summary refs log tree commit diff
path: root/pkgs/development
diff options
context:
space:
mode:
authorConnor Baker <connor.baker@tweag.io>2023-11-09 10:28:19 -0500
committerGitHub <noreply@github.com>2023-11-09 10:28:19 -0500
commit47f07caec979e8986380fca526be8f3aa628dc8b (patch)
tree85c917c255d4df0a90b40d2ccdab120291cdd469 /pkgs/development
parent417c2051a15ee5fbae81e6e2e1096e2779de9f85 (diff)
parent2a42503192f2fcf77009915424a60a016a8364ff (diff)
Merge pull request #266081 from ConnorBaker/fix/torch-jetson
python3Packages.torch: patch cpp_extension.py for Jetson support
Diffstat (limited to 'pkgs/development')
-rw-r--r--pkgs/development/python-modules/torch/default.nix15
1 files changed, 14 insertions, 1 deletions
diff --git a/pkgs/development/python-modules/torch/default.nix b/pkgs/development/python-modules/torch/default.nix
index 52ab9ee5b25df..7cd029879e30a 100644
--- a/pkgs/development/python-modules/torch/default.nix
+++ b/pkgs/development/python-modules/torch/default.nix
@@ -48,7 +48,10 @@
 
 let
   inherit (lib) attrsets lists strings trivial;
-  inherit (cudaPackages) cudaFlags cudnn nccl;
+  inherit (cudaPackages) cudaFlags cudnn;
+
+  # Some packages are not available on all platforms
+  nccl = cudaPackages.nccl or null;
 
   setBool = v: if v then "1" else "0";
 
@@ -178,6 +181,13 @@ in buildPythonPackage rec {
         'message(FATAL_ERROR "Found NCCL header version and library version' \
         'message(WARNING "Found NCCL header version and library version'
   ''
+  # TODO(@connorbaker): Remove this patch after 2.1.0 lands.
+  + lib.optionalString cudaSupport ''
+    substituteInPlace torch/utils/cpp_extension.py \
+      --replace \
+        "'8.6', '8.9'" \
+        "'8.6', '8.7', '8.9'"
+  ''
   # error: no member named 'aligned_alloc' in the global namespace; did you mean simply 'aligned_alloc'
   # This lib overrided aligned_alloc hence the error message. Tltr: his function is linkable but not in header.
   + lib.optionalString (stdenv.isDarwin && lib.versionOlder stdenv.targetPlatform.darwinSdkVersion "11.0") ''
@@ -253,6 +263,7 @@ in buildPythonPackage rec {
   PYTORCH_BUILD_VERSION = version;
   PYTORCH_BUILD_NUMBER = 0;
 
+  USE_NCCL = setBool (nccl != null);
   USE_SYSTEM_NCCL = setBool useSystemNccl;                  # don't build pytorch's third_party NCCL
   USE_STATIC_NCCL = setBool useSystemNccl;
 
@@ -316,6 +327,8 @@ in buildPythonPackage rec {
       libcusolver.lib
       libcusparse.dev
       libcusparse.lib
+    ] ++ lists.optionals (nccl != null) [
+      # Some platforms do not support NCCL (i.e., Jetson)
       nccl.dev # Provides nccl.h AND a static copy of NCCL!
     ] ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [
       cuda_nvprof.dev # <cuda_profiler_api.h>