diff options
Diffstat (limited to 'pkgs/development/python-modules/tinygrad/default.nix')
-rw-r--r-- | pkgs/development/python-modules/tinygrad/default.nix | 47 |
1 files changed, 40 insertions, 7 deletions
diff --git a/pkgs/development/python-modules/tinygrad/default.nix b/pkgs/development/python-modules/tinygrad/default.nix index 387ee633264e..ae55395b2ca2 100644 --- a/pkgs/development/python-modules/tinygrad/default.nix +++ b/pkgs/development/python-modules/tinygrad/default.nix @@ -9,14 +9,17 @@ rocmSupport ? config.rocmSupport, cudaPackages, ocl-icd, - stdenv, rocmPackages, + stdenv, # build-system setuptools, # dependencies + llvmlite, numpy, + triton, + unicorn, # tests blobfile, @@ -25,9 +28,9 @@ hexdump, hypothesis, librosa, + networkx, onnx, pillow, - pydot, pytest-xdist, pytestCheckHook, safetensors, @@ -36,6 +39,8 @@ torch, tqdm, transformers, + + tinygrad, }: buildPythonPackage rec { @@ -67,6 +72,18 @@ buildPythonPackage rec { substituteInPlace tinygrad/runtime/autogen/opencl.py \ --replace-fail "ctypes.util.find_library('OpenCL')" "'${ocl-icd}/lib/libOpenCL.so'" '' + # Patch `clang` directly in the source file + + '' + substituteInPlace tinygrad/runtime/ops_clang.py \ + --replace-fail "'clang'" "'${lib.getExe clang}'" + '' + # `cuda_fp16.h` and co. are needed at runtime to compile kernels + + lib.optionalString cudaSupport '' + substituteInPlace tinygrad/runtime/support/compiler_cuda.py \ + --replace-fail \ + ', "-I/usr/local/cuda/include", "-I/usr/include", "-I/opt/cuda/include/"' \ + ', "-I${lib.getDev cudaPackages.cuda_cudart}/include/"' + '' + lib.optionalString rocmSupport '' substituteInPlace tinygrad/runtime/autogen/hip.py \ --replace-fail "/opt/rocm/lib/libamdhip64.so" "${rocmPackages.clr}/lib/libamdhip64.so" \ @@ -82,12 +99,24 @@ buildPythonPackage rec { [ numpy ] - ++ lib.optionals stdenv.isDarwin [ + ++ lib.optionals stdenv.hostPlatform.isDarwin [ # pyobjc-framework-libdispatch # pyobjc-framework-metal ]; - pythonImportsCheck = [ "tinygrad" ]; + optional-dependencies = { + llvm = [ llvmlite ]; + arm = [ unicorn ]; + triton = [ triton ]; + }; + + pythonImportsCheck = + [ + "tinygrad" + ] + ++ lib.optionals cudaSupport [ + "tinygrad.runtime.ops_nv" + ]; nativeCheckInputs = [ blobfile @@ -96,9 +125,9 @@ buildPythonPackage rec { hexdump hypothesis librosa + networkx onnx pillow - pydot pytest-xdist pytestCheckHook safetensors @@ -107,7 +136,7 @@ buildPythonPackage rec { torch tqdm transformers - ]; + ] ++ networkx.optional-dependencies.extra; preCheck = '' export HOME=$(mktemp -d) @@ -170,6 +199,10 @@ buildPythonPackage rec { "extra/" ]; + passthru.tests = { + withCuda = tinygrad.override { cudaSupport = true; }; + }; + meta = { description = "Simple and powerful neural network framework"; homepage = "https://github.com/tinygrad/tinygrad"; @@ -177,6 +210,6 @@ buildPythonPackage rec { license = lib.licenses.mit; maintainers = with lib.maintainers; [ GaetanLepage ]; # Requires unpackaged pyobjc-framework-libdispatch and pyobjc-framework-metal - broken = stdenv.isDarwin; + broken = stdenv.hostPlatform.isDarwin; }; } |