diff options
Diffstat (limited to 'pkgs/development/python-modules/xformers/default.nix')
-rw-r--r-- | pkgs/development/python-modules/xformers/default.nix | 109 |
1 files changed, 61 insertions, 48 deletions
diff --git a/pkgs/development/python-modules/xformers/default.nix b/pkgs/development/python-modules/xformers/default.nix index e0e6e9569ef3..8790b380b769 100644 --- a/pkgs/development/python-modules/xformers/default.nix +++ b/pkgs/development/python-modules/xformers/default.nix @@ -1,77 +1,90 @@ -{ lib -, buildPythonPackage -, pythonOlder -, fetchFromGitHub -, which -# runtime dependencies -, numpy -, torch -# check dependencies -, pytestCheckHook -, pytest-cov -# , pytest-mpi -, pytest-timeout -# , pytorch-image-models -, hydra-core -, fairscale -, scipy -, cmake -, openai-triton -, networkx -#, apex -, einops -, transformers -, timm +{ + lib, + stdenv, + buildPythonPackage, + pythonOlder, + fetchFromGitHub, + which, + setuptools, + # runtime dependencies + numpy, + torch, + # check dependencies + pytestCheckHook, + pytest-cov-stub, + # , pytest-mpi + pytest-timeout, + # , pytorch-image-models + hydra-core, + fairscale, + scipy, + cmake, + ninja, + triton, + networkx, + #, apex + einops, + transformers, + timm, #, flash-attn }: let inherit (torch) cudaCapabilities cudaPackages cudaSupport; - version = "0.0.23.post1"; + version = "0.0.28.post3"; in buildPythonPackage { pname = "xformers"; inherit version; - format = "setuptools"; + pyproject = true; - disabled = pythonOlder "3.7"; + disabled = pythonOlder "3.9"; src = fetchFromGitHub { owner = "facebookresearch"; repo = "xformers"; rev = "refs/tags/v${version}"; - hash = "sha256-AJXow8MmX4GxtEE2jJJ/ZIBr+3i+uS4cA6vofb390rY="; + hash = "sha256-23tnhCHK+Z0No8fqZxkgDFp2VIgXZR4jpM+pkb/vvmw="; fetchSubmodules = true; }; - patches = [ - ./0001-fix-allow-building-without-git.patch - ]; + patches = [ ./0001-fix-allow-building-without-git.patch ]; + + build-system = [ setuptools ]; preBuild = '' cat << EOF > ./xformers/version.py # noqa: C801 __version__ = "${version}" EOF - '' + lib.optionalString cudaSupport '' - export CUDA_HOME=${cudaPackages.cuda_nvcc} - export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}" + + export MAX_JOBS=$NIX_BUILD_CORES ''; - buildInputs = lib.optionals cudaSupport (with cudaPackages; [ - # flash-attn build - cuda_cudart # cuda_runtime_api.h - libcusparse.dev # cusparse.h - cuda_cccl.dev # nv/target - libcublas.dev # cublas_v2.h - libcusolver.dev # cusolverDn.h - libcurand.dev # curand_kernel.h - ]); + env = lib.attrsets.optionalAttrs cudaSupport { + TORCH_CUDA_ARCH_LIST = "${lib.concatStringsSep ";" torch.cudaCapabilities}"; + }; + + stdenv = if cudaSupport then cudaPackages.backendStdenv else stdenv; + + buildInputs = lib.optionals cudaSupport ( + with cudaPackages; + [ + # flash-attn build + cuda_cudart # cuda_runtime_api.h + libcusparse # cusparse.h + cuda_cccl # nv/target + libcublas # cublas_v2.h + libcusolver # cusolverDn.h + libcurand # curand_kernel.h + ] + ); nativeBuildInputs = [ + ninja which - ]; + ] ++ lib.optionals cudaSupport (with cudaPackages; [ cuda_nvcc ]); - propagatedBuildInputs = [ + dependencies = [ numpy torch ]; @@ -89,14 +102,14 @@ buildPythonPackage { nativeCheckInputs = [ pytestCheckHook - pytest-cov + pytest-cov-stub pytest-timeout hydra-core fairscale scipy cmake networkx - openai-triton + triton # apex einops transformers @@ -105,7 +118,7 @@ buildPythonPackage { ]; meta = with lib; { - description = "XFormers: A collection of composable Transformer building blocks"; + description = "Collection of composable Transformer building blocks"; homepage = "https://github.com/facebookresearch/xformers"; changelog = "https://github.com/facebookresearch/xformers/blob/${version}/CHANGELOG.md"; license = licenses.bsd3; |