diff options
author | Yaroslav Bolyukin <iam@lach.pw> | 2024-03-02 04:08:30 +0100 |
---|---|---|
committer | Yaroslav Bolyukin <iam@lach.pw> | 2024-03-02 10:02:45 +0100 |
commit | 2eedfae46b0e632089aefad388cbc7093e7eea0b (patch) | |
tree | 67a4cdd5861c27606678eabd95bdae663b524a72 | |
parent | ac43ec34d9181114a290ca734d13f16e25229d7f (diff) |
torchaudio: add rocm support
-rw-r--r-- | pkgs/development/python-modules/torchaudio/default.nix | 52 |
1 files changed, 51 insertions, 1 deletions
diff --git a/pkgs/development/python-modules/torchaudio/default.nix b/pkgs/development/python-modules/torchaudio/default.nix index 73aec87cca611..f9d4945ef6c06 100644 --- a/pkgs/development/python-modules/torchaudio/default.nix +++ b/pkgs/development/python-modules/torchaudio/default.nix @@ -9,10 +9,46 @@ , pybind11 , sox , torch + , cudaSupport ? torch.cudaSupport , cudaPackages +, rocmSupport ? torch.rocmSupport +, rocmPackages + +, gpuTargets ? [] }: +let + # TODO: Reuse one defined in torch? + # Some of those dependencies are probbly not required, + # but it breaks when the store path is different between torch and torchaudio + rocmtoolkit_joined = symlinkJoin { + name = "rocm-merged"; + + paths = with rocmPackages; [ + rocm-core clr rccl miopen miopengemm rocrand rocblas + rocsparse hipsparse rocthrust rocprim hipcub roctracer + rocfft rocsolver hipfft hipsolver hipblas + rocminfo rocm-thunk rocm-comgr rocm-device-libs + rocm-runtime clr.icd hipify + ]; + + # Fix `setuptools` not being found + postBuild = '' + rm -rf $out/nix-support + ''; + }; + # Only used for ROCm + gpuTargetString = lib.strings.concatStringsSep ";" ( + if gpuTargets != [ ] then + # If gpuTargets is specified, it always takes priority. + gpuTargets + else if rocmSupport then + rocmPackages.clr.gpuTargets + else + throw "No GPU targets specified" + ); +in buildPythonPackage rec { pname = "torchaudio"; version = "2.2.1"; @@ -33,6 +69,11 @@ buildPythonPackage rec { substituteInPlace setup.py \ --replace 'print(" --- Initializing submodules")' "return" \ --replace "_fetch_archives(_parse_sources())" "pass" + '' + + lib.optionalString rocmSupport '' + # There is no .info/version-dev, only .info/version + substituteInPlace cmake/LoadHIP.cmake \ + --replace "/.info/version-dev" "/.info/version" ''; env = { @@ -55,7 +96,11 @@ buildPythonPackage rec { ninja ] ++ lib.optionals cudaSupport [ cudaPackages.cuda_nvcc - ]; + ] ++ lib.optionals rocmSupport (with rocmPackages; [ + clr + rocblas + hipblas + ]); buildInputs = [ ffmpeg-full @@ -73,6 +118,11 @@ buildPythonPackage rec { BUILD_RNNT=0; BUILD_CTC_DECODER=0; + preConfigure = lib.optionalString rocmSupport '' + export ROCM_PATH=${rocmtoolkit_joined} + export PYTORCH_ROCM_ARCH="${gpuTargetString}" + ''; + dontUseCmakeConfigure = true; doCheck = false; # requires sox backend |