1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
# Builds an executable Python script ("jax-test-cuda") that checks JAX's first
# device is a GPU and runs a small matrix multiplication on it.
# NOTE(review): presumably meant to be run by hand on a CUDA-capable machine;
# it will fail its assertion when JAX only sees a CPU.
{
  jax,
  jaxlib,
  pkgs,
}:

pkgs.writers.writePython3Bin "jax-test-cuda" { libraries = [ jax jaxlib ]; } ''
  import jax
  from jax import random

  # Fail loudly, showing the devices JAX actually found, if the first device
  # is not a GPU (e.g. jaxlib without working CUDA support).
  devices = jax.devices()
  assert devices[0].platform == "gpu", f"expected a GPU, got {devices}"

  rng = random.PRNGKey(0)
  x = random.normal(rng, (100, 100))

  # JAX dispatches work asynchronously; block until the matmul has actually
  # finished on the device so any runtime failure surfaces before we report
  # success.
  (x @ x).block_until_ready()

  print("success!")
''