about summary refs log tree commit diff
path: root/pkgs/development/python-modules/jaxlib
diff options
context:
space:
mode:
authorSamuel Ainsworth <skainsworth@gmail.com>2021-09-04 20:57:32 +0000
committerSamuel Ainsworth <skainsworth@gmail.com>2021-09-04 20:58:15 +0000
commit3ea5dbdd720abec9a48665a69e69e5486bfe0fa0 (patch)
tree9da35e2012e055e0a6ef749e84ed95a523f43734 /pkgs/development/python-modules/jaxlib
parent0ff986154ea140f256e47a78233579ad27d8387a (diff)
python3Packages.jaxlib: add cudatoolkit_11 in propagatedBuildInputs
Diffstat (limited to 'pkgs/development/python-modules/jaxlib')
-rw-r--r--pkgs/development/python-modules/jaxlib/default.nix13
1 files changed, 8 insertions, 5 deletions
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix
index 039c3deb8fc87..af0755bf9dd09 100644
--- a/pkgs/development/python-modules/jaxlib/default.nix
+++ b/pkgs/development/python-modules/jaxlib/default.nix
@@ -2,10 +2,13 @@
 # backend will require some additional work. Those wheels are located here:
 # https://storage.googleapis.com/jax-releases/libtpu_releases.html.
 
-# For future reference, the easiest way to test that the gpu is being used is:
+# For future reference, the easiest way to test the GPU backend is to run
 #   NIX_PATH=.. nix-shell -p python3 python3Packages.jax "python3Packages.jaxlib.override { cudaSupport = true; }"
-#   python -c "from jax.lib import xla_bridge; print(xla_bridge.get_backend().platform)"
-# See https://github.com/google/jax/issues/971#issuecomment-508216439.
+#   python -c "from jax.lib import xla_bridge; assert xla_bridge.get_backend().platform == 'gpu'"
+#   python -c "from jax import random; random.PRNGKey(0)"
+# See https://github.com/google/jax/issues/971#issuecomment-508216439. There's
+# no convenient way to test the GPU backend in the derivation since the nix
+# build environment blocks access to the GPU.
 
 { addOpenGLRunpath, autoPatchelfHook, buildPythonPackage, config, fetchPypi
 , fetchurl, isPy39, lib, stdenv
@@ -66,8 +69,8 @@ buildPythonPackage rec {
     done
   '';
 
-  # pip dependencies
-  propagatedBuildInputs = [ absl-py flatbuffers scipy ];
+  # pip dependencies and optionally cudatoolkit.
+  propagatedBuildInputs = [ absl-py flatbuffers scipy ] ++ lib.optional cudaSupport cudatoolkit_11;
 
   pythonImportsCheck = [ "jaxlib" ];