diff options
Diffstat (limited to 'pkgs/development/python-modules/torch/default.nix')
-rw-r--r-- | pkgs/development/python-modules/torch/default.nix | 49 |
1 files changed, 19 insertions, 30 deletions
diff --git a/pkgs/development/python-modules/torch/default.nix b/pkgs/development/python-modules/torch/default.nix index 9597a047bdb48..cc08339f6e0ff 100644 --- a/pkgs/development/python-modules/torch/default.nix +++ b/pkgs/development/python-modules/torch/default.nix @@ -301,11 +301,11 @@ buildPythonPackage rec { preConfigure = lib.optionalString cudaSupport '' export TORCH_CUDA_ARCH_LIST="${gpuTargetString}" - export CUPTI_INCLUDE_DIR=${cudaPackages.cuda_cupti.dev}/include - export CUPTI_LIBRARY_DIR=${cudaPackages.cuda_cupti.lib}/lib + export CUPTI_INCLUDE_DIR=${lib.getDev cudaPackages.cuda_cupti}/include + export CUPTI_LIBRARY_DIR=${lib.getLib cudaPackages.cuda_cupti}/lib '' + lib.optionalString (cudaSupport && cudaPackages ? cudnn) '' - export CUDNN_INCLUDE_DIR=${cudnn.dev}/include + export CUDNN_INCLUDE_DIR=${lib.getLib cudnn}/include export CUDNN_LIB_DIR=${cudnn.lib}/lib '' + lib.optionalString rocmSupport '' @@ -453,42 +453,31 @@ buildPythonPackage rec { ++ lib.optionals cudaSupport ( with cudaPackages; [ - cuda_cccl.dev # <thrust/*> - cuda_cudart.dev # cuda_runtime.h and libraries - cuda_cudart.lib - cuda_cudart.static - cuda_cupti.dev # For kineto - cuda_cupti.lib # For kineto - cuda_nvcc.dev # crt/host_config.h; even though we include this in nativeBuildinputs, it's needed here too - cuda_nvml_dev.dev # <nvml.h> - cuda_nvrtc.dev - cuda_nvrtc.lib - cuda_nvtx.dev - cuda_nvtx.lib # -llibNVToolsExt - libcublas.dev - libcublas.lib - libcufft.dev - libcufft.lib - libcurand.dev - libcurand.lib - libcusolver.dev - libcusolver.lib - libcusparse.dev - libcusparse.lib + cuda_cccl # <thrust/*> + cuda_cudart # cuda_runtime.h and libraries + cuda_cupti # For kineto + cuda_nvcc # crt/host_config.h; even though we include this in nativeBuildinputs, it's needed here too + cuda_nvml_dev # <nvml.h> + cuda_nvrtc + cuda_nvtx # -llibNVToolsExt + libcublas + libcufft + libcurand + libcusolver + libcusparse ] ++ lists.optionals (cudaPackages ? cudnn) [ - cudnn.dev - cudnn.lib + cudnn ] ++ lists.optionals useSystemNccl [ # Some platforms do not support NCCL (i.e., Jetson) - nccl.dev # Provides nccl.h AND a static copy of NCCL! + nccl # Provides nccl.h AND a static copy of NCCL! ] ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [ - cuda_nvprof.dev # <cuda_profiler_api.h> + cuda_nvprof # <cuda_profiler_api.h> ] ++ lists.optionals (strings.versionAtLeast cudaVersion "11.8") [ - cuda_profiler_api.dev # <cuda_profiler_api.h> + cuda_profiler_api # <cuda_profiler_api.h> ] ) ++ lib.optionals rocmSupport [ rocmPackages.llvm.openmp ] |