Skip to content

Commit

Permalink
update api usecase and fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
eunwoosh committed Aug 21, 2024
1 parent c078f63 commit 400bd29
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,39 @@ Usage example:
$ otx train \
--config ... \
--data.input_size 512
.. tab-set::

.. tab-item:: API 1

.. code-block:: python
from otx.algo.detection.yolox import YOLOXS
from otx.core.data.module import OTXDataModule
from otx.engine import Engine
input_size = (512, 512)
model = YOLOXS(label_info=5, input_size=input_size) # should be tuple[int, int]
datamodule = OTXDataModule(..., input_size=input_size)
engine = Engine(model=model, datamodule=datamodule)
engine.train()
.. tab-item:: API 2

.. code-block:: python
from otx.core.data.module import OTXDataModule
from otx.engine import Engine
datamodule = OTXDataModule(..., input_size=(512, 512))
engine = Engine(model="yolox_s", datamodule=datamodule) # model input size will be aligned with the datamodule input size
engine.train()
.. tab-item:: CLI

.. code-block:: bash
(otx) ...$ otx train ... --data.input_size 512
.. _adaptive-input-size:

Expand All @@ -32,11 +64,30 @@ In "downscale" mode, the input size will either decrease or remain unchanged, en

To activate this feature, use the following command with the desired mode:

.. code-block::
.. tab-set::

$ otx train \
--config ... \
--data.adaptive_input_size "auto | downscale"
.. tab-item:: API

.. code-block:: python
from otx.algo.detection.yolox import YOLOXS
from otx.core.data.module import OTXDataModule
from otx.engine import Engine
datamodule = OTXDataModule(
...
adaptive_input_size="auto", # auto or downscale
input_size_multiplier=YOLOXS.input_size_multiplier, # should set the input_size_multiplier of the model
)
model = YOLOXS(label_info=5, input_size=datamodule.input_size)
engine = Engine(model=model, datamodule=datamodule)
engine.train()
.. tab-item:: CLI

.. code-block:: bash
(otx) ...$ otx train ... --data.adaptive_input_size "auto | downscale"
The adaptive process includes the following steps:

Expand Down
3 changes: 2 additions & 1 deletion src/otx/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ def __init__(
get_model_args: dict[str, Any] = {}
if self._datamodule is not None:
get_model_args["label_info"] = self._datamodule.label_info
get_model_args["input_size"] = self._datamodule.input_size
if (input_size := self._datamodule.input_size) is not None:
get_model_args["input_size"] = (input_size, input_size) if isinstance(input_size, int) else input_size
self._model: OTXModel = (
model if isinstance(model, OTXModel) else self._auto_configurator.get_model(**get_model_args)
)
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ def test_model_init(self, tmp_path, mock_datamodule):
assert engine._model.input_size == (1234, 1234)
assert engine._model.label_info.num_classes == 4321

def test_model_init_datamodule_ipt_size_int(self, tmp_path, mock_datamodule):
mock_datamodule.input_size = 1234
data_root = "tests/assets/classification_dataset"
engine = Engine(work_dir=tmp_path, data_root=data_root)

assert engine._model.input_size == (1234, 1234)
assert engine._model.label_info.num_classes == 4321

def test_model_setter(self, fxt_engine, mocker) -> None:
assert isinstance(fxt_engine.model, TVModelForMulticlassCls)
fxt_engine.model = "efficientnet_b0"
Expand Down

0 comments on commit 400bd29

Please sign in to comment.