about summary refs log tree commit diff
path: root/pkgs/development/python-modules/torchaudio/default.nix
blob: 2ad66d1691a431fac9a3f772eb38c0b98c2480af (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Inputs of the torchaudio package expression.
{ lib
, buildPythonPackage
, fetchFromGitHub
, cmake
, symlinkJoin
, ffmpeg-full
, pkg-config
, ninja
, pybind11
, sox
, torch

  # GPU backend switches; each defaults to the matching attribute of `torch`.
, cudaSupport ? torch.cudaSupport
, cudaPackages
, rocmSupport ? torch.rocmSupport
, rocmPackages

  # Optional explicit list of GPU targets. A non-empty list takes priority
  # over `rocmPackages.clr.gpuTargets`; it is only consumed by the ROCm build.
, gpuTargets ? []
}:

let
  # Single merged tree of the ROCm packages, exported as ROCM_PATH for the
  # ROCm build.
  # TODO: Reuse the one defined in torch?
  # Some of these dependencies are probably not required, but the build breaks
  # when the store path differs between torch and torchaudio.
  rocmtoolkit_joined = symlinkJoin {
    name = "rocm-merged";

    # Removing the merged nix-support directory fixes `setuptools` not being
    # found.
    postBuild = ''
      rm -rf $out/nix-support
    '';

    paths = [
      rocmPackages.rocm-core
      rocmPackages.clr
      rocmPackages.rccl
      rocmPackages.miopen
      rocmPackages.miopengemm
      rocmPackages.rocrand
      rocmPackages.rocblas
      rocmPackages.rocsparse
      rocmPackages.hipsparse
      rocmPackages.rocthrust
      rocmPackages.rocprim
      rocmPackages.hipcub
      rocmPackages.roctracer
      rocmPackages.rocfft
      rocmPackages.rocsolver
      rocmPackages.hipfft
      rocmPackages.hipsolver
      rocmPackages.hipblas
      rocmPackages.rocminfo
      rocmPackages.rocm-thunk
      rocmPackages.rocm-comgr
      rocmPackages.rocm-device-libs
      rocmPackages.rocm-runtime
      rocmPackages.clr.icd
      rocmPackages.hipify
    ];
  };
  # Semicolon-separated list of GPU targets. Only used for ROCm.
  gpuTargetString =
    let
      selectedTargets =
        # An explicitly passed, non-empty `gpuTargets` always takes priority;
        # otherwise fall back to the targets of the ROCm `clr` package.
        if gpuTargets != [ ] then gpuTargets
        else if rocmSupport then rocmPackages.clr.gpuTargets
        else throw "No GPU targets specified";
    in
    lib.concatStringsSep ";" selectedTargets;
in
# torchaudio 2.3.0 built from the GitHub sources against the `torch` passed to
# this expression, with optional CUDA or ROCm support.
buildPythonPackage rec {
  pname = "torchaudio";
  version = "2.3.0";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "audio";
    rev = "refs/tags/v${version}";
    hash = "sha256-8EPoZ/dfxrQjdtE0rZ+2pOaXxlyhRuweYnVuA9i0Fgc=";
  };

  patches = [
    ./0001-setup.py-propagate-cmakeFlags.patch
  ];

  # Neutralise two statements in setup.py (presumably the git submodule
  # initialisation and the third-party archive download -- confirm against
  # upstream).
  #
  # `--replace-warn` is the explicit spelling of what the deprecated plain
  # `--replace` did: substitute, and only warn when the pattern is missing.
  # NOTE(review): tighten to `--replace-fail` once the patterns have been
  # confirmed to match the pinned upstream sources.
  postPatch = ''
    substituteInPlace setup.py \
      --replace-warn 'print(" --- Initializing submodules")' "return" \
      --replace-warn "_fetch_archives(_parse_sources())" "pass"
  ''
  + lib.optionalString rocmSupport ''
    # There is no .info/version-dev, only .info/version
    substituteInPlace cmake/LoadHIP.cmake \
      --replace-warn "/.info/version-dev" "/.info/version"
  '';

  env = {
    # Semicolon-separated list of the CUDA capabilities reported by torch.
    TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" torch.cudaCapabilities;
  };

  # Merged bin/dev/lib outputs of ffmpeg-full, exposed as a single root.
  # https://github.com/pytorch/audio/blob/v2.1.0/docs/source/build.linux.rst#optional-build-torchaudio-with-a-custom-built-ffmpeg
  FFMPEG_ROOT = symlinkJoin {
    name = "ffmpeg";
    paths = [
      ffmpeg-full.bin
      ffmpeg-full.dev
      ffmpeg-full.lib
    ];
  };

  nativeBuildInputs = [
    cmake
    pkg-config
    ninja
  ] ++ lib.optionals cudaSupport [
    cudaPackages.cuda_nvcc
  ] ++ lib.optionals rocmSupport [
    rocmPackages.clr
    rocmPackages.rocblas
    rocmPackages.hipblas
  ];

  buildInputs = [
    ffmpeg-full
    pybind11
    sox
    torch.cxxdev
  ];

  propagatedBuildInputs = [
    torch
  ];

  # Build switches passed through the environment, all set to 0.
  # NOTE(review): presumably read by torchaudio's build scripts to disable the
  # corresponding optional components -- confirm against upstream setup.py.
  BUILD_SOX = 0;
  BUILD_KALDI = 0;
  BUILD_RNNT = 0;
  BUILD_CTC_DECODER = 0;

  # Only for ROCm: point the build at the merged ROCm tree and pass the
  # semicolon-separated GPU target list.
  preConfigure = lib.optionalString rocmSupport ''
    export ROCM_PATH=${rocmtoolkit_joined}
    export PYTORCH_ROCM_ARCH="${gpuTargetString}"
  '';

  # cmake is a build tool here, not the configure driver.
  # NOTE(review): presumably setup.py invokes CMake itself -- confirm.
  dontUseCmakeConfigure = true;

  doCheck = false; # requires sox backend

  meta = {
    description = "PyTorch audio library";
    homepage = "https://pytorch.org/";
    changelog = "https://github.com/pytorch/audio/releases/tag/v${version}";
    license = lib.licenses.bsd2;
    platforms = [ "aarch64-darwin" "aarch64-linux" "x86_64-linux" ];
    maintainers = with lib.maintainers; [ junjihashimoto ];
  };
}