diff options
author | Samuel Ainsworth <skainsworth@gmail.com> | 2021-09-04 05:47:09 +0000 |
---|---|---|
committer | Samuel Ainsworth <skainsworth@gmail.com> | 2021-09-04 05:47:09 +0000 |
commit | 0ff986154ea140f256e47a78233579ad27d8387a (patch) | |
tree | 4d052331b81b2471c1fda9da88d5417458489c80 /pkgs/development/python-modules/jaxlib | |
parent | 6f44416cf2655b87378d3752eda00dfccb39dea6 (diff) |
python3Packages.jaxlib: add CUDA support
Diffstat (limited to 'pkgs/development/python-modules/jaxlib')
-rw-r--r-- | pkgs/development/python-modules/jaxlib/default.nix | 66 |
1 files changed, 51 insertions, 15 deletions
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix index 240c5a7d6d0ee..039c3deb8fc87 100644 --- a/pkgs/development/python-modules/jaxlib/default.nix +++ b/pkgs/development/python-modules/jaxlib/default.nix @@ -1,16 +1,28 @@ -# For the moment we only support the CPU backend of jaxlib. GPU and TPU backends -# require some additional work. Their wheels are not located on PyPI. -# * CPU/GPU: https://storage.googleapis.com/jax-releases/jax_releases.html -# * TPU: https://storage.googleapis.com/jax-releases/libtpu_releases.html +# For the moment we only support the CPU and GPU backends of jaxlib. The TPU +# backend will require some additional work. Those wheels are located here: +# https://storage.googleapis.com/jax-releases/libtpu_releases.html. -{ autoPatchelfHook, buildPythonPackage, fetchPypi, isPy39, lib, stdenv +# For future reference, the easiest way to test that the gpu is being used is: +# NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib.override { cudaSupport = true; }" +# python -c "from jax.lib import xla_bridge; print(xla_bridge.get_backend().platform)" +# See https://github.com/google/jax/issues/971#issuecomment-508216439. + +{ addOpenGLRunpath, autoPatchelfHook, buildPythonPackage, config, fetchPypi +, fetchurl, isPy39, lib, stdenv # propagatedBuildInputs -, absl-py, flatbuffers, scipy +, absl-py, flatbuffers, scipy, cudatoolkit_11 +# Options: +, cudaSupport ? config.cudaSupport or false }: +assert cudaSupport -> lib.versionAtLeast cudatoolkit_11.version "11.1"; + +let + device = if cudaSupport then "gpu" else "cpu"; +in buildPythonPackage rec { pname = "jaxlib"; - version = "0.1.70"; + version = "0.1.71"; format = "wheel"; # At the time of writing (8/19/21), there are releases for 3.7-3.9. Supporting @@ -18,23 +30,47 @@ buildPythonPackage rec { # version. disabled = !isPy39; - src = fetchPypi { - inherit pname version format; - dist = "cp39"; - python = "cp39"; - platform = "manylinux2010_x86_64"; - sha256 = "sha256-mytMTqoavpuRawj52MU5/iFj27SGlm8DaoQ5vd/3bss="; - }; + src = { + cpu = fetchurl { + url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl"; + sha256 = "sha256:0rqhs6qabydizlv5d3rb20dbv6612rr7dqfniy9r6h4kazdinsn6"; + }; + gpu = fetchurl { + url = "https://storage.googleapis.com/jax-releases/cuda111/jaxlib-${version}+cuda111-cp39-none-manylinux2010_x86_64.whl"; + sha256 = "sha256:065kyzjsk9m84d138p99iymdiiicm1qz8a3iwxz8rspl43rwrw89"; + }; + }.${device}; # Prebuilt wheels are dynamically linked against things that nix can't find. # Run `autoPatchelfHook` to automagically fix them. - nativeBuildInputs = [ autoPatchelfHook ]; + nativeBuildInputs = [ autoPatchelfHook ] ++ lib.optional cudaSupport addOpenGLRunpath; # Dynamic link dependencies buildInputs = [ stdenv.cc.cc ]; + # jaxlib contains shared libraries that open other shared libraries via dlopen + # and these implicit dependencies are not recognized by ldd or + # autoPatchelfHook. That means we need to sneak them into rpath. This step + # must be done after autoPatchelfHook and the automatic stripping of + # artifacts. autoPatchelfHook runs in postFixup and auto-stripping runs in the + # patchPhase. Dependencies: + # * libcudart.so.11.0 -> cudatoolkit_11.lib + # * libcuda.so.1 -> opengl driver in /run/opengl-driver/lib + preInstallCheck = lib.optional cudaSupport '' + shopt -s globstar + + addOpenGLRunpath $out/**/*.so + + for file in $out/**/*.so; do + rpath=$(patchelf --print-rpath $file) + patchelf --set-rpath "$rpath:${lib.makeLibraryPath [ cudatoolkit_11.lib ]}" $file + done + ''; + # pip dependencies propagatedBuildInputs = [ absl-py flatbuffers scipy ]; + pythonImportsCheck = [ "jaxlib" ]; + meta = with lib; { description = "XLA library for JAX"; homepage = "https://github.com/google/jax"; |