about summary refs log tree commit diff
path: root/pkgs/development/python-modules/tinygrad/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/tinygrad/default.nix')
-rw-r--r--pkgs/development/python-modules/tinygrad/default.nix160
1 files changed, 114 insertions, 46 deletions
diff --git a/pkgs/development/python-modules/tinygrad/default.nix b/pkgs/development/python-modules/tinygrad/default.nix
index 82a57f7d7f08b..760b29c1adfc5 100644
--- a/pkgs/development/python-modules/tinygrad/default.nix
+++ b/pkgs/development/python-modules/tinygrad/default.nix
@@ -1,12 +1,25 @@
 {
   lib,
+  config,
   buildPythonPackage,
   fetchFromGitHub,
+  substituteAll,
+  addDriverRunpath,
+  cudaSupport ? config.cudaSupport,
+  rocmSupport ? config.rocmSupport,
+  cudaPackages,
+  ocl-icd,
+  stdenv,
+  rocmPackages,
+  # build-system
   setuptools,
   wheel,
-  gpuctypes,
+  # dependencies
   numpy,
   tqdm,
+  # nativeCheckInputs
+  clang,
+  hexdump,
   hypothesis,
   librosa,
   onnx,
@@ -22,30 +35,67 @@
 
 buildPythonPackage rec {
   pname = "tinygrad";
-  version = "0.8.0";
+  version = "0.9.0";
   pyproject = true;
 
   src = fetchFromGitHub {
     owner = "tinygrad";
     repo = "tinygrad";
     rev = "refs/tags/v${version}";
-    hash = "sha256-QAccZ79qUbe27yUykIf22WdkxYUlOffnMlShakKfp60=";
+    hash = "sha256-opBxciETZruZjHqz/3vO7rogzjvVJKItulIiok/Zs2Y=";
   };
 
-  nativeBuildInputs = [
+  patches = [
+    (substituteAll {
+      src = ./fix-dlopen-cuda.patch;
+      inherit (addDriverRunpath) driverLink;
+      libnvrtc =
+        if cudaSupport then
+          "${lib.getLib cudaPackages.cuda_nvrtc}/lib/libnvrtc.so"
+        else
+          "Please import nixpkgs with `config.cudaSupport = true`";
+    })
+  ];
+
+  postPatch =
+    ''
+      substituteInPlace tinygrad/runtime/autogen/opencl.py \
+        --replace-fail "ctypes.util.find_library('OpenCL')" "'${ocl-icd}/lib/libOpenCL.so'"
+    ''
+    # hipGetDevicePropertiesR0600 is a symbol from rocm-6. We are currently at rocm-5.
+    # We are not sure that this works. Remove when rocm gets updated to version 6.
+    + lib.optionalString rocmSupport ''
+      substituteInPlace extra/hip_gpu_driver/hip_ioctl.py \
+        --replace-fail "processor = platform.processor()" "processor = ${stdenv.hostPlatform.linuxArch}"
+      substituteInPlace tinygrad/runtime/autogen/hip.py \
+        --replace-fail "/opt/rocm/lib/libamdhip64.so" "${rocmPackages.clr}/lib/libamdhip64.so" \
+        --replace-fail "/opt/rocm/lib/libhiprtc.so" "${rocmPackages.clr}/lib/libhiprtc.so" \
+        --replace-fail "hipGetDevicePropertiesR0600" "hipGetDeviceProperties"
+
+      substituteInPlace tinygrad/runtime/autogen/comgr.py \
+        --replace-fail "/opt/rocm/lib/libamd_comgr.so" "${rocmPackages.rocm-comgr}/lib/libamd_comgr.so"
+    '';
+
+  build-system = [
     setuptools
     wheel
   ];
 
-  propagatedBuildInputs = [
-    gpuctypes
-    numpy
-    tqdm
-  ];
+  dependencies =
+    [
+      numpy
+      tqdm
+    ]
+    ++ lib.optionals stdenv.isDarwin [
+      # pyobjc-framework-libdispatch
+      # pyobjc-framework-metal
+    ];
 
   pythonImportsCheck = [ "tinygrad" ];
 
   nativeCheckInputs = [
+    clang
+    hexdump
     hypothesis
     librosa
     onnx
@@ -63,44 +113,60 @@ buildPythonPackage rec {
     export HOME=$(mktemp -d)
   '';
 
-  disabledTests = [
-    # Require internet access
-    "test_benchmark_openpilot_model"
-    "test_bn_alone"
-    "test_bn_linear"
-    "test_bn_mnist"
-    "test_car"
-    "test_chicken"
-    "test_chicken_bigbatch"
-    "test_conv_mnist"
-    "testCopySHMtoDefault"
-    "test_data_parallel_resnet"
-    "test_e2e_big"
-    "test_fetch_small"
-    "test_huggingface_enet_safetensors"
-    "test_linear_mnist"
-    "test_load_convnext"
-    "test_load_enet"
-    "test_load_enet_alt"
-    "test_load_llama2bfloat"
-    "test_load_resnet"
-    "test_openpilot_model"
-    "test_resnet"
-    "test_shufflenet"
-    "test_transcribe_batch12"
-    "test_transcribe_batch21"
-    "test_transcribe_file1"
-    "test_transcribe_file2"
-    "test_transcribe_long"
-    "test_transcribe_long_no_batch"
-    "test_vgg7"
-  ];
+  disabledTests =
+    [
+      # Require internet access
+      "test_benchmark_openpilot_model"
+      "test_bn_alone"
+      "test_bn_linear"
+      "test_bn_mnist"
+      "test_car"
+      "test_chicken"
+      "test_chicken_bigbatch"
+      "test_conv_mnist"
+      "testCopySHMtoDefault"
+      "test_data_parallel_resnet"
+      "test_e2e_big"
+      "test_fetch_small"
+      "test_huggingface_enet_safetensors"
+      "test_linear_mnist"
+      "test_load_convnext"
+      "test_load_enet"
+      "test_load_enet_alt"
+      "test_load_llama2bfloat"
+      "test_load_resnet"
+      "test_openpilot_model"
+      "test_resnet"
+      "test_shufflenet"
+      "test_transcribe_batch12"
+      "test_transcribe_batch21"
+      "test_transcribe_file1"
+      "test_transcribe_file2"
+      "test_transcribe_long"
+      "test_transcribe_long_no_batch"
+      "test_vgg7"
+    ]
+    # Fail on aarch64-linux with AssertionError
+    ++ lib.optionals (stdenv.hostPlatform.system == "aarch64-linux") [
+      "test_casts_to"
+      "test_casts_to"
+      "test_int8_to_uint16_negative"
+      "test_casts_to"
+      "test_casts_to"
+      "test_casts_from"
+      "test_casts_to"
+      "test_int8"
+      "test_casts_to"
+    ];
 
-  disabledTestPaths = [
-    "test/extra/test_lr_scheduler.py"
-    "test/models/test_mnist.py"
-    "test/models/test_real_world.py"
-  ];
+  disabledTestPaths =
+    [
+      # Require internet access
+      "test/models/test_mnist.py"
+      "test/models/test_real_world.py"
+      "test/testextra/test_lr_scheduler.py"
+    ]
+    ++ lib.optionals (!rocmSupport) [ "extra/hip_gpu_driver/" ];
 
   meta = with lib; {
     description = "A simple and powerful neural network framework";
@@ -108,5 +174,7 @@ buildPythonPackage rec {
     changelog = "https://github.com/tinygrad/tinygrad/releases/tag/v${version}";
     license = licenses.mit;
     maintainers = with maintainers; [ GaetanLepage ];
+    # Requires unpackaged pyobjc-framework-libdispatch and pyobjc-framework-metal
+    broken = stdenv.isDarwin;
   };
 }