about summary refs log tree commit diff
path: root/pkgs/development/libraries/science/math/cutensor/generic.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/libraries/science/math/cutensor/generic.nix')
-rw-r--r--pkgs/development/libraries/science/math/cutensor/generic.nix29
1 files changed, 22 insertions, 7 deletions
diff --git a/pkgs/development/libraries/science/math/cutensor/generic.nix b/pkgs/development/libraries/science/math/cutensor/generic.nix
index c957fcdd99d4e..02fe13851620b 100644
--- a/pkgs/development/libraries/science/math/cutensor/generic.nix
+++ b/pkgs/development/libraries/science/math/cutensor/generic.nix
@@ -1,7 +1,11 @@
 { stdenv
 , lib
 , libPath
+, cuda_cudart
+, cudaMajorVersion
+, cuda_nvcc
 , cudatoolkit
+, libcublas
 , fetchurl
 , autoPatchelfHook
 , addOpenGLRunpath
@@ -17,7 +21,7 @@ let
 in
 
 stdenv.mkDerivation {
-  pname = "cudatoolkit-${cudatoolkit.majorVersion}-cutensor";
+  pname = "cutensor-cu${cudaMajorVersion}";
   inherit version;
 
   src = fetchurl {
@@ -32,20 +36,27 @@ stdenv.mkDerivation {
   nativeBuildInputs = [
     autoPatchelfHook
     addOpenGLRunpath
+    cuda_nvcc
   ];
 
   buildInputs = [
     stdenv.cc.cc.lib
-  ];
-
-  propagatedBuildInputs = [
-    cudatoolkit
+    cuda_cudart
+    libcublas
   ];
 
   # Set RUNPATH so that libcuda in /run/opengl-driver(-32)/lib can be found.
   # See the explanation in addOpenGLRunpath.
   installPhase = ''
     mkdir -p "$out" "$dev"
+
+    if [[ ! -d "${libPath}" ]] ; then
+      echo "Cutensor: ${libPath} does not exist, only found:" >&2
+      find "$(dirname ${libPath})"/ -maxdepth 1 >&2
+      echo "This cutensor release might not support your cudatoolkit version" >&2
+      exit 1
+    fi
+
     mv include "$dev"
     mv ${libPath} "$out/lib"
 
@@ -58,7 +69,7 @@ stdenv.mkDerivation {
   '';
 
   passthru = {
-    inherit cudatoolkit;
+    cudatoolkit = lib.warn "cutensor.passthru: cudaPackages.cudatoolkit is deprecated" cudatoolkit;
     majorVersion = lib.versions.major version;
   };
 
@@ -66,7 +77,11 @@ stdenv.mkDerivation {
     description = "cuTENSOR: A High-Performance CUDA Library For Tensor Primitives";
     homepage = "https://developer.nvidia.com/cutensor";
     sourceProvenance = with sourceTypes; [ binaryNativeCode ];
-    license = licenses.unfree;
+    license = licenses.unfreeRedistributable // {
+      shortName = "cuTENSOR EULA";
+      name = "cuTENSOR SUPPLEMENT TO SOFTWARE LICENSE AGREEMENT FOR NVIDIA SOFTWARE DEVELOPMENT KITS";
+      url = "https://docs.nvidia.com/cuda/cutensor/license.html";
+    };
     platforms = [ "x86_64-linux" ];
     maintainers = with maintainers; [ obsidian-systems-maintenance ];
   };