@@ -99,7 +99,9 @@ def _get_next_batch_id(self) -> int:
     # Function data, function value, and monotone function value
     # ----------------------------------------------------------------------------------
 
-    def fun_data(self, cost_model: CostModel, monotone: bool, dropna: bool = False) -> pd.DataFrame:
+    def fun_data(
+        self, cost_model: CostModel, monotone: bool, dropna: bool = False
+    ) -> pd.DataFrame:
         """Return the function value data.
 
         Args:
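For context on the `monotone` flag in the signature above, here is a minimal standalone sketch of what a monotone ("best so far") function value means. The helper name and column layout are illustrative assumptions, not code from this repository.

import numpy as np
import pandas as pd


def monotone_fun_sketch(fun_values, minimize=True):
    # Hypothetical helper: the "monotone" value at step i is the best function
    # value observed up to and including step i.
    acc = np.minimum.accumulate if minimize else np.maximum.accumulate
    fun = np.asarray(fun_values, dtype=float)
    return pd.DataFrame({"fun": fun, "monotone_fun": acc(fun)})


# For a minimization history the monotone column never increases.
print(monotone_fun_sketch([3.0, 2.5, 2.7, 1.9, 2.0]))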
@@ -119,14 +121,15 @@ def fun_data(self, cost_model: CostModel, monotone: bool, dropna: bool = False)
         timings = self._get_total_timings(cost_model)
 
         if not self.is_serial:
-
             timings = _apply_reduction_to_batches(
                 data=timings,
                 batch_ids=self.batches,
                 reduction_function=cost_model.aggregate_batch_time,
             )
 
-            min_or_max = np.nanmin if self.direction == Direction.MINIMIZE else np.nanmax
+            min_or_max = (
+                np.nanmin if self.direction == Direction.MINIMIZE else np.nanmax
+            )
             fun = _apply_reduction_to_batches(
                 data=fun,
                 batch_ids=self.batches,
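The block above reduces per-evaluation values to one value per batch: timings are aggregated with the cost model, and function values are collapsed with nanmin (minimization) or nanmax (maximization). A small self-contained illustration of that reduction, with made-up numbers:

import numpy as np

# Made-up per-evaluation function values and their batch ids.
fun = np.array([5.0, 4.0, np.nan, 3.5, 3.9, 3.1])
batch_ids = np.array([0, 0, 1, 1, 2, 2])

# One entry per batch: the best (smallest) value found within that batch,
# ignoring NaNs, as nanmin does for a minimization problem.
best_per_batch = np.array(
    [np.nanmin(fun[batch_ids == b]) for b in np.unique(batch_ids)]
)
print(best_per_batch)  # [4.  3.5 3.1]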
@@ -138,7 +141,7 @@ def fun_data(self, cost_model: CostModel, monotone: bool, dropna: bool = False)
 
         if self.is_serial:
             data["task"] = _task_to_categorical(self.task)
-
+
         if dropna:
             data = data.dropna()
 
@@ -202,7 +205,6 @@ def params_data(self, dropna: bool = False) -> pd.DataFrame:
         # 2. Make long
 
         if not self.is_serial:
-
             if self.direction == Direction.MINIMIZE:
                 loc = data.groupby("batches")["fun"].idxmin()
             elif self.direction == Direction.MAXIMIZE:
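The MINIMIZE/MAXIMIZE branches above pick, for each batch, the row with the best function value. A toy example of the same groupby/idxmin pattern (idxmax would be the maximization counterpart); the data values are invented:

import pandas as pd

data = pd.DataFrame(
    {"batches": [0, 0, 1, 1], "fun": [5.0, 4.0, 3.5, 3.9], "a": [1, 2, 3, 4]}
)

# Index of the row with the smallest "fun" within each batch, then keep
# only those rows.
loc = data.groupby("batches")["fun"].idxmin()
print(data.loc[loc])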
@@ -215,17 +217,19 @@ def params_data(self, dropna: bool = False) -> pd.DataFrame:
             data = data.drop(columns=["batches", "fun"])
 
         long = pd.melt(
-            wide, var_name="name", value_name="value", id_vars=["task", "batches", "fun"]
+            wide,
+            var_name="name",
+            value_name="value",
+            id_vars=["task", "batches", "fun"],
         )
 
         data = long.reindex(columns=["name", "value", "task", "batches", "fun"])
-
+
         if dropna:
             data = data.dropna()
 
         return data.rename_axis("counter")
 
-
     @property
     def params(self) -> list[PyTree]:
         return self._params
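The reshape above turns a wide frame (one column per parameter) into a long frame with one row per parameter name and value. The same pd.melt call applied to an invented wide frame:

import pandas as pd

wide = pd.DataFrame(
    {
        "task": ["fun", "fun"],
        "batches": [0, 1],
        "fun": [4.0, 3.5],
        "x_0": [0.1, 0.2],
        "x_1": [1.0, 0.9],
    }
)

long = pd.melt(
    wide,
    var_name="name",
    value_name="value",
    id_vars=["task", "batches", "fun"],
)
data = long.reindex(columns=["name", "value", "task", "batches", "fun"])
print(data.rename_axis("counter"))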
@@ -272,7 +276,6 @@ def _get_total_timings(
 
         return fun_time + jac_time + fun_and_jac_time
 
-
     def _get_timings_per_task(
        self, task: EvalTask, cost_factor: float | None
     ) -> NDArray[np.float64]:
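`_get_total_timings` above returns the elementwise sum of the per-task timings, and `_get_timings_per_task` takes an optional cost factor. A rough sketch of that idea; the treatment of `cost_factor` here (None means measured times, a number means a constant cost per evaluation) is an assumption for illustration only:

import numpy as np


def timings_per_task_sketch(measured_times, cost_factor=None):
    # Assumed semantics: fall back to measured wall times when no cost factor
    # is given, otherwise charge a constant cost per evaluation.
    measured_times = np.asarray(measured_times, dtype=float)
    if cost_factor is None:
        return measured_times
    return np.full_like(measured_times, cost_factor)


fun_time = timings_per_task_sketch([0.1, 0.0, 0.1])
jac_time = timings_per_task_sketch([0.0, 0.5, 0.0], cost_factor=1.0)
fun_and_jac_time = timings_per_task_sketch([0.0, 0.0, 0.6])

# Total cost attributed to each history entry: the elementwise sum.
print(fun_time + jac_time + fun_and_jac_time)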
@@ -314,7 +317,7 @@ def stop_time(self) -> list[float]:
     @property
     def batches(self) -> list[int]:
         return self._batches
-
+
     @property
     def is_serial(self) -> bool:
         return all(self.batches == np.arange(len(self.batches)))
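The `is_serial` check above means: the history is serial exactly when every evaluation has its own batch id 0, 1, 2, ..., so no two evaluations share a batch. A standalone restatement of that check:

import numpy as np


def is_serial_sketch(batches):
    # Serial history: batch ids are exactly 0, 1, ..., n-1, i.e. no two
    # evaluations share a batch.
    batches = np.asarray(batches)
    return bool(np.all(batches == np.arange(len(batches))))


print(is_serial_sketch([0, 1, 2, 3]))  # True
print(is_serial_sketch([0, 0, 1, 1]))  # False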
@@ -455,7 +458,7 @@ def _apply_reduction_to_batches(
 
     Returns:
         The transformed data. Has one entry per unique batch id, equal to the result of
-        applying the reduction function to the data of that batch.
+        applying the reduction function to the data of that batch.
 
     """
    batch_starts, batch_stops = _get_batch_starts_and_stops(batch_ids)
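To make the documented behavior concrete, here is a simplified batch reduction that computes batch start/stop indices itself (assuming sorted, consecutive batch ids) and applies the reduction function to each slice. This is an illustration of the docstring above, not the repository's implementation:

import numpy as np


def apply_reduction_to_batches_sketch(data, batch_ids, reduction_function):
    # One output entry per batch, equal to the reduction applied to that
    # batch's slice of `data`. Assumes batch ids like [0, 0, 1, 1, 1, 2].
    data = np.asarray(data, dtype=float)
    batch_ids = np.asarray(batch_ids)
    starts = np.flatnonzero(np.r_[True, batch_ids[1:] != batch_ids[:-1]])
    stops = np.r_[starts[1:], len(batch_ids)]
    return np.array(
        [reduction_function(data[a:b]) for a, b in zip(starts, stops)]
    )


print(apply_reduction_to_batches_sketch(
    [5.0, 4.0, 3.5, 3.9, 3.1], [0, 0, 1, 1, 2], np.nanmin
))  # [4.  3.5 3.1]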