diff options
author | Connor Baker <connor.baker@tweag.io> | 2023-11-09 10:28:19 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-09 10:28:19 -0500 |
commit | 47f07caec979e8986380fca526be8f3aa628dc8b (patch) | |
tree | 85c917c255d4df0a90b40d2ccdab120291cdd469 /pkgs/development | |
parent | 417c2051a15ee5fbae81e6e2e1096e2779de9f85 (diff) | |
parent | 2a42503192f2fcf77009915424a60a016a8364ff (diff) |
Merge pull request #266081 from ConnorBaker/fix/torch-jetson
python3Packages.torch: patch cpp_extension.py for Jetson support
Diffstat (limited to 'pkgs/development')
-rw-r--r-- | pkgs/development/python-modules/torch/default.nix | 15 |
1 files changed, 14 insertions, 1 deletions
diff --git a/pkgs/development/python-modules/torch/default.nix b/pkgs/development/python-modules/torch/default.nix index 52ab9ee5b25df..7cd029879e30a 100644 --- a/pkgs/development/python-modules/torch/default.nix +++ b/pkgs/development/python-modules/torch/default.nix @@ -48,7 +48,10 @@ let inherit (lib) attrsets lists strings trivial; - inherit (cudaPackages) cudaFlags cudnn nccl; + inherit (cudaPackages) cudaFlags cudnn; + + # Some packages are not available on all platforms + nccl = cudaPackages.nccl or null; setBool = v: if v then "1" else "0"; @@ -178,6 +181,13 @@ in buildPythonPackage rec { 'message(FATAL_ERROR "Found NCCL header version and library version' \ 'message(WARNING "Found NCCL header version and library version' '' + # TODO(@connorbaker): Remove this patch after 2.1.0 lands. + + lib.optionalString cudaSupport '' + substituteInPlace torch/utils/cpp_extension.py \ + --replace \ + "'8.6', '8.9'" \ + "'8.6', '8.7', '8.9'" + '' # error: no member named 'aligned_alloc' in the global namespace; did you mean simply 'aligned_alloc' # This lib overrided aligned_alloc hence the error message. Tltr: his function is linkable but not in header. + lib.optionalString (stdenv.isDarwin && lib.versionOlder stdenv.targetPlatform.darwinSdkVersion "11.0") '' @@ -253,6 +263,7 @@ in buildPythonPackage rec { PYTORCH_BUILD_VERSION = version; PYTORCH_BUILD_NUMBER = 0; + USE_NCCL = setBool (nccl != null); USE_SYSTEM_NCCL = setBool useSystemNccl; # don't build pytorch's third_party NCCL USE_STATIC_NCCL = setBool useSystemNccl; @@ -316,6 +327,8 @@ in buildPythonPackage rec { libcusolver.lib libcusparse.dev libcusparse.lib + ] ++ lists.optionals (nccl != null) [ + # Some platforms do not support NCCL (i.e., Jetson) nccl.dev # Provides nccl.h AND a static copy of NCCL! ] ++ lists.optionals (strings.versionOlder cudaVersion "11.8") [ cuda_nvprof.dev # <cuda_profiler_api.h> |