about summary refs log tree commit diff
path: root/pkgs/development/python-modules/jax/test-cuda.nix
blob: 5aca523f317758eb6c848b23ec6cbc9d85c12814 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Smoke test for JAX on CUDA: builds a `jax-test-cuda` executable that
# exits non-zero unless JAX's first device is a GPU and a small matrix
# product actually executes on it.
{
  jax,
  jaxlib,
  pkgs,
}:

pkgs.writers.writePython3Bin "jax-test-cuda"
  {
    libraries = [
      jax
      jaxlib
    ];
  }
  ''
    import jax
    from jax import random

    # Fail explicitly rather than via `assert`, which is skipped when
    # Python runs with -O / PYTHONOPTIMIZE; also report what was found.
    devices = jax.devices()
    if devices[0].platform != "gpu":
        raise SystemExit(f"expected a GPU device, got: {devices}")

    rng = random.PRNGKey(0)
    x = random.normal(rng, (100, 100))
    # JAX dispatches work asynchronously; wait for the product so any
    # GPU/kernel failure surfaces before "success!" is printed.
    (x @ x).block_until_ready()

    print("success!")
  ''