about summary refs log tree commit diff
diff options
context:
space:
mode:
authorConnor Baker <connor.baker@tweag.io>2023-09-18 19:35:17 -0400
committerGitHub <noreply@github.com>2023-09-18 19:35:17 -0400
commit0ed41137b76534003507b55dfe214ae56c30e021 (patch)
treea2ab50414215efb86b39f30e6dde96c633151af6
parentd35ac80828cf341525a5eda63e3fc741489c5275 (diff)
parenta11c1555527f64f9ff34a0af0abae5044e47be69 (diff)
Merge pull request #255904 from ConnorBaker/fix/torch-descriptive-broken-messages
python3Packages.torch: add descriptive messages when marked broken
-rw-r--r--pkgs/development/python-modules/torch/default.nix25
1 files changed, 12 insertions, 13 deletions
diff --git a/pkgs/development/python-modules/torch/default.nix b/pkgs/development/python-modules/torch/default.nix
index 1fa790686cac1..f9f6e377b1390 100644
--- a/pkgs/development/python-modules/torch/default.nix
+++ b/pkgs/development/python-modules/torch/default.nix
@@ -51,7 +51,7 @@
 }:
 
 let
-  inherit (lib) lists strings trivial;
+  inherit (lib) attrsets lists strings trivial;
   inherit (cudaPackages) cudaFlags cudnn nccl;
 
   setBool = v: if v then "1" else "0";
@@ -105,6 +105,14 @@ let
       rocm-runtime rocm-opencl-runtime hipify
     ];
   };
+
+  brokenConditions = attrsets.filterAttrs (_: cond: cond) {
+    "CUDA and ROCm are not mutually exclusive" = cudaSupport && rocmSupport;
+    "CUDA is not targeting Linux" = cudaSupport && !stdenv.isLinux;
+    "Unsupported CUDA version" = cudaSupport && (cudaPackages.cudaMajorVersion != "11");
+    "MPI cudatoolkit does not match cudaPackages.cudatoolkit" = MPISupport && cudaSupport && (mpi.cudatoolkit != cudaPackages.cudatoolkit);
+    "Magma cudaPackages does not match cudaPackages" = cudaSupport && (magma.cudaPackages != cudaPackages);
+  };
 in buildPythonPackage rec {
   pname = "torch";
   # Don't forget to update torch-bin to the same version.
@@ -426,6 +434,8 @@ in buildPythonPackage rec {
     inherit cudaSupport cudaPackages;
     # At least for 1.10.2 `torch.fft` is unavailable unless BLAS provider is MKL. This attribute allows for easy detection of its availability.
     blasProvider = blas.provider;
+    # To help debug when a package is broken due to CUDA support
+    inherit brokenConditions;
   } // lib.optionalAttrs cudaSupport {
     # NOTE: supportedCudaCapabilities isn't computed unless cudaSupport is true, so we can't use
     #   it in the passthru set above because a downstream package might try to access it even
@@ -441,17 +451,6 @@ in buildPythonPackage rec {
     license = licenses.bsd3;
     maintainers = with maintainers; [ teh thoughtpolice tscholak ]; # tscholak esp. for darwin-related builds
     platforms = with platforms; linux ++ lib.optionals (!cudaSupport && !rocmSupport) darwin;
-    broken = builtins.any trivial.id [
-      # CUDA and ROCm are mutually exclusive
-      (cudaSupport && rocmSupport)
-      # CUDA is only supported on Linux
-      (cudaSupport && !stdenv.isLinux)
-      # Only CUDA 11 is currently supported
-      (cudaSupport && (cudaPackages.cudaMajorVersion != "11"))
-      # MPI cudatoolkit does not match cudaPackages.cudatoolkit
-      (MPISupport && cudaSupport && (mpi.cudatoolkit != cudaPackages.cudatoolkit))
-      # Magma cudaPackages does not match cudaPackages
-      (cudaSupport && (magma.cudaPackages != cudaPackages))
-    ];
+    broken = builtins.any trivial.id (builtins.attrValues brokenConditions);
   };
 }