about summary refs log tree commit diff
path: root/pkgs/development/python-modules/torch
diff options
context:
space:
mode:
authorYann Hamdaoui <yann.hamdaoui@tweag.io>2024-01-17 16:32:24 +0100
committerYann Hamdaoui <yann.hamdaoui@tweag.io>2024-03-15 15:54:21 +0100
commit63746cac08fe242003947167550d28ebc182bf77 (patch)
tree54e1172b75850dfdd823f7fa94601e9e495c402a /pkgs/development/python-modules/torch
parent6a9c892aec57608b49c5ffc524629a3550e8efe6 (diff)
cudaPackages: generalize and refactor setup hook
This PR refactor CUDA setup hooks, and in particular
autoAddOpenGLRunpath and autoAddCudaCompatRunpathHook, that were using a
lot of code in common (in fact, I introduced the latter by copy pasting
most of the bash script of the former). This is not satisfying for
maintenance, as a recent patch showed, because we need to duplicate
changes to both hooks.

This commit abstract the common part in a single shell script that
applies a generic patch action to every elf file in the output. For
autoAddOpenGLRunpath the action is just addOpenGLRunpath (now
addDriverRunpath), and is few line function for
autoAddCudaCompatRunpathHook.

Doing so, we also takes the occasion to use the newer addDriverRunpath
instead of the previous addOpenGLRunpath, and rename the CUDA hook to
reflect that as well.

Co-Authored-By: Connor Baker <connor.baker@tweag.io>
Diffstat (limited to 'pkgs/development/python-modules/torch')
-rw-r--r--pkgs/development/python-modules/torch/bin.nix2
-rw-r--r--pkgs/development/python-modules/torch/default.nix2
2 files changed, 2 insertions, 2 deletions
diff --git a/pkgs/development/python-modules/torch/bin.nix b/pkgs/development/python-modules/torch/bin.nix
index 0bb415574e392..bee32b6163453 100644
--- a/pkgs/development/python-modules/torch/bin.nix
+++ b/pkgs/development/python-modules/torch/bin.nix
@@ -40,7 +40,7 @@ in buildPythonPackage {
   nativeBuildInputs = lib.optionals stdenv.isLinux [
     addOpenGLRunpath
     autoPatchelfHook
-    cudaPackages.autoAddOpenGLRunpathHook
+    cudaPackages.autoAddDriverRunpath
   ];
 
   buildInputs = lib.optionals stdenv.isLinux (with cudaPackages; [
diff --git a/pkgs/development/python-modules/torch/default.nix b/pkgs/development/python-modules/torch/default.nix
index d6c51904bd9df..10eecd1de99b7 100644
--- a/pkgs/development/python-modules/torch/default.nix
+++ b/pkgs/development/python-modules/torch/default.nix
@@ -338,7 +338,7 @@ in buildPythonPackage rec {
     pythonRelaxDepsHook
     removeReferencesTo
   ] ++ lib.optionals cudaSupport (with cudaPackages; [
-    autoAddOpenGLRunpathHook
+    autoAddDriverRunpath
     cuda_nvcc
   ])
   ++ lib.optionals rocmSupport [ rocmtoolkit_joined ];