Skip to content

Commit 90c5e6e

Browse files
committed
Integrate comments from review
1 parent 8607490 commit 90c5e6e

File tree

2 files changed

+143
-78
lines changed

2 files changed

+143
-78
lines changed

src/optimagic/optimization/history.py

+68-44
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import warnings
24
from dataclasses import dataclass
35
from functools import partial
@@ -99,17 +101,12 @@ def _get_next_batch_id(self) -> int:
99101
# Function data, function value, and monotone function value
100102
# ----------------------------------------------------------------------------------
101103

102-
def fun_data(
103-
self, cost_model: CostModel, monotone: bool = False, dropna: bool = False
104-
) -> pd.DataFrame:
104+
def fun_data(self, cost_model: CostModel, monotone: bool = False) -> pd.DataFrame:
105105
"""Return the function value data.
106106
107107
Args:
108108
cost_model: The cost model that is used to calculate the time measure.
109109
monotone: Whether to return the monotone function values. Defaults to False.
110-
dropna: Whether to drop rows with missing values. These correspond to
111-
parameters that were used to calculate a pure jacobian. Defaults to
112-
False.
113110
114111
Returns:
115112
pd.DataFrame: The function value data. The columns are: 'fun', 'time' and
@@ -124,8 +121,9 @@ def fun_data(
124121
fun = np.array(self.fun, dtype=np.float64) # converts None to nan
125122

126123
timings = self._get_total_timings(cost_model)
124+
task = _task_to_categorical(self.task)
127125

128-
if not self.is_serial:
126+
if not self._is_serial():
129127
# In the non-serial case, we take the batching into account and reduce
130128
# timings and fun to one value per batch.
131129
timings = _apply_reduction_to_batches(
@@ -143,18 +141,16 @@ def fun_data(
143141
reduction_function=min_or_max, # type: ignore[arg-type]
144142
)
145143

146-
time = np.cumsum(timings)
147-
data = pd.DataFrame({"fun": fun, "time": time})
148-
149-
if self.is_serial:
150-
# In the non-serial case, the task column is meaningless, since multiple
151-
# tasks would need to be reduced to one.
152-
data["task"] = _task_to_categorical(self.task)
144+
# Verify that tasks are homogeneous in each batch, and select first if true.
145+
tasks_and_batches = pd.DataFrame({"task": task, "batches": self.batches})
146+
grouped_tasks = tasks_and_batches.groupby("batches")["task"]
147+
if not grouped_tasks.nunique().eq(1).all():
148+
raise ValueError("Tasks are not homogeneous in each batch.")
153149

154-
if dropna:
155-
data = data.dropna()
150+
task = grouped_tasks.first().reset_index(drop=True)
156151

157-
return data.rename_axis("counter")
152+
time = np.cumsum(timings)
153+
return pd.DataFrame({"fun": fun, "time": time, "task": task})
158154

159155
@property
160156
def fun(self) -> list[float | None]:
@@ -191,13 +187,18 @@ def is_accepted(self) -> NDArray[np.bool_]:
191187
# Parameter data, params, flat params, and flat params names
192188
# ----------------------------------------------------------------------------------
193189

194-
def params_data(self, dropna: bool = False) -> pd.DataFrame:
190+
def params_data(
191+
self, dropna: bool = False, collapse_batches: bool = False
192+
) -> pd.DataFrame:
195193
"""Return the parameter data.
196194
197195
Args:
198-
dropna: Whether to drop rows with missing values. These correspond to
199-
parameters that were used to calculate a pure jacobian. Defaults to
196+
dropna: Whether to drop rows with missing function values. These correspond
197+
to parameters that were used to calculate pure jacobians. Defaults to
200198
False.
199+
collapse_batches: Whether to collapse the batches and only keep the
200+
parameters that led to the minimal (or maximal) function value in each
201+
batch. Defaults to False.
201202
202203
Returns:
203204
pd.DataFrame: The parameter data. The columns are: 'name' (the parameter
@@ -210,31 +211,50 @@ def params_data(self, dropna: bool = False) -> pd.DataFrame:
210211
wide["task"] = _task_to_categorical(self.task)
211212
wide["fun"] = self.fun
212213

213-
# In the batch case, we select only the parameters in a batch that led to the
214-
# minimal (or maximal) function value in that batch.
215-
if not self.is_serial:
214+
# If requested, we collapse the batches and only keep the parameters that led to
215+
# the minimal (or maximal) function value in each batch.
216+
if collapse_batches and not self._is_serial():
216217
wide["batches"] = self.batches
217218

219+
# Verify that tasks are homogeneous in each batch
220+
if not wide.groupby("batches")["task"].nunique().eq(1).all():
221+
raise ValueError("Tasks are not homogeneous in each batch.")
222+
223+
# We fill nans with inf or -inf to make sure that the idxmin/idxmax is
224+
# well-defined, since there is the possibility that all fun values are nans
225+
# in a batch.
218226
if self.direction == Direction.MINIMIZE:
219-
loc_of_fun_optimizer = wide.groupby("batches")["fun"].idxmin()
227+
loc = (
228+
wide.assign(fun_without_nan=wide["fun"].fillna(np.inf))
229+
.groupby("batches")["fun_without_nan"]
230+
.idxmin()
231+
)
220232
elif self.direction == Direction.MAXIMIZE:
221-
loc_of_fun_optimizer = wide.groupby("batches")["fun"].idxmax()
233+
loc = (
234+
wide.assign(fun_without_nan=wide["fun"].fillna(-np.inf))
235+
.groupby("batches")["fun_without_nan"]
236+
.idxmax()
237+
)
238+
239+
wide = wide.loc[loc].drop(columns="batches")
222240

223-
wide = wide.loc[loc_of_fun_optimizer].drop(columns="batches")
241+
# We drop rows with missing values if requested. These correspond to parameters
242+
# that were used to calculate pure jacobians. This step must be done before
243+
# dropping the fun column and before setting the counter.
244+
if dropna:
245+
wide = wide.dropna()
246+
wide = wide.drop(columns="fun")
224247

225248
wide["counter"] = np.arange(len(wide))
226249

227250
long = pd.melt(
228251
wide,
229252
var_name="name",
230253
value_name="value",
231-
id_vars=["task", "fun", "counter"],
254+
id_vars=["task", "counter"],
232255
)
233256

234-
data = long.reindex(columns=["counter", "name", "value", "task", "fun"])
235-
236-
if dropna:
237-
data = data.dropna()
257+
data = long.reindex(columns=["counter", "name", "value", "task"])
238258

239259
return data.set_index(["counter", "name"]).sort_index()
240260

@@ -326,8 +346,7 @@ def stop_time(self) -> list[float]:
326346
def batches(self) -> list[int]:
327347
return self._batches
328348

329-
@property
330-
def is_serial(self) -> bool:
349+
def _is_serial(self) -> bool:
331350
return np.array_equal(self.batches, np.arange(len(self.batches)))
332351

333352
# Tasks
@@ -373,7 +392,7 @@ def __getitem__(self, key: str) -> Any:
373392

374393

375394
# ======================================================================================
376-
# Methods
395+
# Functions directly used in History methods
377396
# ======================================================================================
378397

379398

@@ -442,10 +461,9 @@ def _validate_args_are_all_none_or_lists_of_same_length(
442461
raise ValueError("All arguments must be lists of the same length or None.")
443462

444463

445-
def _task_to_categorical(task: list[EvalTask]) -> pd.Categorical:
446-
return pd.Categorical(
447-
[t.value for t in task], categories=[t.value for t in EvalTask]
448-
)
464+
def _task_to_categorical(task: list[EvalTask]) -> pd.Series[str]:
465+
EvalTaskDtype = pd.CategoricalDtype(categories=[t.value for t in EvalTask])
466+
return pd.Series([t.value for t in task], dtype=EvalTaskDtype)
449467

450468

451469
def _apply_reduction_to_batches(
@@ -462,7 +480,8 @@ def _apply_reduction_to_batches(
462480
batch_ids: A list with batch ids whose length is equal to the size of data.
463481
Values need to be sorted and can be repeated.
464482
reduction_function: A reduction function that takes an iterable of floats as
465-
input (e.g., a numpy.ndarray or list) and returns a scalar.
483+
input (e.g., a numpy.ndarray or list of floats) and returns a scalar. The
484+
function must be able to handle NaN's.
466485
467486
Returns:
468487
The transformed data. Has one entry per unique batch id, equal to the result of
@@ -478,21 +497,26 @@ def _apply_reduction_to_batches(
478497
batch_id = batch_ids[start]
479498

480499
try:
481-
reduced = reduction_function(batch_data)
500+
if np.isnan(batch_data).all():
501+
reduced = np.nan
502+
else:
503+
reduced = reduction_function(batch_data)
482504
except Exception as e:
483505
msg = (
484506
f"Calling function {reduction_function.__name__} on batch {batch_id} "
485507
"of the History raised an Exception. Please verify that "
486-
f"{reduction_function.__name__} is well-defined, takes a list of "
487-
"floats as input, and returns a scalar."
508+
f"{reduction_function.__name__} is well-defined, takes an iterable of "
509+
"floats as input and returns a scalar. The function must be able to "
510+
"handle NaN's."
488511
)
489512
raise ValueError(msg) from e
490513

491514
if not np.isscalar(reduced):
492515
msg = (
493516
f"Function {reduction_function.__name__} did not return a scalar for "
494517
f"batch {batch_id}. Please verify that {reduction_function.__name__} "
495-
"returns a scalar when called on a list of floats."
518+
"returns a scalar when called on an iterable of floats. The function "
519+
"must be able to handle NaN's."
496520
)
497521
raise ValueError(msg)
498522

@@ -504,7 +528,7 @@ def _apply_reduction_to_batches(
504528
def _get_batch_starts_and_stops(batch_ids: list[int]) -> tuple[list[int], list[int]]:
505529
"""Get start and stop indices of batches.
506530
507-
This function assumes that batch_ids non-empty and sorted.
531+
This function assumes that batch_ids are non-empty and sorted.
508532
509533
"""
510534
ids_arr = np.array(batch_ids, dtype=np.int64)

0 commit comments

Comments (0)