about summary refs log tree commit diff
path: root/pkgs/development/python-modules/jaxlib/bin.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/bin.nix')
-rw-r--r--pkgs/development/python-modules/jaxlib/bin.nix46
1 files changed, 34 insertions, 12 deletions
diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix
index f597eeacfced4..6e70b24f67da7 100644
--- a/pkgs/development/python-modules/jaxlib/bin.nix
+++ b/pkgs/development/python-modules/jaxlib/bin.nix
@@ -13,22 +13,35 @@
 #   * https://github.com/google/jax/issues/971#issuecomment-508216439
 #   * https://github.com/google/jax/issues/5723#issuecomment-913038780
 
-{ addOpenGLRunpath, autoPatchelfHook, buildPythonPackage, config
-, fetchurl, isPy39, lib, stdenv
-# propagatedBuildInputs
-, absl-py, flatbuffers, scipy, cudatoolkit_11
-# Options:
+{ absl-py
+, addOpenGLRunpath
+, autoPatchelfHook
+, buildPythonPackage
+, config
+, cudatoolkit_11
+, cudnn
+, fetchurl
+, flatbuffers
+, isPy39
+, lib
+, scipy
+, stdenv
+  # Options:
 , cudaSupport ? config.cudaSupport or false
 }:
 
+# Note that these values are tied to the specific version of the GPU wheel that
+# we fetch. When updating, try to go for the latest possible versions that are
+# still compatible with the cudatoolkit and cudnn versions available in nixpkgs.
 assert cudaSupport -> lib.versionAtLeast cudatoolkit_11.version "11.1";
+assert cudaSupport -> lib.versionAtLeast cudnn.version "8.0.5";
 
 let
   device = if cudaSupport then "gpu" else "cpu";
 in
 buildPythonPackage rec {
   pname = "jaxlib";
-  version = "0.1.71";
+  version = "0.3.0";
   format = "wheel";
 
   # At the time of writing (8/19/21), there are releases for 3.7-3.9. Supporting
@@ -36,14 +49,23 @@ buildPythonPackage rec {
   # version.
   disabled = !isPy39;
 
+  # Find new releases at https://storage.googleapis.com/jax-releases.
   src = {
     cpu = fetchurl {
       url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl";
-      sha256 = "sha256:0rqhs6qabydizlv5d3rb20dbv6612rr7dqfniy9r6h4kazdinsn6";
+      sha256 = "151p4vqli8x0iqgrzrr8piqk7d76a2xq2krf23jlb142iam5bw01";
     };
     gpu = fetchurl {
-      url = "https://storage.googleapis.com/jax-releases/cuda111/jaxlib-${version}+cuda111-cp39-none-manylinux2010_x86_64.whl";
-      sha256 = "sha256:065kyzjsk9m84d138p99iymdiiicm1qz8a3iwxz8rspl43rwrw89";
+      # Note that there's also a release targeting cuDNN 8.2, but unfortunately
+      # we don't yet have that packaged at the time of writing (02/03/2022).
+      # Check pkgs/development/libraries/science/math/cudnn/default.nix for more
+      # details.
+      url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl";
+      sha256 = "0z15rdw3a8sq51rpjmfc41ix1q095aasl79rvlib85ir6f3wh2h8";
+
+      # This is what the cuDNN 8.2 download looks like for future reference:
+      # url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl";
+      # sha256 = "000mnm2masm3sx3haddcmgw43j4gxa3m4fcm14p9nb8dnncjkgpb";
     };
   }.${device};
 
@@ -71,7 +93,7 @@ buildPythonPackage rec {
       rpath=$(patchelf --print-rpath $file)
       # For some reason `makeLibraryPath` on `cudatoolkit_11` maps to
       # <cudatoolkit_11.lib>/lib which is different from <cudatoolkit_11>/lib.
-      patchelf --set-rpath "$rpath:${cudatoolkit_11}/lib:${lib.makeLibraryPath [ cudatoolkit_11.lib ]}" $file
+      patchelf --set-rpath "$rpath:${cudatoolkit_11}/lib:${lib.makeLibraryPath [ cudatoolkit_11.lib cudnn ]}" $file
     done
   '';
 
@@ -82,8 +104,8 @@ buildPythonPackage rec {
 
   meta = with lib; {
     description = "XLA library for JAX";
-    homepage    = "https://github.com/google/jax";
-    license     = licenses.asl20;
+    homepage = "https://github.com/google/jax";
+    license = licenses.asl20;
     maintainers = with maintainers; [ samuela ];
     platforms = [ "x86_64-linux" ];
   };