about summary refs log tree commit diff
diff options
context:
space:
mode:
authorConnor Baker <connor.baker@tweag.io>2023-10-18 09:43:46 -0400
committerGitHub <noreply@github.com>2023-10-18 09:43:46 -0400
commitd26b6de226ec3248acdd930f7228f9bfd0f1e101 (patch)
tree5254740fa125f80344c99f8d0145410cd74bc23e
parent028029876bfa413723e05630cce969ad6312b3d0 (diff)
parent563f516b381479dc10de793f7050f69b2dfc1499 (diff)
Merge pull request #261654 from ConnorBaker/fix/openai-triton-arm-cuda-support
python3Packages.openai-triton: narrow the definition of broken
-rw-r--r--pkgs/development/python-modules/openai-triton/llvm.nix19
1 files changed, 16 insertions, 3 deletions
diff --git a/pkgs/development/python-modules/openai-triton/llvm.nix b/pkgs/development/python-modules/openai-triton/llvm.nix
index 6ac0d9f5738c3..70ea69a9b15fe 100644
--- a/pkgs/development/python-modules/openai-triton/llvm.nix
+++ b/pkgs/development/python-modules/openai-triton/llvm.nix
@@ -1,4 +1,5 @@
-{ lib
+{ config
+, lib
 , stdenv
 , fetchFromGitHub
 , pkg-config
@@ -68,7 +69,17 @@ stdenv.mkDerivation (finalAttrs: {
   sourceRoot = "${finalAttrs.src.name}/llvm";
 
   cmakeFlags = [
-    "-DLLVM_TARGETS_TO_BUILD=X86;AMDGPU;NVPTX"
+    "-DLLVM_TARGETS_TO_BUILD=${
+      let
+        # Targets can be found in
+        # https://github.com/llvm/llvm-project/tree/f28c006a5895fc0e329fe15fead81e37457cb1d1/clang/lib/Basic/Targets
+        # NOTE: Unsure of how "host" would function, especially given that we might be cross-compiling.
+        llvmTargets = [ "AMDGPU" "NVPTX" ]
+        ++ lib.optionals stdenv.isAarch64 [ "AArch64" ]
+        ++ lib.optionals stdenv.isx86_64 [ "X86" ];
+      in
+      lib.concatStringsSep ";" llvmTargets
+    }"
     "-DLLVM_ENABLE_PROJECTS=llvm;mlir"
     "-DLLVM_INSTALL_UTILS=ON"
   ] ++ lib.optionals (buildDocs || buildMan) [
@@ -107,6 +118,8 @@ stdenv.mkDerivation (finalAttrs: {
     license = with licenses; [ ncsa ];
     maintainers = with maintainers; [ SomeoneSerge Madouura ];
     platforms = platforms.linux;
-    broken = stdenv.isAarch64; # https://github.com/RadeonOpenCompute/ROCm/issues/1831#issuecomment-1278205344
+    # Consider the derivation broken if we're not building for CUDA or ROCm, or if we're building for aarch64
+    # and ROCm is enabled. See https://github.com/RadeonOpenCompute/ROCm/issues/1831#issuecomment-1278205344.
+    broken = stdenv.isAarch64 && !config.cudaSupport;
   };
 })