diff options
Diffstat (limited to 'pkgs/development/python-modules/objax/replace-deprecated-device_buffers.patch')
-rw-r--r-- | pkgs/development/python-modules/objax/replace-deprecated-device_buffers.patch | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/pkgs/development/python-modules/objax/replace-deprecated-device_buffers.patch b/pkgs/development/python-modules/objax/replace-deprecated-device_buffers.patch new file mode 100644 index 0000000000000..fc0fd50a90ce8 --- /dev/null +++ b/pkgs/development/python-modules/objax/replace-deprecated-device_buffers.patch @@ -0,0 +1,14 @@ +diff --git a/objax/util/util.py b/objax/util/util.py +index c31a356..344cf9a 100644 +--- a/objax/util/util.py ++++ b/objax/util/util.py +@@ -117,7 +117,8 @@ def get_local_devices(): + if _local_devices is None: + x = jn.zeros((jax.local_device_count(), 1), dtype=jn.float32) + sharded_x = map_to_device(x) +- _local_devices = [b.device() for b in sharded_x.device_buffers] ++ device_buffers = [buf.data for buf in sharded_x.addressable_shards] ++ _local_devices = [list(b.devices())[0] for b in device_buffers] + return _local_devices + + |