about summary refs log tree commit diff
path: root/pkgs/development/python-modules/objax/replace-deprecated-device_buffers.patch
blob: fc0fd50a90ce882403bcf483a6567a0b7680d259 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
diff --git a/objax/util/util.py b/objax/util/util.py
index c31a356..344cf9a 100644
--- a/objax/util/util.py
+++ b/objax/util/util.py
@@ -117,7 +117,8 @@ def get_local_devices():
     if _local_devices is None:
         x = jn.zeros((jax.local_device_count(), 1), dtype=jn.float32)
         sharded_x = map_to_device(x)
-        _local_devices = [b.device() for b in sharded_x.device_buffers]
+        device_buffers = [buf.data for buf in sharded_x.addressable_shards]
+        _local_devices = [list(b.devices())[0] for b in device_buffers]
     return _local_devices