Diffstat (limited to 'pkgs/development/python-modules/openai-triton/default.nix')
-rw-r--r--  pkgs/development/python-modules/openai-triton/default.nix | 212
1 file changed, 115 insertions(+), 97 deletions(-)
diff --git a/pkgs/development/python-modules/openai-triton/default.nix b/pkgs/development/python-modules/openai-triton/default.nix
index 2bdb8d918af3f..1b4d713311ee1 100644
--- a/pkgs/development/python-modules/openai-triton/default.nix
+++ b/pkgs/development/python-modules/openai-triton/default.nix
@@ -1,29 +1,30 @@
-{ lib
-, config
-, buildPythonPackage
-, fetchFromGitHub
-, fetchpatch
-, addOpenGLRunpath
-, setuptools
-, pytestCheckHook
-, pythonRelaxDepsHook
-, cmake
-, ninja
-, pybind11
-, gtest
-, zlib
-, ncurses
-, libxml2
-, lit
-, llvm
-, filelock
-, torchWithRocm
-, python
-
-, runCommand
-
-, cudaPackages
-, cudaSupport ? config.cudaSupport
+{
+  lib,
+  config,
+  buildPythonPackage,
+  fetchFromGitHub,
+  fetchpatch,
+  addOpenGLRunpath,
+  setuptools,
+  pytestCheckHook,
+  pythonRelaxDepsHook,
+  cmake,
+  ninja,
+  pybind11,
+  gtest,
+  zlib,
+  ncurses,
+  libxml2,
+  lit,
+  llvm,
+  filelock,
+  torchWithRocm,
+  python,
+
+  runCommand,
+
+  cudaPackages,
+  cudaSupport ? config.cudaSupport,
 }:
 
 let
@@ -41,18 +42,20 @@ buildPythonPackage rec {
     hash = "sha256-8UTUwLH+SriiJnpejdrzz9qIquP2zBp1/uwLdHmv0XQ=";
   };
 
-  patches = [
-    # fix overflow error
-    (fetchpatch {
-      url = "https://github.com/openai/triton/commit/52c146f66b79b6079bcd28c55312fc6ea1852519.patch";
-      hash = "sha256-098/TCQrzvrBAbQiaVGCMaF3o5Yc3yWDxzwSkzIuAtY=";
-    })
-  ] ++ lib.optionals (!cudaSupport) [
-    ./0000-dont-download-ptxas.patch
-    # openai-triton wants to get ptxas version even if ptxas is not
-    # used, resulting in ptxas not found error.
-    ./0001-ptxas-disable-version-key-for-non-cuda-targets.patch
-  ];
+  patches =
+    [
+      # fix overflow error
+      (fetchpatch {
+        url = "https://github.com/openai/triton/commit/52c146f66b79b6079bcd28c55312fc6ea1852519.patch";
+        hash = "sha256-098/TCQrzvrBAbQiaVGCMaF3o5Yc3yWDxzwSkzIuAtY=";
+      })
+    ]
+    ++ lib.optionals (!cudaSupport) [
+      ./0000-dont-download-ptxas.patch
+      # openai-triton wants to get ptxas version even if ptxas is not
+      # used, resulting in ptxas not found error.
+      ./0001-ptxas-disable-version-key-for-non-cuda-targets.patch
+    ];
 
   nativeBuildInputs = [
     setuptools
@@ -84,58 +87,67 @@ buildPythonPackage rec {
     setuptools
   ];
 
-  postPatch = let
-    # Bash was getting weird without linting,
-    # but basically upstream contains [cc, ..., "-lcuda", ...]
-    # and we replace it with [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...]
-    old = [ "-lcuda" ];
-    new = [ "-lcuda" "-L${addOpenGLRunpath.driverLink}" "-L${cudaPackages.cuda_cudart}/lib/stubs/" ];
-
-    quote = x: ''"${x}"'';
-    oldStr = lib.concatMapStringsSep ", " quote old;
-    newStr = lib.concatMapStringsSep ", " quote new;
-  in ''
-    # Use our `cmakeFlags` instead and avoid downloading dependencies
-    substituteInPlace python/setup.py \
-      --replace "= get_thirdparty_packages(triton_cache_path)" "= os.environ[\"cmakeFlags\"].split()"
-
-    # Already defined in llvm, when built with -DLLVM_INSTALL_UTILS
-    substituteInPlace bin/CMakeLists.txt \
-      --replace "add_subdirectory(FileCheck)" ""
-
-    # Don't fetch googletest
-    substituteInPlace unittest/CMakeLists.txt \
-      --replace "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" ""\
-      --replace "include(GoogleTest)" "find_package(GTest REQUIRED)"
-  '' + lib.optionalString cudaSupport ''
-    # Use our linker flags
-    substituteInPlace python/triton/common/build.py \
-      --replace '${oldStr}' '${newStr}'
-  '';
+  postPatch =
+    let
+      # Bash was getting weird without linting,
+      # but basically upstream contains [cc, ..., "-lcuda", ...]
+      # and we replace it with [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...]
+      old = [ "-lcuda" ];
+      new = [
+        "-lcuda"
+        "-L${addOpenGLRunpath.driverLink}"
+        "-L${cudaPackages.cuda_cudart}/lib/stubs/"
+      ];
+
+      quote = x: ''"${x}"'';
+      oldStr = lib.concatMapStringsSep ", " quote old;
+      newStr = lib.concatMapStringsSep ", " quote new;
+    in
+    ''
+      # Use our `cmakeFlags` instead and avoid downloading dependencies
+      substituteInPlace python/setup.py \
+        --replace "= get_thirdparty_packages(triton_cache_path)" "= os.environ[\"cmakeFlags\"].split()"
+
+      # Already defined in llvm, when built with -DLLVM_INSTALL_UTILS
+      substituteInPlace bin/CMakeLists.txt \
+        --replace "add_subdirectory(FileCheck)" ""
+
+      # Don't fetch googletest
+      substituteInPlace unittest/CMakeLists.txt \
+        --replace "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" ""\
+        --replace "include(GoogleTest)" "find_package(GTest REQUIRED)"
+    ''
+    + lib.optionalString cudaSupport ''
+      # Use our linker flags
+      substituteInPlace python/triton/common/build.py \
+        --replace '${oldStr}' '${newStr}'
+    '';
 
   # Avoid GLIBCXX mismatch with other cuda-enabled python packages
-  preConfigure = ''
-    # Ensure that the build process uses the requested number of cores
-    export MAX_JOBS="$NIX_BUILD_CORES"
-
-    # Upstream's setup.py tries to write cache somewhere in ~/
-    export HOME=$(mktemp -d)
-
-    # Upstream's github actions patch setup.cfg to write base-dir. May be redundant
-    echo "
-    [build_ext]
-    base-dir=$PWD" >> python/setup.cfg
-
-    # The rest (including buildPhase) is relative to ./python/
-    cd python
-  '' + lib.optionalString cudaSupport ''
-    export CC=${cudaPackages.backendStdenv.cc}/bin/cc;
-    export CXX=${cudaPackages.backendStdenv.cc}/bin/c++;
-
-    # Work around download_and_copy_ptxas()
-    mkdir -p $PWD/triton/third_party/cuda/bin
-    ln -s ${ptxas} $PWD/triton/third_party/cuda/bin
-  '';
+  preConfigure =
+    ''
+      # Ensure that the build process uses the requested number of cores
+      export MAX_JOBS="$NIX_BUILD_CORES"
+
+      # Upstream's setup.py tries to write cache somewhere in ~/
+      export HOME=$(mktemp -d)
+
+      # Upstream's github actions patch setup.cfg to write base-dir. May be redundant
+      echo "
+      [build_ext]
+      base-dir=$PWD" >> python/setup.cfg
+
+      # The rest (including buildPhase) is relative to ./python/
+      cd python
+    ''
+    + lib.optionalString cudaSupport ''
+      export CC=${cudaPackages.backendStdenv.cc}/bin/cc;
+      export CXX=${cudaPackages.backendStdenv.cc}/bin/c++;
+
+      # Work around download_and_copy_ptxas()
+      mkdir -p $PWD/triton/third_party/cuda/bin
+      ln -s ${ptxas} $PWD/triton/third_party/cuda/bin
+    '';
 
   # CMake is run by setup.py instead
   dontUseCmakeConfigure = true;
@@ -168,13 +180,16 @@ buildPythonPackage rec {
     inherit torchWithRocm;
     # Implemented as alternative to pythonImportsCheck, in case if circular dependency on torch occurs again,
     # and pythonImportsCheck is commented back.
-    import-triton = runCommand "import-triton" { nativeBuildInputs = [(python.withPackages (ps: [ps.openai-triton]))]; } ''
-      python << \EOF
-      import triton
-      import triton.language
-      EOF
-      touch "$out"
-    '';
+    import-triton =
+      runCommand "import-triton"
+        { nativeBuildInputs = [ (python.withPackages (ps: [ ps.openai-triton ])) ]; }
+        ''
+          python << \EOF
+          import triton
+          import triton.language
+          EOF
+          touch "$out"
+        '';
   };
 
   pythonRemoveDeps = [
@@ -189,8 +204,11 @@ buildPythonPackage rec {
   meta = with lib; {
     description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
     homepage = "https://github.com/openai/triton";
-    platforms = lib.platforms.unix;
+    platforms = platforms.linux;
     license = licenses.mit;
-    maintainers = with maintainers; [ SomeoneSerge Madouura ];
+    maintainers = with maintainers; [
+      SomeoneSerge
+      Madouura
+    ];
   };
 }
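
For context, a minimal sketch of how the derivation reformatted above might be exercised from a local nixpkgs checkout. The attribute path `python3Packages.openai-triton` and the `passthru.tests.import-triton` attribute are assumptions inferred from this diff, not confirmed by the page itself:

    # Hypothetical usage sketch, not part of the commit.
    # Evaluated from the root of a nixpkgs checkout; attribute paths assumed.
    let
      pkgs = import ./. { config.cudaSupport = false; };
      triton = pkgs.python3Packages.openai-triton;
    in
    {
      # The package itself (the file changed in this diff).
      inherit triton;
      # The import smoke test defined in passthru.tests in the hunk above.
      importCheck = triton.passthru.tests.import-triton;
    }

Building either attribute with `nix-build` against such an expression would exercise the non-CUDA code path that the ptxas patches in this diff apply to.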