Skip to content

Commit

Permalink
Re-sync with internal repository (#1675)
Browse files Browse the repository at this point in the history
Co-authored-by: Facebook Community Bot <6422482+facebook-github-bot@users.noreply.github.com>
  • Loading branch information
facebook-github-bot and facebook-github-bot authored Jun 21, 2023
1 parent 637847e commit 139e9a8
Show file tree
Hide file tree
Showing 6 changed files with 177 additions and 60 deletions.
31 changes: 20 additions & 11 deletions ax/core/batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@
from ax.core.data import Data
from ax.core.generator_run import ArmWeight, GeneratorRun, GeneratorRunType
from ax.core.trial import immutable_once_run
from ax.core.types import TCandidateMetadata, TEvaluationOutcome
from ax.core.types import (
TCandidateMetadata,
TEvaluationOutcome,
validate_evaluation_outcome,
)
from ax.exceptions.core import AxError, UserInputError
from ax.utils.common.base import SortableBase
from ax.utils.common.docutils import copy_doc
from ax.utils.common.equality import datetime_equals, equality_typechecker
from ax.utils.common.logger import _round_floats_for_logging, get_logger
from ax.utils.common.typeutils import checked_cast, checked_cast_complex, not_none
from ax.utils.common.typeutils import checked_cast, not_none


logger: Logger = get_logger(__name__)
Expand Down Expand Up @@ -568,23 +572,28 @@ def attach_batch_trial_data(
complete_trial: Whether to mark trial as complete after
attaching data. Defaults to False.
"""
# Validate type of raw_data
if not isinstance(raw_data, dict):
raise ValueError(BATCH_TRIAL_RAW_DATA_FORMAT_ERROR_MESSAGE)

for key, value in raw_data.items():
if not isinstance(key, str):
raise ValueError(BATCH_TRIAL_RAW_DATA_FORMAT_ERROR_MESSAGE)

try:
validate_evaluation_outcome(outcome=value)
except TypeError:
raise ValueError(BATCH_TRIAL_RAW_DATA_FORMAT_ERROR_MESSAGE)

# Format the data to save.
raw_data_by_arm = checked_cast_complex(
Dict[str, TEvaluationOutcome],
raw_data,
message=BATCH_TRIAL_RAW_DATA_FORMAT_ERROR_MESSAGE,
)
not_trial_arm_names = set(raw_data_by_arm.keys()) - set(
self.arms_by_name.keys()
)
not_trial_arm_names = set(raw_data.keys()) - set(self.arms_by_name.keys())
if not_trial_arm_names:
raise UserInputError( # pragma: no cover
f"Arms {not_trial_arm_names} are not part of trial #{self.index}."
)

evaluations, data = self._make_evaluations_and_data(
raw_data=raw_data_by_arm, metadata=metadata, sample_sizes=sample_sizes
raw_data=raw_data, metadata=metadata, sample_sizes=sample_sizes
)
self._validate_batch_trial_data(data=data)

Expand Down
45 changes: 44 additions & 1 deletion ax/core/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from ax.core.types import merge_model_predict
from ax.core.types import (
merge_model_predict,
validate_evaluation_outcome,
validate_floatlike,
validate_map_dict,
validate_param_value,
validate_parameterization,
validate_single_metric_data,
validate_trial_evaluation,
)
from ax.utils.common.testutils import TestCase


Expand Down Expand Up @@ -40,3 +49,37 @@ def testMergeModelPredictFail(self) -> None:
cov_append = {"m1": {"m1": [0.0], "m2": [0.0]}}
with self.assertRaises(ValueError):
merge_model_predict(self.predict, (mu_append, cov_append))

def testValidate(self) -> None:
trial_evaluation = {"foo": 0.0}
trial_evaluation_with_noise = {"foo": (0.0, 0.0)}
fidelity_trial_evaluation = [({"a": 0.0}, trial_evaluation)]
map_trial_evaluation = [({"a": 0.0}, trial_evaluation)]

validate_evaluation_outcome(outcome=trial_evaluation) # pyre-ignore[6]
validate_evaluation_outcome(
outcome=trial_evaluation_with_noise # pyre-ignore[6]
)
validate_evaluation_outcome(outcome=fidelity_trial_evaluation) # pyre-ignore[6]
validate_evaluation_outcome(outcome=map_trial_evaluation) # pyre-ignore[6]

with self.assertRaises(TypeError):
validate_floatlike(floatlike="foo") # pyre-ignore[6]

with self.assertRaises(TypeError):
validate_single_metric_data(data=(0, 1, 2)) # pyre-ignore[6]

with self.assertRaises(TypeError):
validate_trial_evaluation(evaluation={0: 0}) # pyre-ignore[6]

with self.assertRaises(TypeError):
validate_param_value(param_value=[]) # pyre-ignore[6]

with self.assertRaises(TypeError):
validate_parameterization(parameterization={0: 0}) # pyre-ignore[6]

with self.assertRaises(TypeError):
validate_map_dict(map_dict={0: 0}) # pyre-ignore[6]

with self.assertRaises(TypeError):
validate_map_dict(map_dict={"foo": []}) # pyre-ignore[6]
20 changes: 11 additions & 9 deletions ax/core/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
from ax.core.base_trial import BaseTrial, immutable_once_run
from ax.core.data import Data
from ax.core.generator_run import GeneratorRun, GeneratorRunType
from ax.core.types import TCandidateMetadata, TEvaluationOutcome
from ax.core.types import (
TCandidateMetadata,
TEvaluationOutcome,
validate_evaluation_outcome,
)
from ax.utils.common.docutils import copy_doc
from ax.utils.common.logger import _round_floats_for_logging, get_logger
from ax.utils.common.typeutils import checked_cast_complex, not_none
from ax.utils.common.typeutils import not_none

logger: Logger = get_logger(__name__)

Expand Down Expand Up @@ -294,13 +298,11 @@ def update_trial_data(
sample_sizes = {not_none(self.arm).name: sample_size} if sample_size else {}

arm_name = not_none(self.arm).name
raw_data_by_arm = {
arm_name: checked_cast_complex(
TEvaluationOutcome,
raw_data,
message=TRIAL_RAW_DATA_FORMAT_ERROR_MESSAGE,
)
}
try:
validate_evaluation_outcome(outcome=raw_data)
except Exception:
raise ValueError(TRIAL_RAW_DATA_FORMAT_ERROR_MESSAGE)
raw_data_by_arm = {arm_name: raw_data}
not_trial_arm_names = set(raw_data_by_arm.keys()) - set(
self.arms_by_name.keys()
)
Expand Down
102 changes: 101 additions & 1 deletion ax/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
# ( { metric -> mean }, { metric -> { other_metric -> covariance } } ).
TModelPredictArm = Tuple[Dict[str, float], Optional[Dict[str, Dict[str, float]]]]

FloatLike = Union[float, np.floating, np.integer]
FloatLike = Union[int, float, np.floating, np.integer]
SingleMetricDataTuple = Tuple[FloatLike, Optional[FloatLike]]
SingleMetricData = Union[FloatLike, Tuple[FloatLike, Optional[FloatLike]]]
# 1-arm `Trial` evaluation data: {metric_name -> (mean, standard error)}}.
Expand Down Expand Up @@ -110,3 +110,103 @@ def merge_model_predict(
cov_values + cov_append[metric_name][co_metric_name]
)
return mu, cov


def validate_floatlike(floatlike: FloatLike) -> None:
if not (
isinstance(floatlike, float)
or isinstance(floatlike, int)
or isinstance(floatlike, np.floating)
or isinstance(floatlike, np.integer)
):
raise TypeError(f"Expected FloatLike, found {floatlike}")


def validate_single_metric_data(data: SingleMetricData) -> None:
if isinstance(data, tuple):
if len(data) != 2:
raise TypeError(
f"Tuple-valued SingleMetricData must have len == 2, found {data}"
)

mean, sem = data
validate_floatlike(floatlike=mean)

if sem is not None:
validate_floatlike(floatlike=sem)

else:
validate_floatlike(floatlike=data)


def validate_trial_evaluation(evaluation: TTrialEvaluation) -> None:
for key, value in evaluation.items():
if type(key) != str:
raise TypeError(f"Keys must be strings in TTrialEvaluation, found {key}.")

validate_single_metric_data(data=value)


def validate_param_value(param_value: TParamValue) -> None:
if not (
isinstance(param_value, str)
or isinstance(param_value, bool)
or isinstance(param_value, float)
or isinstance(param_value, int)
or param_value is None
):
raise TypeError(f"Expected None, bool, float, int, or str, found {param_value}")


def validate_parameterization(parameterization: TParameterization) -> None:
for key, value in parameterization.items():
if type(key) != str:
raise TypeError(f"Keys must be strings in TParameterization, found {key}.")

validate_param_value(param_value=value)


def validate_map_dict(map_dict: TMapDict) -> None:
for key, value in map_dict.items():
if type(key) != str:
raise TypeError(f"Keys must be strings in TMapDict, found {key}.")

if not isinstance(value, Hashable):
raise TypeError(f"Values must be Hashable in TMapDict, found {value}.")


def validate_fidelity_trial_evaluation(evaluation: TFidelityTrialEvaluation) -> None:
for parameterization, trial_evaluation in evaluation:
validate_parameterization(parameterization=parameterization)
validate_trial_evaluation(evaluation=trial_evaluation)


def validate_map_trial_evaluation(evaluation: TMapTrialEvaluation) -> None:
for map_dict, trial_evaluation in evaluation:
validate_map_dict(map_dict=map_dict)
validate_trial_evaluation(evaluation=trial_evaluation)


def validate_evaluation_outcome(outcome: TEvaluationOutcome) -> None:
"""Runtime validate that the supplied outcome has correct structure."""

if isinstance(outcome, dict):
# Check if outcome is TTrialEvaluation
validate_trial_evaluation(evaluation=outcome)

elif isinstance(outcome, list):
# Check if outcome is TFidelityTrialEvaluation or TMapTrialEvaluation
try:
validate_fidelity_trial_evaluation(evaluation=outcome) # pyre-ignore[6]
except Exception:
try:
validate_map_trial_evaluation(evaluation=outcome) # pyre-ignore[6]
except Exception:
raise TypeError(
"Expected either TFidelityTrialEvaluation or TMapTrialEvaluation, "
f"found {type(outcome)}"
)

else:
# Check if outcome is SingleMetricData
validate_single_metric_data(data=outcome)
8 changes: 0 additions & 8 deletions ax/utils/common/tests/test_typeutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import numpy as np
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import (
checked_cast,
checked_cast_complex,
checked_cast_dict,
checked_cast_list,
checked_cast_optional,
Expand All @@ -37,12 +35,6 @@ def test_checked_cast_with_error_override(self) -> None:
float, 2, exception=NotImplementedError("foo() doesn't support ints")
)

def test_checked_cast_complex(self) -> None:
t = Dict[int, str]
self.assertEqual(checked_cast_complex(t, {1: "one"}), {1: "one"})
with self.assertRaises(ValueError):
checked_cast_complex(t, {"one": 1})

def test_checked_cast_list(self) -> None:
self.assertEqual(checked_cast_list(float, [1.0, 2.0]), [1.0, 2.0])
with self.assertRaises(ValueError):
Expand Down
31 changes: 1 addition & 30 deletions ax/utils/common/typeutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, cast, Dict, List, Optional, Tuple, Type, TypeVar
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar

import numpy as np
from typeguard import check_type


T = TypeVar("T")
Expand Down Expand Up @@ -61,34 +60,6 @@ def checked_cast(typ: Type[T], val: V, exception: Optional[Exception] = None) ->
return val


def checked_cast_complex(typ: Type[T], val: V, message: Optional[str] = None) -> T:
"""
Cast a value to a type (with a runtime safety check). Used for subscripted generics
which isinstance cannot run against.
Returns the value unchanged and checks its type at runtime. This signals to the
typechecker that the value has the designated type.
Like `typing.cast`_ ``check_cast`` performs no runtime conversion on its argument,
but, unlike ``typing.cast``, ``checked_cast`` will throw an error if the value is
not of the expected type.
Args:
typ: the type to cast to
val: the value that we are casting
message: message to print on error
Returns:
the ``val`` argument casted to typ
.. _typing.cast: https://docs.python.org/3/library/typing.html#typing.cast
"""
try:
check_type("val", val, typ)
return cast(T, val)
except TypeError:
raise ValueError(message or f"Value was not of type {typ}: {val}")


def checked_cast_optional(typ: Type[T], val: Optional[V]) -> Optional[T]:
"""Calls checked_cast only if value is not None."""
if val is None:
Expand Down

0 comments on commit 139e9a8

Please sign in to comment.