Commit

Nits
fchollet committed Aug 27, 2023
1 parent a051b5c commit 3638a4d
Showing 1 changed file with 103 additions and 101 deletions.
204 changes: 103 additions & 101 deletions keras_core/distribution/distribution_lib.py
@@ -2,8 +2,8 @@
!!!DO NOT USE!!! Currently under development and APIs are not final.
Currently only the JAX backend has been implemented, and the Tensorflow backend
will be implemented in future (via tf.dtensor API).
Currently only the JAX backend has been implemented. The TensorFlow backend
will be implemented in the future (via tf.dtensor API).
"""

import collections
@@ -23,12 +23,13 @@
def list_devices(device_type=None):
"""Return all the available devices based on the device type.
Note that this should return the global devices in a distributed setting.
Note that in a distributed setting, global devices are returned.
Args:
device_type: string of `"cpu"`, `"gpu"` or `"tpu"`. Default to `gpu` or
`tpu` if available when device_type is not provided. Otherwise
will return the `cpu` devices.
device_type: string, one of `"cpu"`, `"gpu"` or `"tpu"`.
Defaults to `"gpu"` or `"tpu"` if available when
`device_type` is not provided. Otherwise
the `"cpu"` devices will be returned.
Return:
List of devices that are available for distributed computation.
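For illustration, a minimal usage sketch (the import path is assumed from the file path shown above; this module is still under development):

```python
# Hypothetical usage of list_devices(); the import path may change while
# the API is in flux.
from keras_core.distribution.distribution_lib import list_devices

all_devices = list_devices()       # "gpu"/"tpu" devices if available, else "cpu"
cpu_devices = list_devices("cpu")  # explicitly request CPU devices
print(len(all_devices), len(cpu_devices))
```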
@@ -37,15 +37,27 @@ def list_devices(device_type=None):


class DeviceMesh:
"""The cluster of computation devices for distributed computation.
"""A cluster of computation devices for distributed computation.
This is aligned with `jax.sharding.Mesh` and `tf.dtensor.Mesh`, which
This API is aligned with `jax.sharding.Mesh` and `tf.dtensor.Mesh`, which
represents the computation devices in the global context.
See more details in [jax.sharding.Mesh](
https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh)
and [tf.dtensor.Mesh](
https://www.tensorflow.org/api_docs/python/tf/experimental/dtensor/Mesh).
Args:
shape: tuple or list of integers. The shape of the overall
`DeviceMesh`, e.g. `(8,)` for a data parallel only distribution,
or `(4, 2)` for a model+data parallel distribution.
axis_names: List of strings. The logical name of each axis of
the `DeviceMesh`. The length of `axis_names` should match
the rank of `shape`. The `axis_names` will be used to
match/create the `TensorLayout` when distributing the data and
variables.
devices: Optional list of devices. Defaults to all the locally
available devices from `list_devices()`.
"""

def __init__(
@@ -54,39 +54,25 @@ def __init__(
axis_names,
devices=None,
):
"""Initialize the DeviceMesh for the given topology.
Args:
shape: tuple of list of integers. The shape of the overall
DeviceMesh, e.g. `(8,)` for a data parallel only distribution,
or `(4, 2)` for a model+data parallel distribution.
axis_names: List of string. The logical name of the each axis for
the DeviceMesh. The length of the `axis_names` should match to
the rank of the `shape`. The `axis_names` will be used to
match/create the `TensorLayout` when distribute the data and
weights.
devices: Optional list of devices. Default to all the available
devices locally from `list_devices()`.
"""
if not shape or not axis_names:
raise ValueError(
"Shape and axis_names cannot be empty. Got "
f"shape: {shape}, axis_names: {axis_names}"
"Shape and axis_names cannot be empty. Received: "
f"shape={shape}, axis_names={axis_names}"
)

if len(shape) != len(axis_names):
raise ValueError(
"Shape and axis_names should have same size,"
f"got shape: {shape} and axis_names: {axis_names}"
"Shape and axis_names should have same size. "
f"Received: shape={shape}, axis_names={axis_names}"
)
if devices is None:
devices = list_devices()
devices = np.array(devices)
if np.prod(shape) != np.prod(devices.shape):
raise ValueError(
"Shape does not match the number of devices. "
f"Got shape: {shape}, and shape of the "
f"devices: {devices.shape}"
f"Received: shape={shape}; devices.shape="
f"{devices.shape}"
)

self._shape = shape
@@ -107,30 +106,26 @@ def devices(self):


class TensorLayout:
"""The layout of a Tensor.
"""The layout of a tensor.
This is aligned with `jax.sharding.NamedSharding` and `tf.dtensor.Layout`,
which allocate the tensor to its logic axis based on the `DeviceMesh`. With
`DeviceMesh` and `TensorLayout`, the actual mapping between a Tensor to the
physical devices can be determined.
This API is aligned with `jax.sharding.NamedSharding`
and `tf.dtensor.Layout`.
See more details in [jax.sharding.NamedSharding](
https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.NamedSharding)
and [tf.dtensor.Layout](
https://www.tensorflow.org/api_docs/python/tf/experimental/dtensor/Layout).
Args:
axes: list of strings that should map to the `axis_names` in
`DeviceMesh`. For any dimension that doesn't need any sharding,
`None` can be used as a placeholder.
device_mesh: Optional `DeviceMesh` that will be used to create
the layout. The actual mapping of tensor to physical device
is not known until the mesh is specified.
"""

def __init__(self, axes, device_mesh=None):
"""Initialize the TensorLayout with axis names.
Args:
axes: list of strings that should map to the `axis_names` in
`DeviceMesh`. For any dimentions that doesn't need any sharding,
A `None` can be used a placeholder.
device_mesh: Optional `DeviceMesh` that will be used to create
the layout. The actual mapping of tensor to physical device
is not known until the mesh is specified.
"""
self._axes = axes
self._device_mesh = device_mesh
self._validate_axes()
@@ -174,6 +169,9 @@ class Distribution:
It can create a context scope so that the framework can properly detect the
`Distribution` and distribute the variables/data accordingly.
Args:
device_mesh: A `DeviceMesh` instance.
"""

def __init__(self, device_mesh):
@@ -195,7 +193,7 @@ def get_variable_layout(self, variable):
"""Retrieve the `TensorLayout` for the variable.
Args:
variable: A `KerasVariable` to retrieve the `TensorLayout`.
variable: A `KerasVariable` instance.
return:
The `TensorLayout` for the variable, which can be used by
@@ -219,24 +217,29 @@ def device_mesh(self):


class DataParallel(Distribution):
def __init__(self, device_mesh=None, devices=None):
"""Create the data parallel distribution.
"""Distribution for data parallelism.
You can choose to create this instance by either specifing
the `device_mesh` or `devices` parameters (but not both).
You can choose to create this instance by either specifying
the `device_mesh` or `devices` arguments (but not both).
The device_mesh is expected to be a `DeviceMesh` instance, and is
expected to be 1D only. In case that the mesh has multiple axes, then
the first axis will be treated as the data parallel dimension
(and a warning will be raised).
The `device_mesh` argument is expected to be a `DeviceMesh` instance,
and is expected to be 1D only. If the mesh has multiple axes,
then the first axis will be treated as the data parallel dimension
(and a warning will be raised).
When a list of `devices` are provided, they will be used to construct a
1D mesh.
When a list of `devices` is provided, it will be used to construct a
1D mesh.
When both `mesh` and `devices` are absent, then `list_devices()`
will be used to detect any available devices and create a 1D mesh from
them.
"""
When both `mesh` and `devices` are absent, then `list_devices()`
will be used to detect any available devices and create a 1D mesh from
them.
Args:
device_mesh: Optional `DeviceMesh` instance.
devices: Optional list of devices.
"""

def __init__(self, device_mesh=None, devices=None):
if device_mesh:
self._initialize_with_device_mesh(device_mesh)
elif devices:
@@ -290,31 +293,33 @@ def get_variable_layout(self, variable):


class ModelParallel(Distribution):
"""Distribution that shard model weights.
"""Distribution that shards model variables.
Compare to DataParallel which replicates the weights across all the devices,
ModelParallel allows user to shard weights in addition to the input data.
Compared to `DataParallel`, which replicates the variables across all devices,
`ModelParallel` allows you to shard variables in addition to the input data.
To construct a ModelParallel distribution, user need to provide device mesh
and layout mapping.
To construct a `ModelParallel` distribution, you need to provide a
`DeviceMesh` and a `LayoutMap`.
1. `DeviceMesh`contains physcial device information, and the axis names in
the mesh will be used to map the weight and data layout.
2. `LayoutMap` contains the mapping for the variable path to its
1. `DeviceMesh` contains physical device information. The axis names in
the mesh will be used to map the variable and data layout.
2. `LayoutMap` contains the mapping between variable paths and their
corresponding `TensorLayout`.
Example:
```python
devices = list_devices() # Assume there are 8 devices.
# Create a mesh with 2 devices on data parallel and 4 devices on weight
# parallel.
# Create a mesh with 2 devices for data parallelism and 4 devices for
# model parallelism.
device_mesh = DeviceMesh(shape=(2, 4), axis_names=('batch', 'model'),
devices=devices)
# Create a layout map that shard the dense layer and conv2d layer weights
# on the last dimension. Based on the device_mesh, this means the weights
# will be split across 4 devices. Any other weights that doesn't match for
# any key in layout map will get be fully replicated.
# Create a layout map that shards the `Dense` layer and `Conv2D`
# layer variables on the last dimension.
# Based on the `device_mesh`, this means the variables
# will be split across 4 devices. Any other variable that doesn't
# match any key in the layout map will be fully replicated.
layout_map = LayoutMap(device_mesh)
layout_map['.*dense.*kernel'] = TensorLayout([None, 'model'])
layout_map['.*dense.*bias'] = TensorLayout(['model'])
@@ -332,42 +337,40 @@ class ModelParallel(Distribution):
model.fit(data)
```
User can quickly update the device mesh shape to change the sharding factor
of the weights. E.g.
You can quickly update the device mesh shape to change the sharding factor
of the variables. E.g.
```
# With only the shape change for the device mesh, the weights will be
# sharded across 8 devices instead of 4, which further reduce the memory
# footprint of weights on each of the device.
# With only the shape change for the device mesh, the variables will be
# sharded across 8 devices instead of 4, which further reduces the memory
# footprint of variables on each device.
device_mesh = DeviceMesh(shape=(1, 8), axis_names=('batch', 'model'),
devices=devices)
```
To figure out a proper layout mapping rule for all the model weights, you
can first list out all the model weights path, which will be used as the key
to map the weights to `TensorLayout`.
To figure out a proper layout mapping rule for all the model variables, you
can first list out all the model variable paths, which will be used as the
key to map the variables to `TensorLayout`.
e.g.
```
model = create_model()
for w in model.weights:
print(w.path)
for v in model.variables:
print(v.path)
```
Args:
device_mesh: `DeviceMesh` instance for physical device and its
logical mapping.
layout_map: `LayoutMap` instance which maps the variable path to the
corresponding `TensorLayout`. The axis names of the
`TensorLayout`s should match the axis names in the
`device_mesh`, or an exception will be raised.
batch_dim_name: optional string, the axis name in the `device_mesh`
that will be used to distribute data. If unspecified, the
first axis from the `device_mesh` will be used.
"""

def __init__(self, device_mesh, layout_map, batch_dim_name=None):
"""Initialize the model parallel distribution.
Args:
device_mesh: `DeviceMesh` instance for physical device and its
logical mapping.
layout_map: `LayoutMap` instance which map the variable path to the
corresponding `TensorLayout`. The axis names of the
`TensorLayout`s should match to the axis names in the
device_mesh, or exception will be raised.
batch_dim_name: optional string, the axis name in the device_mesh
that will be used to distribute data. The first axis from the
device_mesh will be used if user didn't specify any.
"""
super().__init__(device_mesh)
self._layout_map = layout_map
self._batch_dim_name = batch_dim_name or self.device_mesh.axis_names[0]
@@ -419,24 +422,23 @@ class LayoutMap(collections.abc.MutableMapping):
Args:
device_mesh: An optional `DeviceMesh` that can be used to populate the
`TensorLayout.device_mesh` if the `TensorLayout.device_mesh` is not
set.
`TensorLayout.device_mesh` if `TensorLayout.device_mesh` is not set.
"""

def __init__(self, device_mesh=None):
self._layout_map = collections.OrderedDict()
self._device_mesh = device_mesh

def __getitem__(self, key):
"""Retrieve the corresponding layout by the string key.
"""Retrieves the corresponding layout by the string key.
When there isn't an exact match, all the existing keys in the layout map
will be treated as regexes and matched against the input key. The
first match will be returned, based on the key insertion order. Return
None if there isn't any match found.
first match will be returned, based on the key insertion order. Returns
`None` if there isn't any match found.
Args:
key: the string key as the query for the layout.
key: String key to query a layout.
Returns:
Corresponding layout based on the query.
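A sketch of the lookup behavior described above (the variable paths and the `mesh` object are hypothetical):

```python
# Keys that are not exact matches are treated as regex patterns and matched
# against the query key in insertion order; a query with no match returns None.
from keras_core.distribution.distribution_lib import LayoutMap, TensorLayout

layout_map = LayoutMap(device_mesh=mesh)  # `mesh` is a DeviceMesh built earlier
layout_map[".*dense.*kernel"] = TensorLayout([None, "model"])
layout_map[".*dense.*bias"] = TensorLayout(["model"])

dense_kernel_layout = layout_map["model/dense_1/kernel"]  # matches ".*dense.*kernel"
conv_kernel_layout = layout_map["model/conv2d_1/kernel"]  # no match, returns None
```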
