about summary refs log tree commit diff
path: root/pkgs/development/python-modules/tinygrad/fix-dlopen-cuda.patch
blob: 6b77173b4eccf88873df2d9822262d80cc3c47a1 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
diff --git a/tinygrad/runtime/autogen/cuda.py b/tinygrad/runtime/autogen/cuda.py
index 359083a9..3cd5f7be 100644
--- a/tinygrad/runtime/autogen/cuda.py
+++ b/tinygrad/runtime/autogen/cuda.py
@@ -143,10 +143,25 @@ def char_pointer_cast(string, encoding='utf-8'):
     return ctypes.cast(string, ctypes.POINTER(ctypes.c_char))
 
 
+NAME_TO_PATHS = {
+    "libcuda.so": ["@driverLink@/lib/libcuda.so"],
+    "libnvrtc.so": ["@libnvrtc@"],
+}
+def _try_dlopen(name):
+    try:
+        return ctypes.CDLL(name)
+    except OSError:
+        pass
+    for candidate in NAME_TO_PATHS.get(name, []):
+        try:
+            return ctypes.CDLL(candidate)
+        except OSError:
+	        pass
+    raise RuntimeError(f"{name} not found")
 
 _libraries = {}
-_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda'))
-_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc'))
+_libraries['libcuda.so'] = _try_dlopen('libcuda.so')
+_libraries['libnvrtc.so'] = _try_dlopen('libnvrtc.so')
 
 
 cuuint32_t = ctypes.c_uint32