diff options
Diffstat (limited to 'pkgs/development/python-modules/causal-conv1d/default.nix')
-rw-r--r-- | pkgs/development/python-modules/causal-conv1d/default.nix | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/pkgs/development/python-modules/causal-conv1d/default.nix b/pkgs/development/python-modules/causal-conv1d/default.nix new file mode 100644 index 000000000000..0653959ed6da --- /dev/null +++ b/pkgs/development/python-modules/causal-conv1d/default.nix @@ -0,0 +1,69 @@ +{ + lib, + buildPythonPackage, + fetchFromGitHub, + ninja, + setuptools, + torch, + cudaPackages, + rocmPackages, + config, + cudaSupport ? config.cudaSupport, + which, +}: + +buildPythonPackage rec { + pname = "causal-conv1d"; + version = "1.4.0"; + pyproject = true; + + src = fetchFromGitHub { + owner = "Dao-AILab"; + repo = "causal-conv1d"; + rev = "refs/tags/v${version}"; + hash = "sha256-p5x5u3zEmEMN3mWd88o3jmcpKUnovTvn7I9jIOj/ie0="; + }; + + build-system = [ + ninja + setuptools + torch + ]; + + nativeBuildInputs = [ which ]; + + buildInputs = ( + lib.optionals cudaSupport ( + with cudaPackages; + [ + cuda_cudart # cuda_runtime.h, -lcudart + cuda_cccl + libcusparse # cusparse.h + libcusolver # cusolverDn.h + cuda_nvcc + libcublas + ] + ) + ); + + dependencies = [ + torch + ]; + + # pytest tests not enabled due to nvidia GPU dependency + pythonImportsCheck = [ "causal_conv1d" ]; + + env = { + CAUSAL_CONV1D_FORCE_BUILD = "TRUE"; + } // lib.optionalAttrs cudaSupport { CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}"; }; + + meta = with lib; { + description = "Causal depthwise conv1d in CUDA with a PyTorch interface"; + homepage = "https://github.com/Dao-AILab/causal-conv1d"; + license = licenses.bsd3; + maintainers = with maintainers; [ cfhammill ]; + # The package requires CUDA or ROCm, the ROCm build hasn't + # been completed or tested, so broken if not using cuda. + broken = !cudaSupport; + }; +} |