about summary refs log tree commit diff
path: root/pkgs/development/python-modules/jaxlib/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/default.nix')
-rw-r--r--pkgs/development/python-modules/jaxlib/default.nix14
1 files changed, 7 insertions, 7 deletions
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix
index c70ab0ac2b327..4293cc781cb89 100644
--- a/pkgs/development/python-modules/jaxlib/default.nix
+++ b/pkgs/development/python-modules/jaxlib/default.nix
@@ -44,17 +44,17 @@
 , config
   # CUDA flags:
 , cudaSupport ? config.cudaSupport
-, cudaPackages ? {}
+, cudaPackagesGoogle
 
   # MKL:
 , mklSupport ? true
 }:
 
 let
-  inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;
+  inherit (cudaPackagesGoogle) backendStdenv cudatoolkit cudaFlags cudnn nccl;
 
   pname = "jaxlib";
-  version = "0.4.20";
+  version = "0.4.21";
 
   meta = with lib; {
     description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
@@ -150,7 +150,7 @@ let
       repo = "jax";
       # google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
       rev = "refs/tags/${pname}-v${version}";
-      hash = "sha256-WLYXUtchOaA6SGnKuVhN9CmV06xMCLQTEuEtL13ttZU=";
+      hash = "sha256-CMsW/t4/itJxN4pST8EKkN0ooHWdjRnLs073FwbXRJM=";
     };
 
     nativeBuildInputs = [
@@ -263,10 +263,10 @@ let
       ];
 
       sha256 = (if cudaSupport then {
-        x86_64-linux = "sha256-QczClHxHElLZCqIZlHc3z3DXJ7rZQJaMs2XIb+lxarI=";
+        x86_64-linux = "sha256-TgIH7r1IXNkbOFSXvaKVbU9kL+TuQqxVrBge7iv2ykQ=";
       } else {
-        x86_64-linux = "sha256-mqiJe4u0NYh1PKCbQfbo0U2e9/kYiBqj98d+BPHFSxQ=";
-        aarch64-linux = "sha256-EuLqamVBJ+qoVMCFIYUT846AghltZolfLGdtO9UeXSM=";
+        x86_64-linux = "sha256-on14CAolJ3mvJmKxX2PE21BsYOJJFUSQuUOnOuVR2GQ=";
+        aarch64-linux = "sha256-2tcIiQlMUKMc+juCy+dt37s+lFqr2pcVizCyYkkQtOM=";
       }).${stdenv.system} or (throw "jaxlib: unsupported system: ${stdenv.system}");
     };