diff options
Diffstat (limited to 'pkgs/development/python-modules/tinygrad/fix-dlopen-cuda.patch')
-rw-r--r-- | pkgs/development/python-modules/tinygrad/fix-dlopen-cuda.patch | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/pkgs/development/python-modules/tinygrad/fix-dlopen-cuda.patch b/pkgs/development/python-modules/tinygrad/fix-dlopen-cuda.patch new file mode 100644 index 0000000000000..6b77173b4eccf --- /dev/null +++ b/pkgs/development/python-modules/tinygrad/fix-dlopen-cuda.patch @@ -0,0 +1,32 @@ +diff --git a/tinygrad/runtime/autogen/cuda.py b/tinygrad/runtime/autogen/cuda.py +index 359083a9..3cd5f7be 100644 +--- a/tinygrad/runtime/autogen/cuda.py ++++ b/tinygrad/runtime/autogen/cuda.py +@@ -143,10 +143,25 @@ def char_pointer_cast(string, encoding='utf-8'): + return ctypes.cast(string, ctypes.POINTER(ctypes.c_char)) + + ++NAME_TO_PATHS = { ++ "libcuda.so": ["@driverLink@/lib/libcuda.so"], ++ "libnvrtc.so": ["@libnvrtc@"], ++} ++def _try_dlopen(name): ++ try: ++ return ctypes.CDLL(name) ++ except OSError: ++ pass ++ for candidate in NAME_TO_PATHS.get(name, []): ++ try: ++ return ctypes.CDLL(candidate) ++ except OSError: ++ pass ++ raise RuntimeError(f"{name} not found") + + _libraries = {} +-_libraries['libcuda.so'] = ctypes.CDLL(ctypes.util.find_library('cuda')) +-_libraries['libnvrtc.so'] = ctypes.CDLL(ctypes.util.find_library('nvrtc')) ++_libraries['libcuda.so'] = _try_dlopen('libcuda.so') ++_libraries['libnvrtc.so'] = _try_dlopen('libnvrtc.so') + + + cuuint32_t = ctypes.c_uint32 |