about summary refs log tree commit diff
path: root/pkgs/development/python-modules/torch/bin.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/torch/bin.nix')
-rw-r--r--pkgs/development/python-modules/torch/bin.nix17
1 files changed, 9 insertions, 8 deletions
diff --git a/pkgs/development/python-modules/torch/bin.nix b/pkgs/development/python-modules/torch/bin.nix
index e2899c081e08b..4ecaac16be187 100644
--- a/pkgs/development/python-modules/torch/bin.nix
+++ b/pkgs/development/python-modules/torch/bin.nix
@@ -7,7 +7,7 @@
   python,
   pythonAtLeast,
   pythonOlder,
-  addOpenGLRunpath,
+  addDriverRunpath,
   callPackage,
   cudaPackages,
   future,
@@ -16,20 +16,19 @@
   pyyaml,
   requests,
   setuptools,
-  torch-bin,
   typing-extensions,
   sympy,
   jinja2,
   networkx,
   filelock,
-  openai-triton,
+  triton,
 }:
 
 let
   pyVerNoDot = builtins.replaceStrings [ "." ] [ "" ] python.pythonVersion;
   srcs = import ./binary-hashes.nix version;
   unsupported = throw "Unsupported system";
-  version = "2.3.1";
+  version = "2.4.0";
 in
 buildPythonPackage {
   inherit version;
@@ -44,7 +43,7 @@ buildPythonPackage {
   src = fetchurl srcs."${stdenv.system}-${pyVerNoDot}" or unsupported;
 
   nativeBuildInputs = lib.optionals stdenv.isLinux [
-    addOpenGLRunpath
+    addDriverRunpath
     autoPatchelfHook
     autoAddDriverRunpath
   ];
@@ -88,7 +87,7 @@ buildPythonPackage {
     jinja2
     networkx
     filelock
-  ] ++ lib.optionals (stdenv.isLinux && stdenv.isx86_64) [ openai-triton ];
+  ] ++ lib.optionals (stdenv.isLinux && stdenv.isx86_64) [ triton ];
 
   postInstall = ''
     # ONNX conversion
@@ -103,7 +102,9 @@ buildPythonPackage {
   #
   # This is a quick hack to add `libnvrtc` to the runpath so that torch can find
   # it when it is needed at runtime.
-  extraRunpaths = lib.optionals stdenv.hostPlatform.isLinux [ "${lib.getLib cudaPackages.cuda_nvrtc}/lib" ];
+  extraRunpaths = lib.optionals stdenv.hostPlatform.isLinux [
+    "${lib.getLib cudaPackages.cuda_nvrtc}/lib"
+  ];
   postPhases = lib.optionals stdenv.isLinux [ "postPatchelfPhase" ];
   postPatchelfPhase = ''
     while IFS= read -r -d $'\0' elf ; do
@@ -121,7 +122,7 @@ buildPythonPackage {
 
   pythonImportsCheck = [ "torch" ];
 
-  passthru.gpuChecks.cudaAvailable = callPackage ./test-cuda.nix { torch = torch-bin; };
+  passthru.tests = callPackage ./tests.nix { };
 
   meta = {
     description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";