about summary refs log tree commit diff
path: root/pkgs/development/python-modules/torchaudio/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/torchaudio/default.nix')
-rw-r--r--pkgs/development/python-modules/torchaudio/default.nix152
1 files changed, 114 insertions, 38 deletions
diff --git a/pkgs/development/python-modules/torchaudio/default.nix b/pkgs/development/python-modules/torchaudio/default.nix
index 3ca33cc36b656..849335f0778ff 100644
--- a/pkgs/development/python-modules/torchaudio/default.nix
+++ b/pkgs/development/python-modules/torchaudio/default.nix
@@ -1,39 +1,100 @@
-{ lib
-, buildPythonPackage
-, fetchFromGitHub
-, cmake
-, symlinkJoin
-, ffmpeg-full
-, pkg-config
-, ninja
-, pybind11
-, sox
-, torch
-, cudaSupport ? torch.cudaSupport
-, cudaPackages
+{
+  lib,
+  buildPythonPackage,
+  fetchFromGitHub,
+  cmake,
+  symlinkJoin,
+  ffmpeg-full,
+  pkg-config,
+  ninja,
+  pybind11,
+  sox,
+  torch,
+
+  cudaSupport ? torch.cudaSupport,
+  cudaPackages,
+  rocmSupport ? torch.rocmSupport,
+  rocmPackages,
+
+  gpuTargets ? [ ],
 }:
 
+let
+  # TODO: Reuse one defined in torch?
+  # Some of those dependencies are probbly not required,
+  # but it breaks when the store path is different between torch and torchaudio
+  rocmtoolkit_joined = symlinkJoin {
+    name = "rocm-merged";
+
+    paths = with rocmPackages; [
+      rocm-core
+      clr
+      rccl
+      miopen
+      miopengemm
+      rocrand
+      rocblas
+      rocsparse
+      hipsparse
+      rocthrust
+      rocprim
+      hipcub
+      roctracer
+      rocfft
+      rocsolver
+      hipfft
+      hipsolver
+      hipblas
+      rocminfo
+      rocm-thunk
+      rocm-comgr
+      rocm-device-libs
+      rocm-runtime
+      clr.icd
+      hipify
+    ];
+
+    # Fix `setuptools` not being found
+    postBuild = ''
+      rm -rf $out/nix-support
+    '';
+  };
+  # Only used for ROCm
+  gpuTargetString = lib.strings.concatStringsSep ";" (
+    if gpuTargets != [ ] then
+      # If gpuTargets is specified, it always takes priority.
+      gpuTargets
+    else if rocmSupport then
+      rocmPackages.clr.gpuTargets
+    else
+      throw "No GPU targets specified"
+  );
+in
 buildPythonPackage rec {
   pname = "torchaudio";
-  version = "2.2.2";
+  version = "2.3.0";
   pyproject = true;
 
   src = fetchFromGitHub {
     owner = "pytorch";
     repo = "audio";
     rev = "refs/tags/v${version}";
-    hash = "sha256-rW4xLUFTpGpUeMnTBdrI/2OjgZX1ihK0EfcVK6snmpk=";
+    hash = "sha256-8EPoZ/dfxrQjdtE0rZ+2pOaXxlyhRuweYnVuA9i0Fgc=";
   };
 
-  patches = [
-    ./0001-setup.py-propagate-cmakeFlags.patch
-  ];
+  patches = [ ./0001-setup.py-propagate-cmakeFlags.patch ];
 
-  postPatch = ''
-    substituteInPlace setup.py \
-      --replace 'print(" --- Initializing submodules")' "return" \
-      --replace "_fetch_archives(_parse_sources())" "pass"
-  '';
+  postPatch =
+    ''
+      substituteInPlace setup.py \
+        --replace 'print(" --- Initializing submodules")' "return" \
+        --replace "_fetch_archives(_parse_sources())" "pass"
+    ''
+    + lib.optionalString rocmSupport ''
+      # There is no .info/version-dev, only .info/version
+      substituteInPlace cmake/LoadHIP.cmake \
+        --replace "/.info/version-dev" "/.info/version"
+    '';
 
   env = {
     TORCH_CUDA_ARCH_LIST = "${lib.concatStringsSep ";" torch.cudaCapabilities}";
@@ -49,13 +110,21 @@ buildPythonPackage rec {
     ];
   };
 
-  nativeBuildInputs = [
-    cmake
-    pkg-config
-    ninja
-  ] ++ lib.optionals cudaSupport [
-    cudaPackages.cuda_nvcc
-  ];
+  nativeBuildInputs =
+    [
+      cmake
+      pkg-config
+      ninja
+    ]
+    ++ lib.optionals cudaSupport [ cudaPackages.cuda_nvcc ]
+    ++ lib.optionals rocmSupport (
+      with rocmPackages;
+      [
+        clr
+        rocblas
+        hipblas
+      ]
+    );
 
   buildInputs = [
     ffmpeg-full
@@ -64,14 +133,17 @@ buildPythonPackage rec {
     torch.cxxdev
   ];
 
-  propagatedBuildInputs = [
-    torch
-  ];
+  propagatedBuildInputs = [ torch ];
 
-  BUILD_SOX=0;
-  BUILD_KALDI=0;
-  BUILD_RNNT=0;
-  BUILD_CTC_DECODER=0;
+  BUILD_SOX = 0;
+  BUILD_KALDI = 0;
+  BUILD_RNNT = 0;
+  BUILD_CTC_DECODER = 0;
+
+  preConfigure = lib.optionalString rocmSupport ''
+    export ROCM_PATH=${rocmtoolkit_joined}
+    export PYTORCH_ROCM_ARCH="${gpuTargetString}"
+  '';
 
   dontUseCmakeConfigure = true;
 
@@ -82,7 +154,11 @@ buildPythonPackage rec {
     homepage = "https://pytorch.org/";
     changelog = "https://github.com/pytorch/audio/releases/tag/v${version}";
     license = licenses.bsd2;
-    platforms = platforms.unix;
+    platforms = [
+      "aarch64-darwin"
+      "aarch64-linux"
+      "x86_64-linux"
+    ];
     maintainers = with maintainers; [ junjihashimoto ];
   };
 }