From eb1cf7c13b6b2de7083176d680943ea0ba93ae1b Mon Sep 17 00:00:00 2001 From: Jerry Lin Date: Wed, 14 Aug 2024 12:54:41 -0700 Subject: [PATCH] add auxiliary experiments to SQAExperiment (#2658) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2658 SQA storage change for auxiliary_expriments in experiments Differential Revision: D60927108 --- ax/storage/sqa_store/decoder.py | 25 +++++++++++ ax/storage/sqa_store/encoder.py | 10 +++++ ax/storage/sqa_store/sqa_classes.py | 7 +++- ax/storage/sqa_store/sqa_config.py | 2 + ax/storage/sqa_store/tests/test_sqa_store.py | 44 +++++++++++++++++++- 5 files changed, 86 insertions(+), 2 deletions(-) diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index e3788a12280..7a1eb0636fd 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -17,6 +17,7 @@ from ax.analysis.analysis import AnalysisCard, AnalysisCardLevel from ax.core.arm import Arm +from ax.core.auxiliary import AuxiliaryExperiment from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.batch_trial import AbandonedArm, BatchTrial, GeneratorRunStruct from ax.core.data import Data @@ -147,6 +148,29 @@ def _init_experiment_from_sqa( # so need to convert it to regular dict. properties = dict(experiment_sqa.properties or {}) default_data_type = experiment_sqa.default_data_type + + auxiliary_experiments_by_purpose = None + if experiment_sqa.auxiliary_experiments_by_purpose: + from ax.storage.sqa_store.load import load_experiment + + auxiliary_experiments_by_purpose = {} + aux_exp_name_dict = not_none( + experiment_sqa.auxiliary_experiments_by_purpose + ) + for aux_exp_purpose_str, aux_exp_names in aux_exp_name_dict.items(): + aux_exp_purpose = [ + member + for member in self.config.auxiliary_experiment_purpose_enum + if member.value == aux_exp_purpose_str + ][0] + auxiliary_experiments_by_purpose[aux_exp_purpose] = [] + for aux_exp_name in aux_exp_names: + auxiliary_experiments_by_purpose[aux_exp_purpose].append( + AuxiliaryExperiment( + experiment=load_experiment(aux_exp_name, config=self.config) + ) + ) + return Experiment( name=experiment_sqa.name, description=experiment_sqa.description, @@ -158,6 +182,7 @@ def _init_experiment_from_sqa( is_test=experiment_sqa.is_test, properties=properties, default_data_type=default_data_type, + auxiliary_experiments_by_purpose=auxiliary_experiments_by_purpose, ) def _init_mt_experiment_from_sqa( diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 3102b9e271c..fe4a920ecc6 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -174,6 +174,15 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment: value=experiment.experiment_type, enum=self.config.experiment_type_enum ) + auxiliary_experiments_by_purpose = {} + for ( + aux_exp_type_enum, + aux_exps, + ) in experiment.auxiliary_experiments_by_purpose.items(): + aux_exp_type = aux_exp_type_enum.value + aux_exp_jsons = [aux_exp.experiment.name for aux_exp in aux_exps] + auxiliary_experiments_by_purpose[aux_exp_type] = aux_exp_jsons + properties = experiment._properties runners = [] if isinstance(experiment, MultiTypeExperiment): @@ -213,6 +222,7 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment: properties=properties, default_trial_type=experiment.default_trial_type, default_data_type=experiment.default_data_type, + auxiliary_experiments_by_purpose=auxiliary_experiments_by_purpose, ) return exp_sqa diff --git a/ax/storage/sqa_store/sqa_classes.py b/ax/storage/sqa_store/sqa_classes.py index 544000b8669..9c5decc8ee1 100644 --- a/ax/storage/sqa_store/sqa_classes.py +++ b/ax/storage/sqa_store/sqa_classes.py @@ -49,7 +49,6 @@ ) from sqlalchemy.orm import backref, relationship - ONLY_ONE_FIELDS = ["experiment_id", "generator_run_id"] @@ -510,6 +509,12 @@ class SQAExperiment(Base): default_trial_type: Optional[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH)) # pyre-fixme[8]: Attribute has type `DataType`; used as `Column[typing.Any]`. default_data_type: DataType = Column(IntEnum(DataType), nullable=True) + # pyre-fixme[8]: Incompatible attribute type [8]: Attribute + # `auxiliary_experiments_by_purpose` declared in class `SQAExperiment` has + # type `Optional[Dict[str, List[str]]]` but is used as type `Column[typing.Any]` + auxiliary_experiments_by_purpose: Optional[dict[str, list[str]]] = Column( + JSONEncodedTextDict, nullable=True, default={} + ) # relationships # Trials and experiments are mutable, so the children relationships need diff --git a/ax/storage/sqa_store/sqa_config.py b/ax/storage/sqa_store/sqa_config.py index e256d03f499..f66f4db6dbd 100644 --- a/ax/storage/sqa_store/sqa_config.py +++ b/ax/storage/sqa_store/sqa_config.py @@ -13,6 +13,7 @@ from ax.analysis.analysis import AnalysisCard from ax.core.arm import Arm +from ax.core.auxiliary import AuxiliaryExperimentPurpose from ax.core.batch_trial import AbandonedArm from ax.core.data import Data from ax.core.experiment import Experiment @@ -86,6 +87,7 @@ def _default_class_to_sqa_class(self=None) -> dict[type[Base], type[SQABase]]: ) experiment_type_enum: Optional[Union[Enum, type[Enum]]] = None generator_run_type_enum: Optional[Union[Enum, type[Enum]]] = GeneratorRunType + auxiliary_experiment_purpose_enum: type[Enum] = AuxiliaryExperimentPurpose # pyre-fixme[4]: Attribute annotation cannot contain `Any`. # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index daccf4df0d9..efb37c2e13f 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -8,7 +8,7 @@ import logging from datetime import datetime -from enum import Enum +from enum import Enum, unique from logging import Logger from typing import Any from unittest import mock @@ -20,7 +20,9 @@ from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard from ax.analysis.plotly.plotly_analysis import PlotlyAnalysisCard from ax.core.arm import Arm +from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose from ax.core.batch_trial import BatchTrial, LifecycleStage +from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric from ax.core.objective import MultiObjective, Objective @@ -208,6 +210,46 @@ def test_ExperimentSaveAndLoad(self) -> None: loaded_experiment = load_experiment(exp.name) self.assertEqual(loaded_experiment, exp) + def test_saving_and_loading_experiment_with_aux_exp(self) -> None: + @unique + class TestAuxiliaryExperimentPurpose(AuxiliaryExperimentPurpose): + MyAuxExpPurpose = "my_auxiliary_experiment_purpose" + + self.config.auxiliary_experiment_purpose_enum = TestAuxiliaryExperimentPurpose + + aux_experiment = Experiment( + name="test_aux_exp_in_SQAStoreTest", + search_space=get_search_space(), + optimization_config=get_optimization_config(), + description="test description", + tracking_metrics=[Metric(name="tracking")], + is_test=True, + ) + save_experiment(aux_experiment, config=self.config) + + experiment_w_aux_exp = Experiment( + name="test_experiment_w_aux_exp_in_SQAStoreTest", + search_space=get_search_space(), + optimization_config=get_optimization_config(), + description="test description", + tracking_metrics=[Metric(name="tracking")], + is_test=True, + auxiliary_experiments_by_purpose={ + # pyre-ignore[16]: `AuxiliaryExperimentPurpose` has no attribute + self.config.auxiliary_experiment_purpose_enum.MyAuxExpPurpose: [ + AuxiliaryExperiment(experiment=aux_experiment) + ] + }, + ) + self.assertIsNone(experiment_w_aux_exp.db_id) + save_experiment(experiment_w_aux_exp, config=self.config) + self.assertIsNotNone(experiment_w_aux_exp.db_id) + loaded_experiment = load_experiment( + experiment_w_aux_exp.name, config=self.config + ) + self.assertEqual(experiment_w_aux_exp, loaded_experiment) + self.assertEqual(len(loaded_experiment.auxiliary_experiments_by_purpose), 1) + def test_saving_an_experiment_with_type_requires_an_enum(self) -> None: self.experiment.experiment_type = "TEST" with self.assertRaises(SQAEncodeError):