Skip to content

Commit

Permalink
add auxiliary experiments to SQAExperiment (facebook#2658)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#2658

SQA storage change for auxiliary_expriments in experiments

Differential Revision: D60927108
  • Loading branch information
ItsMrLin authored and facebook-github-bot committed Aug 14, 2024
1 parent a57cd68 commit eb1cf7c
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 2 deletions.
25 changes: 25 additions & 0 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ax.analysis.analysis import AnalysisCard, AnalysisCardLevel

from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.batch_trial import AbandonedArm, BatchTrial, GeneratorRunStruct
from ax.core.data import Data
Expand Down Expand Up @@ -147,6 +148,29 @@ def _init_experiment_from_sqa(
# so need to convert it to regular dict.
properties = dict(experiment_sqa.properties or {})
default_data_type = experiment_sqa.default_data_type

auxiliary_experiments_by_purpose = None
if experiment_sqa.auxiliary_experiments_by_purpose:
from ax.storage.sqa_store.load import load_experiment

auxiliary_experiments_by_purpose = {}
aux_exp_name_dict = not_none(
experiment_sqa.auxiliary_experiments_by_purpose
)
for aux_exp_purpose_str, aux_exp_names in aux_exp_name_dict.items():
aux_exp_purpose = [
member
for member in self.config.auxiliary_experiment_purpose_enum
if member.value == aux_exp_purpose_str
][0]
auxiliary_experiments_by_purpose[aux_exp_purpose] = []
for aux_exp_name in aux_exp_names:
auxiliary_experiments_by_purpose[aux_exp_purpose].append(
AuxiliaryExperiment(
experiment=load_experiment(aux_exp_name, config=self.config)
)
)

return Experiment(
name=experiment_sqa.name,
description=experiment_sqa.description,
Expand All @@ -158,6 +182,7 @@ def _init_experiment_from_sqa(
is_test=experiment_sqa.is_test,
properties=properties,
default_data_type=default_data_type,
auxiliary_experiments_by_purpose=auxiliary_experiments_by_purpose,
)

def _init_mt_experiment_from_sqa(
Expand Down
10 changes: 10 additions & 0 deletions ax/storage/sqa_store/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,15 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
value=experiment.experiment_type, enum=self.config.experiment_type_enum
)

auxiliary_experiments_by_purpose = {}
for (
aux_exp_type_enum,
aux_exps,
) in experiment.auxiliary_experiments_by_purpose.items():
aux_exp_type = aux_exp_type_enum.value
aux_exp_jsons = [aux_exp.experiment.name for aux_exp in aux_exps]
auxiliary_experiments_by_purpose[aux_exp_type] = aux_exp_jsons

properties = experiment._properties
runners = []
if isinstance(experiment, MultiTypeExperiment):
Expand Down Expand Up @@ -213,6 +222,7 @@ def experiment_to_sqa(self, experiment: Experiment) -> SQAExperiment:
properties=properties,
default_trial_type=experiment.default_trial_type,
default_data_type=experiment.default_data_type,
auxiliary_experiments_by_purpose=auxiliary_experiments_by_purpose,
)
return exp_sqa

Expand Down
7 changes: 6 additions & 1 deletion ax/storage/sqa_store/sqa_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@
)
from sqlalchemy.orm import backref, relationship


ONLY_ONE_FIELDS = ["experiment_id", "generator_run_id"]


Expand Down Expand Up @@ -510,6 +509,12 @@ class SQAExperiment(Base):
default_trial_type: Optional[str] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
# pyre-fixme[8]: Attribute has type `DataType`; used as `Column[typing.Any]`.
default_data_type: DataType = Column(IntEnum(DataType), nullable=True)
# pyre-fixme[8]: Incompatible attribute type [8]: Attribute
# `auxiliary_experiments_by_purpose` declared in class `SQAExperiment` has
# type `Optional[Dict[str, List[str]]]` but is used as type `Column[typing.Any]`
auxiliary_experiments_by_purpose: Optional[dict[str, list[str]]] = Column(
JSONEncodedTextDict, nullable=True, default={}
)

# relationships
# Trials and experiments are mutable, so the children relationships need
Expand Down
2 changes: 2 additions & 0 deletions ax/storage/sqa_store/sqa_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ax.analysis.analysis import AnalysisCard

from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperimentPurpose
from ax.core.batch_trial import AbandonedArm
from ax.core.data import Data
from ax.core.experiment import Experiment
Expand Down Expand Up @@ -86,6 +87,7 @@ def _default_class_to_sqa_class(self=None) -> dict[type[Base], type[SQABase]]:
)
experiment_type_enum: Optional[Union[Enum, type[Enum]]] = None
generator_run_type_enum: Optional[Union[Enum, type[Enum]]] = GeneratorRunType
auxiliary_experiment_purpose_enum: type[Enum] = AuxiliaryExperimentPurpose

# pyre-fixme[4]: Attribute annotation cannot contain `Any`.
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
Expand Down
44 changes: 43 additions & 1 deletion ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import logging
from datetime import datetime
from enum import Enum
from enum import Enum, unique
from logging import Logger
from typing import Any
from unittest import mock
Expand All @@ -20,7 +20,9 @@
from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysisCard
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
from ax.core.batch_trial import BatchTrial, LifecycleStage
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.metric import Metric
from ax.core.objective import MultiObjective, Objective
Expand Down Expand Up @@ -208,6 +210,46 @@ def test_ExperimentSaveAndLoad(self) -> None:
loaded_experiment = load_experiment(exp.name)
self.assertEqual(loaded_experiment, exp)

def test_saving_and_loading_experiment_with_aux_exp(self) -> None:
@unique
class TestAuxiliaryExperimentPurpose(AuxiliaryExperimentPurpose):
MyAuxExpPurpose = "my_auxiliary_experiment_purpose"

self.config.auxiliary_experiment_purpose_enum = TestAuxiliaryExperimentPurpose

aux_experiment = Experiment(
name="test_aux_exp_in_SQAStoreTest",
search_space=get_search_space(),
optimization_config=get_optimization_config(),
description="test description",
tracking_metrics=[Metric(name="tracking")],
is_test=True,
)
save_experiment(aux_experiment, config=self.config)

experiment_w_aux_exp = Experiment(
name="test_experiment_w_aux_exp_in_SQAStoreTest",
search_space=get_search_space(),
optimization_config=get_optimization_config(),
description="test description",
tracking_metrics=[Metric(name="tracking")],
is_test=True,
auxiliary_experiments_by_purpose={
# pyre-ignore[16]: `AuxiliaryExperimentPurpose` has no attribute
self.config.auxiliary_experiment_purpose_enum.MyAuxExpPurpose: [
AuxiliaryExperiment(experiment=aux_experiment)
]
},
)
self.assertIsNone(experiment_w_aux_exp.db_id)
save_experiment(experiment_w_aux_exp, config=self.config)
self.assertIsNotNone(experiment_w_aux_exp.db_id)
loaded_experiment = load_experiment(
experiment_w_aux_exp.name, config=self.config
)
self.assertEqual(experiment_w_aux_exp, loaded_experiment)
self.assertEqual(len(loaded_experiment.auxiliary_experiments_by_purpose), 1)

def test_saving_an_experiment_with_type_requires_an_enum(self) -> None:
self.experiment.experiment_type = "TEST"
with self.assertRaises(SQAEncodeError):
Expand Down

0 comments on commit eb1cf7c

Please sign in to comment.