about summary refs log tree commit diff
path: root/pkgs
diff options
context:
space:
mode:
authorSomeone <else@someonex.net>2024-06-26 19:16:53 +0000
committerGitHub <noreply@github.com>2024-06-26 19:16:53 +0000
commitcb69dc5b8da282b19049b98aae7bea852a4f948a (patch)
tree52ad5e0bf058548c00d2e149e80ab94819a1d3cb /pkgs
parentfba0991e11bde8394c104b22fa4bae0fe7e1151a (diff)
parent79a7186f1ce8d94b0c136a7cc7c3e2e31facc794 (diff)
Merge pull request #256230 from SomeoneSerge/feat/gpu-tests-py
GPU access in the sandbox
Diffstat (limited to 'pkgs')
-rw-r--r--pkgs/applications/misc/blender/default.nix18
-rw-r--r--pkgs/applications/misc/blender/test-cuda.py8
-rw-r--r--pkgs/by-name/ni/nix-required-mounts/closure.nix37
-rw-r--r--pkgs/by-name/ni/nix-required-mounts/nix_required_mounts.py201
-rw-r--r--pkgs/by-name/ni/nix-required-mounts/package.nix67
-rw-r--r--pkgs/by-name/ni/nix-required-mounts/pyproject.toml20
-rw-r--r--pkgs/by-name/ni/nix-required-mounts/scripts/nix_required_mounts_closure.py45
-rw-r--r--pkgs/development/cuda-modules/saxpy/default.nix11
-rw-r--r--pkgs/development/cuda-modules/write-gpu-python-test.nix29
-rw-r--r--pkgs/development/python-modules/pynvml/default.nix9
-rw-r--r--pkgs/development/python-modules/torch/bin.nix4
-rw-r--r--pkgs/development/python-modules/torch/default.nix7
-rw-r--r--pkgs/development/python-modules/torch/gpu-checks.nix40
-rw-r--r--pkgs/development/python-modules/torch/tests.nix3
-rw-r--r--pkgs/top-level/cuda-packages.nix2
15 files changed, 498 insertions, 3 deletions
diff --git a/pkgs/applications/misc/blender/default.nix b/pkgs/applications/misc/blender/default.nix
index e54cae9e56a24..3d044abaad6d9 100644
--- a/pkgs/applications/misc/blender/default.nix
+++ b/pkgs/applications/misc/blender/default.nix
@@ -7,6 +7,7 @@
   SDL,
   addOpenGLRunpath,
   alembic,
+  blender,
   boost,
   brotli,
   callPackage,
@@ -372,6 +373,20 @@ stdenv.mkDerivation (finalAttrs: {
             --render-frame 1
         done
       '';
+      tester-cudaAvailable = cudaPackages.writeGpuTestPython { } ''
+        import subprocess
+        subprocess.run([${
+          lib.concatMapStringsSep ", " (x: ''"${x}"'') [
+            (lib.getExe (blender.override { cudaSupport = true; }))
+            "--background"
+            "-noaudio"
+            "--python-exit-code"
+            "1"
+            "--python"
+            "${./test-cuda.py}"
+          ]
+        }], check=True)  # noqa: E501
+      '';
     };
   };
 
@@ -381,7 +396,8 @@ stdenv.mkDerivation (finalAttrs: {
     # They comment two licenses: GPLv2 and Blender License, but they
     # say: "We've decided to cancel the BL offering for an indefinite period."
     # OptiX, enabled with cudaSupport, is non-free.
-    license = with lib.licenses; [ gpl2Plus ] ++ lib.optional cudaSupport unfree;
+    license = with lib.licenses; [ gpl2Plus ] ++ lib.optional cudaSupport (unfree // { shortName = "NVidia OptiX EULA"; });
+
     platforms = [
       "aarch64-linux"
       "x86_64-darwin"
diff --git a/pkgs/applications/misc/blender/test-cuda.py b/pkgs/applications/misc/blender/test-cuda.py
new file mode 100644
index 0000000000000..8a3ec57347592
--- /dev/null
+++ b/pkgs/applications/misc/blender/test-cuda.py
@@ -0,0 +1,8 @@
+import bpy
+
+preferences = bpy.context.preferences.addons["cycles"].preferences
+devices = preferences.get_devices_for_type("CUDA")
+ids = [d.id for d in devices]
+
+assert any("CUDA" in i for i in ids), f"CUDA not present in {ids}"
+print("CUDA is available")
diff --git a/pkgs/by-name/ni/nix-required-mounts/closure.nix b/pkgs/by-name/ni/nix-required-mounts/closure.nix
new file mode 100644
index 0000000000000..3e361114bc4cb
--- /dev/null
+++ b/pkgs/by-name/ni/nix-required-mounts/closure.nix
@@ -0,0 +1,37 @@
+# Use exportReferencesGraph to capture the possible dependencies of the
+# drivers (e.g. libc linked through DT_RUNPATH) and ensure they are mounted
+# in the sandbox as well. In practice, things seemed to have worked without
+# this as well, but we go with the safe option until we understand why.
+
+{
+  lib,
+  runCommand,
+  python3Packages,
+  allowedPatterns,
+}:
+runCommand "allowed-patterns.json"
+  {
+    nativeBuildInputs = [ python3Packages.python ];
+    exportReferencesGraph = builtins.concatMap (
+      name:
+      builtins.concatMap (
+        path:
+        let
+          prefix = "${builtins.storeDir}/";
+          # Has to start with a letter: https://github.com/NixOS/nix/blob/516e7ddc41f39ff939b5d5b5dc71e590f24890d4/src/libstore/build/local-derivation-goal.cc#L568
+          exportName = ''references-${lib.strings.removePrefix prefix "${path}"}'';
+          isStorePath = lib.isStorePath path && (lib.hasPrefix prefix "${path}");
+        in
+        lib.optionals isStorePath [
+          exportName
+          path
+        ]
+      ) allowedPatterns.${name}.paths
+    ) (builtins.attrNames allowedPatterns);
+    env.storeDir = "${builtins.storeDir}/";
+    shallowConfig = builtins.toJSON allowedPatterns;
+    passAsFile = [ "shallowConfig" ];
+  }
+  ''
+    python ${./scripts/nix_required_mounts_closure.py}
+  ''
diff --git a/pkgs/by-name/ni/nix-required-mounts/nix_required_mounts.py b/pkgs/by-name/ni/nix-required-mounts/nix_required_mounts.py
new file mode 100644
index 0000000000000..6f05ee913a5a4
--- /dev/null
+++ b/pkgs/by-name/ni/nix-required-mounts/nix_required_mounts.py
@@ -0,0 +1,201 @@
+#!/usr/bin/env python3
+
+import glob
+import json
+import subprocess
+import textwrap
+from argparse import ArgumentParser
+from collections import deque
+from itertools import chain
+from pathlib import Path
+from typing import Deque, Dict, List, Set, Tuple, TypeAlias, TypedDict
+import logging
+
+Glob: TypeAlias = str
+PathString: TypeAlias = str
+
+
+class Mount(TypedDict):
+    host: PathString
+    guest: PathString
+
+
+class Pattern(TypedDict):
+    onFeatures: List[str]
+    paths: List[Glob | Mount]
+    unsafeFollowSymlinks: bool
+
+
+AllowedPatterns: TypeAlias = Dict[str, Pattern]
+
+
+parser = ArgumentParser("pre-build-hook")
+parser.add_argument("derivation_path")
+parser.add_argument("sandbox_path", nargs="?")
+parser.add_argument("--patterns", type=Path, required=True)
+parser.add_argument("--nix-exe", type=Path, required=True)
+parser.add_argument(
+    "--issue-command",
+    choices=("always", "conditional", "never"),
+    default="conditional",
+    help="Whether to print extra-sandbox-paths",
+)
+parser.add_argument(
+    "--issue-stop",
+    choices=("always", "conditional", "never"),
+    default="conditional",
+    help="Whether to print the final empty line",
+)
+parser.add_argument("-v", "--verbose", action="count", default=0)
+
+
+def symlink_parents(p: Path) -> List[Path]:
+    out = []
+    while p.is_symlink() and p not in out:
+        parent = p.readlink()
+        if parent.is_relative_to("."):
+            p = p / parent
+        else:
+            p = parent
+        out.append(p)
+    return out
+
+
+def get_strings(drv_env: dict, name: str) -> List[str]:
+    if "__json" in drv_env:
+        return list(json.loads(drv_env["__json"]).get(name, []))
+    else:
+        return drv_env.get(name, "").split()
+
+
+def validate_mounts(pattern: Pattern) -> List[Tuple[PathString, PathString, bool]]:
+    roots = []
+    for mount in pattern["paths"]:
+        if isinstance(mount, PathString):
+            matches = glob.glob(mount)
+            assert matches, f"Specified host paths do not exist: {mount}"
+
+            roots.extend((m, m, pattern["unsafeFollowSymlinks"]) for m in matches)
+        else:
+            assert isinstance(mount, dict) and "host" in mount, mount
+            assert Path(
+                mount["host"]
+            ).exists(), f"Specified host paths do not exist: {mount['host']}"
+            roots.append(
+                (
+                    mount["guest"],
+                    mount["host"],
+                    pattern["unsafeFollowSymlinks"],
+                )
+            )
+
+    return roots
+
+
+def entrypoint():
+    args = parser.parse_args()
+
+    VERBOSITY_LEVELS = [logging.ERROR, logging.INFO, logging.DEBUG]
+
+    level_index = min(args.verbose, len(VERBOSITY_LEVELS) - 1)
+    logging.basicConfig(level=VERBOSITY_LEVELS[level_index])
+
+    drv_path = args.derivation_path
+
+    with open(args.patterns, "r") as f:
+        allowed_patterns = json.load(f)
+
+    if not Path(drv_path).exists():
+        logging.error(
+            f"{drv_path} doesn't exist."
+            " Cf. https://github.com/NixOS/nix/issues/9272"
+            " Exiting the hook",
+        )
+
+    proc = subprocess.run(
+        [
+            args.nix_exe,
+            "show-derivation",
+            drv_path,
+        ],
+        capture_output=True,
+    )
+    try:
+        parsed_drv = json.loads(proc.stdout)
+    except json.JSONDecodeError:
+        logging.error(
+            "Couldn't parse the output of"
+            "`nix show-derivation`"
+            f". Expected JSON, observed: {proc.stdout}",
+        )
+        logging.error(textwrap.indent(proc.stdout.decode("utf8"), prefix=" " * 4))
+        logging.info("Exiting the nix-required-binds hook")
+        return
+    [canon_drv_path] = parsed_drv.keys()
+
+    known_features = set(
+        chain.from_iterable(
+            pattern["onFeatures"] for pattern in allowed_patterns.values()
+        )
+    )
+
+    parsed_drv = parsed_drv[canon_drv_path]
+    drv_env = parsed_drv.get("env", {})
+    required_features = get_strings(drv_env, "requiredSystemFeatures")
+    required_features = list(filter(known_features.__contains__, required_features))
+
+    patterns: List[Pattern] = list(
+        pattern
+        for pattern in allowed_patterns.values()
+        for path in pattern["paths"]
+        if any(feature in required_features for feature in pattern["onFeatures"])
+    )  # noqa: E501
+
+    queue: Deque[Tuple[PathString, PathString, bool]] = deque(
+        (mnt for pattern in patterns for mnt in validate_mounts(pattern))
+    )
+
+    unique_mounts: Set[Tuple[PathString, PathString]] = set()
+    mounts: List[Tuple[PathString, PathString]] = []
+
+    while queue:
+        guest_path_str, host_path_str, follow_symlinks = queue.popleft()
+        if (guest_path_str, host_path_str) not in unique_mounts:
+            mounts.append((guest_path_str, host_path_str))
+            unique_mounts.add((guest_path_str, host_path_str))
+
+        if not follow_symlinks:
+            continue
+
+        host_path = Path(host_path_str)
+        if not (host_path.is_dir() or host_path.is_symlink()):
+            continue
+
+        # assert host_path_str == guest_path_str, (host_path_str, guest_path_str)
+
+        for child in host_path.iterdir() if host_path.is_dir() else [host_path]:
+            for parent in symlink_parents(child):
+                parent_str = parent.absolute().as_posix()
+                queue.append((parent_str, parent_str, follow_symlinks))
+
+    # the pre-build-hook command
+    if args.issue_command == "always" or (
+        args.issue_command == "conditional" and mounts
+    ):
+        print("extra-sandbox-paths")
+        print_paths = True
+    else:
+        print_paths = False
+
+    # arguments, one per line
+    for guest_path_str, host_path_str in mounts if print_paths else []:
+        print(f"{guest_path_str}={host_path_str}")
+
+    # terminated by an empty line
+    something_to_terminate = args.issue_stop == "conditional" and mounts
+    if args.issue_stop == "always" or something_to_terminate:
+        print()
+
+
+if __name__ == "__main__":
+    entrypoint()
diff --git a/pkgs/by-name/ni/nix-required-mounts/package.nix b/pkgs/by-name/ni/nix-required-mounts/package.nix
new file mode 100644
index 0000000000000..197e0812a8ec5
--- /dev/null
+++ b/pkgs/by-name/ni/nix-required-mounts/package.nix
@@ -0,0 +1,67 @@
+{
+  addOpenGLRunpath,
+  allowedPatternsPath ? callPackage ./closure.nix { inherit allowedPatterns; },
+  allowedPatterns ? rec {
+    # This config is just an example.
+    # When the hook observes either of the following requiredSystemFeatures:
+    nvidia-gpu.onFeatures = [
+      "gpu"
+      "nvidia-gpu"
+      "opengl"
+      "cuda"
+    ];
+    # It exposes these paths in the sandbox:
+    nvidia-gpu.paths = [
+      addOpenGLRunpath.driverLink
+      "/dev/dri"
+      "/dev/nvidia*"
+    ];
+    nvidia-gpu.unsafeFollowSymlinks = true;
+  },
+  callPackage,
+  extraWrapperArgs ? [ ],
+  lib,
+  makeWrapper,
+  nix,
+  nixosTests,
+  python3Packages,
+}:
+
+let
+  attrs = builtins.fromTOML (builtins.readFile ./pyproject.toml);
+  pname = attrs.project.name;
+  inherit (attrs.project) version;
+in
+
+python3Packages.buildPythonApplication {
+  inherit pname version;
+  pyproject = true;
+
+  src = lib.cleanSource ./.;
+
+  nativeBuildInputs = [
+    makeWrapper
+    python3Packages.setuptools
+  ];
+
+  postFixup = ''
+    wrapProgram $out/bin/${pname} \
+      --add-flags "--patterns ${allowedPatternsPath}" \
+      --add-flags "--nix-exe ${lib.getExe nix}" \
+      ${builtins.concatStringsSep " " extraWrapperArgs}
+  '';
+
+  passthru = {
+    inherit allowedPatterns;
+    tests = {
+      inherit (nixosTests) nix-required-mounts;
+    };
+  };
+  meta = {
+    inherit (attrs.project) description;
+    homepage = attrs.project.urls.Homepage;
+    license = lib.licenses.mit;
+    mainProgram = attrs.project.name;
+    maintainers = with lib.maintainers; [ SomeoneSerge ];
+  };
+}
diff --git a/pkgs/by-name/ni/nix-required-mounts/pyproject.toml b/pkgs/by-name/ni/nix-required-mounts/pyproject.toml
new file mode 100644
index 0000000000000..bb754e08ab1d3
--- /dev/null
+++ b/pkgs/by-name/ni/nix-required-mounts/pyproject.toml
@@ -0,0 +1,20 @@
+[build-system]
+build-backend = "setuptools.build_meta"
+requires = [ "setuptools" ]
+
+[project]
+name = "nix-required-mounts"
+version = "0.0.1"
+description = """
+A --pre-build-hook for Nix, \
+that allows to expose extra paths in the build sandbox \
+based on derivations' requiredSystemFeatrues"""
+
+[project.urls]
+Homepage = "https://github.com/NixOS/nixpkgs/tree/master/pkgs/by-name/ni/nix-required-mounts"
+
+[project.scripts]
+nix-required-mounts = "nix_required_mounts:entrypoint"
+
+[tool.black]
+line-length = 79
diff --git a/pkgs/by-name/ni/nix-required-mounts/scripts/nix_required_mounts_closure.py b/pkgs/by-name/ni/nix-required-mounts/scripts/nix_required_mounts_closure.py
new file mode 100644
index 0000000000000..4425e98d09251
--- /dev/null
+++ b/pkgs/by-name/ni/nix-required-mounts/scripts/nix_required_mounts_closure.py
@@ -0,0 +1,45 @@
+import json
+import os
+
+store_dir = os.environ["storeDir"]
+
+with open(os.environ["shallowConfigPath"], "r") as f:
+    config = json.load(f)
+
+cache = {}
+
+
+def read_edges(path: str | dict) -> list[str | dict]:
+    if isinstance(path, dict):
+        return [path]
+    assert isinstance(path, str)
+
+    if not path.startswith(store_dir):
+        return [path]
+    if path in cache:
+        return cache[path]
+
+    name = f"references-{path.removeprefix(store_dir)}"
+
+    assert os.path.exists(name)
+
+    with open(name, "r") as f:
+        return [p.strip() for p in f.readlines() if p.startswith(store_dir)]
+
+
+def host_path(mount: str | dict) -> str:
+    if isinstance(mount, dict):
+        return mount["host"]
+    assert isinstance(mount, str), mount
+    return mount
+
+
+for pattern in config:
+    closure = []
+    for path in config[pattern]["paths"]:
+        closure.append(path)
+        closure.extend(read_edges(path))
+    config[pattern]["paths"] = list({host_path(m): m for m in closure}.values())
+
+with open(os.environ["out"], "w") as f:
+    json.dump(config, f)
diff --git a/pkgs/development/cuda-modules/saxpy/default.nix b/pkgs/development/cuda-modules/saxpy/default.nix
index 9b7326cd321fa..5eb0a235ace81 100644
--- a/pkgs/development/cuda-modules/saxpy/default.nix
+++ b/pkgs/development/cuda-modules/saxpy/default.nix
@@ -3,6 +3,7 @@
   cmake,
   cudaPackages,
   lib,
+  saxpy,
 }:
 let
   inherit (cudaPackages)
@@ -15,7 +16,6 @@ let
     cudatoolkit
     flags
     libcublas
-    setupCudaHook
     ;
   inherit (lib) getDev getLib getOutput;
   fs = lib.fileset;
@@ -58,10 +58,19 @@ backendStdenv.mkDerivation {
     (lib.cmakeFeature "CMAKE_CUDA_ARCHITECTURES" flags.cmakeCudaArchitecturesString)
   ];
 
+  passthru.gpuCheck = saxpy.overrideAttrs (_: {
+    requiredSystemFeatures = [ "cuda" ];
+    doInstallCheck = true;
+    postInstallCheck = ''
+      $out/bin/${saxpy.meta.mainProgram or (lib.getName saxpy)}
+    '';
+  });
+
   meta = rec {
     description = "Simple (Single-precision AX Plus Y) FindCUDAToolkit.cmake example for testing cross-compilation";
     license = lib.licenses.mit;
     maintainers = lib.teams.cuda.members;
+    mainProgram = "saxpy";
     platforms = lib.platforms.unix;
     badPlatforms = lib.optionals (flags.isJetsonBuild && cudaOlder "11.4") platforms;
   };
diff --git a/pkgs/development/cuda-modules/write-gpu-python-test.nix b/pkgs/development/cuda-modules/write-gpu-python-test.nix
new file mode 100644
index 0000000000000..5f0d5c6b8fe68
--- /dev/null
+++ b/pkgs/development/cuda-modules/write-gpu-python-test.nix
@@ -0,0 +1,29 @@
+{
+  lib,
+  writers,
+  runCommand,
+}:
+{
+  feature ? "cuda",
+  name ? feature,
+  libraries ? [ ],
+}:
+content:
+
+let
+  tester = writers.writePython3Bin "tester-${name}" { inherit libraries; } content;
+  tester' = tester.overrideAttrs (oldAttrs: {
+    passthru.gpuCheck =
+      runCommand "test-${name}"
+        {
+          nativeBuildInputs = [ tester' ];
+          requiredSystemFeatures = [ feature ];
+        }
+        ''
+          set -e
+          ${tester.meta.mainProgram or (lib.getName tester')}
+          touch $out
+        '';
+  });
+in
+tester'
diff --git a/pkgs/development/python-modules/pynvml/default.nix b/pkgs/development/python-modules/pynvml/default.nix
index a115cd723998f..762771c66a2bb 100644
--- a/pkgs/development/python-modules/pynvml/default.nix
+++ b/pkgs/development/python-modules/pynvml/default.nix
@@ -1,6 +1,7 @@
 {
   lib,
   buildPythonPackage,
+  cudaPackages,
   fetchFromGitHub,
   substituteAll,
   pythonOlder,
@@ -8,6 +9,7 @@
   setuptools,
   pytestCheckHook,
   versioneer,
+  pynvml,
 }:
 
 buildPythonPackage rec {
@@ -50,6 +52,13 @@ buildPythonPackage rec {
   # OSError: /run/opengl-driver/lib/libnvidia-ml.so.1: cannot open shared object file: No such file or directory
   doCheck = false;
 
+  passthru.tests.tester-nvmlInit = cudaPackages.writeGpuTestPython { libraries = [ pynvml ]; } ''
+    import pynvml
+    from pynvml.smi import nvidia_smi  # noqa: F401
+
+    print(f"{pynvml.nvmlInit()=}")
+  '';
+
   meta = with lib; {
     description = "Python bindings for the NVIDIA Management Library";
     homepage = "https://github.com/gpuopenanalytics/pynvml";
diff --git a/pkgs/development/python-modules/torch/bin.nix b/pkgs/development/python-modules/torch/bin.nix
index f9d5cd97c183a..e2899c081e08b 100644
--- a/pkgs/development/python-modules/torch/bin.nix
+++ b/pkgs/development/python-modules/torch/bin.nix
@@ -8,6 +8,7 @@
   pythonAtLeast,
   pythonOlder,
   addOpenGLRunpath,
+  callPackage,
   cudaPackages,
   future,
   numpy,
@@ -15,6 +16,7 @@
   pyyaml,
   requests,
   setuptools,
+  torch-bin,
   typing-extensions,
   sympy,
   jinja2,
@@ -119,6 +121,8 @@ buildPythonPackage {
 
   pythonImportsCheck = [ "torch" ];
 
+  passthru.gpuChecks.cudaAvailable = callPackage ./test-cuda.nix { torch = torch-bin; };
+
   meta = {
     description = "PyTorch: Tensors and Dynamic neural networks in Python with strong GPU acceleration";
     homepage = "https://pytorch.org/";
diff --git a/pkgs/development/python-modules/torch/default.nix b/pkgs/development/python-modules/torch/default.nix
index d5d7e823bed7c..9597a047bdb48 100644
--- a/pkgs/development/python-modules/torch/default.nix
+++ b/pkgs/development/python-modules/torch/default.nix
@@ -24,6 +24,10 @@
   mpi,
   buildDocs ? false,
 
+  # tests.cudaAvailable:
+  callPackage,
+  torchWithCuda,
+
   # Native build inputs
   cmake,
   symlinkJoin,
@@ -639,11 +643,12 @@ buildPythonPackage rec {
       rocmSupport
       rocmPackages
       ;
+    cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ];
     # At least for 1.10.2 `torch.fft` is unavailable unless BLAS provider is MKL. This attribute allows for easy detection of its availability.
     blasProvider = blas.provider;
     # To help debug when a package is broken due to CUDA support
     inherit brokenConditions;
-    cudaCapabilities = if cudaSupport then supportedCudaCapabilities else [ ];
+    tests = callPackage ./tests.nix { };
   };
 
   meta = {
diff --git a/pkgs/development/python-modules/torch/gpu-checks.nix b/pkgs/development/python-modules/torch/gpu-checks.nix
new file mode 100644
index 0000000000000..d01fffe45cb0c
--- /dev/null
+++ b/pkgs/development/python-modules/torch/gpu-checks.nix
@@ -0,0 +1,40 @@
+{
+  lib,
+  torchWithCuda,
+  torchWithRocm,
+  callPackage,
+}:
+
+let
+  accelAvailable =
+    {
+      feature,
+      versionAttr,
+      torch,
+      cudaPackages,
+    }:
+    cudaPackages.writeGpuPythonTest
+      {
+        inherit feature;
+        libraries = [ torch ];
+        name = "${feature}Available";
+      }
+      ''
+        import torch
+        message = f"{torch.cuda.is_available()=} and {torch.version.${versionAttr}=}"
+        assert torch.cuda.is_available() and torch.version.${versionAttr}, message
+        print(message)
+      '';
+in
+{
+  tester-cudaAvailable = callPackage accelAvailable {
+    feature = "cuda";
+    versionAttr = "cuda";
+    torch = torchWithCuda;
+  };
+  tester-rocmAvailable = callPackage accelAvailable {
+    feature = "rocm";
+    versionAttr = "hip";
+    torch = torchWithRocm;
+  };
+}
diff --git a/pkgs/development/python-modules/torch/tests.nix b/pkgs/development/python-modules/torch/tests.nix
new file mode 100644
index 0000000000000..5a46d0886868c
--- /dev/null
+++ b/pkgs/development/python-modules/torch/tests.nix
@@ -0,0 +1,3 @@
+{ callPackage }:
+
+callPackage ./gpu-checks.nix { }
diff --git a/pkgs/top-level/cuda-packages.nix b/pkgs/top-level/cuda-packages.nix
index d34a37294ae0a..7f01f4310c9ec 100644
--- a/pkgs/top-level/cuda-packages.nix
+++ b/pkgs/top-level/cuda-packages.nix
@@ -77,6 +77,8 @@ let
     saxpy = final.callPackage ../development/cuda-modules/saxpy { };
     nccl = final.callPackage ../development/cuda-modules/nccl { };
     nccl-tests = final.callPackage ../development/cuda-modules/nccl-tests { };
+
+    writeGpuTestPython = final.callPackage ../development/cuda-modules/write-gpu-python-test.nix { };
   });
 
   mkVersionedPackageName =