Skip to content

Commit

Permalink
AnalysisCard load/save methods (#2645)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2645

Add logic to load/save `AnalysisCard`

Reviewed By: mpolson64

Differential Revision: D60321245
  • Loading branch information
Cesar-Cardoso authored and facebook-github-bot committed Aug 8, 2024
1 parent 98d7d9f commit acd02bd
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 10 deletions.
34 changes: 31 additions & 3 deletions ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from math import ceil
from typing import Any, cast, Dict, List, Optional, Type

from ax.analysis.analysis import AnalysisCard

from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.metric import Metric
Expand All @@ -22,6 +24,7 @@
get_query_options_to_defer_large_model_cols,
)
from ax.storage.sqa_store.sqa_classes import (
SQAAnalysisCard,
SQAExperiment,
SQAGenerationStrategy,
SQAGeneratorRun,
Expand Down Expand Up @@ -63,7 +66,7 @@ def load_experiment(
of metrics, this option converts the loaded metrics into a base
metric avoiding conversion related to custom properties of the metric.
"""
config = config or SQAConfig()
config = SQAConfig() if config is None else config
decoder = Decoder(config=config)
return _load_experiment(
experiment_name=experiment_name,
Expand Down Expand Up @@ -363,7 +366,7 @@ def load_generation_strategy_by_experiment_name(
"""Finds a generation strategy attached to an experiment specified by a name
and restores it from its corresponding SQA object.
"""
config = config or SQAConfig()
config = SQAConfig() if config is None else config
decoder = Decoder(config=config)
return _load_generation_strategy_by_experiment_name(
experiment_name=experiment_name,
Expand All @@ -381,7 +384,7 @@ def load_generation_strategy_by_id(
reduced_state: bool = False,
) -> GenerationStrategy:
"""Finds a generation strategy stored by a given ID and restores it."""
config = config or SQAConfig()
config = SQAConfig() if config is None else config
decoder = Decoder(config=config)
return _load_generation_strategy_by_id(
gs_id=gs_id, decoder=decoder, experiment=experiment, reduced_state=reduced_state
Expand Down Expand Up @@ -584,3 +587,28 @@ def _get_generation_strategy_sqa_immutable_opt_config_and_search_space(
lazyload("generator_runs.metrics"),
],
)


def load_analysis_cards_by_experiment_name(
experiment_name: str,
config: Optional[SQAConfig] = None,
) -> List[AnalysisCard]:
"""Loads analysis cards for an experiment."""
config = SQAConfig() if config is None else config
decoder = Decoder(config=config)
analysis_card_sqa_class: SQAAnalysisCard = cast(
SQAAnalysisCard, decoder.config.class_to_sqa_class[AnalysisCard]
)
exp_sqa_class: SQAExperiment = cast(
SQAExperiment, decoder.config.class_to_sqa_class[Experiment]
)
with session_scope() as session:
analysis_cards_sqa = (
session.query(analysis_card_sqa_class)
.join(exp_sqa_class.analysis_cards)
.filter(exp_sqa_class.name == experiment_name)
)
return [
decoder.analysis_card_from_sqa(analysis_card_sqa=analysis_card_sqa)
for analysis_card_sqa in analysis_cards_sqa
]
63 changes: 56 additions & 7 deletions ax/storage/sqa_store/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
# pyre-strict

import os
from datetime import datetime

from logging import Logger
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union

from ax.analysis.analysis import AnalysisCard

from ax.core.base_trial import BaseTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
Expand Down Expand Up @@ -47,7 +50,7 @@ def save_experiment(experiment: Experiment, config: Optional[SQAConfig] = None)
raise ValueError("Can only save instances of Experiment")
if not experiment.has_name:
raise ValueError("Experiment name must be set prior to saving.")
config = config or SQAConfig()
config = SQAConfig() if config is None else config
encoder = Encoder(config=config)
decoder = Decoder(config=config)
_save_experiment(experiment=experiment, encoder=encoder, decoder=decoder)
Expand Down Expand Up @@ -107,7 +110,7 @@ def save_generation_strategy(
The ID of the saved generation strategy.
"""
# Start up SQA encoder.
config = config or SQAConfig()
config = SQAConfig() if config is None else config
encoder = Encoder(config=config)
decoder = Decoder(config=config)

Expand Down Expand Up @@ -150,7 +153,7 @@ def save_or_update_trial(
) -> None:
"""Add new trial to the experiment, or update if already exists
(using default SQAConfig)."""
config = config or SQAConfig()
config = SQAConfig() if config is None else config
encoder = Encoder(config=config)
decoder = Decoder(config=config)
_save_or_update_trial(
Expand Down Expand Up @@ -189,7 +192,7 @@ def save_or_update_trials(
will also be added to the experiment, but existing data objects in the
database will *not* be updated or removed.
"""
config = config or SQAConfig()
config = SQAConfig() if config is None else config
encoder = Encoder(config=config)
decoder = Decoder(config=config)
_save_or_update_trials(
Expand Down Expand Up @@ -308,7 +311,7 @@ def update_generation_strategy(
) -> None:
"""Update generation strategy's current step and attach generator runs
(using default SQAConfig)."""
config = config or SQAConfig()
config = SQAConfig() if config is None else config
encoder = Encoder(config=config)
decoder = Decoder(config=config)
_update_generation_strategy(
Expand Down Expand Up @@ -450,7 +453,7 @@ def update_properties_on_experiment(
experiment_with_updated_properties: Experiment,
config: Optional[SQAConfig] = None,
) -> None:
config = config or SQAConfig()
config = SQAConfig() if config is None else config
exp_sqa_class = config.class_to_sqa_class[Experiment]

exp_id = experiment_with_updated_properties.db_id
Expand All @@ -469,7 +472,7 @@ def update_properties_on_trial(
trial_with_updated_properties: BaseTrial,
config: Optional[SQAConfig] = None,
) -> None:
config = config or SQAConfig()
config = SQAConfig() if config is None else config
trial_sqa_class = config.class_to_sqa_class[Trial]

trial_id = trial_with_updated_properties.db_id
Expand All @@ -484,6 +487,52 @@ def update_properties_on_trial(
)


def save_analysis_cards(
analysis_cards: List[AnalysisCard],
experiment: Experiment,
config: Optional[SQAConfig] = None,
) -> None:
# Start up SQA encoder.
config = SQAConfig() if config is None else config
encoder = Encoder(config=config)
decoder = Decoder(config=config)
timestamp = datetime.utcnow()
_save_analysis_cards(
analysis_cards=analysis_cards,
experiment=experiment,
timestamp=timestamp,
encoder=encoder,
decoder=decoder,
)


def _save_analysis_cards(
analysis_cards: List[AnalysisCard],
experiment: Experiment,
timestamp: datetime,
encoder: Encoder,
decoder: Decoder,
) -> None:
if any(analysis_card.db_id is not None for analysis_card in analysis_cards):
raise ValueError("Analysis cards cannot be updated.")
if experiment.db_id is None:
raise ValueError(
f"Experiment {experiment.name} should be saved before analysis cards."
)
_bulk_merge_into_session(
objs=analysis_cards,
encode_func=encoder.analysis_card_to_sqa,
decode_func=decoder.analysis_card_from_sqa,
encode_args_list=[
{
"experiment_id": experiment.db_id,
"timestamp": timestamp,
}
for _analysis_card in analysis_cards
],
)


def _merge_into_session(
obj: Base,
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
Expand Down
64 changes: 64 additions & 0 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
from unittest import mock
from unittest.mock import MagicMock, Mock, patch

import pandas as pd

from ax.analysis.analysis import AnalysisCard, AnalysisCardLevel
from ax.analysis.markdown.markdown_analysis import MarkdownAnalysisCard
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysisCard
from ax.core.arm import Arm
from ax.core.batch_trial import BatchTrial, LifecycleStage
from ax.core.generator_run import GeneratorRun
Expand Down Expand Up @@ -48,12 +53,14 @@
_get_experiment_immutable_opt_config_and_search_space,
_get_experiment_sqa_immutable_opt_config_and_search_space,
_get_generation_strategy_sqa_immutable_opt_config_and_search_space,
load_analysis_cards_by_experiment_name,
load_experiment,
load_generation_strategy_by_experiment_name,
load_generation_strategy_by_id,
)
from ax.storage.sqa_store.reduced_state import GR_LARGE_MODEL_ATTRS
from ax.storage.sqa_store.save import (
save_analysis_cards,
save_experiment,
save_generation_strategy,
save_or_update_trial,
Expand Down Expand Up @@ -114,6 +121,7 @@
get_synthetic_runner,
)
from ax.utils.testing.modeling_stubs import get_generation_strategy
from plotly import graph_objects as go, io as pio

logger: Logger = get_logger(__name__)

Expand Down Expand Up @@ -1926,3 +1934,59 @@ def test_CreateAllTablesException(self) -> None:
engine.dialect.default_schema_name = "ax"
with self.assertRaises(ValueError):
create_all_tables(engine)

def test_AnalysisCard(self) -> None:
test_df = pd.DataFrame(
columns=["a", "b"],
data=[
[1, 2],
[3, 4],
],
)
base_analysis_card = AnalysisCard(
name="test_base_analysis_card",
title="test_title",
subtitle="test_subtitle",
level=AnalysisCardLevel.DEBUG,
df=test_df,
blob="test blob",
)
markdown_analysis_card = MarkdownAnalysisCard(
name="test_markdown_analysis_card",
title="test_title",
subtitle="test_subtitle",
level=AnalysisCardLevel.DEBUG,
df=test_df,
blob="This is some **really cool** markdown",
)
plotly_analysis_card = PlotlyAnalysisCard(
name="test_plotly_analysis_card",
title="test_title",
subtitle="test_subtitle",
level=AnalysisCardLevel.DEBUG,
df=test_df,
blob=pio.to_json(go.Figure()),
)
with self.subTest("test_save_analysis_cards"):
save_experiment(self.experiment)
save_analysis_cards(
[base_analysis_card, markdown_analysis_card, plotly_analysis_card],
self.experiment,
)
with self.subTest("test_load_analysis_cards"):
loaded_analysis_cards = load_analysis_cards_by_experiment_name(
self.experiment.name
)
self.assertEqual(len(loaded_analysis_cards), 3)
self.assertEqual(
loaded_analysis_cards[0].blob,
base_analysis_card.blob,
)
self.assertEqual(
loaded_analysis_cards[1].blob,
markdown_analysis_card.blob,
)
self.assertEqual(
loaded_analysis_cards[2].blob,
plotly_analysis_card.blob,
)

0 comments on commit acd02bd

Please sign in to comment.