about summary refs log tree commit diff
path: root/pkgs/development/python-modules/jaxlib
diff options
context:
space:
mode:
authorSamuel Ainsworth <skainsworth@gmail.com>2021-08-22 20:37:42 +0000
committerSamuel Ainsworth <skainsworth@gmail.com>2021-08-22 20:37:42 +0000
commit1f8686373abe21bf6b1ce972f0ea55405b449329 (patch)
tree1ae89cc7cff818bb95d75fc1f5a8612eb78b2d4e /pkgs/development/python-modules/jaxlib
parent00ca3a1fda61078c7ee1239f5f46904a0353e9b3 (diff)
python3Packages.jaxlib: init at 0.1.70
Diffstat (limited to 'pkgs/development/python-modules/jaxlib')
-rw-r--r--pkgs/development/python-modules/jaxlib/default.nix45
1 files changed, 45 insertions, 0 deletions
diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix
new file mode 100644
index 0000000000000..240c5a7d6d0ee
--- /dev/null
+++ b/pkgs/development/python-modules/jaxlib/default.nix
@@ -0,0 +1,45 @@
+# For the moment we only support the CPU backend of jaxlib. GPU and TPU backends
+# require some additional work. Their wheels are not located on PyPI.
+#  * CPU/GPU: https://storage.googleapis.com/jax-releases/jax_releases.html
+#  * TPU: https://storage.googleapis.com/jax-releases/libtpu_releases.html
+
+{ autoPatchelfHook, buildPythonPackage, fetchPypi, isPy39, lib, stdenv
+# propagatedBuildInputs
+, absl-py, flatbuffers, scipy
+}:
+
+buildPythonPackage rec {
+  pname = "jaxlib";
+  version = "0.1.70";
+  format = "wheel";
+
+  # At the time of writing (8/19/21), there are releases for 3.7-3.9. Supporting
+  # all of them is a pain, so we focus on 3.9, the current nixpkgs python3
+  # version.
+  disabled = !isPy39;
+
+  src = fetchPypi {
+    inherit pname version format;
+    dist = "cp39";
+    python = "cp39";
+    platform = "manylinux2010_x86_64";
+    sha256 = "sha256-mytMTqoavpuRawj52MU5/iFj27SGlm8DaoQ5vd/3bss=";
+  };
+
+  # Prebuilt wheels are dynamically linked against things that nix can't find.
+  # Run `autoPatchelfHook` to automagically fix them.
+  nativeBuildInputs = [ autoPatchelfHook ];
+  # Dynamic link dependencies
+  buildInputs = [ stdenv.cc.cc ];
+
+  # pip dependencies
+  propagatedBuildInputs = [ absl-py flatbuffers scipy ];
+
+  meta = with lib; {
+    description = "XLA library for JAX";
+    homepage    = "https://github.com/google/jax";
+    license     = licenses.asl20;
+    maintainers = with maintainers; [ samuela ];
+    platforms = [ "x86_64-linux" ];
+  };
+}