Skip to content

Commit

Permalink
[feature] Add support for OffloadModel to enable training large model…
Browse files Browse the repository at this point in the history
…s on 1 GPU. (#432)

* clean start

* removing per layer split strategy, probably not that useful indeed

* initial transformer benchmark

* hack, enable testing ViT + offload, python3 benchmarks/oss.py  --epochs 2 --optim_type oss_offload_ddp --batch_size=32 --model vit_large_patch16_224

* proper cuda streams and device, something off in terms of mems consumption

* minor, stashing

* unit test fix

* removing all the distributed parts

* simpler test, needs debugging

* working OOP, running a model which does not fit on the gpu memory

* spring cleaning

* removing the ill-advised optimizer bits, better keep that orthogonal

* [offload] Add support for activation offloading + other changes (#367)

* initial fwd/bwd commit

* checkpoint work

* modify shard loop

* activation offloading and test to start with

* fix lint errors

* update comments

* fix lint

* remove unused var

* remove commented out lines

* modify name

* remove break

* remove profiler comments

* avoid saving inputs

* fix lint errors

Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>

* [offload] Add support for fp16 training (#374)

* initial fwd/bwd commit

* checkpoint work

* modify shard loop

* activation offloading and test to start with

* fix lint errors

* update comments

* fix lint

* remove unused var

* remove commented out lines

* modify name

* remove break

* remove profiler comments

* add support for fp16

* add unit tests

* fix lint errors

* fix test failure

Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>

* [offload] Add support for activation checkpointing for all layers. (#381)

* initial fwd/bwd commit

* checkpoint work

* modify shard loop

* activation offloading and test to start with

* fix lint errors

* update comments

* fix lint

* remove unused var

* remove commented out lines

* modify name

* remove break

* remove profiler comments

* add support for fp16

* add unit tests

* fix lint errors

* fix test failure

* cp work, incorrect output dimensions still need to be fixed

* fixed activation outputs

* intermediate cp of work

* add tests

* fix lint errors

Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>

* add support for microbatches

* revert benchmark config changes

* add parametrization

* fix lint errors and tests

* skip test for 1.5

* fix lint errors

* skip test if there are no GPUs

* fix lint errors

* fix lint errors

* move experimental to the fairscale repo

* lint error fixes

* modify test imports

* lint error fixes

* move offload files to the experimental directory

* move tests and benchmarks to their forlder

* fix mypy errors

* cp intermediate working benchmarks

* more changes

* split benchmark configs

* remove print statements

* fix lint errors

* remove unused print

* stress testing

* remove unused file

* change param nae

* lint fixes

* move file to the right folder

* offload_experimental

* add doc string

* add error message

Co-authored-by: Benjamin Lefaudeux <benjamin.lefaudeux@gmail.com>
Co-authored-by: Benjamin Lefaudeux <benjamin.lefaudeux@protonmail.com>
Co-authored-by: Anjali Sridhar <anj@devfair0443.h2.fair>
  • Loading branch information
4 people authored Feb 26, 2021
1 parent 7ee228b commit f7813d6
Show file tree
Hide file tree
Showing 6 changed files with 1,110 additions and 39 deletions.
2 changes: 1 addition & 1 deletion benchmarks/datasets/wikitext2_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def data_process(raw_text_iter):
test_dataset = data_process(iter(io.open(test_filepath, encoding="utf8")))

def batchify(data):
batch_size = args.batch_size
batch_size = benchmark_config["batch_size"]
return _batchify(data, batch_size)

total_batch_size = _get_total_batch_size(benchmark_config, model_specs)
Expand Down
Loading

0 comments on commit f7813d6

Please sign in to comment.