about summary refs log tree commit diff
diff options
context:
space:
mode:
authorYaroslav Bolyukin <iam@lach.pw>2024-03-02 04:08:30 +0100
committerYaroslav Bolyukin <iam@lach.pw>2024-03-02 10:02:45 +0100
commit2eedfae46b0e632089aefad388cbc7093e7eea0b (patch)
tree67a4cdd5861c27606678eabd95bdae663b524a72
parentac43ec34d9181114a290ca734d13f16e25229d7f (diff)
torchaudio: add rocm support
-rw-r--r--pkgs/development/python-modules/torchaudio/default.nix52
1 files changed, 51 insertions, 1 deletions
diff --git a/pkgs/development/python-modules/torchaudio/default.nix b/pkgs/development/python-modules/torchaudio/default.nix
index 73aec87cca611..f9d4945ef6c06 100644
--- a/pkgs/development/python-modules/torchaudio/default.nix
+++ b/pkgs/development/python-modules/torchaudio/default.nix
@@ -9,10 +9,46 @@
 , pybind11
 , sox
 , torch
+
 , cudaSupport ? torch.cudaSupport
 , cudaPackages
+, rocmSupport ? torch.rocmSupport
+, rocmPackages
+
+, gpuTargets ? []
 }:
 
+let
+  # TODO: Reuse one defined in torch?
+  # Some of those dependencies are probbly not required,
+  # but it breaks when the store path is different between torch and torchaudio
+  rocmtoolkit_joined = symlinkJoin {
+    name = "rocm-merged";
+
+    paths = with rocmPackages; [
+      rocm-core clr rccl miopen miopengemm rocrand rocblas
+      rocsparse hipsparse rocthrust rocprim hipcub roctracer
+      rocfft rocsolver hipfft hipsolver hipblas
+      rocminfo rocm-thunk rocm-comgr rocm-device-libs
+      rocm-runtime clr.icd hipify
+    ];
+
+    # Fix `setuptools` not being found
+    postBuild = ''
+      rm -rf $out/nix-support
+    '';
+  };
+  # Only used for ROCm
+  gpuTargetString = lib.strings.concatStringsSep ";" (
+    if gpuTargets != [ ] then
+    # If gpuTargets is specified, it always takes priority.
+      gpuTargets
+    else if rocmSupport then
+      rocmPackages.clr.gpuTargets
+    else
+      throw "No GPU targets specified"
+  );
+in
 buildPythonPackage rec {
   pname = "torchaudio";
   version = "2.2.1";
@@ -33,6 +69,11 @@ buildPythonPackage rec {
     substituteInPlace setup.py \
       --replace 'print(" --- Initializing submodules")' "return" \
       --replace "_fetch_archives(_parse_sources())" "pass"
+  ''
+  + lib.optionalString rocmSupport ''
+    # There is no .info/version-dev, only .info/version
+    substituteInPlace cmake/LoadHIP.cmake \
+      --replace "/.info/version-dev" "/.info/version"
   '';
 
   env = {
@@ -55,7 +96,11 @@ buildPythonPackage rec {
     ninja
   ] ++ lib.optionals cudaSupport [
     cudaPackages.cuda_nvcc
-  ];
+  ] ++ lib.optionals rocmSupport (with rocmPackages; [
+    clr
+    rocblas
+    hipblas
+  ]);
 
   buildInputs = [
     ffmpeg-full
@@ -73,6 +118,11 @@ buildPythonPackage rec {
   BUILD_RNNT=0;
   BUILD_CTC_DECODER=0;
 
+  preConfigure = lib.optionalString rocmSupport ''
+    export ROCM_PATH=${rocmtoolkit_joined}
+    export PYTORCH_ROCM_ARCH="${gpuTargetString}"
+  '';
+
   dontUseCmakeConfigure = true;
 
   doCheck = false; # requires sox backend