Fix the distribution logic for the JAX trainer. #832

Merged: 3 commits, Sep 3, 2023
161 changes: 137 additions & 24 deletions keras_core/backend/jax/trainer.py
@@ -79,7 +79,6 @@ def train_step(self, state, data):
metrics_variables,
) = state
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
x, y, sample_weight = self._distribute_data((x, y, sample_weight))
grad_fn = jax.value_and_grad(
self.compute_loss_and_updates, has_aux=True
)
@@ -116,7 +115,7 @@ def train_step(self, state, data):
new_metrics_variables.append(new_v)
metrics_variables = new_metrics_variables

state = (
state = self._enforce_jax_state_sharding(
trainable_variables,
non_trainable_variables,
optimizer_variables,
@@ -131,7 +130,6 @@ def test_step(self, state, data):
metrics_variables,
) = state
x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)
x, y, sample_weight = self._distribute_data((x, y, sample_weight))
loss, (
y_pred,
non_trainable_variables,
@@ -161,6 +159,17 @@ def test_step(self, state, data):
new_metrics_variables.append(new_v)
metrics_variables = new_metrics_variables

(
trainable_variables,
non_trainable_variables,
_,
metrics_variables,
) = self._enforce_jax_state_sharding(
trainable_variables=trainable_variables,
non_trainable_variables=non_trainable_variables,
optimizer_variables=None,
metrics_variables=metrics_variables,
)
state = (
trainable_variables,
non_trainable_variables,
@@ -175,10 +184,20 @@ def predict_step(self, state, data):
kwargs["training"] = False

x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)
x = self._distribute_data(x)
outputs, non_trainable_variables = self.stateless_call(
trainable_variables, non_trainable_variables, x, **kwargs
)
(
trainable_variables,
non_trainable_variables,
_,
_,
) = self._enforce_jax_state_sharding(
trainable_variables=trainable_variables,
non_trainable_variables=non_trainable_variables,
optimizer_variables=None,
metrics_variables=None,
)
return outputs, (trainable_variables, non_trainable_variables)

def make_train_function(self, force=False):
@@ -356,6 +375,7 @@ def fit(
steps=epoch_iterator.num_batches,
model=self,
)
self._record_training_state_sharding_spec()

self.make_train_function()
self.stop_training = False
@@ -365,10 +385,12 @@ def fit(
self.reset_metrics()
callbacks.on_epoch_begin(epoch)

trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables
optimizer_variables = self.optimizer.variables
metrics_variables = self.metrics_variables
trainable_variables = [v.value for v in self.trainable_variables]
non_trainable_variables = [
v.value for v in self.non_trainable_variables
]
optimizer_variables = [v.value for v in self.optimizer.variables]
metrics_variables = [v.value for v in self.metrics_variables]

for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
# Callbacks
@@ -381,6 +403,7 @@ def fit(
optimizer_variables,
metrics_variables,
)
data = self._distribute_data(data)
logs, state = self.train_function(state, data)
(
trainable_variables,
@@ -490,7 +513,13 @@ def evaluate(
steps_per_execution=self.steps_per_execution,
)

if not all(layer.built for layer in self._flatten_layers()):
needs_building = not all(
layer.built for layer in self._flatten_layers()
) or (
self._compile_metrics is not None
and not self._compile_metrics.built
)
if needs_building:
# Build the model on one batch of data.
for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
data_batch = data[0]
@@ -508,15 +537,18 @@ def evaluate(
steps=epoch_iterator.num_batches,
model=self,
)
self._record_training_state_sharding_spec()

self.make_test_function()
callbacks.on_test_begin()
logs = None
self.reset_metrics()

trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables
metrics_variables = self.metrics_variables
trainable_variables = [v.value for v in self.trainable_variables]
non_trainable_variables = [
v.value for v in self.non_trainable_variables
]
metrics_variables = [v.value for v in self.metrics_variables]

for step, data in epoch_iterator.enumerate_epoch(return_type="np"):
callbacks.on_test_batch_begin(step)
@@ -526,6 +558,7 @@ def evaluate(
non_trainable_variables,
metrics_variables,
)
data = self._distribute_data(data)
logs, state = self.test_function(state, data)
# Note that trainable variables are not returned since they're
# immutable here.
@@ -584,6 +617,7 @@ def predict(
steps=epoch_iterator.num_batches,
model=self,
)
self._record_training_state_sharding_spec()

self.make_predict_function()
callbacks.on_predict_begin()
@@ -603,12 +637,15 @@ def append_to_outputs(batch_outputs, outputs):
)
return outputs

trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables
trainable_variables = [v.value for v in self.trainable_variables]
non_trainable_variables = [
v.value for v in self.non_trainable_variables
]
state = (trainable_variables, non_trainable_variables)
outputs = None
for step, x in epoch_iterator.enumerate_epoch(return_type="np"):
callbacks.on_predict_batch_begin(step)
x = self._distribute_data(x)
batch_outputs, state = self.predict_function(state, x)
outputs = append_to_outputs(batch_outputs, outputs)
callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
@@ -636,16 +673,20 @@ def train_on_batch(
y, class_weight
)
data = (x, y, sample_weight)
data = self._distribute_data(data)

# Maybe build model
self._eager_build(data)
self._record_training_state_sharding_spec()
self.make_train_function()

# Train step
trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables
optimizer_variables = self.optimizer.variables
metrics_variables = self.metrics_variables
trainable_variables = [v.value for v in self.trainable_variables]
non_trainable_variables = [
v.value for v in self.non_trainable_variables
]
optimizer_variables = [v.value for v in self.optimizer.variables]
metrics_variables = [v.value for v in self.metrics_variables]
state = (
trainable_variables,
non_trainable_variables,
@@ -685,14 +726,18 @@ def test_on_batch(
self._assert_compile_called("test_on_batch")

data = (x, y, sample_weight)
data = self._distribute_data(data)
# Maybe build model
self._eager_build(data)
self._record_training_state_sharding_spec()
self.make_test_function()

# Test step
trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables
metrics_variables = self.metrics_variables
trainable_variables = [v.value for v in self.trainable_variables]
non_trainable_variables = [
v.value for v in self.non_trainable_variables
]
metrics_variables = [v.value for v in self.metrics_variables]
state = (
trainable_variables,
non_trainable_variables,
@@ -719,10 +764,13 @@ def predict_on_batch(self, x):
# Build model
with backend.StatelessScope():
self(x)

self._record_training_state_sharding_spec()
self.make_predict_function()
trainable_variables = self.trainable_variables
non_trainable_variables = self.non_trainable_variables

trainable_variables = [v.value for v in self.trainable_variables]
non_trainable_variables = [
v.value for v in self.non_trainable_variables
]
state = (trainable_variables, non_trainable_variables)
batch_outputs, state = self.predict_function(state, [(x,)])
batch_outputs = tree.map_structure(lambda x: np.array(x), batch_outputs)
@@ -764,3 +812,68 @@ def distribute_single_value(d):
return jax.tree_util.tree_map(distribute_single_value, data)
else:
return data

def _record_training_state_sharding_spec(self):
self._trainable_variable_shardings = [
v.value.sharding for v in self.trainable_variables
]
self._non_trainable_variable_shardings = [
v.value.sharding for v in self.non_trainable_variables
]
if hasattr(self, "optimizer"):
self._optimizer_variable_shardings = [
v.value.sharding for v in self.optimizer.variables
]
else:
self._optimizer_variable_shardings = []
self._metrics_variable_shardings = [
v.value.sharding for v in self.metrics_variables
]

def _enforce_jax_state_sharding(
self,
trainable_variables=None,
non_trainable_variables=None,
optimizer_variables=None,
metrics_variables=None,
):
"""Enforce the sharding spec constraint for all the training state.

Since the outputs of the train/eval step will be used as inputs to the
next step, we need to ensure that they have the same sharding spec, so
that jax.jit won't have to recompile the train/eval function.

Note that this function also relies on the recorded sharding spec for
each of the states.

This function is expected to be called within the jitted train/eval
function, typically near its end.
"""
trainable_variables = trainable_variables or []
non_trainable_variables = non_trainable_variables or []
optimizer_variables = optimizer_variables or []
metrics_variables = metrics_variables or []

for i in range(len(trainable_variables)):
trainable_variables[i] = jax.lax.with_sharding_constraint(
trainable_variables[i], self._trainable_variable_shardings[i]
)
for i in range(len(non_trainable_variables)):
non_trainable_variables[i] = jax.lax.with_sharding_constraint(
non_trainable_variables[i],
self._non_trainable_variable_shardings[i],
)
for i in range(len(optimizer_variables)):
optimizer_variables[i] = jax.lax.with_sharding_constraint(
optimizer_variables[i], self._optimizer_variable_shardings[i]
)
for i in range(len(metrics_variables)):
metrics_variables[i] = jax.lax.with_sharding_constraint(
metrics_variables[i], self._metrics_variable_shardings[i]
)
return (
trainable_variables,
non_trainable_variables,
optimizer_variables,
metrics_variables,
)
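
Taken together, the changes above move data distribution out of the jitted step functions and into the host-side fit/evaluate/predict loops, and pin the sharding of every piece of training state that a step returns. The docstring of _enforce_jax_state_sharding explains why: the outputs of one jitted step become the inputs of the next, so their shardings must stay fixed or jax.jit will retrace. The snippet below is a minimal, self-contained sketch of that idea in plain JAX; it is not part of the PR, and the mesh axis name, shapes, and function names are illustrative assumptions.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over whatever devices are available (a single
# device also works).
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
replicated = NamedSharding(mesh, PartitionSpec())            # model state
data_sharding = NamedSharding(mesh, PartitionSpec("batch"))  # shard batch dim

@jax.jit
def train_step(state, batch):
    new_state = state + jnp.mean(batch)
    # Mirror _enforce_jax_state_sharding: constrain the output to the
    # sharding recorded for the state, so the next call sees the same
    # layout and jax.jit does not recompile.
    return jax.lax.with_sharding_constraint(new_state, replicated)

state = jax.device_put(jnp.zeros(()), replicated)
for _ in range(3):
    # Mirror the fit()/evaluate()/predict() loops: distribute the batch on
    # the host, before it enters the jitted function. Assumes the batch
    # dimension is divisible by the number of devices.
    batch = jax.device_put(np.ones((8, 4), "float32"), data_sharding)
    state = train_step(state, batch)
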
4 changes: 2 additions & 2 deletions keras_core/trainers/trainer.py
@@ -121,9 +121,9 @@ def run_eagerly(self, value):

@property
def metrics(self):
metrics = [self._loss_tracker]
metrics = [self._loss_tracker] if self.compiled else []
metrics.extend(self._metrics[:])
if self._compile_metrics is not None:
if self.compiled and self._compile_metrics is not None:
metrics += [self._compile_metrics]
return metrics

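The trainer.py change above guards the metrics property so that the loss tracker and compiled metrics are only reported once compile() has been called, which lets an uncompiled model (for example, during predict()) query its metrics safely. A rough, hedged sketch of the expected behavior; the model definition below is an assumption for illustration and not part of the PR:

import keras_core as keras

model = keras.Sequential([keras.layers.Dense(1)])
# With the guarded property, an uncompiled model is expected to report
# no loss tracker or compile metrics rather than raising an error.
print([m.name for m in model.metrics])

model.compile(optimizer="sgd", loss="mse", metrics=["mae"])
# After compile(), the loss tracker and the compiled metrics are included.
print([m.name for m in model.metrics])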