about summary refs log tree commit diff
path: root/pkgs/development/cuda-modules/cuda/overrides.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/cuda-modules/cuda/overrides.nix')
-rw-r--r--pkgs/development/cuda-modules/cuda/overrides.nix45
1 files changed, 24 insertions, 21 deletions
diff --git a/pkgs/development/cuda-modules/cuda/overrides.nix b/pkgs/development/cuda-modules/cuda/overrides.nix
index 5d23d8f7f2a1a..46e4401f6a26e 100644
--- a/pkgs/development/cuda-modules/cuda/overrides.nix
+++ b/pkgs/development/cuda-modules/cuda/overrides.nix
@@ -44,7 +44,7 @@ filterAndCreateOverrides {
     }:
     prevAttrs: {
       buildInputs = prevAttrs.buildInputs ++ [
-        libcublas.lib
+        libcublas
         numactl
         rdma-core
       ];
@@ -66,17 +66,17 @@ filterAndCreateOverrides {
       buildInputs =
         prevAttrs.buildInputs
         # Always depends on this
-        ++ [ libcublas.lib ]
+        ++ [ libcublas ]
         # Dependency from 12.0 and on
-        ++ lib.lists.optionals (cudaAtLeast "12.0") [ libnvjitlink.lib ]
+        ++ lib.lists.optionals (cudaAtLeast "12.0") [ libnvjitlink ]
         # Dependency from 12.1 and on
-        ++ lib.lists.optionals (cudaAtLeast "12.1") [ libcusparse.lib ];
+        ++ lib.lists.optionals (cudaAtLeast "12.1") [ libcusparse ];
 
       brokenConditions = prevAttrs.brokenConditions // {
         "libnvjitlink missing (CUDA >= 12.0)" =
-          !(cudaAtLeast "12.0" -> (libnvjitlink != null && libnvjitlink.lib != null));
+          !(cudaAtLeast "12.0" -> (libnvjitlink != null && libnvjitlink != null));
         "libcusparse missing (CUDA >= 12.1)" =
-          !(cudaAtLeast "12.1" -> (libcusparse != null && libcusparse.lib != null));
+          !(cudaAtLeast "12.1" -> (libcusparse != null && libcusparse != null));
       };
     };
 
@@ -90,16 +90,16 @@ filterAndCreateOverrides {
       buildInputs =
         prevAttrs.buildInputs
         # Dependency from 12.0 and on
-        ++ lib.lists.optionals (cudaAtLeast "12.0") [ libnvjitlink.lib ];
+        ++ lib.lists.optionals (cudaAtLeast "12.0") [ libnvjitlink ];
 
       brokenConditions = prevAttrs.brokenConditions // {
         "libnvjitlink missing (CUDA >= 12.0)" =
-          !(cudaAtLeast "12.0" -> (libnvjitlink != null && libnvjitlink.lib != null));
+          !(cudaAtLeast "12.0" -> (libnvjitlink != null && libnvjitlink != null));
       };
     };
 
   # TODO(@connorbaker): cuda_cudart.dev depends on crt/host_config.h, which is from
-  # cuda_nvcc.dev. It would be nice to be able to encode that.
+  # (getDev cuda_nvcc). It would be nice to be able to encode that.
   cuda_cudart =
     { addDriverRunpath, lib }:
     prevAttrs: {
@@ -166,13 +166,17 @@ filterAndCreateOverrides {
     };
 
   cuda_nvcc =
-    {
-      backendStdenv,
-      cuda_cudart,
-      lib,
-      setupCudaHook,
-    }:
+    { backendStdenv, setupCudaHook }:
     prevAttrs: {
+      # Merge "bin" and "dev" into "out" to avoid circular references
+      outputs = builtins.filter (
+        x:
+        !(builtins.elem x [
+          "dev"
+          "bin"
+        ])
+      ) prevAttrs.outputs;
+
       # Patch the nvcc.profile.
       # Syntax:
       # - `=` for assignment,
@@ -230,12 +234,11 @@ filterAndCreateOverrides {
       };
     };
 
-  cuda_nvprof =
-    { cuda_cupti }: prevAttrs: { buildInputs = prevAttrs.buildInputs ++ [ cuda_cupti.lib ]; };
+  cuda_nvprof = { cuda_cupti }: prevAttrs: { buildInputs = prevAttrs.buildInputs ++ [ cuda_cupti ]; };
 
   cuda_demo_suite =
     {
-      freeglut,
+      libglut,
       libcufft,
       libcurand,
       libGLU,
@@ -244,9 +247,9 @@ filterAndCreateOverrides {
     }:
     prevAttrs: {
       buildInputs = prevAttrs.buildInputs ++ [
-        freeglut
-        libcufft.lib
-        libcurand.lib
+        libglut
+        libcufft
+        libcurand
         libGLU
         libglvnd
         mesa