about summary refs log tree commit diff
path: root/pkgs/development/python-modules/torch
diff options
context:
space:
mode:
authorConnor Baker <connor.baker@tweag.io>2023-12-01 17:00:22 +0000
committerConnor Baker <connor.baker@tweag.io>2023-12-04 18:49:00 +0000
commit5bf016e1e98accfd62c5bbd5a6dea3049af72595 (patch)
tree8b934541eb140473b0f9a97c59daca07dc47e0d2 /pkgs/development/python-modules/torch
parent5a54140dffa69cf172cbcb624ce2031a975acd98 (diff)
python3Packages.torch: enable cuDNN & NCCL only if available
Diffstat (limited to 'pkgs/development/python-modules/torch')
-rw-r--r--pkgs/development/python-modules/torch/default.nix19
1 files changed, 9 insertions, 10 deletions
diff --git a/pkgs/development/python-modules/torch/default.nix b/pkgs/development/python-modules/torch/default.nix
index 49ed033bd6a15..d18dd97df5b45 100644
--- a/pkgs/development/python-modules/torch/default.nix
+++ b/pkgs/development/python-modules/torch/default.nix
@@ -56,10 +56,7 @@
 
 let
   inherit (lib) attrsets lists strings trivial;
-  inherit (cudaPackages) cudaFlags cudnn;
-
-  # Some packages are not available on all platforms
-  nccl = cudaPackages.nccl or null;
+  inherit (cudaPackages) cudaFlags cudnn nccl;
 
   setBool = v: if v then "1" else "0";
 
@@ -212,10 +209,11 @@ in buildPythonPackage rec {
   # For more, see https://github.com/open-mpi/ompi/issues/7733#issuecomment-629806195.
   preConfigure = lib.optionalString cudaSupport ''
     export TORCH_CUDA_ARCH_LIST="${gpuTargetString}"
-    export CUDNN_INCLUDE_DIR=${cudnn.dev}/include
-    export CUDNN_LIB_DIR=${cudnn.lib}/lib
     export CUPTI_INCLUDE_DIR=${cudaPackages.cuda_cupti.dev}/include
     export CUPTI_LIBRARY_DIR=${cudaPackages.cuda_cupti.lib}/lib
+  '' + lib.optionalString (cudaSupport && cudaPackages ? cudnn) ''
+    export CUDNN_INCLUDE_DIR=${cudnn.dev}/include
+    export CUDNN_LIB_DIR=${cudnn.lib}/lib
   '' + lib.optionalString rocmSupport ''
     export ROCM_PATH=${rocmtoolkit_joined}
     export ROCM_SOURCE_DIR=${rocmtoolkit_joined}
@@ -273,7 +271,7 @@ in buildPythonPackage rec {
   PYTORCH_BUILD_VERSION = version;
   PYTORCH_BUILD_NUMBER = 0;
 
-  USE_NCCL = setBool (nccl != null);
+  USE_NCCL = setBool (cudaPackages ? nccl);
   USE_SYSTEM_NCCL = setBool useSystemNccl;                  # don't build pytorch's third_party NCCL
   USE_STATIC_NCCL = setBool useSystemNccl;
 
@@ -348,8 +346,6 @@ in buildPythonPackage rec {
       cuda_nvrtc.lib
       cuda_nvtx.dev
       cuda_nvtx.lib # -llibNVToolsExt
-      cudnn.dev
-      cudnn.lib
       libcublas.dev
       libcublas.lib
       libcufft.dev
@@ -360,7 +356,10 @@ in buildPythonPackage rec {
       libcusolver.lib
       libcusparse.dev
       libcusparse.lib
-    ] ++ lists.optionals (nccl != null) [
+    ] ++ lists.optionals (cudaPackages ? cudnn) [
+      cudnn.dev
+      cudnn.lib
+    ] ++ lists.optionals (useSystemNccl && cudaPackages ? nccl) [
       # Some platforms do not support NCCL (i.e., Jetson)
       nccl.dev # Provides nccl.h AND a static copy of NCCL!
     ] ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [