about summary refs log tree commit diff
path: root/pkgs/development/python-modules/jax/test-cuda.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/jax/test-cuda.nix')
-rw-r--r--pkgs/development/python-modules/jax/test-cuda.nix32
1 files changed, 20 insertions, 12 deletions
diff --git a/pkgs/development/python-modules/jax/test-cuda.nix b/pkgs/development/python-modules/jax/test-cuda.nix
index d156061f3849..5aca523f3177 100644
--- a/pkgs/development/python-modules/jax/test-cuda.nix
+++ b/pkgs/development/python-modules/jax/test-cuda.nix
@@ -1,17 +1,25 @@
-{ jax
-, jaxlib
-, pkgs
+{
+  jax,
+  jaxlib,
+  pkgs,
 }:
 
-pkgs.writers.writePython3Bin "jax-test-cuda" { libraries = [ jax jaxlib ]; } ''
-  import jax
-  from jax import random
+pkgs.writers.writePython3Bin "jax-test-cuda"
+  {
+    libraries = [
+      jax
+      jaxlib
+    ];
+  }
+  ''
+    import jax
+    from jax import random
 
-  assert jax.devices()[0].platform == "gpu"
+    assert jax.devices()[0].platform == "gpu"
 
-  rng = random.PRNGKey(0)
-  x = random.normal(rng, (100, 100))
-  x @ x
+    rng = random.PRNGKey(0)
+    x = random.normal(rng, (100, 100))
+    x @ x
 
-  print("success!")
-''
+    print("success!")
+  ''