From 3638a4db6428f6dd90bd540d423fe5d7efd82181 Mon Sep 17 00:00:00 2001
From: Francois Chollet
Date: Sat, 26 Aug 2023 19:41:53 -0700
Subject: [PATCH] Nits

---
 keras_core/distribution/distribution_lib.py | 204 ++++++++++----------
 1 file changed, 103 insertions(+), 101 deletions(-)

diff --git a/keras_core/distribution/distribution_lib.py b/keras_core/distribution/distribution_lib.py
index 15d5abd08..893614154 100644
--- a/keras_core/distribution/distribution_lib.py
+++ b/keras_core/distribution/distribution_lib.py
@@ -2,8 +2,8 @@
 !!!DO NOT USE!!! Currently under development and APIs are not final.
 
-Currently only the JAX backend has been implemented, and the Tensorflow backend
-will be implemented in future (via tf.dtensor API).
+Currently only the JAX backend has been implemented. The TensorFlow backend
+will be implemented in the future (via the tf.dtensor API).
 """
 
 import collections
@@ -23,12 +23,13 @@ def list_devices(device_type=None):
     """Return all the available devices based on the device type.
 
-    Note that this should return the global devices in a distributed setting.
+    Note that in a distributed setting, global devices are returned.
 
     Args:
-        device_type: string of `"cpu"`, `"gpu"` or `"tpu"`. Default to `gpu` or
-            `tpu` if available when device_type is not provided. Otherwise
-            will return the `cpu` devices.
+        device_type: string, one of `"cpu"`, `"gpu"` or `"tpu"`.
+            Defaults to `"gpu"` or `"tpu"` if available when
+            `device_type` is not provided. Otherwise
+            the `"cpu"` devices will be returned.
 
     Return:
         List of devices that are available for distribute computation.
@@ -37,15 +38,27 @@ def list_devices(device_type=None):
 
 class DeviceMesh:
-    """The cluster of computation devices for distributed computation.
+    """A cluster of computation devices for distributed computation.
 
-    This is aligned with `jax.sharding.Mesh` and `tf.dtensor.Mesh`, which
+    This API is aligned with `jax.sharding.Mesh` and `tf.dtensor.Mesh`, which
     represents the computation devices in the global context.
 
     See more details in [jax.sharding.Mesh](
     https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh)
     and [tf.dtensor.Mesh](
     https://www.tensorflow.org/api_docs/python/tf/experimental/dtensor/Mesh).
+
+    Args:
+        shape: tuple or list of integers. The shape of the overall
+            `DeviceMesh`, e.g. `(8,)` for a data parallel only distribution,
+            or `(4, 2)` for a model+data parallel distribution.
+        axis_names: List of strings. The logical name of each axis of
+            the `DeviceMesh`. The length of `axis_names` should match
+            the rank of `shape`. The `axis_names` will be used to
+            match/create the `TensorLayout` when distributing the data and
+            variables.
+        devices: Optional list of devices. Defaults to all the devices
+            available locally, as returned by `list_devices()`.
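+
+    Example (a minimal sketch, assuming 8 locally available devices and
+    illustrative axis names):
+
+    ```python
+    # 2 x 4 mesh: 2-way data parallelism on the first axis and 4-way
+    # model parallelism on the second.
+    device_mesh = DeviceMesh(shape=(2, 4),
+                             axis_names=('batch', 'model'),
+                             devices=list_devices())
+    ```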
     """
 
     def __init__(
@@ -54,30 +67,16 @@ def __init__(
         shape,
         axis_names,
         devices=None,
     ):
-        """Initialize the DeviceMesh for the given topology.
-
-        Args:
-            shape: tuple of list of integers. The shape of the overall
-                DeviceMesh, e.g. `(8,)` for a data parallel only distribution,
-                or `(4, 2)` for a model+data parallel distribution.
-            axis_names: List of string. The logical name of the each axis for
-                the DeviceMesh. The length of the `axis_names` should match to
-                the rank of the `shape`. The `axis_names` will be used to
-                match/create the `TensorLayout` when distribute the data and
-                weights.
-            devices: Optional list of devices. Default to all the available
-                devices locally from `list_devices()`.
-        """
         if not shape or not axis_names:
             raise ValueError(
-                "Shape and axis_names cannot be empty. Got "
-                f"shape: {shape}, axis_names: {axis_names}"
+                "Shape and axis_names cannot be empty. Received: "
+                f"shape={shape}, axis_names={axis_names}"
             )
         if len(shape) != len(axis_names):
             raise ValueError(
-                "Shape and axis_names should have same size,"
-                f"got shape: {shape} and axis_names: {axis_names}"
+                "Shape and axis_names should have the same size. "
+                f"Received: shape={shape}, axis_names={axis_names}"
            )
         if devices is None:
             devices = list_devices()
@@ -85,8 +84,8 @@ def __init__(
         if np.prod(shape) != np.prod(devices.shape):
             raise ValueError(
                 "Shape does not match the number of devices. "
-                f"Got shape: {shape}, and shape of the "
-                f"devices: {devices.shape}"
+                f"Received: shape={shape}; devices.shape="
+                f"{devices.shape}"
             )
 
         self._shape = shape
@@ -107,30 +106,26 @@ def devices(self):
 
 class TensorLayout:
-    """The layout of a Tensor.
+    """The layout of a tensor.
 
-    This is aligned with `jax.sharding.NamedSharding` and `tf.dtensor.Layout`,
-    which allocate the tensor to its logic axis based on the `DeviceMesh`. With
-    `DeviceMesh` and `TensorLayout`, the actual mapping between a Tensor to the
-    physical devices can be determined.
+    This API is aligned with `jax.sharding.NamedSharding`
+    and `tf.dtensor.Layout`.
 
     See more details in [jax.sharding.NamedSharding](
     https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.NamedSharding)
     and [tf.dtensor.Layout](
     https://www.tensorflow.org/api_docs/python/tf/experimental/dtensor/Layout).
+
+    Args:
+        axes: list of strings that should map to the `axis_names` in
+            the `DeviceMesh`. For any dimension that doesn't need any
+            sharding, `None` can be used as a placeholder.
+        device_mesh: Optional `DeviceMesh` that will be used to create
+            the layout. The actual mapping of a tensor to a physical
+            device is not known until the mesh is specified.
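+
+    Example (a minimal sketch, assuming a 2D `DeviceMesh` named
+    `device_mesh` that has a `'model'` axis):
+
+    ```python
+    # Replicate the first dimension (`None` means no sharding) and shard
+    # the second dimension across the 'model' mesh axis.
+    layout = TensorLayout(axes=[None, 'model'], device_mesh=device_mesh)
+    ```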
     """
 
     def __init__(self, axes, device_mesh=None):
-        """Initialize the TensorLayout with axis names.
-
-        Args:
-            axes: list of strings that should map to the `axis_names` in
-                `DeviceMesh`. For any dimentions that doesn't need any sharding,
-                A `None` can be used a placeholder.
-            device_mesh: Optional `DeviceMesh` that will be used to create
-                the layout. The actual mapping of tensor to physical device
-                is not known until the mesh is specified.
-        """
         self._axes = axes
         self._device_mesh = device_mesh
         self._validate_axes()
@@ -174,6 +169,9 @@ class Distribution:
 
     It can create a context scope so that the framework to properly detect the
     `Distribution` and distribute the variable/data accordingly.
+
+    Args:
+        device_mesh: A `DeviceMesh` instance.
     """
 
     def __init__(self, device_mesh):
@@ -195,7 +193,7 @@ def get_variable_layout(self, variable):
         """Retrieve the `TensorLayout` for the variable.
 
         Args:
-            variable: A `KerasVariable` to retrieve the `TensorLayout`.
+            variable: A `KerasVariable` instance.
 
         return:
             The `TensorLayout` for the variable, which can be used by
@@ -219,24 +217,29 @@ def device_mesh(self):
 
 class DataParallel(Distribution):
-    def __init__(self, device_mesh=None, devices=None):
-        """Create the data parallel distribution.
+    """Distribution for data parallelism.
 
-        You can choose to create this instance by either specifing
-        the `device_mesh` or `devices` parameters (but not both).
+    You can choose to create this instance by either specifying
+    the `device_mesh` or `devices` arguments (but not both).
 
-        The device_mesh is expected to be a `DeviceMesh` instance, and is
-        expected to be 1D only. In case that the mesh has multiple axes, then
-        the first axis will be treated as the data parallel dimension
-        (and a warning will be raised).
+    The `device_mesh` argument is expected to be a `DeviceMesh` instance,
+    and is expected to be 1D only. If the mesh has multiple axes,
+    then the first axis will be treated as the data parallel dimension
+    (and a warning will be raised).
 
-        When a list of `devices` are provided, they will be used to construct a
-        1D mesh.
+    When a list of `devices` is provided, it will be used to construct a
+    1D mesh.
 
-        When both `mesh` and `devices` are absent, then `list_devices()`
-        will be used to detect any available devices and create a 1D mesh from
-        them.
-        """
+    When both `device_mesh` and `devices` are absent, `list_devices()`
+    will be used to detect any available devices and create a 1D mesh from
+    them.
+
+    Args:
+        device_mesh: Optional `DeviceMesh` instance.
+        devices: Optional list of devices.
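+
+    Example (a minimal sketch; either argument alone is sufficient, and
+    the axis name `'batch'` is illustrative):
+
+    ```python
+    # From a list of devices:
+    distribution = DataParallel(devices=list_devices())
+
+    # Or from a 1D `DeviceMesh`:
+    device_mesh = DeviceMesh(shape=(8,), axis_names=('batch',),
+                             devices=list_devices())
+    distribution = DataParallel(device_mesh=device_mesh)
+    ```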
+    """
+
+    def __init__(self, device_mesh=None, devices=None):
         if device_mesh:
             self._initialize_with_device_mesh(device_mesh)
         elif devices:
@@ -290,31 +293,33 @@ def get_variable_layout(self, variable):
 
 class ModelParallel(Distribution):
-    """Distribution that shard model weights.
+    """Distribution that shards model variables.
 
-    Compare to DataParallel which replicates the weights across all the devices,
-    ModelParallel allows user to shard weights in addition to the input data.
+    Compared to `DataParallel`, which replicates variables across all devices,
+    `ModelParallel` allows you to shard variables in addition to the input data.
 
-    To construct a ModelParallel distribution, user need to provide device mesh
-    and layout mapping.
+    To construct a `ModelParallel` distribution, you need to provide a
+    `DeviceMesh` and a `LayoutMap`.
 
-    1. `DeviceMesh`contains physcial device information, and the axis names in
-       the mesh will be used to map the weight and data layout.
-    2. `LayoutMap` contains the mapping for the variable path to its
+    1. `DeviceMesh` contains physical device information. The axis names in
+       the mesh will be used to map the variable and data layout.
+    2. `LayoutMap` contains the mapping from variable paths to their
       corresponding `TensorLayout`.
 
     Example:
+
     ```python
     devices = list_devices()    # Assume there are 8 devices.
 
-    # Create a mesh with 2 devices on data parallel and 4 devices on weight
-    # parallel.
+    # Create a mesh with 2 devices for data parallelism and 4 devices for
+    # model parallelism.
     device_mesh = DeviceMesh(shape=(2, 4), axis_names=('batch', 'model'),
                              devices=devices)
 
-    # Create a layout map that shard the dense layer and conv2d layer weights
-    # on the last dimension. Based on the device_mesh, this means the weights
-    # will be split across 4 devices. Any other weights that doesn't match for
-    # any key in layout map will get be fully replicated.
+    # Create a layout map that shards the `Dense` layer and `Conv2D`
+    # layer variables on the last dimension.
+    # Based on the `device_mesh`, this means the variables
+    # will be split across 4 devices. Any other variable that doesn't
+    # match any key in the layout map will be fully replicated.
     layout_map = LayoutMap(device_mesh)
     layout_map['.*dense.*kernel'] = TensorLayout([None, 'model'])
     layout_map['.*dense.*bias'] = TensorLayout(['model'])
@@ -332,42 +337,40 @@ class ModelParallel(Distribution):
     model.fit(data)
     ```
 
-    User can quickly update the device mesh shape to change the sharding factor
-    of the weights. E.g.
+    You can quickly update the device mesh shape to change the sharding factor
+    of the variables. E.g.
     ```
-    # With only the shape change for the device mesh, the weights will be
-    # sharded across 8 devices instead of 4, which further reduce the memory
-    # footprint of weights on each of the device.
+    # With only the shape change for the device mesh, the variables will be
+    # sharded across 8 devices instead of 4, which further reduces the memory
+    # footprint of variables on each of the devices.
     device_mesh = DeviceMesh(shape=(1, 8), axis_names=('batch', 'model'),
                              devices=devices)
     ```
 
-    To figure out a proper layout mapping rule for all the model weights, you
-    can first list out all the model weights path, which will be used as the key
-    to map the weights to `TensorLayout`.
+    To figure out a proper layout mapping rule for all the model variables, you
+    can first list out all the model variable paths, which will be used as the
+    keys to map the variables to `TensorLayout`.
 
     e.g.
     ```
     model = create_model()
-    for w in model.weights:
-        print(w.path)
+    for v in model.variables:
+        print(v.path)
     ```
+
+    Args:
+        device_mesh: `DeviceMesh` instance for the physical devices and
+            their logical mapping.
+        layout_map: `LayoutMap` instance which maps variable paths to
+            their corresponding `TensorLayout`. The axis names of the
+            `TensorLayout`s should match the axis names in the
+            `device_mesh`, or an exception will be raised.
+        batch_dim_name: optional string, the axis name in the `device_mesh`
+            that will be used to distribute data. If unspecified, the
+            first axis from the `device_mesh` will be used.
     """
 
     def __init__(self, device_mesh, layout_map, batch_dim_name=None):
-        """Initialize the model parallel distribution.
-
-        Args:
-            device_mesh: `DeviceMesh` instance for physical device and its
-                logical mapping.
-            layout_map: `LayoutMap` instance which map the variable path to the
-                corresponding `TensorLayout`. The axis names of the
-                `TensorLayout`s should match to the axis names in the
-                device_mesh, or exception will be raised.
-            batch_dim_name: optional string, the axis name in the device_mesh
-                that will be used to distribute data. The first axis from the
-                device_mesh will be used if user didn't specify any.
-        """
         super().__init__(device_mesh)
         self._layout_map = layout_map
         self._batch_dim_name = batch_dim_name or self.device_mesh.axis_names[0]
@@ -419,8 +422,7 @@ class LayoutMap(collections.abc.MutableMapping):
 
     Args:
         device_mesh: An optional `DeviceMesh` that can be used to populate the
-            `TensorLayout.device_mesh` if the `TensorLayout.device_mesh` is not
-            set.
+            `TensorLayout.device_mesh` if `TensorLayout.device_mesh` is not set.
     """
 
     def __init__(self, device_mesh=None):
@@ -428,15 +430,15 @@ def __init__(self, device_mesh=None):
         self._layout_map = collections.OrderedDict()
         self._device_mesh = device_mesh
 
     def __getitem__(self, key):
-        """Retrieve the corresponding layout by the string key.
+        """Retrieves the corresponding layout by the string key.
 
         When there isn't an exact match, all the existing keys in the layout
         map will be treated as a regex and map against the input key again. The
-        first match will be returned, based on the key insertion order. Return
-        None if there isn't any match found.
+        first match will be returned, based on the key insertion order. Returns
+        `None` if there isn't any match found.
 
         Args:
-            key: the string key as the query for the layout.
+            key: String key to query a layout.
 
         Returns:
             Corresponding layout based on the query.
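+
+        Example (a minimal sketch of the matching behavior):
+
+        ```
+        layout_map = LayoutMap()
+        layout_map['.*dense.*kernel'] = TensorLayout([None, 'model'])
+
+        # There is no exact key 'model.dense_1.kernel', so the existing
+        # keys are tried as regexes against the query, in insertion order:
+        layout_map['model.dense_1.kernel']  # -> TensorLayout([None, 'model'])
+        layout_map['model.conv2d.kernel']   # -> None (no key matches)
+        ```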