about summary refs log tree commit diff
path: root/pkgs/development/libraries/science/math/tensorrt/extension.nix
blob: b4f47a8969c142253c319b0448abf389b51da993 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
final: prev: let

  inherit (final) callPackage;
  # Only the names actually used below are taken from `prev`.
  inherit (prev) cudaVersion lib;

  ### TensorRT

  # Build one TensorRT release. `callPackage` fills the first argument set of
  # ./generic.nix; `args` is a file description from `tensorRTVersions`
  # (with `fileVersionCuda` stripped by the caller in `allBuilds`).
  buildTensorRTPackage = args:
    callPackage ./generic.nix { } args;

  # Turn a dotted version into an attribute-name suffix: "8.4.0" -> "8_4_0".
  toUnderscore = str: lib.replaceStrings [ "." ] [ "_" ] str;

  # Attribute set with one `tensorrt_<version>` entry per TensorRT release that
  # has a file supporting the current `cudaVersion`, plus the unversioned
  # `tensorrt` alias selected by `tensorRTDefaultVersion`.
  tensorRTPackages = let
    # Whether a file description lists the current CUDA version as supported.
    isSupported = fileData: lib.elem cudaVersion fileData.supportedCudaVersions;
    # First supported file for a release, or null if none matches.
    # In practice at most one file per release should match.
    supportedFile = files: lib.findFirst isSupported null files;
    # Release version -> supported file, dropping releases with no match.
    supportedVersions = lib.filterAttrs (_: file: file != null)
      (lib.mapAttrs (_: supportedFile) tensorRTVersions);
    # Versioned attribute name used in this package set, e.g. "tensorrt_8_4_0".
    computeName = version: "tensorrt_${toUnderscore version}";
    # One attribute per supported release. `fileVersionCuda` is only needed to
    # build the tarball name, so it is not passed on to generic.nix.
    allBuilds = lib.mapAttrs' (version: file:
      lib.nameValuePair (computeName version)
        (buildTensorRTPackage (removeAttrs file [ "fileVersionCuda" ])))
      supportedVersions;
    defaultName = computeName tensorRTDefaultVersion;
    # Unversioned alias, e.g. tensorrt = tensorrt_8_4_0. Attribute values are
    # lazy, so the throw only fires when `tensorrt` itself is evaluated.
    defaultBuild = {
      tensorrt = allBuilds.${defaultName}
        or (throw "tensorrt-${tensorRTDefaultVersion} does not support your cuda version ${cudaVersion}");
    };
  in allBuilds // defaultBuild;

  # Known TensorRT releases. Each release maps to a list of file descriptions,
  # one per downloadable tarball; `supportedCudaVersions` decides which file
  # (if any) is used for the current `cudaVersion`.
  tensorRTVersions = {
    "8.4.0" = [
      rec {
        fileVersionCuda = "11.6";
        fileVersionCudnn = "8.3";
        fullVersion = "8.4.0.6";
        sha256 = "sha256-DNgHHXF/G4cK2nnOWImrPXAkOcNW6Wy+8j0LRpAH/LQ=";
        tarball = "TensorRT-${fullVersion}.Linux.x86_64-gnu.cuda-${fileVersionCuda}.cudnn${fileVersionCudnn}.tar.gz";
        supportedCudaVersions = [ "11.0" "11.1" "11.2" "11.3" "11.4" "11.5" "11.6" "11.7" ];
      }
      rec {
        fileVersionCuda = "10.2";
        fileVersionCudnn = "8.3";
        fullVersion = "8.4.0.6";
        sha256 = "sha256-aCzH0ZI6BrJ0v+e5Bnm7b8mNltA7NNuIa8qRKzAQv+I=";
        tarball = "TensorRT-${fullVersion}.Linux.x86_64-gnu.cuda-${fileVersionCuda}.cudnn${fileVersionCudnn}.tar.gz";
        supportedCudaVersions = [ "10.2" ];
      }
    ];
  };

  # TensorRT release used for the unversioned `tensorrt` attribute, chosen by
  # the current `cudaVersion`; CUDA versions not listed fall back to "8.4.0".
  tensorRTDefaultVersion = {
    "10.2" = "8.4.0";
    "11.0" = "8.4.0";
    "11.1" = "8.4.0";
    "11.2" = "8.4.0";
    "11.3" = "8.4.0";
    "11.4" = "8.4.0";
    "11.5" = "8.4.0";
    "11.6" = "8.4.0";
    "11.7" = "8.4.0";
  }.${cudaVersion} or "8.4.0";

in tensorRTPackages