Skip to content

Commit

Permalink
Fix CPU training issue on non-CUDA system (#2655)
Browse files Browse the repository at this point in the history
Fix bug that auto adaptive batch size raises an error if CUDA isn't available (#2410)

---------
Co-authored-by: Sungman Cho <sungman.cho@intel.com>
Co-authored-by: Eunwoo Shin <eunwoo.shin@intel.com>
  • Loading branch information
goodsong81 authored Nov 21, 2023
1 parent a0780a8 commit aceebda
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/otx/algorithms/common/adapters/mmcv/utils/automatic_bs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Callable, Dict, List

import numpy as np
from torch.cuda import is_available as cuda_available

from otx.algorithms.common.adapters.torch.utils import BsSearchAlgo
from otx.algorithms.common.utils.logger import get_logger
Expand Down Expand Up @@ -53,6 +54,10 @@ def adapt_batch_size(train_func: Callable, cfg, datasets: List, validate: bool =
not_increase (bool) : Whether adapting batch size to larger value than default value or not.
"""

if not cuda_available():
logger.warning("Skip Auto-adaptive batch size: CUDA should be available, but it isn't.")
return

def train_func_single_iter(batch_size):
copied_cfg = deepcopy(cfg)
_set_batch_size(copied_cfg, batch_size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,19 @@ def test_adapt_batch_size(
assert len(mock_train_func.call_args_list[0].kwargs["cfg"].custom_hooks) == 1


def test_adapt_batch_size_no_gpu(mocker, common_cfg, mock_dataset):
# prepare
mock_train_func = mocker.MagicMock()
mock_config = set_mock_cfg_not_action(common_cfg)
mocker.patch.object(automatic_bs, "cuda_available", return_value=False)

# execute
adapt_batch_size(mock_train_func, mock_config, mock_dataset, False, True)

# check train function ins't called.
mock_train_func.assert_not_called()


class TestSubDataset:
@pytest.fixture(autouse=True)
def set_up(self, mocker):
Expand Down

0 comments on commit aceebda

Please sign in to comment.