25
25
def eval (
26
26
model_name : str ,
27
27
model_config : dict ,
28
- csv_path ,
28
+ dataset_path ,
29
29
checkpoint_path ,
30
30
labelmap_path ,
31
31
output_dir ,
@@ -48,8 +48,8 @@ def eval(
48
48
Model configuration in a ``dict``,
49
49
as loaded from a .toml file,
50
50
and used by the model method ``from_config``.
51
- csv_path : str, pathlib.Path
52
- path to where dataset was saved as a csv.
51
+ dataset_path : str, pathlib.Path
52
+ Path to dataset, e.g., a csv file generated by running ``vak prep`` .
53
53
checkpoint_path : str, pathlib.Path
54
54
path to directory with checkpoint files saved by Torch, to reload model
55
55
output_dir : str, pathlib.Path
@@ -105,8 +105,8 @@ def eval(
105
105
"""
106
106
# ---- pre-conditions ----------------------------------------------------------------------------------------------
107
107
for path , path_name in zip (
108
- (checkpoint_path , csv_path , labelmap_path , spect_scaler_path ),
109
- ('checkpoint_path' , 'csv_path ' , 'labelmap_path' , 'spect_scaler_path' ),
108
+ (checkpoint_path , dataset_path , labelmap_path , spect_scaler_path ),
109
+ ('checkpoint_path' , 'dataset_path ' , 'labelmap_path' , 'spect_scaler_path' ),
110
110
):
111
111
if path is not None : # because `spect_scaler_path` is optional
112
112
if not validators .is_a_file (path ):
@@ -148,9 +148,9 @@ def eval(
148
148
window_size = window_size ,
149
149
return_padding_mask = True ,
150
150
)
151
- logger .info (f"creating dataset for evaluation from: { csv_path } " )
151
+ logger .info (f"creating dataset for evaluation from: { dataset_path } " )
152
152
val_dataset = VocalDataset .from_csv (
153
- csv_path = csv_path ,
153
+ csv_path = dataset_path ,
154
154
split = split ,
155
155
labelmap = labelmap ,
156
156
spect_key = spect_key ,
@@ -173,7 +173,7 @@ def eval(
173
173
input_shape = input_shape [1 :]
174
174
175
175
if post_tfm_kwargs :
176
- dataset_df = pd .read_csv (csv_path )
176
+ dataset_df = pd .read_csv (dataset_path )
177
177
# we use the timebins vector from the first spect path to get timebin dur.
178
178
# this is less careful than calling io.dataframe.validate_and_get_timebin_dur
179
179
# but it's also much faster, and we can assume dataframe was validated when it was made
@@ -227,7 +227,7 @@ def eval(
227
227
("checkpoint_path" , checkpoint_path ),
228
228
("labelmap_path" , labelmap_path ),
229
229
("spect_scaler_path" , spect_scaler_path ),
230
- ("csv_path " , csv_path ),
230
+ ("dataset_path " , dataset_path ),
231
231
]
232
232
)
233
233
# TODO: is this still necessary after switching to Lightning? Stop saying "average"?
0 commit comments