blob: d156061f38495cf3003fa8a491516b1b1ca4ef88 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
# Smoke test for JAX on CUDA, wrapped as the `jax-test-cuda` script.
#
# The embedded Python script checks that the first device JAX reports is a
# GPU, runs a small matrix multiplication, and prints "success!".
#
# Arguments:
#   jax, jaxlib - Python packages made available to the script.
#   pkgs        - package set providing `writers.writePython3Bin`.
{ jax
, jaxlib
, pkgs
}:
pkgs.writers.writePython3Bin "jax-test-cuda" { libraries = [ jax jaxlib ]; } ''
  import jax
  from jax import random

  # Report which backend was selected if JAX did not pick a GPU.
  platform = jax.devices()[0].platform
  assert platform == "gpu", f"expected a gpu backend, got {platform!r}"

  rng = random.PRNGKey(0)
  x = random.normal(rng, (100, 100))
  # JAX dispatches work asynchronously; block on the result so the matmul
  # has actually finished on the device before success is reported.
  (x @ x).block_until_ready()

  print("success!")
''
|