about summary refs log tree commit diff
diff options
context:
space:
mode:
authorAleksana <me@aleksana.moe>2024-05-18 13:21:13 +0800
committerGitHub <noreply@github.com>2024-05-18 13:21:13 +0800
commit51d92d050b44882acfbafd755f015e78c318203c (patch)
treea6702e0e3be477ab1254af1475de4ff796a4e1bf
parent970f689a49f098197464b053a67c753738733a8d (diff)
parent2eedfae46b0e632089aefad388cbc7093e7eea0b (diff)
Merge pull request #292750 from CertainLach/torchaudio-rocm
torchaudio: add rocm support
-rw-r--r--pkgs/development/python-modules/torch/default.nix2
-rw-r--r--pkgs/development/python-modules/torchaudio/default.nix52
2 files changed, 52 insertions, 2 deletions
diff --git a/pkgs/development/python-modules/torch/default.nix b/pkgs/development/python-modules/torch/default.nix
index d8d3a6532a504..452b2f8598ec2 100644
--- a/pkgs/development/python-modules/torch/default.nix
+++ b/pkgs/development/python-modules/torch/default.nix
@@ -495,7 +495,7 @@ in buildPythonPackage rec {
   requiredSystemFeatures = [ "big-parallel" ];
 
   passthru = {
-    inherit cudaSupport cudaPackages;
+    inherit cudaSupport cudaPackages rocmSupport rocmPackages;
     # At least for 1.10.2 `torch.fft` is unavailable unless BLAS provider is MKL. This attribute allows for easy detection of its availability.
     blasProvider = blas.provider;
     # To help debug when a package is broken due to CUDA support
diff --git a/pkgs/development/python-modules/torchaudio/default.nix b/pkgs/development/python-modules/torchaudio/default.nix
index ff56db53a675f..2ad66d1691a43 100644
--- a/pkgs/development/python-modules/torchaudio/default.nix
+++ b/pkgs/development/python-modules/torchaudio/default.nix
@@ -9,10 +9,46 @@
 , pybind11
 , sox
 , torch
+
 , cudaSupport ? torch.cudaSupport
 , cudaPackages
+, rocmSupport ? torch.rocmSupport
+, rocmPackages
+
+, gpuTargets ? []
 }:
 
+let
+  # TODO: Reuse one defined in torch?
+  # Some of those dependencies are probbly not required,
+  # but it breaks when the store path is different between torch and torchaudio
+  rocmtoolkit_joined = symlinkJoin {
+    name = "rocm-merged";
+
+    paths = with rocmPackages; [
+      rocm-core clr rccl miopen miopengemm rocrand rocblas
+      rocsparse hipsparse rocthrust rocprim hipcub roctracer
+      rocfft rocsolver hipfft hipsolver hipblas
+      rocminfo rocm-thunk rocm-comgr rocm-device-libs
+      rocm-runtime clr.icd hipify
+    ];
+
+    # Fix `setuptools` not being found
+    postBuild = ''
+      rm -rf $out/nix-support
+    '';
+  };
+  # Only used for ROCm
+  gpuTargetString = lib.strings.concatStringsSep ";" (
+    if gpuTargets != [ ] then
+    # If gpuTargets is specified, it always takes priority.
+      gpuTargets
+    else if rocmSupport then
+      rocmPackages.clr.gpuTargets
+    else
+      throw "No GPU targets specified"
+  );
+in
 buildPythonPackage rec {
   pname = "torchaudio";
   version = "2.3.0";
@@ -33,6 +69,11 @@ buildPythonPackage rec {
     substituteInPlace setup.py \
       --replace 'print(" --- Initializing submodules")' "return" \
       --replace "_fetch_archives(_parse_sources())" "pass"
+  ''
+  + lib.optionalString rocmSupport ''
+    # There is no .info/version-dev, only .info/version
+    substituteInPlace cmake/LoadHIP.cmake \
+      --replace "/.info/version-dev" "/.info/version"
   '';
 
   env = {
@@ -55,7 +96,11 @@ buildPythonPackage rec {
     ninja
   ] ++ lib.optionals cudaSupport [
     cudaPackages.cuda_nvcc
-  ];
+  ] ++ lib.optionals rocmSupport (with rocmPackages; [
+    clr
+    rocblas
+    hipblas
+  ]);
 
   buildInputs = [
     ffmpeg-full
@@ -73,6 +118,11 @@ buildPythonPackage rec {
   BUILD_RNNT=0;
   BUILD_CTC_DECODER=0;
 
+  preConfigure = lib.optionalString rocmSupport ''
+    export ROCM_PATH=${rocmtoolkit_joined}
+    export PYTORCH_ROCM_ARCH="${gpuTargetString}"
+  '';
+
   dontUseCmakeConfigure = true;
 
   doCheck = false; # requires sox backend