diff options
Diffstat (limited to 'pkgs/development/python-modules/jax/test-cuda.nix')
-rw-r--r-- | pkgs/development/python-modules/jax/test-cuda.nix | 32 |
1 files changed, 20 insertions, 12 deletions
diff --git a/pkgs/development/python-modules/jax/test-cuda.nix b/pkgs/development/python-modules/jax/test-cuda.nix index d156061f3849..5aca523f3177 100644 --- a/pkgs/development/python-modules/jax/test-cuda.nix +++ b/pkgs/development/python-modules/jax/test-cuda.nix @@ -1,17 +1,25 @@ -{ jax -, jaxlib -, pkgs +{ + jax, + jaxlib, + pkgs, }: -pkgs.writers.writePython3Bin "jax-test-cuda" { libraries = [ jax jaxlib ]; } '' - import jax - from jax import random +pkgs.writers.writePython3Bin "jax-test-cuda" + { + libraries = [ + jax + jaxlib + ]; + } + '' + import jax + from jax import random - assert jax.devices()[0].platform == "gpu" + assert jax.devices()[0].platform == "gpu" - rng = random.PRNGKey(0) - x = random.normal(rng, (100, 100)) - x @ x + rng = random.PRNGKey(0) + x = random.normal(rng, (100, 100)) + x @ x - print("success!") -'' + print("success!") + '' |