about summary refs log tree commit diff
path: root/pkgs/development/python-modules/tinygrad/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/tinygrad/default.nix')
-rw-r--r--pkgs/development/python-modules/tinygrad/default.nix47
1 files changed, 40 insertions, 7 deletions
diff --git a/pkgs/development/python-modules/tinygrad/default.nix b/pkgs/development/python-modules/tinygrad/default.nix
index 387ee633264e..ae55395b2ca2 100644
--- a/pkgs/development/python-modules/tinygrad/default.nix
+++ b/pkgs/development/python-modules/tinygrad/default.nix
@@ -9,14 +9,17 @@
   rocmSupport ? config.rocmSupport,
   cudaPackages,
   ocl-icd,
-  stdenv,
   rocmPackages,
+  stdenv,
 
   # build-system
   setuptools,
 
   # dependencies
+  llvmlite,
   numpy,
+  triton,
+  unicorn,
 
   # tests
   blobfile,
@@ -25,9 +28,9 @@
   hexdump,
   hypothesis,
   librosa,
+  networkx,
   onnx,
   pillow,
-  pydot,
   pytest-xdist,
   pytestCheckHook,
   safetensors,
@@ -36,6 +39,8 @@
   torch,
   tqdm,
   transformers,
+
+  tinygrad,
 }:
 
 buildPythonPackage rec {
@@ -67,6 +72,18 @@ buildPythonPackage rec {
       substituteInPlace tinygrad/runtime/autogen/opencl.py \
         --replace-fail "ctypes.util.find_library('OpenCL')" "'${ocl-icd}/lib/libOpenCL.so'"
     ''
+    # Patch `clang` directly in the source file
+    + ''
+      substituteInPlace tinygrad/runtime/ops_clang.py \
+        --replace-fail "'clang'" "'${lib.getExe clang}'"
+    ''
+    # `cuda_fp16.h` and co. are needed at runtime to compile kernels
+    + lib.optionalString cudaSupport ''
+      substituteInPlace tinygrad/runtime/support/compiler_cuda.py \
+        --replace-fail \
+        ', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"' \
+        ', "-I${lib.getDev cudaPackages.cuda_cudart}/include/"'
+    ''
     + lib.optionalString rocmSupport ''
       substituteInPlace tinygrad/runtime/autogen/hip.py \
         --replace-fail "/opt/rocm/lib/libamdhip64.so" "${rocmPackages.clr}/lib/libamdhip64.so" \
@@ -82,12 +99,24 @@ buildPythonPackage rec {
     [
       numpy
     ]
-    ++ lib.optionals stdenv.isDarwin [
+    ++ lib.optionals stdenv.hostPlatform.isDarwin [
       # pyobjc-framework-libdispatch
       # pyobjc-framework-metal
     ];
 
-  pythonImportsCheck = [ "tinygrad" ];
+  optional-dependencies = {
+    llvm = [ llvmlite ];
+    arm = [ unicorn ];
+    triton = [ triton ];
+  };
+
+  pythonImportsCheck =
+    [
+      "tinygrad"
+    ]
+    ++ lib.optionals cudaSupport [
+      "tinygrad.runtime.ops_nv"
+    ];
 
   nativeCheckInputs = [
     blobfile
@@ -96,9 +125,9 @@ buildPythonPackage rec {
     hexdump
     hypothesis
     librosa
+    networkx
     onnx
     pillow
-    pydot
     pytest-xdist
     pytestCheckHook
     safetensors
@@ -107,7 +136,7 @@ buildPythonPackage rec {
     torch
     tqdm
     transformers
-  ];
+  ] ++ networkx.optional-dependencies.extra;
 
   preCheck = ''
     export HOME=$(mktemp -d)
@@ -170,6 +199,10 @@ buildPythonPackage rec {
     "extra/"
   ];
 
+  passthru.tests = {
+    withCuda = tinygrad.override { cudaSupport = true; };
+  };
+
   meta = {
     description = "Simple and powerful neural network framework";
     homepage = "https://github.com/tinygrad/tinygrad";
@@ -177,6 +210,6 @@ buildPythonPackage rec {
     license = lib.licenses.mit;
     maintainers = with lib.maintainers; [ GaetanLepage ];
     # Requires unpackaged pyobjc-framework-libdispatch and pyobjc-framework-metal
-    broken = stdenv.isDarwin;
+    broken = stdenv.hostPlatform.isDarwin;
   };
 }