about summary refs log tree commit diff
path: root/pkgs/development/python-modules/tensorflow/bin.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/tensorflow/bin.nix')
-rw-r--r--pkgs/development/python-modules/tensorflow/bin.nix118
1 files changed, 65 insertions, 53 deletions
diff --git a/pkgs/development/python-modules/tensorflow/bin.nix b/pkgs/development/python-modules/tensorflow/bin.nix
index 1040023619262..0f54cade00cd6 100644
--- a/pkgs/development/python-modules/tensorflow/bin.nix
+++ b/pkgs/development/python-modules/tensorflow/bin.nix
@@ -1,37 +1,42 @@
-{ stdenv
-, lib
-, fetchurl
-, buildPythonPackage
-, isPy3k, pythonOlder, pythonAtLeast, astor
-, gast
-, google-pasta
-, wrapt
-, numpy
-, six
-, termcolor
-, packaging
-, protobuf
-, absl-py
-, grpcio
-, mock
-, scipy
-, wheel
-, jax
-, opt-einsum
-, tensorflow-estimator-bin
-, tensorboard
-, config
-, cudaSupport ? config.cudaSupport
-, cudaPackagesGoogle
-, zlib
-, python
-, keras-applications
-, keras-preprocessing
-, addOpenGLRunpath
-, astunparse
-, flatbuffers
-, h5py
-, typing-extensions
+{
+  stdenv,
+  lib,
+  fetchurl,
+  buildPythonPackage,
+  isPy3k,
+  pythonOlder,
+  pythonAtLeast,
+  astor,
+  gast,
+  google-pasta,
+  wrapt,
+  numpy,
+  six,
+  termcolor,
+  packaging,
+  protobuf,
+  absl-py,
+  grpcio,
+  mock,
+  scipy,
+  wheel,
+  jax,
+  opt-einsum,
+  tensorflow-estimator-bin,
+  tensorboard,
+  config,
+  cudaSupport ? config.cudaSupport,
+  cudaPackages,
+  zlib,
+  python,
+  keras-applications,
+  keras-preprocessing,
+  addOpenGLRunpath,
+  astunparse,
+  flatbuffers,
+  h5py,
+  llvmPackages,
+  typing-extensions,
 }:
 
 # We keep this binary build for two reasons:
@@ -39,24 +44,29 @@
 # - the source build is currently brittle and not easy to maintain
 
 # unsupported combination
-assert ! (stdenv.isDarwin && cudaSupport);
+assert !(stdenv.isDarwin && cudaSupport);
 
 let
   packages = import ./binary-hashes.nix;
-  inherit (cudaPackagesGoogle) cudatoolkit cudnn;
-in buildPythonPackage {
+  inherit (cudaPackages) cudatoolkit cudnn;
+in
+buildPythonPackage {
   pname = "tensorflow" + lib.optionalString cudaSupport "-gpu";
   inherit (packages) version;
   format = "wheel";
 
-  src = let
-    pyVerNoDot = lib.strings.stringAsChars (x: lib.optionalString (x != ".") x) python.pythonVersion;
-    platform = if stdenv.isDarwin then "mac" else "linux";
-    unit = if cudaSupport then "gpu" else "cpu";
-    key = "${platform}_py_${pyVerNoDot}_${unit}";
-  in fetchurl (packages.${key} or {});
+  src =
+    let
+      pyVerNoDot = lib.strings.stringAsChars (x: lib.optionalString (x != ".") x) python.pythonVersion;
+      platform = stdenv.system;
+      cuda = lib.optionalString cudaSupport "_gpu";
+      key = "${platform}_${pyVerNoDot}${cuda}";
+    in
+    fetchurl (packages.${key} or (throw "tensoflow-bin: unsupported system: ${stdenv.system}"));
+
+  buildInputs = [ llvmPackages.openmp ];
 
-  propagatedBuildInputs = [
+  dependencies = [
     astunparse
     flatbuffers
     typing-extensions
@@ -81,7 +91,7 @@ in buildPythonPackage {
     h5py
   ] ++ lib.optional (!isPy3k) mock;
 
-  nativeBuildInputs = [ wheel ] ++ lib.optionals cudaSupport [ addOpenGLRunpath ];
+  build-system = [ wheel ] ++ lib.optionals cudaSupport [ addOpenGLRunpath ];
 
   preConfigure = ''
     unset SOURCE_DATE_EPOCH
@@ -91,7 +101,6 @@ in buildPythonPackage {
 
     pushd dist
 
-    orig_name="$(echo ./*.whl)"
     wheel unpack --dest unpacked ./*.whl
     rm ./*.whl
     (
@@ -113,7 +122,6 @@ in buildPythonPackage {
         -e "s/Requires-Dist: numpy (.*)/Requires-Dist: numpy/"
     )
     wheel pack ./unpacked/tensorflow*
-    mv *.whl $orig_name # avoid changes to the _os_arch.whl suffix
 
     popd
   '';
@@ -168,6 +176,7 @@ in buildPythonPackage {
         "$out/${python.sitePackages}/tensorflow/python/saved_model"
         "$out/${python.sitePackages}/tensorflow/python/util"
         "$out/${python.sitePackages}/tensorflow/tsl/python/lib/core"
+        "$out/${python.sitePackages}/tensorflow.libs/"
         "${rpath}"
       )
 
@@ -199,16 +208,19 @@ in buildPythonPackage {
     "tensorflow.python.framework"
   ];
 
-  passthru = {
-    cudaPackages = cudaPackagesGoogle;
-  };
-
   meta = with lib; {
     description = "Computation using data flow graphs for scalable machine learning";
     homepage = "http://tensorflow.org";
     sourceProvenance = with sourceTypes; [ binaryNativeCode ];
     license = licenses.asl20;
-    maintainers = with maintainers; [ jyp abbradar ];
-    platforms = [ "x86_64-linux" "x86_64-darwin" ];
+    maintainers = with maintainers; [
+      jyp
+      abbradar
+    ];
+    platforms = platforms.all;
+    # Cannot import tensortfow on python 3.12 as it still dependends on distutils:
+    # ModuleNotFoundError: No module named 'distutils'
+    # https://github.com/tensorflow/tensorflow/issues/58073
+    broken = pythonAtLeast "3.12";
   };
 }