about summary refs log tree commit diff
path: root/pkgs/development/python-modules/xformers/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/xformers/default.nix')
-rw-r--r--pkgs/development/python-modules/xformers/default.nix109
1 files changed, 61 insertions, 48 deletions
diff --git a/pkgs/development/python-modules/xformers/default.nix b/pkgs/development/python-modules/xformers/default.nix
index e0e6e9569ef3..8790b380b769 100644
--- a/pkgs/development/python-modules/xformers/default.nix
+++ b/pkgs/development/python-modules/xformers/default.nix
@@ -1,77 +1,90 @@
-{ lib
-, buildPythonPackage
-, pythonOlder
-, fetchFromGitHub
-, which
-# runtime dependencies
-, numpy
-, torch
-# check dependencies
-, pytestCheckHook
-, pytest-cov
-# , pytest-mpi
-, pytest-timeout
-# , pytorch-image-models
-, hydra-core
-, fairscale
-, scipy
-, cmake
-, openai-triton
-, networkx
-#, apex
-, einops
-, transformers
-, timm
+{
+  lib,
+  stdenv,
+  buildPythonPackage,
+  pythonOlder,
+  fetchFromGitHub,
+  which,
+  setuptools,
+  # runtime dependencies
+  numpy,
+  torch,
+  # check dependencies
+  pytestCheckHook,
+  pytest-cov-stub,
+  # , pytest-mpi
+  pytest-timeout,
+  # , pytorch-image-models
+  hydra-core,
+  fairscale,
+  scipy,
+  cmake,
+  ninja,
+  triton,
+  networkx,
+  #, apex
+  einops,
+  transformers,
+  timm,
 #, flash-attn
 }:
 let
   inherit (torch) cudaCapabilities cudaPackages cudaSupport;
-  version = "0.0.23.post1";
+  version = "0.0.28.post3";
 in
 buildPythonPackage {
   pname = "xformers";
   inherit version;
-  format = "setuptools";
+  pyproject = true;
 
-  disabled = pythonOlder "3.7";
+  disabled = pythonOlder "3.9";
 
   src = fetchFromGitHub {
     owner = "facebookresearch";
     repo = "xformers";
     rev = "refs/tags/v${version}";
-    hash = "sha256-AJXow8MmX4GxtEE2jJJ/ZIBr+3i+uS4cA6vofb390rY=";
+    hash = "sha256-23tnhCHK+Z0No8fqZxkgDFp2VIgXZR4jpM+pkb/vvmw=";
     fetchSubmodules = true;
   };
 
-  patches = [
-    ./0001-fix-allow-building-without-git.patch
-  ];
+  patches = [ ./0001-fix-allow-building-without-git.patch ];
+
+  build-system = [ setuptools ];
 
   preBuild = ''
     cat << EOF > ./xformers/version.py
     # noqa: C801
     __version__ = "${version}"
     EOF
-  '' + lib.optionalString cudaSupport ''
-    export CUDA_HOME=${cudaPackages.cuda_nvcc}
-    export TORCH_CUDA_ARCH_LIST="${lib.concatStringsSep ";" cudaCapabilities}"
+
+    export MAX_JOBS=$NIX_BUILD_CORES
   '';
 
-  buildInputs = lib.optionals cudaSupport (with cudaPackages; [
-    # flash-attn build
-    cuda_cudart # cuda_runtime_api.h
-    libcusparse.dev # cusparse.h
-    cuda_cccl.dev # nv/target
-    libcublas.dev # cublas_v2.h
-    libcusolver.dev # cusolverDn.h
-    libcurand.dev # curand_kernel.h
-  ]);
+  env = lib.attrsets.optionalAttrs cudaSupport {
+    TORCH_CUDA_ARCH_LIST = "${lib.concatStringsSep ";" torch.cudaCapabilities}";
+  };
+
+  stdenv = if cudaSupport then cudaPackages.backendStdenv else stdenv;
+
+  buildInputs = lib.optionals cudaSupport (
+    with cudaPackages;
+    [
+      # flash-attn build
+      cuda_cudart # cuda_runtime_api.h
+      libcusparse # cusparse.h
+      cuda_cccl # nv/target
+      libcublas # cublas_v2.h
+      libcusolver # cusolverDn.h
+      libcurand # curand_kernel.h
+    ]
+  );
 
   nativeBuildInputs = [
+    ninja
     which
-  ];
+  ] ++ lib.optionals cudaSupport (with cudaPackages; [ cuda_nvcc ]);
 
-  propagatedBuildInputs = [
+  dependencies = [
     numpy
     torch
   ];
@@ -89,14 +102,14 @@ buildPythonPackage {
 
   nativeCheckInputs = [
     pytestCheckHook
-    pytest-cov
+    pytest-cov-stub
     pytest-timeout
     hydra-core
     fairscale
     scipy
     cmake
     networkx
-    openai-triton
+    triton
     # apex
     einops
     transformers
@@ -105,7 +118,7 @@ buildPythonPackage {
   ];
 
   meta = with lib; {
-    description = "XFormers: A collection of composable Transformer building blocks";
+    description = "Collection of composable Transformer building blocks";
     homepage = "https://github.com/facebookresearch/xformers";
     changelog = "https://github.com/facebookresearch/xformers/blob/${version}/CHANGELOG.md";
     license = licenses.bsd3;