Skip to content

Commit

Permalink
Rely on inheritted _make_metric as much as possible (facebook#2718)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#2718

Simplify code.  Also references to super should help deobscure where this gets called/

Reviewed By: sdaulton

Differential Revision: D61852855
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Aug 29, 2024
1 parent 1a322b0 commit b29c18c
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions ax/service/utils/instantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import enum
from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass

from logging import Logger
Expand Down Expand Up @@ -139,11 +140,12 @@ def _get_deserialized_metric_kwargs(
) -> dict[str, Any]:
"""Get metric kwargs from metric_definitions if available and deserialize
if so. Deserialization is necessary because they were serialized on creation"""
metric_kwargs = (metric_definitions or {}).get(name, {})
# deepcopy is used because of subsequent modifications to the dict
metric_kwargs = deepcopy((metric_definitions or {}).get(name, {}))
metric_class = metric_kwargs.pop("metric_class", metric_class)
metric_kwargs["name"] = name
# this is necessary before deserialization because name will be required
metric_kwargs["name"] = metric_kwargs.get("name", name)
metric_kwargs = metric_class.deserialize_init_args(metric_kwargs)
metric_kwargs.pop("name")
return metric_kwargs

@classmethod
Expand All @@ -160,15 +162,15 @@ def _make_metric(
"Metric names cannot contain spaces when used with AxClient. Got "
f"{name!r}."
)

return metric_class(
kwargs = cls._get_deserialized_metric_kwargs(
name=name,
lower_is_better=lower_is_better,
**cls._get_deserialized_metric_kwargs(
name=name,
metric_definitions=metric_definitions,
metric_class=metric_class,
),
metric_definitions=metric_definitions,
metric_class=metric_class,
)
# avoid conflict is lower_is_better is specified in kwargs
kwargs["lower_is_better"] = kwargs.get("lower_is_better", lower_is_better)
return metric_class(
**kwargs,
)

@staticmethod
Expand Down

0 comments on commit b29c18c

Please sign in to comment.