diff options
author | Samuel Ainsworth <skainsworth@gmail.com> | 2021-09-04 20:57:32 +0000 |
---|---|---|
committer | Samuel Ainsworth <skainsworth@gmail.com> | 2021-09-04 20:58:15 +0000 |
commit | 3ea5dbdd720abec9a48665a69e69e5486bfe0fa0 (patch) | |
tree | 9da35e2012e055e0a6ef749e84ed95a523f43734 /pkgs/development/python-modules/jaxlib | |
parent | 0ff986154ea140f256e47a78233579ad27d8387a (diff) |
python3Packages.jaxlib: add cudatoolkit_11 in propagatedBuildInputs
Diffstat (limited to 'pkgs/development/python-modules/jaxlib')
-rw-r--r-- | pkgs/development/python-modules/jaxlib/default.nix | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix index 039c3deb8fc87..af0755bf9dd09 100644 --- a/pkgs/development/python-modules/jaxlib/default.nix +++ b/pkgs/development/python-modules/jaxlib/default.nix @@ -2,10 +2,13 @@ # backend will require some additional work. Those wheels are located here: # https://storage.googleapis.com/jax-releases/libtpu_releases.html. -# For future reference, the easiest way to test that the gpu is being used is: +# For future reference, the easiest way to test the GPU backend is to run # NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib.override { cudaSupport = true; }" -# python -c "from jax.lib import xla_bridge; print(xla_bridge.get_backend().platform)" -# See https://github.com/google/jax/issues/971#issuecomment-508216439. +# python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'" +# python -c "from jax import random; random.PRNGKey(0)" +# See https://github.com/google/jax/issues/971#issuecomment-508216439. There's +# no convenient way to test the GPU backend in the derivation since the nix +# build environment blocks access to the GPU. { addOpenGLRunpath, autoPatchelfHook, buildPythonPackage, config, fetchPypi , fetchurl, isPy39, lib, stdenv @@ -66,8 +69,8 @@ buildPythonPackage rec { done ''; - # pip dependencies - propagatedBuildInputs = [ absl-py flatbuffers scipy ]; + # pip dependencies and optionally cudatoolkit. + propagatedBuildInputs = [ absl-py flatbuffers scipy ] ++ lib.optional cudaSupport cudatoolkit_11; pythonImportsCheck = [ "jaxlib" ]; |