about summary refs log tree commit diff
path: root/pkgs/development/python-modules/torchaudio/default.nix
blob: 3252a5b19cfaf8f76612db80546f6c3f0db66f6e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
{
  lib,
  buildPythonPackage,
  fetchFromGitHub,
  cmake,
  symlinkJoin,
  ffmpeg-full,
  pkg-config,
  ninja,
  pybind11,
  sox,
  torch,

  cudaSupport ? torch.cudaSupport,
  cudaPackages,
  rocmSupport ? torch.rocmSupport,
  rocmPackages,

  gpuTargets ? [ ],
}:

let
  # A single merged ROCm prefix, exported as ROCM_PATH by preConfigure when
  # rocmSupport is enabled.
  # TODO: Reuse the one defined in torch?
  # Some of these dependencies are probably not required, but the build breaks
  # when the store path differs between torch and torchaudio.
  rocmtoolkit_joined = symlinkJoin {
    name = "rocm-merged";

    # Kept in the original order: symlinkJoin links the paths one after
    # another, so ordering may matter if two members ship the same file.
    paths = [
      rocmPackages.rocm-core
      rocmPackages.clr
      rocmPackages.rccl
      rocmPackages.miopen
      rocmPackages.miopengemm
      rocmPackages.rocrand
      rocmPackages.rocblas
      rocmPackages.rocsparse
      rocmPackages.hipsparse
      rocmPackages.rocthrust
      rocmPackages.rocprim
      rocmPackages.hipcub
      rocmPackages.roctracer
      rocmPackages.rocfft
      rocmPackages.rocsolver
      rocmPackages.hipfft
      rocmPackages.hipsolver
      rocmPackages.hipblas
      rocmPackages.rocminfo
      rocmPackages.rocm-thunk
      rocmPackages.rocm-comgr
      rocmPackages.rocm-device-libs
      rocmPackages.rocm-runtime
      rocmPackages.clr.icd
      rocmPackages.hipify
    ];

    # Fix `setuptools` not being found: remove the nix-support/ directory
    # that the merge pulls in from the member packages.
    postBuild = ''
      rm -rf $out/nix-support
    '';
  };
  # Semicolon-separated GPU architecture list; only consumed by the ROCm build
  # (exported as PYTORCH_ROCM_ARCH). An explicit `gpuTargets` argument always
  # wins; otherwise fall back to the targets clr was built for. Evaluation is
  # lazy, so the throw only fires if this is forced without any targets.
  gpuTargetString =
    let
      selectedTargets =
        if gpuTargets != [ ] then
          gpuTargets
        else if rocmSupport then
          rocmPackages.clr.gpuTargets
        else
          throw "No GPU targets specified";
    in
    lib.concatStringsSep ";" selectedTargets;
in
buildPythonPackage rec {
  pname = "torchaudio";
  version = "2.3.1";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "audio";
    rev = "refs/tags/v${version}";
    hash = "sha256-PYaqRNKIhQ1DnFRZYyJJfBszVM2Bmu7A/lvvzJ6lL3g=";
  };

  patches = [ ./0001-setup.py-propagate-cmakeFlags.patch ];

  # setup.py: turn the submodule-initialisation print into an early `return`
  # and replace the `_fetch_archives(...)` call with `pass`, presumably because
  # git/network access is unavailable in the build sandbox.
  # ROCm: LoadHIP.cmake reads `.info/version-dev`, which ROCm in nixpkgs does
  # not ship (only `.info/version`).
  # --replace-fail (instead of the deprecated --replace) makes the build fail
  # loudly if an upstream bump removes one of these patterns, rather than
  # silently leaving the original code in place.
  postPatch =
    ''
      substituteInPlace setup.py \
        --replace-fail 'print(" --- Initializing submodules")' "return" \
        --replace-fail "_fetch_archives(_parse_sources())" "pass"
    ''
    + lib.optionalString rocmSupport ''
      # There is no .info/version-dev, only .info/version
      substituteInPlace cmake/LoadHIP.cmake \
        --replace-fail "/.info/version-dev" "/.info/version"
    '';

  env = {
    # Semicolon-separated CUDA compute capabilities, matched to torch's build.
    TORCH_CUDA_ARCH_LIST = lib.concatStringsSep ";" torch.cudaCapabilities;
  };

  # https://github.com/pytorch/audio/blob/v2.1.0/docs/source/build.linux.rst#optional-build-torchaudio-with-a-custom-built-ffmpeg
  FFMPEG_ROOT = symlinkJoin {
    name = "ffmpeg";
    paths = [
      ffmpeg-full.bin
      ffmpeg-full.dev
      ffmpeg-full.lib
    ];
  };

  nativeBuildInputs =
    [
      cmake
      pkg-config
      ninja
    ]
    ++ lib.optionals cudaSupport [ cudaPackages.cuda_nvcc ]
    ++ lib.optionals rocmSupport (
      with rocmPackages;
      [
        clr
        rocblas
        hipblas
      ]
    );

  buildInputs = [
    ffmpeg-full
    pybind11
    sox
    torch.cxxdev
  ];

  dependencies = [ torch ];

  # Optional torchaudio components, switched off (0). Top-level derivation
  # attributes are exported to the builder as environment variables, where
  # torchaudio's build presumably reads these BUILD_* switches.
  BUILD_SOX = 0;
  BUILD_KALDI = 0;
  BUILD_RNNT = 0;
  BUILD_CTC_DECODER = 0;

  # Point the build at the merged ROCm prefix and the selected GPU targets.
  preConfigure = lib.optionalString rocmSupport ''
    export ROCM_PATH=${rocmtoolkit_joined}
    export PYTORCH_ROCM_ARCH="${gpuTargetString}"
  '';

  # cmake is only needed as a tool for the Python build (presumably invoked
  # from setup.py, see the cmakeFlags patch); skip the standalone cmake
  # configure phase.
  dontUseCmakeConfigure = true;

  doCheck = false; # requires sox backend

  meta = {
    description = "PyTorch audio library";
    homepage = "https://pytorch.org/";
    changelog = "https://github.com/pytorch/audio/releases/tag/v${version}";
    license = lib.licenses.bsd2;
    platforms = [
      "aarch64-darwin"
      "aarch64-linux"
      "x86_64-linux"
    ];
    maintainers = with lib.maintainers; [ junjihashimoto ];
  };
}