about summary refs log tree commit diff
path: root/pkgs/top-level/cuda-packages.nix
diff options
context:
space:
mode:
authorAidan Gauland <aidalgol@fastmail.net>2022-06-21 15:36:41 +1200
committerAidan Gauland <aidalgol@fastmail.net>2022-06-24 13:02:02 +1200
commitd70b4df686a714f4a7f97bdf67eda1473f87707a (patch)
treec6a4876480dbe582fb10bc7a8edd1692467b0a0d /pkgs/top-level/cuda-packages.nix
parenta19f2c688bcd06d89bdc8848918eaff0e0991aa4 (diff)
tensorrt: init at 8.4.0.6
Add derivation for TensorRT 8, a high-performance deep learning interface SDK
from NVIDIA, which is at this point non-redistributable.  The current version
aldo requires CUDA 11, so this is left out of the cudaPackages_10* scopes.
Diffstat (limited to 'pkgs/top-level/cuda-packages.nix')
-rw-r--r--pkgs/top-level/cuda-packages.nix15
1 files changed, 13 insertions, 2 deletions
diff --git a/pkgs/top-level/cuda-packages.nix b/pkgs/top-level/cuda-packages.nix
index 211540260d10c..af8beb9b58c37 100644
--- a/pkgs/top-level/cuda-packages.nix
+++ b/pkgs/top-level/cuda-packages.nix
@@ -43,6 +43,16 @@ let
     };
   in { inherit cutensor; };
 
+  tensorrtExtension = final: prev: let
+    ### Tensorrt
+
+    inherit (final) cudaMajorMinorVersion cudaMajorVersion;
+
+    # TODO: Add derivations for TensorRT versions that support older CUDA versions.
+
+    tensorrt = final.callPackage ../development/libraries/science/math/tensorrt/8.nix { };
+  in { inherit tensorrt; };
+
   extraPackagesExtension = final: prev: {
 
     nccl = final.callPackage ../development/libraries/science/math/nccl { };
@@ -58,7 +68,7 @@ let
 
   };
 
-  composedExtension = composeManyExtensions [
+  composedExtension = composeManyExtensions ([
     extraPackagesExtension
     (import ../development/compilers/cudatoolkit/extension.nix)
     (import ../development/compilers/cudatoolkit/redist/extension.nix)
@@ -67,6 +77,7 @@ let
     (import ../test/cuda/cuda-samples/extension.nix)
     (import ../test/cuda/cuda-library-samples/extension.nix)
     cutensorExtension
-  ];
+  ] ++ (lib.optional (lib.strings.versionAtLeast cudaVersion "11.0") tensorrtExtension));
+  # We only package the current version of TensorRT, which requires CUDA 11.
 
 in (scope.overrideScope' composedExtension)