about summary refs log tree commit diff
path: root/pkgs/development/python-modules/torch/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/torch/default.nix')
-rw-r--r--pkgs/development/python-modules/torch/default.nix48
1 files changed, 29 insertions, 19 deletions
diff --git a/pkgs/development/python-modules/torch/default.nix b/pkgs/development/python-modules/torch/default.nix
index cc08339f6e0ff..bd6b1b2628379 100644
--- a/pkgs/development/python-modules/torch/default.nix
+++ b/pkgs/development/python-modules/torch/default.nix
@@ -26,7 +26,6 @@
 
   # tests.cudaAvailable:
   callPackage,
-  torchWithCuda,
 
   # Native build inputs
   cmake,
@@ -34,7 +33,6 @@
   which,
   pybind11,
   removeReferencesTo,
-  pythonRelaxDepsHook,
 
   # Build inputs
   numactl,
@@ -54,9 +52,9 @@
   cffi,
   click,
   typing-extensions,
-  # ROCm build and `torch.compile` requires `openai-triton`
+  # ROCm build and `torch.compile` requires `triton`
   tritonSupport ? (!stdenv.isDarwin),
-  openai-triton,
+  triton,
 
   # Unit tests
   hypothesis,
@@ -101,7 +99,7 @@ let
 
   setBool = v: if v then "1" else "0";
 
-  # https://github.com/pytorch/pytorch/blob/v2.0.1/torch/utils/cpp_extension.py#L1744
+  # https://github.com/pytorch/pytorch/blob/v2.4.0/torch/utils/cpp_extension.py#L1953
   supportedTorchCudaCapabilities =
     let
       real = [
@@ -121,6 +119,7 @@ let
         "8.7"
         "8.9"
         "9.0"
+        "9.0a"
       ];
       ptx = lists.map (x: "${x}+PTX") real;
     in
@@ -203,15 +202,19 @@ let
       ]);
     "MPI cudatoolkit does not match cudaPackages.cudatoolkit" =
       MPISupport && cudaSupport && (mpi.cudatoolkit != cudaPackages.cudatoolkit);
+    # This used to be a deep package set comparison between cudaPackages and
+    # effectiveMagma.cudaPackages, making torch too strict in cudaPackages.
+    # In particular, this triggered warnings from cuda's `aliases.nix`
     "Magma cudaPackages does not match cudaPackages" =
-      cudaSupport && (effectiveMagma.cudaPackages != cudaPackages);
-    "Rocm support is currently broken because `rocmPackages.hipblaslt` is unpackaged. (2024-06-09)" = rocmSupport;
+      cudaSupport && (effectiveMagma.cudaPackages.cudaVersion != cudaPackages.cudaVersion);
+    "Rocm support is currently broken because `rocmPackages.hipblaslt` is unpackaged. (2024-06-09)" =
+      rocmSupport;
   };
 in
 buildPythonPackage rec {
   pname = "torch";
   # Don't forget to update torch-bin to the same version.
-  version = "2.3.1";
+  version = "2.4.0";
   pyproject = true;
 
   disabled = pythonOlder "3.8.0";
@@ -229,11 +232,16 @@ buildPythonPackage rec {
     repo = "pytorch";
     rev = "refs/tags/v${version}";
     fetchSubmodules = true;
-    hash = "sha256-vpgtOqzIDKgRuqdT8lB/g6j+oMIH1RPxdbjtlzZFjV8=";
+    hash = "sha256-s49rtarGNNFpnNG+kfJtZLE8ND53Ma201I0cOjeFSts=";
   };
 
   patches =
-    lib.optionals cudaSupport [ ./fix-cmake-cuda-toolkit.patch ]
+    [
+      # Allow setting PYTHON_LIB_REL_PATH with an environment variable.
+      # https://github.com/pytorch/pytorch/pull/128419
+      ./passthrough-python-lib-rel-path.patch
+    ]
+    ++ lib.optionals cudaSupport [ ./fix-cmake-cuda-toolkit.patch ]
     ++ lib.optionals (stdenv.isDarwin && stdenv.isx86_64) [
       # pthreadpool added support for Grand Central Dispatch in April
       # 2020. However, this relies on functionality (DISPATCH_APPLY_AUTO)
@@ -278,11 +286,12 @@ buildPythonPackage rec {
           'message(FATAL_ERROR "Found NCCL header version and library version' \
           'message(WARNING "Found NCCL header version and library version'
     ''
-    # Remove PyTorch's FindCUDAToolkit.cmake and to use CMake's default.
-    # We do not remove the entirety of cmake/Modules_CUDA_fix because we need FindCUDNN.cmake.
+    # Remove PyTorch's FindCUDAToolkit.cmake and use CMake's default.
+    # NOTE: Parts of pytorch rely on unmaintained FindCUDA.cmake with custom patches to support e.g.
+    # newer architectures (sm_90a). We do want to delete vendored patches, but have to keep them
+    # until https://github.com/pytorch/pytorch/issues/76082 is addressed
     + lib.optionalString cudaSupport ''
       rm cmake/Modules/FindCUDAToolkit.cmake
-      rm -rf cmake/Modules_CUDA_fix/{upstream,FindCUDA.cmake}
     ''
     # error: no member named 'aligned_alloc' in the global namespace; did you mean simply 'aligned_alloc'
     # This lib overrided aligned_alloc hence the error message. Tltr: his function is linkable but not in header.
@@ -374,6 +383,10 @@ buildPythonPackage rec {
   USE_SYSTEM_NCCL = USE_NCCL;
   USE_STATIC_NCCL = USE_NCCL;
 
+  # Set the correct Python library path, broken since
+  # https://github.com/pytorch/pytorch/commit/3d617333e
+  PYTHON_LIB_REL_PATH = "${placeholder "out"}/${python.sitePackages}";
+
   # Suppress a weird warning in mkl-dnn, part of ideep in pytorch
   # (upstream seems to have fixed this in the wrong place?)
   # https://github.com/intel/mkl-dnn/commit/8134d346cdb7fe1695a2aa55771071d455fae0bc
@@ -433,7 +446,6 @@ buildPythonPackage rec {
       which
       ninja
       pybind11
-      pythonRelaxDepsHook
       removeReferencesTo
     ]
     ++ lib.optionals cudaSupport (
@@ -466,9 +478,7 @@ buildPythonPackage rec {
         libcusolver
         libcusparse
       ]
-      ++ lists.optionals (cudaPackages ? cudnn) [
-        cudnn
-      ]
+      ++ lists.optionals (cudaPackages ? cudnn) [ cudnn ]
       ++ lists.optionals useSystemNccl [
         # Some platforms do not support NCCL (i.e., Jetson)
         nccl # Provides nccl.h AND a static copy of NCCL!
@@ -488,7 +498,7 @@ buildPythonPackage rec {
       CoreServices
       libobjc
     ]
-    ++ lib.optionals tritonSupport [ openai-triton ]
+    ++ lib.optionals tritonSupport [ triton ]
     ++ lib.optionals MPISupport [ mpi ]
     ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];
 
@@ -516,7 +526,7 @@ buildPythonPackage rec {
 
     # torch/csrc requires `pybind11` at runtime
     pybind11
-  ] ++ lib.optionals tritonSupport [ openai-triton ];
+  ] ++ lib.optionals tritonSupport [ triton ];
 
   propagatedCxxBuildInputs =
     [ ] ++ lib.optionals MPISupport [ mpi ] ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];