Skip to content

Commit ae3bca8

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 1a4cce7 commit ae3bca8

File tree

2 files changed

+20
-13
lines changed

2 files changed

+20
-13
lines changed

src/optimagic/optimization/history.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ def _get_next_batch_id(self) -> int:
9999
# Function data, function value, and monotone function value
100100
# ----------------------------------------------------------------------------------
101101

102-
def fun_data(self, cost_model: CostModel, monotone: bool, dropna: bool = False) -> pd.DataFrame:
102+
def fun_data(
103+
self, cost_model: CostModel, monotone: bool, dropna: bool = False
104+
) -> pd.DataFrame:
103105
"""Return the function value data.
104106
105107
Args:
@@ -119,14 +121,15 @@ def fun_data(self, cost_model: CostModel, monotone: bool, dropna: bool = False)
119121
timings = self._get_total_timings(cost_model)
120122

121123
if not self.is_serial:
122-
123124
timings = _apply_reduction_to_batches(
124125
data=timings,
125126
batch_ids=self.batches,
126127
reduction_function=cost_model.aggregate_batch_time,
127128
)
128129

129-
min_or_max = np.nanmin if self.direction == Direction.MINIMIZE else np.nanmax
130+
min_or_max = (
131+
np.nanmin if self.direction == Direction.MINIMIZE else np.nanmax
132+
)
130133
fun = _apply_reduction_to_batches(
131134
data=fun,
132135
batch_ids=self.batches,
@@ -138,7 +141,7 @@ def fun_data(self, cost_model: CostModel, monotone: bool, dropna: bool = False)
138141

139142
if self.is_serial:
140143
data["task"] = _task_to_categorical(self.task)
141-
144+
142145
if dropna:
143146
data = data.dropna()
144147

@@ -202,7 +205,6 @@ def params_data(self, dropna: bool = False) -> pd.DataFrame:
202205
# 2. Make long
203206

204207
if not self.is_serial:
205-
206208
if self.direction == Direction.MINIMIZE:
207209
loc = data.groupby("batches")["fun"].idxmin()
208210
elif self.direction == Direction.MAXIMIZE:
@@ -215,17 +217,19 @@ def params_data(self, dropna: bool = False) -> pd.DataFrame:
215217
data = data.drop(columns=["batches", "fun"])
216218

217219
long = pd.melt(
218-
wide, var_name="name", value_name="value", id_vars=["task", "batches", "fun"]
220+
wide,
221+
var_name="name",
222+
value_name="value",
223+
id_vars=["task", "batches", "fun"],
219224
)
220225

221226
data = long.reindex(columns=["name", "value", "task", "batches", "fun"])
222-
227+
223228
if dropna:
224229
data = data.dropna()
225230

226231
return data.rename_axis("counter")
227232

228-
229233
@property
230234
def params(self) -> list[PyTree]:
231235
return self._params
@@ -272,7 +276,6 @@ def _get_total_timings(
272276

273277
return fun_time + jac_time + fun_and_jac_time
274278

275-
276279
def _get_timings_per_task(
277280
self, task: EvalTask, cost_factor: float | None
278281
) -> NDArray[np.float64]:
@@ -314,7 +317,7 @@ def stop_time(self) -> list[float]:
314317
@property
315318
def batches(self) -> list[int]:
316319
return self._batches
317-
320+
318321
@property
319322
def is_serial(self) -> bool:
320323
return all(self.batches == np.arange(len(self.batches)))
@@ -455,7 +458,7 @@ def _apply_reduction_to_batches(
455458
456459
Returns:
457460
The transformed data. Has one entry per unique batch id, equal to the result of
458-
applying the reduction function to the data of that batch.
461+
applying the reduction function to the data of that batch.
459462
460463
"""
461464
batch_starts, batch_stops = _get_batch_starts_and_stops(batch_ids)

tests/optimagic/optimization/test_history.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,9 @@ def test_history_fun_data_with_fun_evaluations_cost_model(history: History):
201201
assert_frame_equal(got, exp, check_dtype=False, check_categorical=False)
202202

203203

204-
def test_history_fun_data_with_fun_evaluations_cost_model_and_monotone(history: History):
204+
def test_history_fun_data_with_fun_evaluations_cost_model_and_monotone(
205+
history: History,
206+
):
205207
got = history.fun_data(
206208
cost_model=om.timing.fun_evaluations,
207209
monotone=True,
@@ -561,4 +563,6 @@ def test_apply_to_batch_func_with_non_scalar_return():
561563
data = np.array([0, 1, 2, 3, 4])
562564
batch_ids = [0, 0, 1, 1, 2]
563565
with pytest.raises(ValueError, match="Function <lambda> did not return a scalar"):
564-
_apply_reduction_to_batches(data, batch_ids, reduction_function=lambda _list: _list)
566+
_apply_reduction_to_batches(
567+
data, batch_ids, reduction_function=lambda _list: _list
568+
)

0 commit comments

Comments
 (0)