about summary refs log tree commit diff
path: root/pkgs/development/python-modules/tensorflow/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/tensorflow/default.nix')
-rw-r--r--pkgs/development/python-modules/tensorflow/default.nix51
1 files changed, 34 insertions, 17 deletions
diff --git a/pkgs/development/python-modules/tensorflow/default.nix b/pkgs/development/python-modules/tensorflow/default.nix
index d311edc188ad6..5f41420dffbca 100644
--- a/pkgs/development/python-modules/tensorflow/default.nix
+++ b/pkgs/development/python-modules/tensorflow/default.nix
@@ -116,7 +116,13 @@ let
   # cudaPackages.cudnn led to this:
   # https://github.com/tensorflow/tensorflow/issues/60398
   cudnnAttribute = "cudnn_8_6";
-  cudnn = cudaPackages.${cudnnAttribute};
+  cudnnMerged = symlinkJoin {
+    name = "cudnn-merged";
+    paths = [
+      (lib.getDev cudaPackages.${cudnnAttribute})
+      (lib.getLib cudaPackages.${cudnnAttribute})
+    ];
+  };
   gentoo-patches = fetchzip {
     url = "https://dev.gentoo.org/~perfinion/patches/tensorflow-patches-2.12.0.tar.bz2";
     hash = "sha256-SCRX/5/zML7LmKEPJkcM5Tebez9vv/gmE4xhT/jyqWs=";
@@ -130,19 +136,30 @@ let
 
   withTensorboard = (pythonOlder "3.6") || tensorboardSupport;
 
-  # FIXME: migrate to redist cudaPackages
-  cudatoolkit_joined = symlinkJoin {
-    name = "${cudatoolkit.name}-merged";
-    paths =
-      [
-        cudatoolkit.lib
-        cudatoolkit.out
-      ]
-      ++ lib.optionals (lib.versionOlder cudatoolkit.version "11") [
-        # for some reason some of the required libs are in the targets/x86_64-linux
-        # directory; not sure why but this works around it
-        "${cudatoolkit}/targets/${stdenv.system}"
-      ];
+  cudaComponents = with cudaPackages; [
+    (cuda_nvcc.__spliced.buildHost or cuda_nvcc)
+    (cuda_nvprune.__spliced.buildHost or cuda_nvprune)
+    cuda_cccl # block_load.cuh
+    cuda_cudart # cuda.h
+    cuda_cupti # cupti.h
+    cuda_nvcc # See https://github.com/google/jax/issues/19811
+    cuda_nvml_dev # nvml.h
+    cuda_nvtx # nvToolsExt.h
+    libcublas # cublas_api.h
+    libcufft # cufft.h
+    libcurand # curand.h
+    libcusolver # cusolver_common.h
+    libcusparse # cusparse.h
+  ];
+
+  cudatoolkitDevMerged = symlinkJoin {
+    name = "cuda-${cudaPackages.cudaVersion}-dev-merged";
+    paths = lib.concatMap (p: [
+      (lib.getBin p)
+      (lib.getDev p)
+      (lib.getLib p)
+      (lib.getOutput "static" p) # Makes for a very fat closure
+    ]) cudaComponents;
   };
 
   # Tensorflow expects bintools at hard-coded paths, e.g. /usr/bin/ar
@@ -321,7 +338,7 @@ let
       ]
       ++ lib.optionals cudaSupport [
         cudatoolkit
-        cudnn
+        cudnnMerged
       ]
       ++ lib.optionals mklSupport [ mkl ]
       ++ lib.optionals stdenv.isDarwin [
@@ -402,7 +419,7 @@ let
     TF_NEED_MPI = tfFeature cudaSupport;
 
     TF_NEED_CUDA = tfFeature cudaSupport;
-    TF_CUDA_PATHS = lib.optionalString cudaSupport "${cudatoolkit_joined},${cudnn},${nccl}";
+    TF_CUDA_PATHS = lib.optionalString cudaSupport "${cudatoolkitDevMerged},${cudnnMerged},${lib.getLib nccl}";
     TF_CUDA_COMPUTE_CAPABILITIES = lib.concatStringsSep "," cudaCapabilities;
 
     # Needed even when we override stdenv: e.g. for ar
@@ -653,7 +670,7 @@ buildPythonPackage {
     find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do
       addOpenGLRunpath "$lib"
 
-      patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
+      patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnnMerged}/lib:${lib.getLib nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib"
     done
   '';