Fix the distribution logic for the JAX trainer. (#832)
* Update the sharding logic for the JAX trainer.

* Fix the corner cases that fail the unit tests.

1. eval() should also trigger a build for uninitialized metrics.
2. model.metrics should handle an uncompiled model.

* Address review comments.
qlzh727 authored Sep 3, 2023
1 parent de510e9 commit 2173cbb
Showing 2 changed files with 139 additions and 26 deletions.
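
In outline, the fix moves data distribution out of the jitted step functions and into the outer fit/evaluate/predict loops, passes raw jax.Array values (each variable's .value) into the jitted steps, and pins the returned state to sharding specs recorded up front. A minimal sketch of that state-threading pattern, with placeholder names (run_loop, step_fn, shard_batch) that are not part of this commit:

# Illustrative sketch only: unwrap Keras variables to plain jax.Array values,
# distribute each batch outside the jitted step, and thread the returned
# state into the next iteration so jax.jit sees a stable layout.
def run_loop(model, batches, step_fn, shard_batch):
    state = (
        [v.value for v in model.trainable_variables],
        [v.value for v in model.non_trainable_variables],
    )
    logs = None
    for batch in batches:
        batch = shard_batch(batch)           # distribution happens outside jit
        logs, state = step_fn(state, batch)  # jitted step; sharding enforced inside
    return logs, state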
161 changes: 137 additions & 24 deletions keras_core/backend/jax/trainer.py
@@ -79,7 +79,6 @@ def train_step(self, state, data):
metrics_variables,
) = state
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
x, y, sample_weight = self._distribute_data((x, y, sample_weight))
grad_fn = jax.value_and_grad(
self.compute_loss_and_updates, has_aux=True
)
@@ -116,7 +115,7 @@ def train_step(self, state, data):
new_metrics_variables.append(new_v)
metrics_variables = new_metrics_variables

state = (
state = self._enforce_jax_state_sharding(
trainable_variables,
non_trainable_variables,
optimizer_variables,
@@ -131,7 +130,6 @@ def test_step(self, state, data):
metrics_variables,
) = state
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
x, y, sample_weight = self._distribute_data((x, y, sample_weight))
loss, (
y_pred,
non_trainable_variables,
@@ -161,6 +159,17 @@ def test_step(self, state, data):
new_metrics_variables.append(new_v)
metrics_variables = new_metrics_variables

(
trainable_variables,
non_trainable_variables,
_,
metrics_variables,
) = self._enforce_jax_state_sharding(
trainable_variables=trainable_variables,
non_trainable_variables=non_trainable_variables,
optimizer_variables=None,
metrics_variables=metrics_variables,
)
state = (
trainable_variables,
non_trainable_variables,
@@ -175,10 +184,20 @@ def predict_step(self, state, data):
kwargs["training"] = False

x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)
x = self._distribute_data(x)
outputs, non_trainable_variables = self.stateless_call(
trainable_variables, non_trainable_variables, x, **kwargs
)
(
trainable_variables,
non_trainable_variables,
_,
_,
) = self._enforce_jax_state_sharding(
trainable_variables=trainable_variables,
non_trainable_variables=non_trainable_variables,
optimizer_variables=None,
metrics_variables=None,
)
return outputs, (trainable_variables, non_trainable_variables)

def make_train_function(self, force=False):
@@ -356,6 +375,7 @@ def fit(
steps=epoch_iterator.num_batches,
model=self,
)
self._record_training_state_sharding_spec()

self.make_train_function()
self.stop_training = False
@@ -365,10 +385,12 @@
self.reset_metrics()
callbacks.on_epoch_begin(epoch)

trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables
optimizer_variables = self.optimizer.variables
metrics_variables = self.metrics_variables
trainable_variables = [v.value for v in self.trainable_variables]
non_trainable_variables = [
v.value for v in self.non_trainable_variables
]
optimizer_variables = [v.value for v in self.optimizer.variables]
metrics_variables = [v.value for v in self.metrics_variables]

for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
# Callbacks
@@ -381,6 +403,7 @@
optimizer_variables,
metrics_variables,
)
data = self._distribute_data(data)
logs, state = self.train_function(state, data)
(
trainable_variables,
@@ -490,7 +513,13 @@ def evaluate(
steps_per_execution=self.steps_per_execution,
)

if not all(layer.built for layer in self._flatten_layers()):
needs_building = not all(
layer.built for layer in self._flatten_layers()
) or (
self._compile_metrics is not None
and not self._compile_metrics.built
)
if needs_building:
# Build the model on one batch of data.
for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
data_batch = data[0]
@@ -508,15 +537,18 @@
steps=epoch_iterator.num_batches,
model=self,
)
self._record_training_state_sharding_spec()

self.make_test_function()
callbacks.on_test_begin()
logs = None
self.reset_metrics()

trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables
metrics_variables = self.metrics_variables
trainable_variables = [v.value for v in self.trainable_variables]
non_trainable_variables = [
v.value for v in self.non_trainable_variables
]
metrics_variables = [v.value for v in self.metrics_variables]

for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
callbacks.on_test_batch_begin(step)
@@ -526,6 +558,7 @@
non_trainable_variables,
metrics_variables,
)
data = self._distribute_data(data)
logs, state = self.test_function(state, data)
# Note that trainable variables are not returned since they're
# immutable here.
@@ -584,6 +617,7 @@ def predict(
steps=epoch_iterator.num_batches,
model=self,
)
self._record_training_state_sharding_spec()

self.make_predict_function()
callbacks.on_predict_begin()
@@ -603,12 +637,15 @@ def append_to_outputs(batch_outputs, outputs):
)
return outputs

trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables
trainable_variables = [v.value for v in self.trainable_variables]
non_trainable_variables = [
v.value for v in self.non_trainable_variables
]
state = (trainable_variables, non_trainable_variables)
outputs = None
for step, x in epoch_iterator.enumerate_epoch(return_type="np"):
callbacks.on_predict_batch_begin(step)
x = self._distribute_data(x)
batch_outputs, state = self.predict_function(state, x)
outputs = append_to_outputs(batch_outputs, outputs)
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
@@ -636,16 +673,20 @@ def train_on_batch(
y, class_weight
)
data = (x, y, sample_weight)
data = self._distribute_data(data)

# Maybe build model
self._eager_build(data)
self._record_training_state_sharding_spec()
self.make_train_function()

# Train step
trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables
optimizer_variables = self.optimizer.variables
metrics_variables = self.metrics_variables
trainable_variables = [v.value for v in self.trainable_variables]
non_trainable_variables = [
v.value for v in self.non_trainable_variables
]
optimizer_variables = [v.value for v in self.optimizer.variables]
metrics_variables = [v.value for v in self.metrics_variables]
state = (
trainable_variables,
non_trainable_variables,
@@ -685,14 +726,18 @@ def test_on_batch(
self._assert_compile_called("test_on_batch")

data = (x, y, sample_weight)
data = self._distribute_data(data)
# Maybe build model
self._eager_build(data)
self._record_training_state_sharding_spec()
self.make_test_function()

# Test step
trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables
metrics_variables = self.metrics_variables
trainable_variables = [v.value for v in self.trainable_variables]
non_trainable_variables = [
v.value for v in self.non_trainable_variables
]
metrics_variables = [v.value for v in self.metrics_variables]
state = (
trainable_variables,
non_trainable_variables,
@@ -719,10 +764,13 @@ def predict_on_batch(self, x):
# Build model
with backend.StatelessScope():
self(x)

self._record_training_state_sharding_spec()
self.make_predict_function()
trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables

trainable_variables = [v.value for v in self.trainable_variables]
non_trainable_variables = [
v.value for v in self.non_trainable_variables
]
state = (trainable_variables, non_trainable_variables)
batch_outputs, state = self.predict_function(state, [(x,)])
batch_outputs = tree.map_structure(lambda x: np.array(x), batch_outputs)
@@ -764,3 +812,68 @@ def distribute_single_value(d):
return jax.tree_util.tree_map(distribute_single_value, data)
else:
return data

def _record_training_state_sharding_spec(self):
self._trainable_variable_shardings = [
v.value.sharding for v in self.trainable_variables
]
self._non_trainable_variable_shardings = [
v.value.sharding for v in self.non_trainable_variables
]
if hasattr(self, "optimizer"):
self._optimizer_variable_shardings = [
v.value.sharding for v in self.optimizer.variables
]
else:
self._optimizer_variable_shardings = []
self._metrics_variable_shardings = [
v.value.sharding for v in self.metrics_variables
]

def _enforce_jax_state_sharding(
self,
trainable_variables=None,
non_trainable_variables=None,
optimizer_variables=None,
metrics_variables=None,
):
"""Enforce the sharding spec constraint for all the training state.
Since the output of the train/eval step will be used as inputs to next
step, we need to ensure that they have the same sharding spec, so that
jax.jit won't have to recompile the train/eval function.
Note that this function will also rely on the recorded sharding spec
for each of states.
This function is expected to be called within the jitted train/eval
function, especially around the end of the function.
"""
trainable_variables = trainable_variables or []
non_trainable_variables = non_trainable_variables or []
optimizer_variables = optimizer_variables or []
metrics_variables = metrics_variables or []

for i in range(len(trainable_variables)):
trainable_variables[i] = jax.lax.with_sharding_constraint(
trainable_variables[i], self._trainable_variable_shardings[i]
)
for i in range(len(non_trainable_variables)):
non_trainable_variables[i] = jax.lax.with_sharding_constraint(
non_trainable_variables[i],
self._non_trainable_variable_shardings[i],
)
for i in range(len(optimizer_variables)):
optimizer_variables[i] = jax.lax.with_sharding_constraint(
optimizer_variables[i], self._optimizer_variable_shardings[i]
)
for i in range(len(metrics_variables)):
metrics_variables[i] = jax.lax.with_sharding_constraint(
metrics_variables[i], self._metrics_variable_shardings[i]
)
return (
trainable_variables,
non_trainable_variables,
optimizer_variables,
metrics_variables,
)
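
To make the recompilation point from the docstring concrete, here is a standalone sketch (not part of the commit) of the same idea in plain JAX: record a variable's sharding once, then constrain the jitted step's output to that sharding so the state fed back in on the next iteration keeps a stable layout and jax.jit does not retrace.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Standalone illustration; the names below are not from the trainer. Runs on
# a single device as well as on a multi-device mesh (the batch dimension must
# divide evenly across devices).
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
sharding = NamedSharding(mesh, P("batch"))

params = jax.device_put(jnp.zeros((8, 4)), sharding)
recorded_sharding = params.sharding  # analogous to _record_training_state_sharding_spec

@jax.jit
def train_step(params, x):
    new_params = params - 0.01 * x
    # Analogous to _enforce_jax_state_sharding: pin the output to the
    # recorded layout so it matches the input sharding on the next step.
    return jax.lax.with_sharding_constraint(new_params, recorded_sharding)

x = jax.device_put(jnp.ones((8, 4)), sharding)
for _ in range(3):
    params = train_step(params, x)  # stable sharding across iterations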
4 changes: 2 additions & 2 deletions keras_core/trainers/trainer.py
@@ -121,9 +121,9 @@ def run_eagerly(self, value):

@property
def metrics(self):
metrics = [self._loss_tracker]
metrics = [self._loss_tracker] if self.compiled else []
metrics.extend(self._metrics[:])
if self._compile_metrics is not None:
if self.compiled and self._compile_metrics is not None:
metrics += [self._compile_metrics]
return metrics
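
A rough illustration of the corner case this guard addresses, assuming the keras_core API at the time of the commit: asking for model.metrics on a model that was never compiled should return an empty list rather than touch the loss tracker or compile metrics, which only exist after compile().

import keras_core

# Sketch of the expected behavior after this change.
model = keras_core.Sequential([keras_core.layers.Dense(1)])
print(model.metrics)  # expected: [] for an uncompiled model

model.compile(optimizer="sgd", loss="mse", metrics=["mean_absolute_error"])
print([m.name for m in model.metrics])  # loss tracker plus compile metrics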

