diff options
Diffstat (limited to 'pkgs/development/libraries/science/math/tensorrt/extension.nix')
-rw-r--r-- | pkgs/development/libraries/science/math/tensorrt/extension.nix | 28 |
1 files changed, 22 insertions, 6 deletions
diff --git a/pkgs/development/libraries/science/math/tensorrt/extension.nix b/pkgs/development/libraries/science/math/tensorrt/extension.nix index b4018c6cc284d..ffd9b672684cb 100644 --- a/pkgs/development/libraries/science/math/tensorrt/extension.nix +++ b/pkgs/development/libraries/science/math/tensorrt/extension.nix @@ -17,16 +17,32 @@ final: prev: let isSupported = fileData: elem cudaVersion fileData.supportedCudaVersions; # Return the first file that is supported. In practice there should only ever be one anyway. supportedFile = files: findFirst isSupported null files; - # Supported versions with versions as keys and file as value - supportedVersions = filterAttrs (version: file: file !=null ) (mapAttrs (version: files: supportedFile files) tensorRTVersions); + # Compute versioned attribute name to be used in this package set computeName = version: "tensorrt_${toUnderscore version}"; + + # Supported versions with versions as keys and file as value + supportedVersions = lib.recursiveUpdate + { + tensorrt = { + enable = false; + fileVersionCuda = null; + fileVersionCudnn = null; + fullVersion = "0.0.0"; + sha256 = null; + tarball = null; + supportedCudaVersions = [ ]; + }; + } + (mapAttrs' (version: attrs: nameValuePair (computeName version) attrs) + (filterAttrs (version: file: file != null) (mapAttrs (version: files: supportedFile files) tensorRTVersions))); + # Add all supported builds as attributes - allBuilds = mapAttrs' (version: file: nameValuePair (computeName version) (buildTensorRTPackage (removeAttrs file ["fileVersionCuda"]))) supportedVersions; + allBuilds = mapAttrs (name: file: buildTensorRTPackage (removeAttrs file ["fileVersionCuda"])) supportedVersions; + # Set the default attributes, e.g. tensorrt = tensorrt_8_4; - defaultBuild = { "tensorrt" = if allBuilds ? ${computeName tensorRTDefaultVersion} - then allBuilds.${computeName tensorRTDefaultVersion} - else throw "tensorrt-${tensorRTDefaultVersion} does not support your cuda version ${cudaVersion}"; }; + defaultName = computeName tensorRTDefaultVersion; + defaultBuild = lib.optionalAttrs (allBuilds ? ${defaultName}) { tensorrt = allBuilds.${computeName tensorRTDefaultVersion}; }; in { inherit buildTensorRTPackage; } // allBuilds // defaultBuild; |