diff options
Diffstat (limited to 'pkgs/development/python-modules/torchaudio/default.nix')
-rw-r--r-- | pkgs/development/python-modules/torchaudio/default.nix | 152 |
1 files changed, 114 insertions, 38 deletions
diff --git a/pkgs/development/python-modules/torchaudio/default.nix b/pkgs/development/python-modules/torchaudio/default.nix index 3ca33cc36b656..849335f0778ff 100644 --- a/pkgs/development/python-modules/torchaudio/default.nix +++ b/pkgs/development/python-modules/torchaudio/default.nix @@ -1,39 +1,100 @@ -{ lib -, buildPythonPackage -, fetchFromGitHub -, cmake -, symlinkJoin -, ffmpeg-full -, pkg-config -, ninja -, pybind11 -, sox -, torch -, cudaSupport ? torch.cudaSupport -, cudaPackages +{ + lib, + buildPythonPackage, + fetchFromGitHub, + cmake, + symlinkJoin, + ffmpeg-full, + pkg-config, + ninja, + pybind11, + sox, + torch, + + cudaSupport ? torch.cudaSupport, + cudaPackages, + rocmSupport ? torch.rocmSupport, + rocmPackages, + + gpuTargets ? [ ], }: +let + # TODO: Reuse one defined in torch? + # Some of those dependencies are probbly not required, + # but it breaks when the store path is different between torch and torchaudio + rocmtoolkit_joined = symlinkJoin { + name = "rocm-merged"; + + paths = with rocmPackages; [ + rocm-core + clr + rccl + miopen + miopengemm + rocrand + rocblas + rocsparse + hipsparse + rocthrust + rocprim + hipcub + roctracer + rocfft + rocsolver + hipfft + hipsolver + hipblas + rocminfo + rocm-thunk + rocm-comgr + rocm-device-libs + rocm-runtime + clr.icd + hipify + ]; + + # Fix `setuptools` not being found + postBuild = '' + rm -rf $out/nix-support + ''; + }; + # Only used for ROCm + gpuTargetString = lib.strings.concatStringsSep ";" ( + if gpuTargets != [ ] then + # If gpuTargets is specified, it always takes priority. + gpuTargets + else if rocmSupport then + rocmPackages.clr.gpuTargets + else + throw "No GPU targets specified" + ); +in buildPythonPackage rec { pname = "torchaudio"; - version = "2.2.2"; + version = "2.3.0"; pyproject = true; src = fetchFromGitHub { owner = "pytorch"; repo = "audio"; rev = "refs/tags/v${version}"; - hash = "sha256-rW4xLUFTpGpUeMnTBdrI/2OjgZX1ihK0EfcVK6snmpk="; + hash = "sha256-8EPoZ/dfxrQjdtE0rZ+2pOaXxlyhRuweYnVuA9i0Fgc="; }; - patches = [ - ./0001-setup.py-propagate-cmakeFlags.patch - ]; + patches = [ ./0001-setup.py-propagate-cmakeFlags.patch ]; - postPatch = '' - substituteInPlace setup.py \ - --replace 'print(" --- Initializing submodules")' "return" \ - --replace "_fetch_archives(_parse_sources())" "pass" - ''; + postPatch = + '' + substituteInPlace setup.py \ + --replace 'print(" --- Initializing submodules")' "return" \ + --replace "_fetch_archives(_parse_sources())" "pass" + '' + + lib.optionalString rocmSupport '' + # There is no .info/version-dev, only .info/version + substituteInPlace cmake/LoadHIP.cmake \ + --replace "/.info/version-dev" "/.info/version" + ''; env = { TORCH_CUDA_ARCH_LIST = "${lib.concatStringsSep ";" torch.cudaCapabilities}"; @@ -49,13 +110,21 @@ buildPythonPackage rec { ]; }; - nativeBuildInputs = [ - cmake - pkg-config - ninja - ] ++ lib.optionals cudaSupport [ - cudaPackages.cuda_nvcc - ]; + nativeBuildInputs = + [ + cmake + pkg-config + ninja + ] + ++ lib.optionals cudaSupport [ cudaPackages.cuda_nvcc ] + ++ lib.optionals rocmSupport ( + with rocmPackages; + [ + clr + rocblas + hipblas + ] + ); buildInputs = [ ffmpeg-full @@ -64,14 +133,17 @@ buildPythonPackage rec { torch.cxxdev ]; - propagatedBuildInputs = [ - torch - ]; + propagatedBuildInputs = [ torch ]; - BUILD_SOX=0; - BUILD_KALDI=0; - BUILD_RNNT=0; - BUILD_CTC_DECODER=0; + BUILD_SOX = 0; + BUILD_KALDI = 0; + BUILD_RNNT = 0; + BUILD_CTC_DECODER = 0; + + preConfigure = lib.optionalString rocmSupport '' + export ROCM_PATH=${rocmtoolkit_joined} + export PYTORCH_ROCM_ARCH="${gpuTargetString}" + ''; dontUseCmakeConfigure = true; @@ -82,7 +154,11 @@ buildPythonPackage rec { homepage = "https://pytorch.org/"; changelog = "https://github.com/pytorch/audio/releases/tag/v${version}"; license = licenses.bsd2; - platforms = platforms.unix; + platforms = [ + "aarch64-darwin" + "aarch64-linux" + "x86_64-linux" + ]; maintainers = with maintainers; [ junjihashimoto ]; }; } |