Skip to content

Commit bd50a5d

Browse files
committed
Add more test cases for invalid user funcs; check that batch time aggregator returns scalars
1 parent e089347 commit bd50a5d

File tree

4 files changed

+77
-17
lines changed

4 files changed

+77
-17
lines changed

src/optimagic/optimization/history.py

+23-11
Original file line numberDiff line numberDiff line change
@@ -219,12 +219,12 @@ def _get_time(
219219
)
220220

221221
time = fun_time + jac_time + fun_and_jac_time
222-
batch_time = _batch_apply(
222+
batch_aware_time = _apply_to_batch(
223223
data=time,
224224
batch_ids=self.batches,
225225
func=cost_model.aggregate_batch_time,
226226
)
227-
return np.cumsum(batch_time)
227+
return np.cumsum(batch_aware_time)
228228

229229
def _get_time_per_task(
230230
self, task: EvalTask, cost_factor: float | None
@@ -383,7 +383,7 @@ def _task_as_categorical(task: list[EvalTask]) -> pd.Categorical:
383383
)
384384

385385

386-
def _batch_apply(
386+
def _apply_to_batch(
387387
data: NDArray[np.float64],
388388
batch_ids: list[int],
389389
func: Callable[[Iterable[float]], float],
@@ -392,10 +392,10 @@ def _batch_apply(
392392
393393
Args:
394394
data: 1d array with data.
395-
batch_ids: A list whose length is equal to the size of data. Values need to be
396-
sorted and can be repeated.
395+
batch_ids: A list with batch ids whose length is equal to the size of data.
396+
Values need to be sorted and can be repeated.
397397
func: A reduction function that takes an iterable of floats as input (e.g., a
398-
numpy array or a list) and returns a scalar.
398+
numpy.ndarray or list) and returns a scalar.
399399
400400
Returns:
401401
The transformed data. Has the same length as data. For each batch, the result of
@@ -410,25 +410,37 @@ def _batch_apply(
410410
for batch, (start, stop) in zip(
411411
batch_ids, zip(batch_starts, batch_stops, strict=False), strict=False
412412
):
413+
batch_data = data[start:stop]
414+
413415
try:
414-
batch_data = data[start:stop]
415416
reduced = func(batch_data)
416-
batch_results.append(reduced)
417417
except Exception as e:
418418
msg = (
419419
f"Calling function {func.__name__} on batch {batch} of the History "
420-
f"History raised an Exception. Please verify that {func.__name__} is "
421-
"properly defined."
420+
f"raised an Exception. Please verify that {func.__name__} is "
421+
"well-defined and takes a list of floats as input and returns a scalar."
422422
)
423423
raise ValueError(msg) from e
424424

425+
try:
426+
assert np.isscalar(reduced)
427+
except AssertionError:
428+
msg = (
429+
f"Function {func.__name__} did not return a scalar for batch {batch}. "
430+
f"Please verify that {func.__name__} returns a scalar when called on a "
431+
"list of floats."
432+
)
433+
raise ValueError(msg) from None
434+
435+
batch_results.append(reduced)
436+
425437
out = np.zeros_like(data)
426438
out[batch_starts] = batch_results
427439
return out
428440

429441

430442
def _get_batch_start(batch_ids: list[int]) -> list[int]:
431-
"""Get start indices of batch.
443+
"""Get start indices of batches.
432444
433445
This function assumes that batch_ids non-empty and sorted.
434446

src/optimagic/timing.py

+7
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@ class CostModel:
1010
label: str
1111
aggregate_batch_time: Callable[[Iterable[float]], float]
1212

13+
def __post_init__(self) -> None:
14+
if not callable(self.aggregate_batch_time):
15+
raise ValueError(
16+
"aggregate_batch_time must be a callable, got "
17+
f"{self.aggregate_batch_time}"
18+
)
19+
1320

1421
evaluation_time = CostModel(
1522
fun=None,

tests/optimagic/optimization/test_history.py

+33-6
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from optimagic.optimization.history import (
1111
History,
1212
HistoryEntry,
13-
_batch_apply,
13+
_apply_to_batch,
1414
_calculate_monotone_sequence,
1515
_get_batch_start,
1616
_get_flat_param_names,
@@ -403,6 +403,13 @@ def test_get_time_wall_time(history):
403403
assert_array_equal(got, exp)
404404

405405

406+
def test_get_time_invalid_cost_model(history):
407+
with pytest.raises(
408+
ValueError, match="cost_model must be a CostModel or 'wall_time'."
409+
):
410+
history._get_time(cost_model="invalid")
411+
412+
406413
def test_start_time_property(history):
407414
assert history.start_time == [0, 2, 5, 7, 10, 12]
408415

@@ -465,12 +472,18 @@ def test_get_flat_params_fast_path():
465472
assert_array_equal(got, exp)
466473

467474

468-
def test_get_flat_param_names():
475+
def test_get_flat_param_names_pytree():
469476
got = _get_flat_param_names(param={"a": 0, "b": [0, 1], "c": np.arange(2)})
470477
exp = ["a", "b_0", "b_1", "c_0", "c_1"]
471478
assert got == exp
472479

473480

481+
def test_get_flat_param_names_fast_path():
482+
got = _get_flat_param_names(param=np.arange(2))
483+
exp = ["0", "1"]
484+
assert got == exp
485+
486+
474487
def test_calculate_monotone_sequence_maximize():
475488
sequence = [0, 1, 0, 0, 2, 10, 0]
476489
exp = [0, 1, 1, 1, 2, 10, 10]
@@ -509,17 +522,31 @@ def test_get_batch_start():
509522
assert got == [0, 2, 5, 7]
510523

511524

512-
def test_batch_apply_sum():
525+
def test_apply_to_batch_sum():
513526
data = np.array([0, 1, 2, 3, 4])
514527
batch_ids = [0, 0, 1, 1, 2]
515528
exp = np.array([1, 0, 5, 0, 4])
516-
got = _batch_apply(data, batch_ids, sum)
529+
got = _apply_to_batch(data, batch_ids, sum)
517530
assert_array_equal(exp, got)
518531

519532

520-
def test_batch_apply_max():
533+
def test_apply_to_batch_max():
521534
data = np.array([0, 1, 2, 3, 4])
522535
batch_ids = [0, 0, 1, 1, 2]
523536
exp = np.array([1, 0, 3, 0, 4])
524-
got = _batch_apply(data, batch_ids, max)
537+
got = _apply_to_batch(data, batch_ids, max)
525538
assert_array_equal(exp, got)
539+
540+
541+
def test_apply_to_batch_broken_func():
542+
data = np.array([0, 1, 2, 3, 4])
543+
batch_ids = [0, 0, 1, 1, 2]
544+
with pytest.raises(ValueError, match="Calling function <lambda> on batch [0, 0]"):
545+
_apply_to_batch(data, batch_ids, func=lambda _: 1 / 0)
546+
547+
548+
def test_apply_to_batch_func_with_non_scalar_return():
549+
data = np.array([0, 1, 2, 3, 4])
550+
batch_ids = [0, 0, 1, 1, 2]
551+
with pytest.raises(ValueError, match="Function <lambda> did not return a scalar"):
552+
_apply_to_batch(data, batch_ids, func=lambda _list: _list)

tests/optimagic/test_timing.py

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import pytest
2+
3+
from optimagic import timing
4+
5+
6+
def test_invalid_aggregate_batch_time():
7+
with pytest.raises(ValueError, match="aggregate_batch_time must be a callable"):
8+
timing.CostModel(
9+
fun=None,
10+
jac=None,
11+
fun_and_jac=None,
12+
label="label",
13+
aggregate_batch_time="Not callable",
14+
)

0 commit comments

Comments
 (0)