Skip to content

Commit

Permalink
Refactor objective transform interface (#398)
Browse files Browse the repository at this point in the history
Improves the `Objective.transform` interface and introduces explicit
column validation, similar to what has been done for search spaces in
#289.
  • Loading branch information
AdrianSosic authored Oct 18, 2024
2 parents e045529 + a76e008 commit eb709d9
Show file tree
Hide file tree
Showing 14 changed files with 257 additions and 72 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added
- `allow_missing` and `allow_extra` keyword arguments to `Objective.transform`

### Deprecations
- Passing a dataframe via the `data` argument to `Objective.transform` is no longer
possible. The dataframe must now be passed as positional argument.
- The new `allow_extra` flag is automatically set to `True` in `Objective.transform`
when left unspecified
- `get_transform_parameters` has been replaced with `get_transform_objects`

## [0.11.2] - 2024-10-11
### Added
- `n_restarts` and `n_raw_samples` keywords to configure continuous optimization
Expand Down
22 changes: 18 additions & 4 deletions baybe/objectives/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,29 @@ def targets(self) -> tuple[Target, ...]:
"""The targets included in the objective."""

@abstractmethod
def transform(self, data: pd.DataFrame) -> pd.DataFrame:
def transform(
self,
df: pd.DataFrame,
/,
*,
allow_missing: bool = False,
allow_extra: bool = False,
) -> pd.DataFrame:
"""Transform target values from experimental to computational representation.
Args:
data: The data to be transformed. Must contain columns for all targets
but can contain additional columns.
df: The dataframe to be transformed. The allowed columns of the dataframe
are dictated by the ``allow_missing`` and ``allow_extra`` flags.
allow_missing: If ``False``, each target of the objective must have
exactly one corresponding column in the given dataframe. If ``True``,
the dataframe may contain only a subset of target columns.
allow_extra: If ``False``, each column present in the dataframe must
correspond to exactly one target of the objective. If ``True``, the
dataframe may contain additional non-target-related columns, which
will be ignored.
Returns:
A new dataframe with the targets in computational representation.
A corresponding dataframe with the targets in computational representation.
"""


Expand Down
52 changes: 48 additions & 4 deletions baybe/objectives/desirability.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Functionality for desirability objectives."""

import gc
import warnings
from collections.abc import Callable
from functools import cached_property, partial
from typing import TypeGuard
Expand All @@ -18,7 +19,7 @@
from baybe.targets.base import Target
from baybe.targets.numerical import NumericalTarget
from baybe.utils.basic import to_tuple
from baybe.utils.dataframe import pretty_print_df
from baybe.utils.dataframe import get_transform_objects, pretty_print_df
from baybe.utils.numerical import geom_mean
from baybe.utils.plotting import to_string
from baybe.utils.validation import finite_float
Expand Down Expand Up @@ -140,11 +141,54 @@ def __str__(self) -> str:
return to_string("Objective", *fields)

@override
def transform(self, data: pd.DataFrame) -> pd.DataFrame:
def transform(
self,
df: pd.DataFrame | None = None,
/,
*,
allow_missing: bool = False,
allow_extra: bool | None = None,
data: pd.DataFrame | None = None,
) -> pd.DataFrame:
# >>>>>>>>>> Deprecation
if not ((df is None) ^ (data is None)):
raise ValueError(
"Provide the dataframe to be transformed as argument to `df`."
)

if data is not None:
df = data
warnings.warn(
"Providing the dataframe via the `data` argument is deprecated and "
"will be removed in a future version. Please pass your dataframe "
"as positional argument instead.",
DeprecationWarning,
)

# Mypy does not infer from the above that `df` must be a dataframe here
assert isinstance(df, pd.DataFrame)

if allow_extra is None:
allow_extra = True
if set(df.columns) - {p.name for p in self.targets}:
warnings.warn(
"For backward compatibility, the new `allow_extra` flag is set "
"to `True` when left unspecified. However, this behavior will be "
"changed in a future version. If you want to invoke the old "
"behavior, please explicitly set `allow_extra=True`.",
DeprecationWarning,
)
# <<<<<<<<<< Deprecation

# Extract the relevant part of the dataframe
targets = get_transform_objects(
df, self.targets, allow_missing=allow_missing, allow_extra=allow_extra
)
transformed = df[[t.name for t in targets]].copy()

# Transform all targets individually
transformed = data[[t.name for t in self.targets]].copy()
for target in self.targets:
transformed[target.name] = target.transform(data[[target.name]])
transformed[target.name] = target.transform(df[[target.name]])

# Scalarize the transformed targets into desirability values
vals = scalarize(transformed.values, self.scalarizer, self._normalized_weights)
Expand Down
51 changes: 48 additions & 3 deletions baybe/objectives/single.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Functionality for single-target objectives."""

import gc
import warnings

import pandas as pd
from attrs import define, field
Expand All @@ -9,7 +10,7 @@

from baybe.objectives.base import Objective
from baybe.targets.base import Target
from baybe.utils.dataframe import pretty_print_df
from baybe.utils.dataframe import get_transform_objects, pretty_print_df
from baybe.utils.plotting import to_string


Expand Down Expand Up @@ -38,8 +39,52 @@ def targets(self) -> tuple[Target, ...]:
return (self._target,)

@override
def transform(self, data: pd.DataFrame) -> pd.DataFrame:
target_data = data[[self._target.name]].copy()
def transform(
self,
df: pd.DataFrame | None = None,
/,
*,
allow_missing: bool = False,
allow_extra: bool | None = None,
data: pd.DataFrame | None = None,
) -> pd.DataFrame:
# >>>>>>>>>> Deprecation
if not ((df is None) ^ (data is None)):
raise ValueError(
"Provide the dataframe to be transformed as argument to `df`."
)

if data is not None:
df = data
warnings.warn(
"Providing the dataframe via the `data` argument is deprecated and "
"will be removed in a future version. Please pass your dataframe "
"as positional argument instead.",
DeprecationWarning,
)

# Mypy does not infer from the above that `df` must be a dataframe here
assert isinstance(df, pd.DataFrame)

if allow_extra is None:
allow_extra = True
if set(df.columns) - {p.name for p in self.targets}:
warnings.warn(
"For backward compatibility, the new `allow_extra` flag is set "
"to `True` when left unspecified. However, this behavior will be "
"changed in a future version. If you want to invoke the old "
"behavior, please explicitly set `allow_extra=True`.",
DeprecationWarning,
)
# <<<<<<<<<< Deprecation

# Even for a single target, it is convenient to use the existing machinery
# instead of re-implementing the validation logic
targets = get_transform_objects(
df, [self._target], allow_missing=allow_missing, allow_extra=allow_extra
)
target_data = df[[t.name for t in targets]].copy()

return self._target.transform(target_data)


Expand Down
7 changes: 3 additions & 4 deletions baybe/searchspace/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,11 @@
from baybe.parameters.base import ContinuousParameter
from baybe.parameters.utils import get_parameters_from_dataframe, sort_parameters
from baybe.searchspace.validation import (
get_transform_parameters,
validate_parameter_names,
)
from baybe.serialization import SerialMixin, converter, select_constructor_hook
from baybe.utils.basic import to_tuple
from baybe.utils.dataframe import pretty_print_df
from baybe.utils.dataframe import get_transform_objects, pretty_print_df
from baybe.utils.plotting import to_string

if TYPE_CHECKING:
Expand Down Expand Up @@ -343,8 +342,8 @@ def transform(
# <<<<<<<<<< Deprecation

# Extract the parameters to be transformed
parameters = get_transform_parameters(
self.parameters, df, allow_missing, allow_extra
parameters = get_transform_objects(
df, self.parameters, allow_missing=allow_missing, allow_extra=allow_extra
)

# Transform the parameters
Expand Down
6 changes: 3 additions & 3 deletions baybe/searchspace/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,10 @@ def transform(
The ``None`` default value is for temporary backward compatibility only
and will be removed in a future version.
allow_missing: If ``False``, each parameter of the space must have
(exactly) one corresponding column in the given dataframe. If ``True``,
exactly one corresponding column in the given dataframe. If ``True``,
the dataframe may contain only a subset of parameter columns.
allow_extra: If ``False``, every column present in the dataframe must
correspond to (exactly) one parameter of the space. If ``True``, the
allow_extra: If ``False``, each column present in the dataframe must
correspond to exactly one parameter of the space. If ``True``, the
dataframe may contain additional non-parameter-related columns, which
will be ignored.
The ``None`` default value is for temporary backward compatibility only
Expand Down
6 changes: 3 additions & 3 deletions baybe/searchspace/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from baybe.parameters.base import DiscreteParameter, Parameter
from baybe.parameters.utils import get_parameters_from_dataframe, sort_parameters
from baybe.searchspace.validation import (
get_transform_parameters,
validate_parameter_names,
validate_parameters,
)
Expand All @@ -37,6 +36,7 @@
from baybe.utils.dataframe import (
df_drop_single_value_columns,
fuzzy_row_match,
get_transform_objects,
pretty_print_df,
)
from baybe.utils.memory import bytes_to_human_readable
Expand Down Expand Up @@ -753,8 +753,8 @@ def transform(
# <<<<<<<<<< Deprecation

# Extract the parameters to be transformed
parameters = get_transform_parameters(
self.parameters, df, allow_missing, allow_extra
parameters = get_transform_objects(
df, self.parameters, allow_missing=allow_missing, allow_extra=allow_extra
)

# If the transformed values are not required, return an empty dataframe
Expand Down
48 changes: 12 additions & 36 deletions baybe/searchspace/validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Validation functionality for search spaces."""

import warnings
from collections.abc import Collection, Sequence
from typing import TypeVar

Expand All @@ -8,6 +9,7 @@
from baybe.exceptions import EmptySearchSpaceError
from baybe.parameters import TaskParameter
from baybe.parameters.base import Parameter
from baybe.utils.dataframe import get_transform_objects

_T = TypeVar("_T", bound=Parameter)

Expand Down Expand Up @@ -49,41 +51,15 @@ def validate_parameters(parameters: Collection[Parameter]) -> None: # noqa: DOC
def get_transform_parameters(
parameters: Sequence[_T],
df: pd.DataFrame,
allow_missing: bool,
allow_extra: bool,
allow_missing: bool = False,
allow_extra: bool = False,
) -> list[_T]:
"""Extract the parameters relevant for transforming a given dataframe.
Args:
parameters: The parameters to be considered for transformation (provided
they have match in the given dataframe).
df: See :meth:`baybe.searchspace.core.SearchSpace.transform`.
allow_missing: See :meth:`baybe.searchspace.core.SearchSpace.transform`.
allow_extra: See :meth:`baybe.searchspace.core.SearchSpace.transform`.
Raises:
ValueError: If the given parameters and dataframe are not compatible
under the specified values for the Boolean flags.
Returns:
The (subset of) parameters that need to be considered for the transformation.
"""
parameter_names = [p.name for p in parameters]

if (not allow_missing) and (missing := set(parameter_names) - set(df)): # type: ignore[arg-type]
raise ValueError(
f"The search space parameter(s) {missing} cannot be matched against "
f"the provided dataframe. If you want to transform a subset of "
f"parameter columns, explicitly set `allow_missing=True`."
)

if (not allow_extra) and (extra := set(df) - set(parameter_names)):
raise ValueError(
f"The provided dataframe column(s) {extra} cannot be matched against"
f"the search space parameters. If you want to transform a dataframe "
f"with additional columns, explicitly set `allow_extra=True'."
)

return (
[p for p in parameters if p.name in df] if allow_missing else list(parameters)
"""Deprecated!""" # noqa: D401
warnings.warn(
f"The function 'get_transform_parameters' has been deprecated and will be "
f"removed in a future version. Use '{get_transform_objects.__name__}' instead.",
DeprecationWarning,
)
return get_transform_objects(
df, parameters, allow_missing=allow_missing, allow_extra=allow_extra
)
4 changes: 2 additions & 2 deletions baybe/surrogates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def _make_output_scaler(
scaler = factory(1)

# TODO: Consider taking into account target boundaries when available
scaler(to_tensor(objective.transform(measurements)))
scaler(to_tensor(objective.transform(measurements, allow_extra=True)))
scaler.eval()

return scaler
Expand Down Expand Up @@ -336,7 +336,7 @@ def fit(
# Transform and fit
train_x_comp_rep, train_y_comp_rep = to_tensor(
searchspace.transform(measurements, allow_extra=True),
objective.transform(measurements),
objective.transform(measurements, allow_extra=True),
)
train_x = self._input_scaler.transform(train_x_comp_rep)
train_y = (
Expand Down
Loading

0 comments on commit eb709d9

Please sign in to comment.