Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Action task supports multi GPU training #2057

Merged
merged 8 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion otx/algorithms/action/adapters/mmaction/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,10 @@ def configure(

recipe_cfg.work_dir = self._output_path
recipe_cfg.resume = self._resume
recipe_cfg.distributed = False
recipe_cfg.omnisource = False

self._configure_device(recipe_cfg, training)

if data_cfg is not None:
recipe_cfg.merge_from_dict(data_cfg)

Expand Down Expand Up @@ -196,6 +197,40 @@ def configure(
self._config = recipe_cfg
return recipe_cfg

def _configure_device(self, cfg, training):
    """Fill in ``cfg.distributed``, ``cfg.gpu_ids`` and ``cfg.device``.

    Args:
        cfg: Recipe config that is updated in place.
        training: Whether the config is being prepared for training. Distributed
            mode is only switched on when this is truthy.
    """
    cfg.distributed = False

    if torch.distributed.is_initialized():
        # A process group exists: this process is bound to the GPU named by LOCAL_RANK.
        local_rank = int(os.environ["LOCAL_RANK"])
        cfg.gpu_ids = [local_rank]
        # TODO multi GPU is available only in training. Evaluation needs to be supported later.
        if training:
            cfg.distributed = True
            self.configure_distributed(cfg)
    elif "gpu_ids" not in cfg:
        # No explicit gpu_ids: derive the count from CUDA_VISIBLE_DEVICES, defaulting to one.
        visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        logger.info(f"CUDA_VISIBLE_DEVICES = {visible_devices}")
        num_gpus = 1 if visible_devices is None else len(visible_devices.split(","))
        cfg.gpu_ids = range(num_gpus)

    # consider "cuda" and "cpu" device only
    if torch.cuda.is_available():
        cfg.device = "cuda"
    else:
        cfg.device = "cpu"
        # Without CUDA the ids chosen above are discarded; range(-1, 0) yields only -1.
        cfg.gpu_ids = range(-1, 0)

@staticmethod
def configure_distributed(cfg):
"""Patching for distributed training."""
if hasattr(cfg, "dist_params") and cfg.dist_params.get("linear_scale_lr", False):
new_lr = len(cfg.gpu_ids) * cfg.optimizer.lr
logger.info(
f"enabled linear scaling rule to the learning rate. \
changed LR from {cfg.optimizer.lr} to {new_lr}"
)
cfg.optimizer.lr = new_lr

# pylint: disable=too-many-branches, too-many-statements
def _train_model(
self,
Expand Down
4 changes: 4 additions & 0 deletions otx/algorithms/action/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from otx.api.usecases.evaluation.metrics_helper import MetricsHelper
from otx.api.usecases.tasks.interfaces.export_interface import ExportType
from otx.api.utils.vis_utils import get_actmap
from otx.cli.utils.multi_gpu import is_multigpu_child_process

logger = get_logger()

Expand Down Expand Up @@ -430,6 +431,9 @@ def _generate_training_metrics(learning_curves, scores, metric_name="mAP") -> It

def save_model(self, output_model: ModelEntity):
"""Save best model weights in ActionTrainTask."""
if is_multigpu_child_process():
return

logger.info("called save_model")
buffer = io.BytesIO()
hyperparams_str = ids_to_strings(cfg_helper.convert(self._hyperparams, dict, enum_to_str=True))
Expand Down
6 changes: 5 additions & 1 deletion otx/algorithms/common/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# and limitations under the License.

import io
import logging
import os
import shutil
import tempfile
Expand Down Expand Up @@ -136,7 +137,10 @@ def _setup_multigpu_training():
if not dist.is_initialized():
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
dist.init_process_group(backend="nccl", init_method="env://", timeout=timedelta(seconds=30))
logger.info(f"Dist info: rank {dist.get_rank()} / {dist.get_world_size()} world_size")
rank = dist.get_rank()
logger.info(f"Dist info: rank {rank} / {dist.get_world_size()} world_size")
if rank != 0:
logging.disable(logging.WARNING)

def _get_tmp_dir(self):
self._work_dir_is_temp = True
Expand Down
5 changes: 1 addition & 4 deletions otx/cli/tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

from otx.api.entities.inference_parameters import InferenceParameters
from otx.api.entities.model import ModelEntity
from otx.api.entities.model_template import TaskType
from otx.api.entities.resultset import ResultSetEntity
from otx.api.entities.subset import Subset
from otx.api.entities.task_environment import TaskEnvironment
Expand Down Expand Up @@ -229,9 +228,7 @@ def train(exit_stack: Optional[ExitStack] = None): # pylint: disable=too-many-b

if args.gpus:
multigpu_manager = MultiGPUManager(train, args.gpus, args.rdzv_endpoint, args.base_rank, args.world_size)
if template.task_type in (TaskType.ACTION_CLASSIFICATION, TaskType.ACTION_DETECTION):
print("Multi-GPU training for action tasks isn't supported yet. A single GPU will be used for a training.")
elif (
if (
multigpu_manager.is_available()
and not template.task_type.is_anomaly # anomaly tasks don't use this way for multi-GPU training
):
Expand Down
12 changes: 12 additions & 0 deletions tests/integration/cli/action/test_action_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from copy import deepcopy

import pytest
import torch

from otx.api.entities.model_template import parse_model_template
from otx.cli.registry import Registry
Expand All @@ -29,6 +30,7 @@

otx_dir = os.getcwd()

MULTI_GPU_UNAVAILABLE = torch.cuda.device_count() <= 1
TT_STABILITY_TESTS = os.environ.get("TT_STABILITY_TESTS", False)
if TT_STABILITY_TESTS:
default_template = parse_model_template(
Expand Down Expand Up @@ -77,3 +79,13 @@ def test_otx_train_auto_decrease_batch_size(self, template, tmp_dir_path):
decrease_bs_args["train_params"].extend(["--learning_parameters.auto_decrease_batch_size", "true"])
tmp_dir_path = tmp_dir_path / "action_cls_auto_decrease_batch_size"
otx_train_testing(template, tmp_dir_path, otx_dir, decrease_bs_args)

@e2e_pytest_component
@pytest.mark.skipif(MULTI_GPU_UNAVAILABLE, reason="The number of gpu is insufficient")
@pytest.mark.skipif(TT_STABILITY_TESTS, reason="This is TT_STABILITY_TESTS")
@pytest.mark.parametrize("template", templates, ids=templates_ids)
def test_otx_multi_gpu_train(self, template, tmp_dir_path):
    """Run the training CLI test with ``--gpus 0,1`` for each template.

    Skipped when fewer than two CUDA devices are reported or when
    TT_STABILITY_TESTS is set.
    """
    # Copy the shared args so the extra option does not leak into other tests.
    multi_gpu_args = deepcopy(args)
    multi_gpu_args["--gpus"] = "0,1"
    work_dir = tmp_dir_path / "action_cls/test_multi_gpu"
    otx_train_testing(template, work_dir, otx_dir, multi_gpu_args)