diff options
Diffstat (limited to 'pkgs/development/python-modules/jaxlib/bin.nix')
-rw-r--r-- | pkgs/development/python-modules/jaxlib/bin.nix | 46 |
1 files changed, 34 insertions, 12 deletions
diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix index f597eeacfced4..6e70b24f67da7 100644 --- a/pkgs/development/python-modules/jaxlib/bin.nix +++ b/pkgs/development/python-modules/jaxlib/bin.nix @@ -13,22 +13,35 @@ # * https://github.com/google/jax/issues/971#issuecomment-508216439 # * https://github.com/google/jax/issues/5723#issuecomment-913038780 -{ addOpenGLRunpath, autoPatchelfHook, buildPythonPackage, config -, fetchurl, isPy39, lib, stdenv -# propagatedBuildInputs -, absl-py, flatbuffers, scipy, cudatoolkit_11 -# Options: +{ absl-py +, addOpenGLRunpath +, autoPatchelfHook +, buildPythonPackage +, config +, cudatoolkit_11 +, cudnn +, fetchurl +, flatbuffers +, isPy39 +, lib +, scipy +, stdenv + # Options: , cudaSupport ? config.cudaSupport or false }: +# Note that these values are tied to the specific version of the GPU wheel that +# we fetch. When updating, try to go for the latest possible versions that are +# still compatible with the cudatoolkit and cudnn versions available in nixpkgs. assert cudaSupport -> lib.versionAtLeast cudatoolkit_11.version "11.1"; +assert cudaSupport -> lib.versionAtLeast cudnn.version "8.0.5"; let device = if cudaSupport then "gpu" else "cpu"; in buildPythonPackage rec { pname = "jaxlib"; - version = "0.1.71"; + version = "0.3.0"; format = "wheel"; # At the time of writing (8/19/21), there are releases for 3.7-3.9. Supporting @@ -36,14 +49,23 @@ buildPythonPackage rec { # version. disabled = !isPy39; + # Find new releases at https://storage.googleapis.com/jax-releases. src = { cpu = fetchurl { url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl"; - sha256 = "sha256:0rqhs6qabydizlv5d3rb20dbv6612rr7dqfniy9r6h4kazdinsn6"; + sha256 = "151p4vqli8x0iqgrzrr8piqk7d76a2xq2krf23jlb142iam5bw01"; }; gpu = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda111/jaxlib-${version}+cuda111-cp39-none-manylinux2010_x86_64.whl"; - sha256 = "sha256:065kyzjsk9m84d138p99iymdiiicm1qz8a3iwxz8rspl43rwrw89"; + # Note that there's also a release targeting cuDNN 8.2, but unfortunately + # we don't yet have that packaged at the time of writing (02/03/2022). + # Check pkgs/development/libraries/science/math/cudnn/default.nix for more + # details. + url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl"; + sha256 = "0z15rdw3a8sq51rpjmfc41ix1q095aasl79rvlib85ir6f3wh2h8"; + + # This is what the cuDNN 8.2 download looks like for future reference: + # url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl"; + # sha256 = "000mnm2masm3sx3haddcmgw43j4gxa3m4fcm14p9nb8dnncjkgpb"; }; }.${device}; @@ -71,7 +93,7 @@ buildPythonPackage rec { rpath=$(patchelf --print-rpath $file) # For some reason `makeLibraryPath` on `cudatoolkit_11` maps to # <cudatoolkit_11.lib>/lib which is different from <cudatoolkit_11>/lib. - patchelf --set-rpath "$rpath:${cudatoolkit_11}/lib:${lib.makeLibraryPath [ cudatoolkit_11.lib ]}" $file + patchelf --set-rpath "$rpath:${cudatoolkit_11}/lib:${lib.makeLibraryPath [ cudatoolkit_11.lib cudnn ]}" $file done ''; @@ -82,8 +104,8 @@ buildPythonPackage rec { meta = with lib; { description = "XLA library for JAX"; - homepage = "https://github.com/google/jax"; - license = licenses.asl20; + homepage = "https://github.com/google/jax"; + license = licenses.asl20; maintainers = with maintainers; [ samuela ]; platforms = [ "x86_64-linux" ]; }; |