about summary refs log tree commit diff
path: root/pkgs/development/python-modules/objax/replace-deprecated-device_buffers.patch
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/objax/replace-deprecated-device_buffers.patch')
-rw-r--r--pkgs/development/python-modules/objax/replace-deprecated-device_buffers.patch14
1 files changed, 14 insertions, 0 deletions
diff --git a/pkgs/development/python-modules/objax/replace-deprecated-device_buffers.patch b/pkgs/development/python-modules/objax/replace-deprecated-device_buffers.patch
new file mode 100644
index 0000000000000..fc0fd50a90ce8
--- /dev/null
+++ b/pkgs/development/python-modules/objax/replace-deprecated-device_buffers.patch
@@ -0,0 +1,14 @@
+diff --git a/objax/util/util.py b/objax/util/util.py
+index c31a356..344cf9a 100644
+--- a/objax/util/util.py
++++ b/objax/util/util.py
+@@ -117,7 +117,8 @@ def get_local_devices():
+     if _local_devices is None:
+         x = jn.zeros((jax.local_device_count(), 1), dtype=jn.float32)
+         sharded_x = map_to_device(x)
+-        _local_devices = [b.device() for b in sharded_x.device_buffers]
++        device_buffers = [buf.data for buf in sharded_x.addressable_shards]
++        _local_devices = [list(b.devices())[0] for b in device_buffers]
+     return _local_devices
+ 
+