about summary refs log tree commit diff
path: root/pkgs/development/python-modules/cupy/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/cupy/default.nix')
-rw-r--r--pkgs/development/python-modules/cupy/default.nix46
1 files changed, 37 insertions, 9 deletions
diff --git a/pkgs/development/python-modules/cupy/default.nix b/pkgs/development/python-modules/cupy/default.nix
index e5de149fca14a..71defbb99b985 100644
--- a/pkgs/development/python-modules/cupy/default.nix
+++ b/pkgs/development/python-modules/cupy/default.nix
@@ -11,11 +11,34 @@
 , cudaPackages
 , addOpenGLRunpath
 , pythonOlder
+, symlinkJoin
 }:
 
 let
-  inherit (cudaPackages) cudatoolkit cudnn cutensor nccl;
-in buildPythonPackage rec {
+  inherit (cudaPackages) cudnn cutensor nccl;
+  cudatoolkit-joined = symlinkJoin {
+    name = "cudatoolkit-joined-${cudaPackages.cudaVersion}";
+    paths = with cudaPackages; [
+      cuda_cccl # <nv/target>
+      cuda_cccl.dev
+      cuda_cudart
+      cuda_nvcc.dev # <crt/host_defines.h>
+      cuda_nvprof
+      cuda_nvrtc
+      cuda_nvtx
+      cuda_profiler_api
+      libcublas
+      libcufft
+      libcurand
+      libcusolver
+      libcusparse
+
+      # Missing:
+      # cusparselt
+    ];
+  };
+in
+buildPythonPackage rec {
   pname = "cupy";
   version = "12.2.0";
 
@@ -32,27 +55,32 @@ in buildPythonPackage rec {
   # very short builds and a few extremely long ones, so setting both ends up
   # working nicely in practice.
   preConfigure = ''
-    export CUDA_PATH=${cudatoolkit}
     export CUPY_NUM_BUILD_JOBS="$NIX_BUILD_CORES"
     export CUPY_NUM_NVCC_THREADS="$NIX_BUILD_CORES"
   '';
 
   nativeBuildInputs = [
+    setuptools
+    wheel
     addOpenGLRunpath
     cython
+    cudaPackages.cuda_nvcc
   ];
 
-  LDFLAGS = "-L${cudatoolkit}/lib/stubs";
-
-  propagatedBuildInputs = [
-    cudatoolkit
+  buildInputs = [
+    cudatoolkit-joined
     cudnn
     cutensor
     nccl
+  ];
+
+  NVCC = "${lib.getExe cudaPackages.cuda_nvcc}"; # FIXME: splicing/buildPackages
+  CUDA_PATH = "${cudatoolkit-joined}";
+  LDFLAGS = "-L${cudaPackages.cuda_cudart}/lib/stubs";
+
+  propagatedBuildInputs = [
     fastrlock
     numpy
-    setuptools
-    wheel
   ];
 
   nativeCheckInputs = [