Skip to content

Commit bb0a668

Browse files
authored
[hotfix] set return_outputs=False in examples and polish code (#5404)
* fix: simplify merge_batch
* fix: use return_outputs=False to eliminate extra memory consumption
* feat: add return_outputs warning
* style: remove `return_outputs=False` as it is the default value
1 parent 5fcd779 commit bb0a668

File tree

24 files changed

+28
-36
lines changed

24 files changed

+28
-36
lines changed

applications/ColossalMoE/train.py

-1
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,6 @@ def main():
238238
lambda x, y: x.loss,
239239
optimizer,
240240
return_loss=True,
241-
return_outputs=True,
242241
)
243242
# Backward and optimize
244243
if is_pp_last_stage:

colossalai/booster/plugin/hybrid_parallel_plugin.py

+3
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,9 @@ def execute_pipeline(
11831183
) -> dict:
11841184
assert self.enable_pipeline_parallelism, "pipeline parallelism is not enabled"
11851185

1186+
if return_outputs:
1187+
warnings.warn("return_outputs may lead to significant extra memory consumption.")
1188+
11861189
# Create a context for gradient synchronization based on the optimizer type.
11871190
# If it's a HybridParallelZeroOptimizer, use optimizer.no_sync(); otherwise, use model.no_sync().
11881191
# This is to avoid redundant gradient reduction in pipeline parallelism (multiple microbatch values should be reduced once),

colossalai/pipeline/schedule/one_f_one_b.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from torch.utils._pytree import tree_map
88

99
from colossalai.accelerator import get_accelerator
10-
from colossalai.interface import ModelWrapper, OptimizerWrapper
10+
from colossalai.interface import OptimizerWrapper
1111
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
1212
from colossalai.pipeline.stage_manager import PipelineStageManager
1313
from colossalai.utils import get_current_device
@@ -327,9 +327,7 @@ def run_forward_only(
327327
self.send_forward(output_obj)
328328

329329
if outputs is not None:
330-
if isinstance(model, ModelWrapper):
331-
model = model.unwrap()
332-
outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
330+
outputs = merge_batch(outputs)
333331
return {"loss": accum_loss, "outputs": outputs}
334332

335333
def run_forward_backward(
@@ -412,9 +410,7 @@ def run_forward_backward(
412410
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
413411

414412
if outputs is not None:
415-
if isinstance(model, ModelWrapper):
416-
model = model.unwrap()
417-
outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
413+
outputs = merge_batch(outputs)
418414
return {"loss": accum_loss, "outputs": outputs}
419415

420416
def forward_backward_step(

docs/source/en/advanced_tutorials/train_gpt_using_hybrid_parallelism.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def train_epoch(
178178
for _ in pbar:
179179
if use_pipeline:
180180
outputs = booster.execute_pipeline(
181-
train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
181+
train_dataloader_iter, model, _criterion, optimizer, return_loss=True
182182
)
183183
# Backward and optimize
184184
if is_pp_last_stage:

docs/source/en/advanced_tutorials/train_vit_with_hybrid_parallelism.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def run_forward_backward(
231231
if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1:
232232
# run pipeline forward backward when enabling pp in hybrid parallel plugin
233233
output_dict = booster.execute_pipeline(
234-
data_iter, model, criterion, optimizer, return_loss=True, return_outputs=True
234+
data_iter, model, criterion, optimizer, return_loss=True
235235
)
236236
loss, outputs = output_dict["loss"], output_dict["outputs"]
237237
else:

docs/source/en/features/pipeline_parallel.md

+1-2
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion:
198198
model,
199199
_criterion,
200200
optimizer,
201-
return_loss=True,
202-
return_outputs=True)
201+
return_loss=True)
203202
# Backward and optimize
204203
if is_pp_last_stage:
205204
loss = outputs['loss']

docs/source/en/features/shardformer.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ However, if pipeline parallel is enabled, there are several usages different fro
271271
3. Do forward and backward passing through calling `Booster.execute_pipeline` method:
272272
```python
273273
outputs = booster.execute_pipeline(
274-
train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
274+
train_dataloader_iter, model, _criterion, optimizer, return_loss=True
275275
)
276276
```
277277
Backward passing has been completed by this method, so there is no need to call `loss.backward()` after executing this method.

docs/source/zh-Hans/advanced_tutorials/train_gpt_using_hybrid_parallelism.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def train_epoch(
175175
for _ in pbar:
176176
if use_pipeline:
177177
outputs = booster.execute_pipeline(
178-
train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
178+
train_dataloader_iter, model, _criterion, optimizer, return_loss=True
179179
)
180180
# Backward and optimize
181181
if is_pp_last_stage:

docs/source/zh-Hans/advanced_tutorials/train_vit_with_hybrid_parallelism.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def run_forward_backward(
234234
if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1:
235235
# run pipeline forward backward when enabling pp in hybrid parallel plugin
236236
output_dict = booster.execute_pipeline(
237-
data_iter, model, criterion, optimizer, return_loss=True, return_outputs=True
237+
data_iter, model, criterion, optimizer, return_loss=True
238238
)
239239
loss, outputs = output_dict["loss"], output_dict["outputs"]
240240
else:

docs/source/zh-Hans/features/pipeline_parallel.md

+1-2
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,7 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion:
193193
model,
194194
_criterion,
195195
optimizer,
196-
return_loss=True,
197-
return_outputs=True)
196+
return_loss=True)
198197
# Backward and optimize
199198
if is_pp_last_stage:
200199
loss = outputs['loss']

docs/source/zh-Hans/features/shardformer.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ elif args.plugin == "hybrid_parallel":
264264
3. 通过调用`Booster.execute_pipeline` 方法来执行前向和后向传递:
265265
```python
266266
outputs = booster.execute_pipeline(
267-
train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
267+
train_dataloader_iter, model, _criterion, optimizer, return_loss=True
268268
)
269269
```
270270
该方法会自动执行后向传递,所以在执行该方法后不需要再调用 `loss.backward()`方法。

examples/images/vit/vit_benchmark.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def criterion(outputs, inputs):
120120
# run pipeline forward backward
121121
batch = iter([batch])
122122
outputs = booster.execute_pipeline(
123-
batch, model, criterion, optimizer, return_loss=True, return_outputs=True
123+
batch, model, criterion, optimizer, return_loss=True
124124
)
125125
else:
126126
outputs = model(**batch)

examples/language/bert/finetune.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def train_epoch(
148148
for _ in pbar:
149149
if use_pipeline:
150150
outputs = booster.execute_pipeline(
151-
train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
151+
train_dataloader_iter, model, _criterion, optimizer, return_loss=True
152152
)
153153
# Backward and optimize
154154
if is_pp_last_device:

examples/language/gpt/hybridparallelism/finetune.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def train_epoch(
145145
for _ in pbar:
146146
if use_pipeline:
147147
outputs = booster.execute_pipeline(
148-
train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
148+
train_dataloader_iter, model, _criterion, optimizer, return_loss=True
149149
)
150150
# Backward and optimize
151151
if is_pp_last_stage:

examples/language/llama2/finetune.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def main():
271271
for step in pbar:
272272
if use_pipeline:
273273
outputs = booster.execute_pipeline(
274-
dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
274+
dataloader_iter, model, _criterion, optimizer, return_loss=True
275275
)
276276
loss = outputs["loss"]
277277
else:

examples/language/llama2/pretrain.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def main():
185185
microbatch_size=1,
186186
enable_jit_fused=False,
187187
zero_stage=0,
188-
precision="fp32",
188+
precision=args.mixed_precision,
189189
initial_scale=1,
190190
)
191191
else:
@@ -286,7 +286,7 @@ def main():
286286
for step in pbar:
287287
if use_pipeline:
288288
outputs = booster.execute_pipeline(
289-
dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
289+
dataloader_iter, model, _criterion, optimizer, return_loss=True
290290
)
291291
loss = outputs["loss"]
292292
else:

examples/language/openmoe/benchmark/benchmark_cai.py

-1
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,6 @@ def main():
270270
lambda x, y: x.loss,
271271
optimizer,
272272
return_loss=True,
273-
return_outputs=True,
274273
)
275274
# Backward and optimize
276275
if is_pp_last_stage:

examples/language/openmoe/train.py

-1
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,6 @@ def main():
340340
lambda x, y: x.loss,
341341
optimizer,
342342
return_loss=True,
343-
return_outputs=True,
344343
)
345344
# Backward and optimize
346345
if is_pp_last_stage:

examples/language/opt/opt_train_demo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, b
4242
for _ in pbar:
4343
if use_pipeline:
4444
outputs = booster.execute_pipeline(
45-
dataloader, model, _criterion, optimizer, return_loss=True, return_outputs=True
45+
dataloader, model, _criterion, optimizer, return_loss=True
4646
)
4747
# Backward and optimize
4848
if is_pp_last_stage:

tests/test_booster/test_plugin/test_3d_plugin.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def _criterion(outputs, inputs):
7474
loss = criterion(outputs[output_key])
7575
return loss
7676

77-
booster.execute_pipeline(data_iter, model, _criterion, optimizer, return_loss=True, return_outputs=False)
77+
booster.execute_pipeline(data_iter, model, _criterion, optimizer, return_loss=True)
7878
optimizer.step()
7979

8080
except Exception as e:

tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def _preprocess_data(data):
7575
model.train()
7676
if booster.plugin.stage_manager is not None:
7777
booster.execute_pipeline(
78-
_preprocess_data(data), model, _criterion, optimizer, return_loss=True, return_outputs=False
78+
_preprocess_data(data), model, _criterion, optimizer, return_loss=True
7979
)
8080
else:
8181
output = model(**_preprocess_data(data))
@@ -109,15 +109,14 @@ def _preprocess_data(data):
109109
data_for_origin = data_gen_fn()
110110
if booster.plugin.stage_manager is not None:
111111
booster.execute_pipeline(
112-
_preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True, return_outputs=False
112+
_preprocess_data(data_for_shard), model, _criterion, optimizer, return_loss=True
113113
)
114114
booster.execute_pipeline(
115115
_preprocess_data(data_for_origin),
116116
new_model,
117117
_criterion,
118118
new_optimizer,
119119
return_loss=True,
120-
return_outputs=False,
121120
)
122121
else:
123122
old_model_loss = criterion(model(**_preprocess_data(data_for_shard)))

tests/test_moe/test_moe_checkpoint.py

-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def run_fwd_bwd(
4949
lambda x, y: x.loss,
5050
optimizer,
5151
return_loss=True,
52-
return_outputs=True,
5352
)
5453
# Backward and optimize
5554
if is_pp_last_stage:

tests/test_pipeline/test_schedule/test_interleaved.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def criterion(x, *args, **kwargs):
104104
torch_loss.backward()
105105

106106
pp_ret = schedule.forward_backward_step(
107-
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
107+
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
108108
)
109109

110110
# check loss
@@ -134,7 +134,7 @@ def criterion(x, *args, **kwargs):
134134
torch_loss = criterion(torch_output)
135135

136136
pp_ret = schedule.forward_backward_step(
137-
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
137+
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
138138
)
139139
if stage_manager.is_last_stage(ignore_chunk=True):
140140
assert torch.allclose(torch_loss, pp_ret["loss"])

tests/test_pipeline/test_schedule/test_oneF_oneB.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def custom_fwd(self, x):
100100
torch_loss = criterion(torch_output)
101101
torch_loss.backward()
102102
pp_ret = schedule.forward_backward_step(
103-
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
103+
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
104104
)
105105

106106
# check loss
@@ -130,7 +130,7 @@ def custom_fwd(self, x):
130130
torch_loss = criterion(torch_output)
131131

132132
pp_ret = schedule.forward_backward_step(
133-
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
133+
sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True
134134
)
135135
if stage_manager.is_last_stage():
136136
assert torch.allclose(torch_loss, pp_ret["loss"])

0 commit comments

Comments (0)