about summary refs log tree commit diff
path: root/pkgs/development/libraries/science/math/tensorrt/extension.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/libraries/science/math/tensorrt/extension.nix')
-rw-r--r--pkgs/development/libraries/science/math/tensorrt/extension.nix28
1 files changed, 22 insertions, 6 deletions
diff --git a/pkgs/development/libraries/science/math/tensorrt/extension.nix b/pkgs/development/libraries/science/math/tensorrt/extension.nix
index b4018c6cc284d..ffd9b672684cb 100644
--- a/pkgs/development/libraries/science/math/tensorrt/extension.nix
+++ b/pkgs/development/libraries/science/math/tensorrt/extension.nix
@@ -17,16 +17,32 @@ final: prev: let
     isSupported = fileData: elem cudaVersion fileData.supportedCudaVersions;
     # Return the first file that is supported. In practice there should only ever be one anyway.
     supportedFile = files: findFirst isSupported null files;
-    # Supported versions with versions as keys and file as value
-    supportedVersions = filterAttrs (version: file: file !=null ) (mapAttrs (version: files: supportedFile files) tensorRTVersions);
+
     # Compute versioned attribute name to be used in this package set
     computeName = version: "tensorrt_${toUnderscore version}";
+
+    # Supported versions with versions as keys and file as value
+    supportedVersions = lib.recursiveUpdate
+      {
+        tensorrt = {
+          enable = false;
+          fileVersionCuda = null;
+          fileVersionCudnn = null;
+          fullVersion = "0.0.0";
+          sha256 = null;
+          tarball = null;
+          supportedCudaVersions = [ ];
+        };
+      }
+      (mapAttrs' (version: attrs: nameValuePair (computeName version) attrs)
+        (filterAttrs (version: file: file != null) (mapAttrs (version: files: supportedFile files) tensorRTVersions)));
+
     # Add all supported builds as attributes
-    allBuilds = mapAttrs' (version: file: nameValuePair (computeName version) (buildTensorRTPackage (removeAttrs file ["fileVersionCuda"]))) supportedVersions;
+    allBuilds = mapAttrs (name: file: buildTensorRTPackage (removeAttrs file ["fileVersionCuda"])) supportedVersions;
+
     # Set the default attributes, e.g. tensorrt = tensorrt_8_4;
-    defaultBuild = { "tensorrt" = if allBuilds ? ${computeName tensorRTDefaultVersion}
-      then allBuilds.${computeName tensorRTDefaultVersion}
-      else throw "tensorrt-${tensorRTDefaultVersion} does not support your cuda version ${cudaVersion}"; };
+    defaultName = computeName tensorRTDefaultVersion;
+    defaultBuild = lib.optionalAttrs (allBuilds ? ${defaultName}) { tensorrt = allBuilds.${computeName tensorRTDefaultVersion}; };
   in {
     inherit buildTensorRTPackage;
   } // allBuilds // defaultBuild;