diff --git a/tests/perf/benchmark.py b/tests/perf/benchmark.py
index d444059385f..63b061291e5 100644
--- a/tests/perf/benchmark.py
+++ b/tests/perf/benchmark.py
@@ -70,6 +70,7 @@ class Dataset:
         group: str
         num_repeat: int = 1
         extra_overrides: dict | None = None
+        unlabeled_data_path: Path | None = None

     @dataclass
     class Criterion:
@@ -192,6 +193,13 @@ def run(
             "--engine.device",
             self.accelerator,
         ]
+
+        # Add unlabeled data path if exists
+        if dataset.unlabeled_data_path is not None:
+            command.extend(
+                ["--data.config.unlabeled_subset.data_root", str(self.data_root / dataset.unlabeled_data_path)],
+            )
+
         for key, value in dataset.extra_overrides.get("train", {}).items():
             command.append(f"--{key}")
             command.append(str(value))
diff --git a/tests/perf/test_classification.py b/tests/perf/test_classification.py
index e22806883f6..f94e981334b 100644
--- a/tests/perf/test_classification.py
+++ b/tests/perf/test_classification.py
@@ -22,6 +22,9 @@ class TestPerfSingleLabelClassification(PerfTestBase):
         Benchmark.Model(task="classification/multi_class_cls", name="mobilenet_v3_large", category="accuracy"),
         Benchmark.Model(task="classification/multi_class_cls", name="deit_tiny", category="other"),
         Benchmark.Model(task="classification/multi_class_cls", name="dino_v2", category="other"),
+        Benchmark.Model(task="classification/multi_class_cls", name="tv_efficientnet_b3", category="other"),
+        Benchmark.Model(task="classification/multi_class_cls", name="tv_efficientnet_v2_l", category="other"),
+        Benchmark.Model(task="classification/multi_class_cls", name="tv_mobilenet_v3_small", category="other"),
     ]

     DATASET_TEST_CASES = [
@@ -258,3 +261,121 @@ def test_perf(
             criteria=self.BENCHMARK_CRITERIA,
             resume_from=fxt_resume_from,
         )
+
+
+class TestPerfSemiSLMultiClass(PerfTestBase):
+    """Benchmark single-label classification for Semi-SL task."""
+
+    MODEL_TEST_CASES = [  # noqa: RUF012
+        Benchmark.Model(task="classification/multi_class_cls", name="efficientnet_b0_semisl", category="speed"),
+        Benchmark.Model(task="classification/multi_class_cls", name="mobilenet_v3_large_semisl", category="speed"),
+        Benchmark.Model(task="classification/multi_class_cls", name="efficientnet_v2_semisl", category="accuracy"),
+        Benchmark.Model(task="classification/multi_class_cls", name="deit_tiny_semisl", category="other"),
+        Benchmark.Model(task="classification/multi_class_cls", name="dino_v2_semisl", category="other"),
+        Benchmark.Model(task="classification/multi_class_cls", name="tv_efficientnet_b3_semisl", category="other"),
+        Benchmark.Model(task="classification/multi_class_cls", name="tv_efficientnet_v2_l_semisl", category="other"),
+        Benchmark.Model(task="classification/multi_class_cls", name="tv_mobilenet_v3_small_semisl", category="other"),
+    ]
+
+    DATASET_TEST_CASES = (
+        [
+            Benchmark.Dataset(
+                name=f"cifar10@{num_label}_{idx}",
+                path=Path(f"multiclass_classification/semi-sl/cifar10@{num_label}_{idx}/supervised"),
+                group="cifar10",
+                num_repeat=1,
+                unlabeled_data_path=Path(f"multiclass_classification/semi-sl/cifar10@{num_label}_{idx}/unlabel_data"),
+                extra_overrides={
+                    "train": {
+                        "data.config.train_subset.subset_name": "train_data",
+                        "data.config.val_subset.subset_name": "val_data",
+                        "data.config.test_subset.subset_name": "val_data",
+                        "deterministic": "True",
+                    },
+                },
+            )
+            for idx in (0, 1, 2)
+            for num_label in (4, 10, 25)
+        ]
+        + [
+            Benchmark.Dataset(
+                name=f"svhn@{num_label}_{idx}",
+                path=Path(f"multiclass_classification/semi-sl/svhn@{num_label}_{idx}/supervised"),
+                group="svhn",
+                num_repeat=1,
+                unlabeled_data_path=Path(f"multiclass_classification/semi-sl/svhn@{num_label}_{idx}/unlabel_data"),
+                extra_overrides={
+                    "train": {
+                        "data.config.train_subset.subset_name": "train_data",
+                        "data.config.val_subset.subset_name": "val_data",
+                        "data.config.test_subset.subset_name": "val_data",
+                        "deterministic": "True",
+                    },
+                },
+            )
+            for idx in (0, 1, 2)
+            for num_label in (4, 10, 25)
+        ]
+        + [
+            Benchmark.Dataset(
+                name=f"fmnist@{num_label}_{idx}",
+                path=Path(f"multiclass_classification/semi-sl/fmnist@{num_label}_{idx}/supervised"),
+                group="fmnist",
+                num_repeat=1,
+                unlabeled_data_path=Path(f"multiclass_classification/semi-sl/fmnist@{num_label}_{idx}/unlabel_data"),
+                extra_overrides={
+                    "train": {
+                        "data.config.train_subset.subset_name": "train_data",
+                        "data.config.val_subset.subset_name": "val_data",
+                        "data.config.test_subset.subset_name": "val_data",
+                        "deterministic": "True",
+                    },
+                },
+            )
+            for idx in (0, 1, 2)
+            for num_label in (4, 10, 25)
+        ]
+    )
+
+    BENCHMARK_CRITERIA = [  # noqa: RUF012
+        Benchmark.Criterion(name="train/epoch", summary="max", compare="<", margin=0.1),
+        Benchmark.Criterion(name="train/e2e_time", summary="max", compare="<", margin=0.1),
+        Benchmark.Criterion(name="val/accuracy", summary="max", compare=">", margin=0.1),
+        Benchmark.Criterion(name="test/accuracy", summary="max", compare=">", margin=0.1),
+        Benchmark.Criterion(name="export/accuracy", summary="max", compare=">", margin=0.1),
+        Benchmark.Criterion(name="optimize/accuracy", summary="max", compare=">", margin=0.1),
+        Benchmark.Criterion(name="train/iter_time", summary="mean", compare="<", margin=0.1),
+        Benchmark.Criterion(name="test/iter_time", summary="mean", compare="<", margin=0.1),
+        Benchmark.Criterion(name="export/iter_time", summary="mean", compare="<", margin=0.1),
+        Benchmark.Criterion(name="optimize/iter_time", summary="mean", compare="<", margin=0.1),
+        Benchmark.Criterion(name="test(train)/e2e_time", summary="max", compare=">", margin=0.1),
+        Benchmark.Criterion(name="test(export)/e2e_time", summary="max", compare=">", margin=0.1),
+        Benchmark.Criterion(name="test(optimize)/e2e_time", summary="max", compare=">", margin=0.1),
+    ]
+
+    @pytest.mark.parametrize(
+        "fxt_model",
+        MODEL_TEST_CASES,
+        ids=lambda model: model.name,
+        indirect=True,
+    )
+    @pytest.mark.parametrize(
+        "fxt_dataset",
+        DATASET_TEST_CASES,
+        ids=lambda dataset: dataset.name,
+        indirect=True,
+    )
+    def test_perf(
+        self,
+        fxt_model: Benchmark.Model,
+        fxt_dataset: Benchmark.Dataset,
+        fxt_benchmark: Benchmark,
+        fxt_resume_from: Path | None,
+    ):
+        self._test_perf(
+            model=fxt_model,
+            dataset=fxt_dataset,
+            benchmark=fxt_benchmark,
+            criteria=self.BENCHMARK_CRITERIA,
+            resume_from=fxt_resume_from,
+        )
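Note (not part of the patch): a minimal, self-contained sketch of how the new `unlabeled_data_path` field is consumed, mirroring the `command.extend` hunk in `benchmark.py` above. The trimmed `Dataset` stand-in and the `unlabeled_override` helper below are hypothetical names used only for illustration; the real `Benchmark.Dataset` dataclass has more fields and the real logic lives inside `Benchmark.run()`.

from dataclasses import dataclass
from pathlib import Path


@dataclass
class Dataset:
    # Trimmed stand-in for Benchmark.Dataset, keeping only the fields used here.
    name: str
    path: Path
    group: str
    num_repeat: int = 1
    extra_overrides: dict | None = None
    unlabeled_data_path: Path | None = None


def unlabeled_override(data_root: Path, dataset: Dataset) -> list[str]:
    # Same shape as the patched Benchmark.run(): emit the override only when the
    # dataset actually declares an unlabeled subset.
    if dataset.unlabeled_data_path is None:
        return []
    return [
        "--data.config.unlabeled_subset.data_root",
        str(data_root / dataset.unlabeled_data_path),
    ]


# One of the Semi-SL CIFAR-10 splits from DATASET_TEST_CASES:
ds = Dataset(
    name="cifar10@4_0",
    path=Path("multiclass_classification/semi-sl/cifar10@4_0/supervised"),
    group="cifar10",
    unlabeled_data_path=Path("multiclass_classification/semi-sl/cifar10@4_0/unlabel_data"),
)
print(unlabeled_override(Path("/data"), ds))
# ['--data.config.unlabeled_subset.data_root',
#  '/data/multiclass_classification/semi-sl/cifar10@4_0/unlabel_data']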