diff --git a/objax/util/util.py b/objax/util/util.py index c31a356..344cf9a 100644 --- a/objax/util/util.py +++ b/objax/util/util.py @@ -117,7 +117,8 @@ def get_local_devices(): if _local_devices is None: x = jn.zeros((jax.local_device_count(), 1), dtype=jn.float32) sharded_x = map_to_device(x) - _local_devices = [b.device() for b in sharded_x.device_buffers] + device_buffers = [buf.data for buf in sharded_x.addressable_shards] + _local_devices = [list(b.devices())[0] for b in device_buffers] return _local_devices