Skip to content

Commit

Permalink
Reduce Unit Test Times (Part 3) (#3850)
Browse files Browse the repository at this point in the history
* add coverage report

* define env vars in shared action

* reduce time for longest running tests

* fix broken shared action

* reduce test time

* reducing Pipeline test times

* further reducing test times

* rework Z3 test

* testing new mp.pool and persistent dist envs

* fix import

* reuse distributed environment for tests with lots of param combos

* fix for dist teardown

* fix pickling issue with pool cache

* actually fix pickling problem

* avoid running pool cache stuff on non-distributed tests

* fix issues with nested mp.pool

* fix for nested pools in Pipeline Engine

* re-add params

* update workflows with pytest opts

* implement feedback

* resolve race condition with port selection

* Update tests/unit/common.py

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
  • Loading branch information
mrwyattii and tjruwase authored Jul 12, 2023
1 parent e59f69a commit aef6c65
Show file tree
Hide file tree
Showing 35 changed files with 463 additions and 415 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/amd-mi100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:

- name: Install pytorch
run: |
pip install torch==1.13.1 torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.1.1
pip install --cache-dir $TORCH_CACHE torch==1.13.1 torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.1.1
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
Expand All @@ -50,7 +50,7 @@ jobs:
# Runs a set of commands using the runners shell
- name: Unit tests
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
TORCH_EXTENSIONS_DIR=./torch-extensions pytest -n 4 --verbose unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'sequential' unit/
pytest $PYTEST_OPTS -n 4 --verbose unit/
pytest $PYTEST_OPTS -m 'sequential' unit/
8 changes: 4 additions & 4 deletions .github/workflows/amd-mi200.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:

- name: Install pytorch
run: |
pip install -U --cache-dir /blob/torch_cache torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
pip install -U --cache-dir $TORCH_CACHE torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.4.2
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
Expand Down Expand Up @@ -58,7 +58,7 @@ jobs:
# Runs a set of commands using the runners shell
- name: Unit tests
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
TORCH_EXTENSIONS_DIR=./torch-extensions pytest -n 4 --verbose unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'sequential' unit/
pytest $PYTEST_OPTS -n 4 --verbose unit/
pytest $PYTEST_OPTS -m 'sequential' unit/
5 changes: 2 additions & 3 deletions .github/workflows/cpu-inference.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ jobs:
run: |
source oneCCL/build/_install/env/setvars.sh
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference' unit/inference/test_inference_config.py
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -k TestDistAllReduce unit/comm/test_dist.py
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'inference' unit/inference/test_inference_config.py
TRANSFORMERS_CACHE=~/tmp/transformers_cache/ pytest $PYTEST_OPTS -k TestDistAllReduce unit/comm/test_dist.py
6 changes: 3 additions & 3 deletions .github/workflows/nv-accelerate-v100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:

- name: Install pytorch
run: |
pip install -U --cache-dir /blob/torch_cache torch torchvision --extra-index-url https://download.pytorch.org/whl/cu111
pip install -U --cache-dir $TORCH_CACHE torch torchvision --extra-index-url https://download.pytorch.org/whl/cu111
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
Expand All @@ -41,7 +41,7 @@ jobs:
- name: HF Accelerate tests
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
git clone https://github.com/huggingface/accelerate
cd accelerate
git rev-parse --short HEAD
Expand All @@ -52,4 +52,4 @@ jobs:
# tmp fix: force newer datasets version
#pip install "datasets>=2.0.0"
pip list
HF_DATASETS_CACHE=/blob/datasets_cache/ TRANSFORMERS_CACHE=/blob/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --verbose tests/deepspeed
pytest $PYTEST_OPTS --color=yes --durations=0 --verbose tests/deepspeed
5 changes: 2 additions & 3 deletions .github/workflows/nv-h100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ jobs:
- name: Unit tests
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TORCH_EXTENSIONS_DIR=./torch-extensions python -m pytest -n 4 unit/ --torch_ver="2.0" --cuda_ver="12"
TORCH_EXTENSIONS_DIR=./torch-extensions python -m pytest -m 'sequential' unit/ --torch_ver="2.0" --cuda_ver="12"
python -m pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="2.0" --cuda_ver="12"
python -m pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="2.0" --cuda_ver="12"
15 changes: 10 additions & 5 deletions .github/workflows/nv-inference.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:

- name: Install pytorch
run: |
pip install -U --cache-dir /blob/torch_cache torch==1.13.1 torchvision --extra-index-url https://download.pytorch.org/whl/cu116
pip install -U --cache-dir $TORCH_CACHE torch==1.13.1 torchvision --extra-index-url https://download.pytorch.org/whl/cu116
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
Expand All @@ -49,8 +49,13 @@ jobs:
- name: Unit tests
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TRANSFORMERS_CACHE=/blob/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'seq_inference' unit/ --torch_ver="1.13" --cuda_ver="11.6"
TRANSFORMERS_CACHE=/blob/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'inference_ops' unit/ --torch_ver="1.13" --cuda_ver="11.6"
TRANSFORMERS_CACHE=/blob/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -n 4 -m 'inference' unit/ --torch_ver="1.13" --cuda_ver="11.6"
coverage run --concurrency=multiprocessing -m pytest $PYTEST_OPTS -m 'seq_inference' unit/ --torch_ver="1.13" --cuda_ver="11.6"
coverage run --concurrency=multiprocessing -m pytest $PYTEST_OPTS -m 'inference_ops' unit/ --torch_ver="1.13" --cuda_ver="11.6"
coverage run --concurrency=multiprocessing -m pytest $PYTEST_OPTS --forked -n 4 -m 'inference' unit/ --torch_ver="1.13" --cuda_ver="11.6"
- name: Coverage report
run: |
cd tests
coverage combine
coverage report -m
6 changes: 3 additions & 3 deletions .github/workflows/nv-lightning-v100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:

- name: Install pytorch
run: |
pip install -U --cache-dir /blob/torch_cache torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116
pip install -U --cache-dir $TORCH_CACHE torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
Expand All @@ -41,8 +41,8 @@ jobs:
- name: PyTorch Lightning Tests
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
pip install pytorch-lightning
pip install "protobuf<4.21.0"
cd tests
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --verbose lightning/
pytest $PYTEST_OPTS lightning/
5 changes: 2 additions & 3 deletions .github/workflows/nv-megatron.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:

- name: Install pytorch
run: |
pip install -U --cache-dir /blob/torch_cache torch==1.13.1 torchvision --extra-index-url https://download.pytorch.org/whl/cu116
pip install -U --cache-dir $TORCH_CACHE torch==1.13.1 torchvision --extra-index-url https://download.pytorch.org/whl/cu116
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
Expand Down Expand Up @@ -57,6 +57,5 @@ jobs:
cd Megatron-DeepSpeed
pip install .
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
MEGATRON_CKPT_DIR=/blob/megatron_ckpt/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --verbose ./
pytest $PYTEST_OPTS ./
5 changes: 2 additions & 3 deletions .github/workflows/nv-mii.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:

- name: Install pytorch
run: |
pip3 install -U --cache-dir /blob/torch_cache torch
pip3 install -U --cache-dir $TORCH_CACHE torch
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
Expand Down Expand Up @@ -54,6 +54,5 @@ jobs:
cd DeepSpeed-MII
pip install .[dev]
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TRANSFORMERS_CACHE=/blob/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --color=yes --durations=0 --forked --verbose -m "deepspeed" ./
pytest $PYTEST_OPTS --forked -m "deepspeed" ./
5 changes: 2 additions & 3 deletions .github/workflows/nv-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:

- name: Install pytorch
run: |
pip install -U --cache-dir /blob/torch_cache torch==1.13.1 torchvision --extra-index-url https://download.pytorch.org/whl/cu116
pip install -U --cache-dir $TORCH_CACHE torch==1.13.1 torchvision --extra-index-url https://download.pytorch.org/whl/cu116
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
Expand All @@ -45,6 +45,5 @@ jobs:
- name: Unit tests
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TRANSFORMERS_CACHE=/blob/transformers_cache/ TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -m 'nightly' unit/ --torch_ver="1.13" --cuda_ver="11.6"
pytest $PYTEST_OPTS --forked -m 'nightly' unit/ --torch_ver="1.13" --cuda_ver="11.6"
5 changes: 2 additions & 3 deletions .github/workflows/nv-torch-latest-cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ jobs:
- name: Unit tests
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TORCH_EXTENSIONS_DIR=./torch-extensions pytest -n 4 unit/ --torch_ver="1.12"
TORCH_EXTENSIONS_DIR=./torch-extensions pytest -m 'sequential' unit/ --torch_ver="1.12"
TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -n 4 unit/ --torch_ver="1.12"
TRANSFORMERS_CACHE=/tmp/transformers_cache/ pytest $PYTEST_OPTS -m 'sequential' unit/ --torch_ver="1.12"
13 changes: 9 additions & 4 deletions .github/workflows/nv-torch-latest-v100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:

- name: Install pytorch
run: |
pip install -U --cache-dir /blob/torch_cache torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116
pip install -U --cache-dir $TORCH_CACHE torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
Expand All @@ -51,7 +51,12 @@ jobs:
- name: Unit tests
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -n 4 unit/ --torch_ver="2.0" --cuda_ver="11.7"
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -m 'sequential' unit/ --torch_ver="2.0" --cuda_ver="11.7"
coverage run --concurrency=multiprocessing -m pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="2.0" --cuda_ver="11.7"
coverage run --concurrency=multiprocessing -m pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="2.0" --cuda_ver="11.7"
- name: Coverage report
run: |
cd tests
coverage combine
coverage report -m
5 changes: 2 additions & 3 deletions .github/workflows/nv-torch-nightly-v100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ jobs:
- name: Unit tests
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -n 4 unit/
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -m 'sequential' unit/
pytest $PYTEST_OPTS --forked -n 4 unit/
pytest $PYTEST_OPTS --forked -m 'sequential' unit/
6 changes: 3 additions & 3 deletions .github/workflows/nv-torch19-p40.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:

- name: Install pytorch
run: |
pip install -U --cache-dir /blob/torch_cache torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
pip install -U --cache-dir $TORCH_CACHE torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
Expand All @@ -44,6 +44,6 @@ jobs:
- name: Unit tests
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
cd tests
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -n 4 unit/ --torch_ver="1.9" --cuda_ver="11.1"
pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="1.9" --cuda_ver="11.1"
7 changes: 3 additions & 4 deletions .github/workflows/nv-torch19-v100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:

- name: Install pytorch
run: |
pip install -U --cache-dir /blob/torch_cache torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
pip install -U --cache-dir $TORCH_CACHE torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
Expand All @@ -45,7 +45,6 @@ jobs:
- name: Unit tests
run: |
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
cd tests
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -n 4 unit/ --torch_ver="1.9" --cuda_ver="11"
TORCH_EXTENSIONS_DIR=./torch-extensions pytest --forked -m 'sequential' unit/ --torch_ver="1.9" --cuda_ver="11"
pytest $PYTEST_OPTS --forked -n 4 unit/ --torch_ver="1.9" --cuda_ver="11"
pytest $PYTEST_OPTS --forked -m 'sequential' unit/ --torch_ver="1.9" --cuda_ver="11"
6 changes: 3 additions & 3 deletions .github/workflows/nv-transformers-v100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
- name: Install pytorch
run: |
# use the same pytorch version as transformers CI
pip install -U --cache-dir /blob/torch_cache torch torchvision torchaudio -f https://download.pytorch.org/whl/torch_stable.html
pip install -U --cache-dir $TORCH_CACHE torch torchvision torchaudio -f https://download.pytorch.org/whl/torch_stable.html
python -c "import torch; print('torch:', torch.__version__, torch)"
python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
Expand All @@ -42,7 +42,7 @@ jobs:
- name: HF transformers tests
run: |
if [[ -d ./torch-extensions ]]; then rm -rf ./torch-extensions; fi
unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
git clone https://github.com/huggingface/transformers
cd transformers
# if needed switch to the last known good SHA until transformers@master is fixed
Expand All @@ -57,4 +57,4 @@ jobs:
# force protobuf version due to issues
pip install "protobuf<4.21.0"
pip list
HF_DATASETS_CACHE=/blob/datasets_cache/ TRANSFORMERS_CACHE=/blob/transformers_cache/ WANDB_DISABLED=true TORCH_EXTENSIONS_DIR=./torch-extensions RUN_SLOW=1 pytest --color=yes --durations=0 --verbose tests/deepspeed
WANDB_DISABLED=true RUN_SLOW=1 pytest $PYTEST_OPTS tests/deepspeed
10 changes: 10 additions & 0 deletions .github/workflows/setup-venv/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ runs:
pip install wheel # required after pip>=23.1
echo PATH=$PATH >> $GITHUB_ENV # Make it so venv is inherited for other steps
shell: bash
- id: set-env-vars
run: |
echo TEST_DATA_DIR=/blob/ >> $GITHUB_ENV
echo TRANSFORMERS_CACHE=/blob/transformers_cache/ >> $GITHUB_ENV
echo TORCH_EXTENSIONS_DIR=./torch-extensions/ >> $GITHUB_ENV
echo TORCH_CACHE=/blob/torch_cache/ >> $GITHUB_ENV
echo HF_DATASETS_CACHE=/blob/datasets_cache/ >> $GITHUB_ENV
echo MEGATRON_CKPT_DIR=/blob/megatron_ckpt/ >> $GITHUB_ENV
echo PYTEST_OPTS="--color=yes --durations=0 --verbose -rF" >> $GITHUB_ENV
shell: bash
- id: print-env
run: |
which python
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
clang-format==16.0.2
coverage
docutils<0.18
future
importlib-metadata>=4
Expand Down
5 changes: 5 additions & 0 deletions tests/.coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# .coveragerc to control coverage.py
[run]
parallel = True
sigterm = True
source = deepspeed
12 changes: 10 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,18 @@ def pytest_runtest_call(item):
item.runtest = lambda: True # Dummy function so test is not run twice


# We allow DistributedTest to reuse distributed environments. When the last
# test for a class is run, we want to make sure those distributed environments
# are destroyed.
def pytest_runtest_teardown(item, nextitem):
if getattr(item.cls, "reuse_dist_env", False) and not nextitem:
dist_test_class = item.cls()
for num_procs, pool in dist_test_class._pool_cache.items():
dist_test_class._close_pool(pool, num_procs, force=True)


@pytest.hookimpl(tryfirst=True)
def pytest_fixture_setup(fixturedef, request):
if getattr(fixturedef.func, "is_dist_fixture", False):
#for val in dir(request):
# print(val.upper(), getattr(request, val), "\n")
dist_fixture_class = fixturedef.func()
dist_fixture_class(request)
20 changes: 19 additions & 1 deletion tests/unit/alexnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# DeepSpeed Team

import pytest
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -98,7 +99,12 @@ def cifar_trainset(fp16=False):
dist.barrier()
if local_rank != 0:
dist.barrier()
trainset = torchvision.datasets.CIFAR10(root='/blob/cifar10-data', train=True, download=True, transform=transform)

data_root = os.getenv("TEST_DATA_DIR", "/tmp/")
trainset = torchvision.datasets.CIFAR10(root=os.path.join(data_root, "cifar10-data"),
train=True,
download=True,
transform=transform)
if local_rank == 0:
dist.barrier()
return trainset
Expand All @@ -114,6 +120,18 @@ def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True,
trainset = cifar_trainset(fp16=fp16)
config['local_rank'] = dist.get_rank()

# deepspeed_io defaults to creating a dataloader that uses a
# multiprocessing pool. Our tests use pools and we cannot nest pools in
# python. Therefore we're injecting this kwarg to ensure that no pools
# are used in the dataloader.
old_method = deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io

def new_method(*args, **kwargs):
kwargs["num_local_io_workers"] = 0
return old_method(*args, **kwargs)

deepspeed.runtime.engine.DeepSpeedEngine.deepspeed_io = new_method

engine, _, _, _ = deepspeed.initialize(config=config,
model=model,
model_parameters=[p for p in model.parameters()],
Expand Down
Loading

0 comments on commit aef6c65

Please sign in to comment.