about summary refs log tree commit diff
path: root/pkgs/development/python-modules/torch/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/torch/default.nix')
-rw-r--r--pkgs/development/python-modules/torch/default.nix49
1 files changed, 19 insertions, 30 deletions
diff --git a/pkgs/development/python-modules/torch/default.nix b/pkgs/development/python-modules/torch/default.nix
index 9597a047bdb48..cc08339f6e0ff 100644
--- a/pkgs/development/python-modules/torch/default.nix
+++ b/pkgs/development/python-modules/torch/default.nix
@@ -301,11 +301,11 @@ buildPythonPackage rec {
   preConfigure =
     lib.optionalString cudaSupport ''
       export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
-      export CUPTI_INCLUDE_DIR=${cudaPackages.cuda_cupti.dev}/include
-      export CUPTI_LIBRARY_DIR=${cudaPackages.cuda_cupti.lib}/lib
+      export CUPTI_INCLUDE_DIR=${lib.getDev cudaPackages.cuda_cupti}/include
+      export CUPTI_LIBRARY_DIR=${lib.getLib cudaPackages.cuda_cupti}/lib
     ''
     + lib.optionalString (cudaSupport && cudaPackages ? cudnn) ''
-      export CUDNN_INCLUDE_DIR=${cudnn.dev}/include
+      export CUDNN_INCLUDE_DIR=${lib.getLib cudnn}/include
       export CUDNN_LIB_DIR=${cudnn.lib}/lib
     ''
     + lib.optionalString rocmSupport ''
@@ -453,42 +453,31 @@ buildPythonPackage rec {
     ++ lib.optionals cudaSupport (
       with cudaPackages;
       [
-        cuda_cccl.dev # <thrust/*>
-        cuda_cudart.dev # cuda_runtime.h and libraries
-        cuda_cudart.lib
-        cuda_cudart.static
-        cuda_cupti.dev # For kineto
-        cuda_cupti.lib # For kineto
-        cuda_nvcc.dev # crt/host_config.h; even though we include this in nativeBuildinputs, it's needed here too
-        cuda_nvml_dev.dev # <nvml.h>
-        cuda_nvrtc.dev
-        cuda_nvrtc.lib
-        cuda_nvtx.dev
-        cuda_nvtx.lib # -llibNVToolsExt
-        libcublas.dev
-        libcublas.lib
-        libcufft.dev
-        libcufft.lib
-        libcurand.dev
-        libcurand.lib
-        libcusolver.dev
-        libcusolver.lib
-        libcusparse.dev
-        libcusparse.lib
+        cuda_cccl # <thrust/*>
+        cuda_cudart # cuda_runtime.h and libraries
+        cuda_cupti # For kineto
+        cuda_nvcc # crt/host_config.h; even though we include this in nativeBuildinputs, it's needed here too
+        cuda_nvml_dev # <nvml.h>
+        cuda_nvrtc
+        cuda_nvtx # -llibNVToolsExt
+        libcublas
+        libcufft
+        libcurand
+        libcusolver
+        libcusparse
       ]
       ++ lists.optionals (cudaPackages ? cudnn) [
-        cudnn.dev
-        cudnn.lib
+        cudnn
       ]
       ++ lists.optionals useSystemNccl [
         # Some platforms do not support NCCL (i.e., Jetson)
-        nccl.dev # Provides nccl.h AND a static copy of NCCL!
+        nccl # Provides nccl.h AND a static copy of NCCL!
       ]
       ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [
-        cuda_nvprof.dev # <cuda_profiler_api.h>
+        cuda_nvprof # <cuda_profiler_api.h>
       ]
       ++ lists.optionals (strings.versionAtLeast cudaVersion "11.8") [
-        cuda_profiler_api.dev # <cuda_profiler_api.h>
+        cuda_profiler_api # <cuda_profiler_api.h>
       ]
     )
     ++ lib.optionals rocmSupport [ rocmPackages.llvm.openmp ]