diff options
Diffstat (limited to 'pkgs/development/python-modules/tensorflow/default.nix')
-rw-r--r-- | pkgs/development/python-modules/tensorflow/default.nix | 51 |
1 files changed, 34 insertions, 17 deletions
diff --git a/pkgs/development/python-modules/tensorflow/default.nix b/pkgs/development/python-modules/tensorflow/default.nix index d311edc188ad6..5f41420dffbca 100644 --- a/pkgs/development/python-modules/tensorflow/default.nix +++ b/pkgs/development/python-modules/tensorflow/default.nix @@ -116,7 +116,13 @@ let # cudaPackages.cudnn led to this: # https://github.com/tensorflow/tensorflow/issues/60398 cudnnAttribute = "cudnn_8_6"; - cudnn = cudaPackages.${cudnnAttribute}; + cudnnMerged = symlinkJoin { + name = "cudnn-merged"; + paths = [ + (lib.getDev cudaPackages.${cudnnAttribute}) + (lib.getLib cudaPackages.${cudnnAttribute}) + ]; + }; gentoo-patches = fetchzip { url = "https://dev.gentoo.org/~perfinion/patches/tensorflow-patches-2.12.0.tar.bz2"; hash = "sha256-SCRX/5/zML7LmKEPJkcM5Tebez9vv/gmE4xhT/jyqWs="; @@ -130,19 +136,30 @@ let withTensorboard = (pythonOlder "3.6") || tensorboardSupport; - # FIXME: migrate to redist cudaPackages - cudatoolkit_joined = symlinkJoin { - name = "${cudatoolkit.name}-merged"; - paths = - [ - cudatoolkit.lib - cudatoolkit.out - ] - ++ lib.optionals (lib.versionOlder cudatoolkit.version "11") [ - # for some reason some of the required libs are in the targets/x86_64-linux - # directory; not sure why but this works around it - "${cudatoolkit}/targets/${stdenv.system}" - ]; + cudaComponents = with cudaPackages; [ + (cuda_nvcc.__spliced.buildHost or cuda_nvcc) + (cuda_nvprune.__spliced.buildHost or cuda_nvprune) + cuda_cccl # block_load.cuh + cuda_cudart # cuda.h + cuda_cupti # cupti.h + cuda_nvcc # See https://github.com/google/jax/issues/19811 + cuda_nvml_dev # nvml.h + cuda_nvtx # nvToolsExt.h + libcublas # cublas_api.h + libcufft # cufft.h + libcurand # curand.h + libcusolver # cusolver_common.h + libcusparse # cusparse.h + ]; + + cudatoolkitDevMerged = symlinkJoin { + name = "cuda-${cudaPackages.cudaVersion}-dev-merged"; + paths = lib.concatMap (p: [ + (lib.getBin p) + (lib.getDev p) + (lib.getLib p) + (lib.getOutput "static" p) # Makes for a very fat closure + ]) cudaComponents; }; # Tensorflow expects bintools at hard-coded paths, e.g. /usr/bin/ar @@ -321,7 +338,7 @@ let ] ++ lib.optionals cudaSupport [ cudatoolkit - cudnn + cudnnMerged ] ++ lib.optionals mklSupport [ mkl ] ++ lib.optionals stdenv.isDarwin [ @@ -402,7 +419,7 @@ let TF_NEED_MPI = tfFeature cudaSupport; TF_NEED_CUDA = tfFeature cudaSupport; - TF_CUDA_PATHS = lib.optionalString cudaSupport "${cudatoolkit_joined},${cudnn},${nccl}"; + TF_CUDA_PATHS = lib.optionalString cudaSupport "${cudatoolkitDevMerged},${cudnnMerged},${lib.getLib nccl}"; TF_CUDA_COMPUTE_CAPABILITIES = lib.concatStringsSep "," cudaCapabilities; # Needed even when we override stdenv: e.g. for ar @@ -653,7 +670,7 @@ buildPythonPackage { find $out -type f \( -name '*.so' -or -name '*.so.*' \) | while read lib; do addOpenGLRunpath "$lib" - patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnn}/lib:${nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib" + patchelf --set-rpath "${cudatoolkit}/lib:${cudatoolkit.lib}/lib:${cudnnMerged}/lib:${lib.getLib nccl}/lib:$(patchelf --print-rpath "$lib")" "$lib" done ''; |