diff options
Diffstat (limited to 'pkgs/development/python-modules/mamba-ssm/default.nix')
-rw-r--r-- | pkgs/development/python-modules/mamba-ssm/default.nix | 77 |
1 files changed, 77 insertions, 0 deletions
diff --git a/pkgs/development/python-modules/mamba-ssm/default.nix b/pkgs/development/python-modules/mamba-ssm/default.nix new file mode 100644 index 000000000000..11ac68c1e19d --- /dev/null +++ b/pkgs/development/python-modules/mamba-ssm/default.nix @@ -0,0 +1,77 @@ +{ + lib, + buildPythonPackage, + fetchFromGitHub, + causal-conv1d, + einops, + ninja, + setuptools, + torch, + transformers, + triton, + cudaPackages, + rocmPackages, + config, + cudaSupport ? config.cudaSupport, + which, +}: + +buildPythonPackage rec { + pname = "mamba"; + version = "2.2.2"; + pyproject = true; + + src = fetchFromGitHub { + owner = "state-spaces"; + repo = "mamba"; + rev = "refs/tags/v${version}"; + hash = "sha256-R702JjM3AGk7upN7GkNK8u1q4ekMK9fYQkpO6Re45Ng="; + }; + + build-system = [ + ninja + setuptools + torch + ]; + + nativeBuildInputs = [ which ]; + + buildInputs = ( + lib.optionals cudaSupport ( + with cudaPackages; + [ + cuda_cudart # cuda_runtime.h, -lcudart + cuda_cccl + libcusparse # cusparse.h + libcusolver # cusolverDn.h + cuda_nvcc + libcublas + ] + ) + ); + + dependencies = [ + causal-conv1d + einops + torch + transformers + triton + ]; + + env = { + MAMBA_FORCE_BUILD = "TRUE"; + } // lib.optionalAttrs cudaSupport { CUDA_HOME = "${lib.getDev cudaPackages.cuda_nvcc}"; }; + + # pytest tests not enabled due to nvidia GPU dependency + pythonImportsCheck = [ "mamba_ssm" ]; + + meta = with lib; { + description = "Linear-Time Sequence Modeling with Selective State Spaces"; + homepage = "https://github.com/state-spaces/mamba"; + license = licenses.asl20; + maintainers = with maintainers; [ cfhammill ]; + # The package requires CUDA or ROCm, the ROCm build hasn't + # been completed or tested, so broken if not using cuda. + broken = !cudaSupport; + }; +} |