Commit

Export distribution lib to public API
fchollet committed Sep 4, 2023
1 parent 0d295ad commit 2956edb
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions keras_core/distribution/distribution_lib.py
@@ -13,17 +13,19 @@

 import numpy as np

+from keras_core.api_export import keras_core_export
 from keras_core.backend import distribution_lib
 from keras_core.backend.common import global_state

 DEFAULT_BATCH_DIM_NAME = "batch"
 GLOBAL_ATTRIBUTE_NAME = "distribution"


+@keras_core_export("keras_core.distribution.list_devices")
 def list_devices(device_type=None):
     """Return all the available devices based on the device type.

-    Note in a distributed setting, global devices are returned.
+    Note: in a distributed setting, global devices are returned.

     Args:
         device_type: string, one of `"cpu"`, `"gpu"` or `"tpu"`.
@@ -37,6 +39,7 @@ def list_devices(device_type=None):
     return distribution_lib.list_devices(device_type)


+@keras_core_export("keras_core.distribution.DeviceMesh")
 class DeviceMesh:
     """A cluster of computation devices for distributed computation.

@@ -58,7 +61,7 @@ class DeviceMesh:
             match/create the `TensorLayout` when distribute the data and
             variables.
         devices: Optional list of devices. Default to all the available
-            devices locally from `list_devices()`.
+            devices locally from `keras_core.distribution.list_devices()`.
     """

     def __init__(
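The `DeviceMesh` export can then be combined with `list_devices` to describe the physical topology. A sketch under the assumption that the constructor accepts `shape`, `axis_names`, and an optional `devices` list, as the docstring fragment above suggests:

import keras_core

devices = keras_core.distribution.list_devices()

# Hypothetical 8-device mesh arranged as a 2 x 4 "data" x "model" grid;
# per the docstring, `devices` defaults to all local devices when omitted.
mesh = keras_core.distribution.DeviceMesh(
    shape=(2, 4),
    axis_names=["data", "model"],
    devices=devices[:8],
)
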
@@ -105,8 +108,9 @@ def devices(self):
         return self._devices


+@keras_core_export("keras_core.distribution.TensorLayout")
 class TensorLayout:
-    """The layout of a tensor.
+    """A layout to apply to a tensor.

     This API is aligned with `jax.sharding.NamedSharding`
     and `tf.dtensor.Layout`.
@@ -118,7 +122,7 @@ class TensorLayout:

     Args:
         axes: list of strings that should map to the `axis_names` in
-            `DeviceMesh`. For any dimentions that doesn't need any sharding,
+            a `DeviceMesh`. For any dimentions that doesn't need any sharding,
             A `None` can be used a placeholder.
         device_mesh: Optional `DeviceMesh` that will be used to create
             the layout. The actual mapping of tensor to physical device
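A `TensorLayout` ties tensor dimensions to those mesh axis names, with `None` marking dimensions that are not sharded. A sketch assuming the constructor takes the `axes` and `device_mesh` arguments described above:

import keras_core

devices = keras_core.distribution.list_devices()
mesh = keras_core.distribution.DeviceMesh(
    shape=(2, 4), axis_names=["data", "model"], devices=devices[:8]
)

# Shard dim 0 over "data" and dim 2 over "model"; dim 1 stays replicated.
layout = keras_core.distribution.TensorLayout(
    axes=["data", None, "model"], device_mesh=mesh
)
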
@@ -160,12 +164,12 @@ def _validate_axes(self):


 class Distribution:
-    """Base class for the distribution.
+    """Base class for variable distribution strategies.

-    The `Distribution` has following key functionalities.
+    A `Distribution` has following key functionalities:

-    1. Distribute the model variables to the `DeviceMesh`.
-    2. Distribute the input data to the `DeviceMesh`.
+    1. Distribute the model variables to a `DeviceMesh`.
+    2. Distribute the input data to a `DeviceMesh`.

     It can create a context scope so that the framework to properly detect the
     `Distribution` and distribute the variable/data accordingly.
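In code, the first of those responsibilities surfaces as layout queries such as the `get_variable_layout(variable)` method visible in the hunk headers below. A rough sketch, assuming `DataParallel()` can be constructed with defaults and that a built `Dense` layer exposes its `kernel` variable:

import keras_core

# Assumption: a no-argument DataParallel uses all locally visible devices.
strategy = keras_core.distribution.DataParallel()

layer = keras_core.layers.Dense(4)
layer.build((None, 8))

# Ask the strategy where the kernel variable should live; per the code in
# the hunks below, the result is a TensorLayout bound to the strategy's
# DeviceMesh.
kernel_layout = strategy.get_variable_layout(layer.kernel)
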
@@ -216,6 +220,7 @@ def device_mesh(self):
         return self._device_mesh


+@keras_core_export("keras_core.distribution.DataParallel")
 class DataParallel(Distribution):
     """Distribution for data parallelism.

@@ -292,6 +297,7 @@ def get_variable_layout(self, variable):
         return TensorLayout(variable_shard_spec, self.device_mesh)


+@keras_core_export("keras_core.distribution.ModelParallel")
 class ModelParallel(Distribution):
     """Distribution that shards model variables.

@@ -388,6 +394,7 @@ def get_variable_layout(self, variable):
         return TensorLayout(variable_shard_spec, self.device_mesh)


+@keras_core_export("keras_core.distribution.LayoutMap")
 class LayoutMap(collections.abc.MutableMapping):
     """A dict-like object that maps string to `TensorLayout` instances.

@@ -484,11 +491,13 @@ def _maybe_populate_device_mesh(self, layout):
             layout.device_mesh = self.device_mesh


+@keras_core_export("keras_core.distribution.distribution")
 def distribution():
     """Retrieve the current distribution from global context."""
     return global_state.get_global_attribute(GLOBAL_ATTRIBUTE_NAME)


+@keras_core_export("keras_core.distribution.set_distribution")
 def set_distribution(value):
     """Set the distribution as the global distribution setting.

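Finally, the two module-level helpers in this last hunk form the global getter/setter pair. A short sketch (the no-argument `DataParallel()` construction is an assumption; the two function calls are exactly the ones exported here):

import keras_core

# Install a strategy as the process-wide default...
keras_core.distribution.set_distribution(
    keras_core.distribution.DataParallel()
)

# ...and read it back from anywhere else in the program.
active = keras_core.distribution.distribution()
print(type(active).__name__)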
