blob: 55a4b45f715229786731c0919af1b6e718660239 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
|
# GPU availability testers for torch.
#
# Arguments (the set is closed: there is no `...`, so callers may pass only
# these names):
#   lib            - not referenced in the code shown here
#   torchWithCuda  - handed to the CUDA tester as its `torch` argument
#   torchWithRocm  - handed to the ROCm tester as its `torch` argument
#   callPackage    - used to instantiate each tester
#
# Evaluates to an attribute set with two entries, `tester-cudaAvailable` and
# `tester-rocmAvailable`.
{
lib,
torchWithCuda,
torchWithRocm,
callPackage,
}:
let
# Builds one "is the accelerator usable?" test for a given torch build.
#
#   feature      - forwarded to writeGpuTestPython and used to form the test
#                  name "<feature>Available"
#   versionAttr  - attribute of Python's `torch.version` that must be truthy
#   torch        - the torch package made available to the script
#   cudaPackages - provides `writeGpuTestPython`
#
# The generated Python script asserts that `torch.cuda.is_available()` and
# `torch.version.<versionAttr>` are both truthy, then prints the same message
# it would have used for the assertion failure.
accelAvailable =
  {
    feature,
    versionAttr,
    torch,
    cudaPackages,
  }:
  let
    # First argument of writeGpuTestPython: the test's settings.
    testArgs = {
      name = "${feature}Available";
      libraries = [ torch ];
      inherit feature;
    };

    # Second argument: the Python source. Common leading indentation is
    # stripped by the indented-string syntax, so the script starts at column 0.
    checkScript = ''
      import torch
      message = f"{torch.cuda.is_available()=} and {torch.version.${versionAttr}=}"
      assert torch.cuda.is_available() and torch.version.${versionAttr}, message
      print(message)
    '';
  in
  cudaPackages.writeGpuTestPython testArgs checkScript;
in
let
  # Instantiate accelAvailable for one backend. `cudaPackages` is not passed
  # here, so it is left for callPackage to supply.
  mkAvailabilityTest =
    feature: versionAttr: torch:
    callPackage accelAvailable {
      inherit feature versionAttr torch;
    };
in
{
  # CUDA build: checks torch.version.cuda.
  tester-cudaAvailable = mkAvailabilityTest "cuda" "cuda" torchWithCuda;

  # ROCm build: the feature is "rocm" but the version attribute checked is
  # torch.version.hip.
  tester-rocmAvailable = mkAvailabilityTest "rocm" "hip" torchWithRocm;
}
|