Skip to content

Commit 79508b7

Browse files
committed
ENH: Add eval with post-processing, fix #472
- Add post_tfm_kwargs to config/eval.py - Add post_tfm_kwargs attribute to LearncurveConfig - Add 'post_tfm_kwargs' option to config/valid.toml - Add post_tfm_kwargs to LEARNCURVE section of vak/config/valid.toml - Add use of post_tfm eval in engine.Model - Add post_tfm_kwargs to core.eval and use with model - Add logic in core/eval.py to use post_tfm_kwargs to make post_tfm - Use multi_char_labels_to_single_char in core.eval, not in transforms, to make sure edit distance is computed correctl - Add post_tfm parameter to vak.models.from_model_config_map - Add parameter and put in docstring, - Pass argument into Model.from_config - Add post_tfm_kwargs to TeenyTweetyNet.from_config - Add post_tfm_kwargs to unit test in test_core/test_eval.py - Pass post_tfm_kwargs into core.eval in cli/eval.py - Add parameter post_tfm_kwargs to vak.core.learncurve function, pass into calls to core.eval - Pass post_tfm_kwargs into core.learncurve inside cli.learncurve
1 parent 8fb7665 commit 79508b7

File tree

11 files changed

+230
-28
lines changed

11 files changed

+230
-28
lines changed

src/vak/cli/eval.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from .. import (
55
config,
66
core,
7-
validators
87
)
98
from ..logging import config_logging_for_cli, log_version
109

@@ -65,4 +64,5 @@ def eval(toml_path):
6564
spect_key=cfg.spect_params.spect_key,
6665
timebins_key=cfg.spect_params.timebins_key,
6766
device=cfg.eval.device,
67+
post_tfm_kwargs=cfg.eval.post_tfm_kwargs,
6868
)

src/vak/cli/learncurve.py

+1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def learning_curve(toml_path):
7171
num_workers=cfg.learncurve.num_workers,
7272
results_path=results_path,
7373
previous_run_path=cfg.learncurve.previous_run_path,
74+
post_tfm_kwargs=cfg.learncurve.post_tfm_kwargs,
7475
spect_key=cfg.spect_params.spect_key,
7576
timebins_key=cfg.spect_params.timebins_key,
7677
normalize_spectrograms=cfg.learncurve.normalize_spectrograms,

src/vak/config/eval.py

+70-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,63 @@
11
"""parses [EVAL] section of config"""
22
import attr
3-
from attr import converters
3+
from attr import converters, validators
44
from attr.validators import instance_of
55

66
from .validators import is_valid_model_name
77
from .. import device
88
from ..converters import comma_separated_list, expanded_user_path
99

1010

11+
def convert_post_tfm_kwargs(post_tfm_kwargs: dict) -> dict:
12+
post_tfm_kwargs = dict(post_tfm_kwargs)
13+
14+
if 'min_segment_dur' not in post_tfm_kwargs:
15+
# because there's no null in TOML,
16+
# users leave arg out of config then we set it to None
17+
post_tfm_kwargs['min_segment_dur'] = None
18+
else:
19+
post_tfm_kwargs['min_segment_dur'] = float(post_tfm_kwargs['min_segment_dur'])
20+
21+
if 'majority_vote' not in post_tfm_kwargs:
22+
# set default for this one too
23+
post_tfm_kwargs['majority_vote'] = False
24+
else:
25+
post_tfm_kwargs['majority_vote'] = bool(post_tfm_kwargs['majority_vote'])
26+
27+
return post_tfm_kwargs
28+
29+
30+
def are_valid_post_tfm_kwargs(instance, attribute, value):
31+
"""check if ``post_tfm_kwargs`` is valid"""
32+
if not isinstance(value, dict):
33+
raise TypeError(
34+
"'post_tfm_kwargs' should be declared in toml config as an inline table "
35+
f"that parses as a dict, but type was: {type(value)}. "
36+
"Please declare in a similar fashion: `{majority_vote = True, min_segment_dur = 0.02}`"
37+
)
38+
if any(
39+
[k not in {'majority_vote', 'min_segment_dur'} for k in value.keys()]
40+
):
41+
invalid_kwargs = [k for k in value.keys()
42+
if k not in {'majority_vote', 'min_segment_dur'}]
43+
raise ValueError(
44+
f"Invalid keyword argument name specified for 'post_tfm_kwargs': {invalid_kwargs}."
45+
"Valid names are: {'majority_vote', 'min_segment_dur'}"
46+
)
47+
if 'majority_vote' in value:
48+
if not isinstance(value['majority_vote'], bool):
49+
raise TypeError(
50+
"'post_tfm_kwargs' keyword argument 'majority_vote' "
51+
f"should be of type bool but was: {type(value['majority_vote'])}"
52+
)
53+
if 'min_segment_dur' in value:
54+
if value['min_segment_dur'] and not isinstance(value['min_segment_dur'], float):
55+
raise TypeError(
56+
"'post_tfm_kwargs' keyword argument 'min_segment_dur' type "
57+
f"should be float but was: {type(value['min_segment_dur'])}"
58+
)
59+
60+
1161
@attr.s
1262
class EvalConfig:
1363
"""class that represents [EVAL] section of config.toml file
@@ -36,6 +86,19 @@ class EvalConfig:
3686
path to a saved SpectScaler object used to normalize spectrograms.
3787
If spectrograms were normalized and this is not provided, will give
3888
incorrect results.
89+
post_tfm_kwargs : dict
90+
Keyword arguments to post-processing transform.
91+
If None, then no additional clean-up is applied
92+
when transforming labeled timebins to segments,
93+
the default behavior.
94+
The transform used is
95+
``vak.transforms.labeled_timebins.ToSegmentsWithPostProcessing`.
96+
Valid keyword argument names are 'majority_vote'
97+
and 'min_segment_dur', and should be appropriate
98+
values for those arguments: Boolean for ``majority_vote``,
99+
a float value for ``min_segment_dur``.
100+
See the docstring of the transform for more details on
101+
these arguments and how they work.
39102
"""
40103
# required, external files
41104
checkpoint_path = attr.ib(converter=expanded_user_path)
@@ -62,6 +125,12 @@ class EvalConfig:
62125
default=None,
63126
)
64127

128+
post_tfm_kwargs = attr.ib(
129+
validator=validators.optional(are_valid_post_tfm_kwargs),
130+
converter=converters.optional(convert_post_tfm_kwargs),
131+
default={}, # empty dict so we can pass into transform with **kwargs expansion
132+
)
133+
65134
# optional, data loader
66135
num_workers = attr.ib(validator=instance_of(int), default=2)
67136
device = attr.ib(validator=instance_of(str), default=device.get_default())

src/vak/config/learncurve.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""parses [LEARNCURVE] section of config"""
22
import attr
3-
from attr import converters
3+
from attr import converters, validators
44
from attr.validators import instance_of
55

6+
from .eval import are_valid_post_tfm_kwargs, convert_post_tfm_kwargs
67
from .train import TrainConfig
78
from ..converters import expanded_user_path
89

@@ -49,10 +50,29 @@ class LearncurveConfig(TrainConfig):
4950
previous_run_path : str
5051
path to results directory from a previous run.
5152
Used for training if use_train_subsets_from_previous_run is True.
53+
post_tfm_kwargs : dict
54+
Keyword arguments to post-processing transform.
55+
If None, then no additional clean-up is applied
56+
when transforming labeled timebins to segments,
57+
the default behavior.
58+
The transform used is
59+
``vak.transforms.labeled_timebins.ToSegmentsWithPostProcessing`.
60+
Valid keyword argument names are 'majority_vote'
61+
and 'min_segment_dur', and should be appropriate
62+
values for those arguments: Boolean for ``majority_vote``,
63+
a float value for ``min_segment_dur``.
64+
See the docstring of the transform for more details on
65+
these arguments and how they work.
5266
"""
5367
train_set_durs = attr.ib(validator=instance_of(list), kw_only=True)
5468
num_replicates = attr.ib(validator=instance_of(int), kw_only=True)
5569
previous_run_path = attr.ib(
5670
converter=converters.optional(expanded_user_path),
5771
default=None,
5872
)
73+
74+
post_tfm_kwargs = attr.ib(
75+
validator=validators.optional(are_valid_post_tfm_kwargs),
76+
converter=converters.optional(convert_post_tfm_kwargs),
77+
default={}, # empty dict so we can pass into transform with **kwargs expansion
78+
)

src/vak/config/valid.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ batch_size = 11
6262
num_workers = 4
6363
device = 'cuda'
6464
spect_scaler_path = '/home/user/results_181014_194418/spect_scaler'
65-
65+
post_tfm_kwargs = {'majority_vote' = true, 'min_segment_dur' = 0.01}
6666

6767
[LEARNCURVE]
6868
models = 'TweetyNet'
@@ -79,6 +79,7 @@ num_replicates = 2
7979
csv_path = 'tests/test_data/prep/learncurve/032312_prep_191224_225910.csv'
8080
results_dir_made_by_main_script = '/some/path/to/learncurve/'
8181
previous_run_path = '/some/path/to/learncurve/results_20210106_132152'
82+
post_tfm_kwargs = {'majority_vote' = true, 'min_segment_dur' = 0.01}
8283
num_workers = 4
8384
device = 'cuda'
8485

src/vak/core/eval.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,14 @@
88
import torch.utils.data
99

1010
from .. import (
11+
files,
1112
models,
13+
timebins,
1214
transforms,
1315
validators
1416
)
1517
from ..datasets.vocal_dataset import VocalDataset
18+
from ..labels import multi_char_labels_to_single_char
1619

1720

1821
logger = logging.getLogger(__name__)
@@ -28,6 +31,7 @@ def eval(
2831
num_workers,
2932
split="test",
3033
spect_scaler_path=None,
34+
post_tfm_kwargs=None,
3135
spect_key="s",
3236
timebins_key="t",
3337
device=None,
@@ -64,13 +68,34 @@ def eval(
6468
If spectrograms were normalized and this is not provided, will give
6569
incorrect results.
6670
Default is None.
71+
post_tfm_kwargs : dict
72+
Keyword arguments to post-processing transform.
73+
If None, then no additional clean-up is applied
74+
when transforming labeled timebins to segments,
75+
the default behavior. The transform used is
76+
``vak.transforms.labeled_timebins.ToSegmentsWithPostProcessing`.
77+
Valid keyword argument names are 'majority_vote'
78+
and 'min_segment_dur', and should be appropriate
79+
values for those arguments: Boolean for ``majority_vote``,
80+
a float value for ``min_segment_dur``.
81+
See the docstring of the transform for more details on
82+
these arguments and how they work.
6783
spect_key : str
6884
key for accessing spectrogram in files. Default is 's'.
6985
timebins_key : str
7086
key for accessing vector of time bins in files. Default is 't'.
7187
device : str
7288
Device on which to work with model + data.
7389
Defaults to 'cuda' if torch.cuda.is_available is True.
90+
91+
Notes
92+
-----
93+
Note that unlike ``core.predict``, this function
94+
can modify ``labelmap`` so that metrics like edit distance
95+
are correctly computed, by converting any string labels
96+
in ``labelmap`` with multiple characters
97+
to (mock) single-character labels,
98+
with ``vak.labels.multi_char_labels_to_single_char``.
7499
"""
75100
# ---- pre-conditions ----------------------------------------------------------------------------------------------
76101
for path, path_name in zip(
@@ -102,6 +127,15 @@ def eval(
102127
with labelmap_path.open("r") as f:
103128
labelmap = json.load(f)
104129

130+
# replace any multiple character labels in mapping
131+
# with dummy single-character labels
132+
# so that we do not affect edit distance computation
133+
# see https://github.com/NickleDave/vak/issues/373
134+
labelmap_keys = [lbl for lbl in labelmap.keys() if lbl != 'unlabeled']
135+
if any([len(label) > 1 for label in labelmap_keys]): # only re-map if necessary
136+
# (to minimize chance of knock-on bugs)
137+
labelmap = multi_char_labels_to_single_char(labelmap)
138+
105139
item_transform = transforms.get_defaults(
106140
"eval",
107141
spect_standardizer,
@@ -132,8 +166,23 @@ def eval(
132166
if len(input_shape) == 4:
133167
input_shape = input_shape[1:]
134168

169+
if post_tfm_kwargs:
170+
dataset_df = pd.read_csv(csv_path)
171+
# we use the timebins vector from the first spect path to get timebin dur.
172+
# this is less careful than calling io.dataframe.validate_and_get_timebin_dur
173+
# but it's also much faster, and we can assume dataframe was validated when it was made
174+
spect_dict = files.spect.load(dataset_df['spect_path'].values[0])
175+
timebin_dur = timebins.timebin_dur_from_vec(spect_dict[timebins_key])
176+
177+
post_tfm = transforms.labeled_timebins.PostProcess(
178+
timebin_dur=timebin_dur,
179+
**post_tfm_kwargs,
180+
)
181+
else:
182+
post_tfm = None
183+
135184
models_map = models.from_model_config_map(
136-
model_config_map, num_classes=len(labelmap), input_shape=input_shape
185+
model_config_map, num_classes=len(labelmap), input_shape=input_shape, post_tfm=post_tfm
137186
)
138187

139188
for model_name, model in models_map.items():

src/vak/core/learncurve/learncurve.py

+15
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
logger = logging.getLogger(__name__)
2020

2121

22+
# TODO: add post_tfm_kwargs here
2223
def learning_curve(
2324
model_config_map,
2425
train_set_durs,
@@ -32,6 +33,7 @@ def learning_curve(
3233
root_results_dir=None,
3334
results_path=None,
3435
previous_run_path=None,
36+
post_tfm_kwargs=None,
3537
spect_key="s",
3638
timebins_key="t",
3739
normalize_spectrograms=True,
@@ -86,6 +88,18 @@ def learning_curve(
8688
Typically directory will have a name like ``results_{timestamp}``
8789
and the actual .csv splits will be in sub-directories with names
8890
corresponding to the training set duration
91+
post_tfm_kwargs : dict
92+
Keyword arguments to post-processing transform.
93+
If None, then no additional clean-up is applied
94+
when transforming labeled timebins to segments,
95+
the default behavior. The transform used is
96+
``vak.transforms.labeled_timebins.ToSegmentsWithPostProcessing`.
97+
Valid keyword argument names are 'majority_vote'
98+
and 'min_segment_dur', and should be appropriate
99+
values for those arguments: Boolean for ``majority_vote``,
100+
a float value for ``min_segment_dur``.
101+
See the docstring of the transform for more details on
102+
these arguments and how they work.
89103
spect_key : str
90104
key for accessing spectrogram in files. Default is 's'.
91105
timebins_key : str
@@ -318,6 +332,7 @@ def learning_curve(
318332
num_workers=num_workers,
319333
split="test",
320334
spect_scaler_path=spect_scaler_path,
335+
post_tfm_kwargs=post_tfm_kwargs,
321336
spect_key=spect_key,
322337
timebins_key=timebins_key,
323338
device=device,

0 commit comments

Comments
 (0)