about summary refs log tree commit diff
path: root/pkgs/development/python-modules/objax/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/objax/default.nix')
-rw-r--r--pkgs/development/python-modules/objax/default.nix55
1 files changed, 27 insertions, 28 deletions
diff --git a/pkgs/development/python-modules/objax/default.nix b/pkgs/development/python-modules/objax/default.nix
index 7f2725e9d286..63d9c1d03846 100644
--- a/pkgs/development/python-modules/objax/default.nix
+++ b/pkgs/development/python-modules/objax/default.nix
@@ -1,19 +1,19 @@
-{ lib
-, buildPythonPackage
-, fetchFromGitHub
-, fetchpatch
-, jax
-, jaxlib
-, keras
-, numpy
-, parameterized
-, pillow
-, pytestCheckHook
-, pythonOlder
-, scipy
-, setuptools
-, tensorboard
-, tensorflow
+{
+  lib,
+  buildPythonPackage,
+  fetchFromGitHub,
+  jax,
+  jaxlib,
+  keras,
+  numpy,
+  parameterized,
+  pillow,
+  pytestCheckHook,
+  pythonOlder,
+  scipy,
+  setuptools,
+  tensorboard,
+  tensorflow,
 }:
 
 buildPythonPackage rec {
@@ -30,17 +30,18 @@ buildPythonPackage rec {
     hash = "sha256-WD+pmR8cEay4iziRXqF3sHUzCMBjmLJ3wZ3iYOD+hzk=";
   };
 
-  nativeBuildInputs = [
-    setuptools
+  patches = [
+    # Issue reported upstream: https://github.com/google/objax/issues/270
+    ./replace-deprecated-device_buffers.patch
   ];
 
+  build-system = [ setuptools ];
+
   # Avoid propagating the dependency on `jaxlib`, see
   # https://github.com/NixOS/nixpkgs/issues/156767
-  buildInputs = [
-    jaxlib
-  ];
+  buildInputs = [ jaxlib ];
 
-  propagatedBuildInputs = [
+  dependencies = [
     jax
     numpy
     parameterized
@@ -49,9 +50,7 @@ buildPythonPackage rec {
     tensorboard
   ];
 
-  pythonImportsCheck = [
-    "objax"
-  ];
+  pythonImportsCheck = [ "objax" ];
 
   # This is necessay to ignore the presence of two protobufs version (tensorflow is bringing an
   # older version).
@@ -63,9 +62,7 @@ buildPythonPackage rec {
     tensorflow
   ];
 
-  pytestFlagsArray = [
-    "tests/*.py"
-  ];
+  pytestFlagsArray = [ "tests/*.py" ];
 
   disabledTests = [
     # Test requires internet access for prefetching some weights
@@ -80,5 +77,7 @@ buildPythonPackage rec {
     changelog = "https://github.com/google/objax/releases/tag/v${version}";
     license = licenses.asl20;
     maintainers = with maintainers; [ ndl ];
+    # Tests test_syncbn_{0,1,2}d and other tests from tests/parallel.py fail
+    broken = true;
   };
 }