
Commit 5ebe407

ENH: fix frame classification model to work with BioSoundSegBench (#774)
This PR consists mainly of fixes needed for FrameClassificationModel to work with the BioSoundSegBench dataset (a sketch of the core post-processing idea follows the change list).

* Hack a learning rate scheduler into FrameClassificationModel
* Add boundary_labels parameter to transforms.frame_labels.transforms.PostProcess.__call__
* Add background_label to post_tfm_kwargs in src/vak/eval/frame_classification.py
* In `validation_step` of FrameClassificationModel, use boundary labels when they are present to post-process multi-class frame labels
* Unpack `dataset_path` from dataset_config in the code block for built-in datasets in eval/frame_classification.py, to make sure this variable exists when we build the DataFrame with eval results
* Create the variable `frame_dur` inside the code block for built-in datasets in eval/frame_classification.py so this variable exists when we get the post-processing transform
* Pass `background_label` into transforms.frame_labels.PostProcess inside eval/frame_classification.py, using `constants.DEFAULT_BACKGROUND_LABEL`
* Fix how we call self.manual_backward in FrameClassificationModel to handle the case when the loss function returns a dict
* In FrameClassificationModel.validation_step, convert boundary_preds to numpy when we pass them to self.post_tfm
* In FrameClassificationModel.validation_step, when logging accuracy, call it 'val_multi_acc' to distinguish it from boundary_acc and for consistency with val_multi_acc_tfm
* Change how we get and log frame_dur in train/frame_classification.py so we have it as a separate variable; we will use it for post_tfm kwargs when we add those later
* Change the one-line summary of the __call__ method of frame_labels.transforms.PostProcess
* BUG: Ensure boundary_labels is 1d in the post-process transform; fix #767
* Fix which metric we use for the learning rate scheduler: use val_multi_acc for models with multiple accuracies
* Remove the trainer module from common; the code is used only by the frame classification model
* Add get_trainer and get_train_callbacks to train/frame_classification.py; fix so that we monitor 'val_multi_acc' when a model has multiple targets, and just 'val_acc' otherwise
* Add missing self.manual_backward in training_step of FrameClassificationModel
* Fix how we determine whether there are multiple targets and what to monitor in train/frame_classification.py
* Fix how we validate boundary_labels in transforms.frame_labels.functional.postprocess: don't validate when boundary_labels is None
* Fix vak/predict/frame_classification.py to handle the edge case where no non-background segments are predicted for any sample in the dataset
* Revise a comment
* Catch an edge case in transforms.frame_labels.functional.boundary_inds_from_boundary_labels
* Add minimal unit tests for vak.transforms.frame_labels.functional.boundary_inds_from_boundary_labels
* Remove the learning rate scheduler for now
1 parent f4afc5c commit 5ebe407
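
To make the headline change concrete, here is a minimal sketch of the post-processing idea the change list describes: use predicted boundary frames to segment the vector of multi-class frame labels, then clean each segment, e.g. by majority vote. This is an illustration with made-up arrays, not the vak source.

```python
import numpy as np

# per-frame class predictions and per-frame boundary "hits" (1 = boundary)
frame_labels = np.array([1, 1, 2, 1, 1, 3, 3, 3])
boundary_labels = np.array([1, 0, 0, 0, 0, 1, 0, 0])

boundary_inds = np.nonzero(boundary_labels)[0]
if boundary_inds.size == 0:
    boundary_inds = np.array([0])  # edge case: no boundaries predicted
elif boundary_inds[0] != 0:
    boundary_inds = np.insert(boundary_inds, 0, 0)  # force a boundary at index 0

# split at the boundaries and give each segment its majority-vote label
segments = np.split(frame_labels, boundary_inds[1:])
cleaned = np.concatenate(
    [np.full_like(seg, np.bincount(seg).argmax()) for seg in segments]
)
print(cleaned)  # [1 1 1 1 1 3 3 3]
```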

File tree

9 files changed: +190 −114 lines changed


src/vak/common/__init__.py (−2)

```diff
@@ -21,7 +21,6 @@
     tensorboard,
     timebins,
     timenow,
-    trainer,
     typing,
     validators,
 )
@@ -39,7 +38,6 @@
     "tensorboard",
     "timebins",
     "timenow",
-    "trainer",
     "typing",
     "validators",
 ]
```

src/vak/common/trainer.py (−88)

This file was deleted.

src/vak/eval/frame_classification.py (+10 −1)

```diff
@@ -14,7 +14,7 @@
 import torch.utils.data

 from .. import datapipes, datasets, models, transforms
-from ..common import validators
+from ..common import constants, validators
 from ..datapipes.frame_classification import InferDatapipe

 logger = logging.getLogger(__name__)
@@ -154,12 +154,20 @@ def eval_frame_classification_model(
     )
     # ---- *yes* using a built-in dataset ----------------------------------
     else:
+        # next line: we don't use dataset_path in this code block,
+        # but we need it below when we build the DataFrame with eval results.
+        # we unpack it here just as we do above with a prep'd dataset
+        dataset_path = pathlib.Path(dataset_config["path"])
         dataset_config["params"]["return_padding_mask"] = True
         val_dataset = datasets.get(
             dataset_config,
             split=split,
             frames_standardizer=frames_standardizer,
         )
+        frame_dur = val_dataset.frame_dur
+        logger.info(
+            f"Duration of a frame in dataset, in seconds: {frame_dur}",
+        )

     val_loader = torch.utils.data.DataLoader(
         dataset=val_dataset,
@@ -179,6 +187,7 @@ def eval_frame_classification_model(
     if post_tfm_kwargs:
         post_tfm = transforms.frame_labels.PostProcess(
             timebin_dur=frame_dur,
+            background_label=labelmap[constants.DEFAULT_BACKGROUND_LABEL],
             **post_tfm_kwargs,
         )
     else:
```
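
For orientation, a minimal sketch of how the transform above gets built. The labelmap dict here is hypothetical (in eval it comes from the dataset), and the `min_segment_dur` / `majority_vote` values are illustrative stand-ins for `post_tfm_kwargs`, not values from this commit.

```python
from vak import transforms
from vak.common import constants

# hypothetical labelmap, keyed by the default background label
labelmap = {constants.DEFAULT_BACKGROUND_LABEL: 0, "a": 1, "b": 2}
frame_dur = 0.002  # seconds per frame, as read from the dataset

post_tfm = transforms.frame_labels.PostProcess(
    timebin_dur=frame_dur,
    background_label=labelmap[constants.DEFAULT_BACKGROUND_LABEL],
    min_segment_dur=0.01,  # illustrative post_tfm_kwargs
    majority_vote=True,
)
```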

src/vak/models/frame_classification_model.py (+12 −7)

```diff
@@ -130,7 +130,6 @@ def __init__(
             :const:`vak.common.constants.DEFAULT_BACKGROUND_LABEL`.
         """
         super().__init__()
-
         self.network = network
         self.loss = loss
         self.optimizer = optimizer
@@ -365,9 +364,15 @@ def validation_step(self, batch: tuple, batch_idx: int):
         class_preds_str = self.to_labels_eval(class_preds.cpu().numpy())

         if self.post_tfm:
-            class_preds_tfm = self.post_tfm(
-                class_preds.cpu().numpy(),
-            )
+            if target_types == ("multi_frame_labels",):
+                class_preds_tfm = self.post_tfm(
+                    class_preds.cpu().numpy(),
+                )
+            elif target_types == ("multi_frame_labels", "boundary_frame_labels"):
+                class_preds_tfm = self.post_tfm(
+                    class_preds.cpu().numpy(),
+                    boundary_labels=boundary_preds.cpu().numpy(),
+                )
             class_preds_tfm_str = self.to_labels_eval(class_preds_tfm)
             # convert back to tensor so we can compute accuracy
             class_preds_tfm = torch.from_numpy(class_preds_tfm).to(
@@ -395,8 +400,8 @@ def validation_step(self, batch: tuple, batch_idx: int):
             loss = self.loss(
                 class_logits,
                 boundary_logits,
-                batch["multi_frame_labels"],
-                batch["boundary_frame_labels"],
+                target["multi_frame_labels"],
+                target["boundary_frame_labels"],
             )
             if isinstance(loss, torch.Tensor):
                 self.log(
@@ -435,7 +440,7 @@ def validation_step(self, batch: tuple, batch_idx: int):
                 )
             else:
                 self.log(
-                    f"val_{metric_name}",
+                    f"val_multi_{metric_name}",
                     metric_callable(
                         class_preds, target["multi_frame_labels"]
                     ),
```
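
The change list also mentions fixing how `self.manual_backward` is called when the loss function returns a dict; that code isn't shown in this diff, so here is a hedged sketch of the pattern. The dict handling (summing named component losses) is an assumption, not the verbatim source.

```python
import torch

def backward_on_loss(model, loss):
    """Backprop whether the loss callable returned a tensor or a dict.

    ``model`` is assumed to be a LightningModule using manual optimization,
    so ``manual_backward`` is available; the dict case sums component losses.
    """
    if isinstance(loss, torch.Tensor):
        model.manual_backward(loss)
    elif isinstance(loss, dict):
        # assumption: the dict holds named component losses; backprop their sum
        model.manual_backward(sum(loss.values()))
    else:
        raise TypeError(f"Unexpected loss type: {type(loss)}")
```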

src/vak/predict/frame_classification.py (+5 −2)

```diff
@@ -468,8 +468,11 @@ def predict_with_frame_classification_model(
             annot_path=annot_csv_path.name,
         )
         annots.append(annot)
-
-    if all([isinstance(annot, crowsetta.Annotation) for annot in annots]):
+    if len(annots) < 1:
+        # catch edge case where nothing was predicted
+        # FIXME: this should have columns that match GenericSeq
+        pd.DataFrame.from_records([]).to_csv(annot_csv_path)
+    elif all([isinstance(annot, crowsetta.Annotation) for annot in annots]):
         generic_seq = crowsetta.formats.seq.GenericSeq(annots=annots)
         generic_seq.to_file(annot_path=annot_csv_path)
     elif all([isinstance(annot, AnnotationDataFrame) for annot in annots]):
```
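
As a usage note, a tiny sketch of the edge case this guards against: with no predicted segments there are no Annotations to convert, so an empty CSV is written and downstream steps still find a file. The path here is hypothetical, and as the FIXME says, the empty frame's columns don't yet match what GenericSeq writes.

```python
import pandas as pd

annots = []  # suppose the model predicted no non-background segments
annot_csv_path = "predictions.annot.csv"  # hypothetical output path

if len(annots) < 1:
    # writes a valid, column-less CSV so downstream code finds a file
    pd.DataFrame.from_records([]).to_csv(annot_csv_path)
```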

src/vak/train/frame_classification.py (+108 −9)

```diff
@@ -8,13 +8,13 @@
 import pathlib
 import shutil

+import lightning
 import joblib
 import pandas as pd
 import torch.utils.data

 from .. import datapipes, datasets, models, transforms
 from ..common import validators
-from ..common.trainer import get_default_trainer
 from ..datapipes.frame_classification import InferDatapipe, TrainDatapipe

 logger = logging.getLogger(__name__)
@@ -25,6 +25,92 @@ def get_split_dur(df: pd.DataFrame, split: str) -> float:
     return df[df["split"] == split]["duration"].sum()


+def get_train_callbacks(
+    ckpt_root: str | pathlib.Path,
+    ckpt_step: int,
+    patience: int,
+    checkpoint_monitor: str = "val_acc",
+    early_stopping_monitor: str = "val_acc",
+    early_stopping_mode: str = "max",
+) -> list[lightning.pytorch.callbacks.Callback]:
+    ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
+        dirpath=ckpt_root,
+        filename="checkpoint",
+        every_n_train_steps=ckpt_step,
+        save_last=True,
+        verbose=True,
+    )
+    ckpt_callback.CHECKPOINT_NAME_LAST = "checkpoint"
+    ckpt_callback.FILE_EXTENSION = ".pt"
+
+    val_ckpt_callback = lightning.pytorch.callbacks.ModelCheckpoint(
+        monitor=checkpoint_monitor,
+        dirpath=ckpt_root,
+        save_top_k=1,
+        mode="max",
+        filename="max-val-acc-checkpoint",
+        auto_insert_metric_name=False,
+        verbose=True,
+    )
+    val_ckpt_callback.FILE_EXTENSION = ".pt"
+
+    early_stopping = lightning.pytorch.callbacks.EarlyStopping(
+        mode=early_stopping_mode,
+        monitor=early_stopping_monitor,
+        patience=patience,
+        verbose=True,
+    )
+
+    return [ckpt_callback, val_ckpt_callback, early_stopping]
+
+
+def get_trainer(
+    accelerator: str,
+    devices: int | list[int],
+    max_steps: int,
+    log_save_dir: str | pathlib.Path,
+    val_step: int,
+    callback_kwargs: dict | None = None,
+) -> lightning.pytorch.Trainer:
+    """Returns an instance of :class:`lightning.pytorch.Trainer`
+    with a default set of callbacks.
+
+    Used by :func:`vak.train.frame_classification`.
+    The default set of callbacks is provided by
+    :func:`get_train_callbacks`.
+
+    Parameters
+    ----------
+    accelerator : str
+    devices : int, list of int
+    max_steps : int
+    log_save_dir : str, pathlib.Path
+    val_step : int
+    callback_kwargs : dict, optional
+
+    Returns
+    -------
+    trainer : lightning.pytorch.Trainer
+    """
+    if callback_kwargs:
+        callbacks = get_train_callbacks(**callback_kwargs)
+    else:
+        callbacks = None
+
+    logger = lightning.pytorch.loggers.TensorBoardLogger(save_dir=log_save_dir)
+
+    trainer = lightning.pytorch.Trainer(
+        accelerator=accelerator,
+        devices=devices,
+        callbacks=callbacks,
+        val_check_interval=val_step,
+        max_steps=max_steps,
+        logger=logger,
+    )
+    return trainer
+
+
 def train_frame_classification_model(
     model_config: dict,
     dataset_config: dict,
@@ -245,8 +331,9 @@ def train_frame_classification_model(
         dataset_config,
         split="train",
     )
+    frame_dur = train_dataset.frame_dur
     logger.info(
-        f"Duration of a frame in dataset, in seconds: {train_dataset.frame_dur}",
+        f"Duration of a frame in dataset, in seconds: {frame_dur}",
     )
     # copy labelmap from dataset to new results_path
     labelmap = train_dataset.labelmap
@@ -334,18 +421,30 @@ def train_frame_classification_model(
     ckpt_root.mkdir()
     logger.info(f"training {model_name}")
     max_steps = num_epochs * len(train_loader)
-    default_callback_kwargs = {
-        "ckpt_root": ckpt_root,
-        "ckpt_step": ckpt_step,
-        "patience": patience,
-    }
-    trainer = get_default_trainer(
+    target_type = dataset_config["params"]["target_type"]
+    if isinstance(target_type, list) and all(
+        isinstance(t, str) for t in target_type
+    ):
+        multiple_targets = True
+    elif isinstance(target_type, str):
+        multiple_targets = False
+    else:
+        raise ValueError(
+            f'Invalid value for dataset_config["params"]["target_type"]: {target_type}'
+        )
+
+    callback_kwargs = dict(
+        ckpt_root=ckpt_root,
+        ckpt_step=ckpt_step,
+        patience=patience,
+        checkpoint_monitor="val_multi_acc" if multiple_targets else "val_acc",
+        early_stopping_monitor="val_multi_acc" if multiple_targets else "val_acc",
+        early_stopping_mode="max",
+    )
+    trainer = get_trainer(
         accelerator=trainer_config["accelerator"],
         devices=trainer_config["devices"],
         max_steps=max_steps,
         log_save_dir=results_model_root,
         val_step=val_step,
-        default_callback_kwargs=default_callback_kwargs,
+        callback_kwargs=callback_kwargs,
     )
     train_time_start = datetime.datetime.now()
     logger.info(f"Training start time: {train_time_start.isoformat()}")
```

src/vak/transforms/frame_labels/functional.py (+11 −3)

```diff
@@ -401,11 +401,17 @@ def boundary_inds_from_boundary_labels(
         If ``True``, and the first index of ``boundary_labels`` is not classified as a boundary,
         force it to be a boundary.
     """
+    boundary_labels = row_or_1d(boundary_labels)
     boundary_inds = np.nonzero(boundary_labels)[0]

-    if boundary_inds[0] != 0 and force_boundary_first_ind:
-        # force there to be a boundary at index 0
-        np.insert(boundary_inds, 0, 0)
+    if force_boundary_first_ind:
+        if len(boundary_inds) == 0:
+            # handle edge case where no boundaries were predicted
+            boundary_inds = np.array([0])  # replace with a single boundary, at index 0
+        else:
+            if boundary_inds[0] != 0:
+                # force there to be a boundary at index 0
+                boundary_inds = np.insert(boundary_inds, 0, 0)

     return boundary_inds

@@ -531,6 +537,8 @@ def postprocess(
         Vector of frame labels after post-processing is applied.
     """
     frame_labels = row_or_1d(frame_labels)
+    if boundary_labels is not None:
+        boundary_labels = row_or_1d(boundary_labels)

     # handle the case when all time bins are predicted to be unlabeled
     # see https://github.com/NickleDave/vak/issues/383
```
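
A minimal test of the edge case fixed above, in the spirit of the unit tests the change list mentions. It assumes the function is importable from the path shown in the diff and that `force_boundary_first_ind` can be passed as a keyword, as its docstring suggests.

```python
import numpy as np

from vak.transforms.frame_labels.functional import (
    boundary_inds_from_boundary_labels,
)

# edge case: no frames predicted as boundaries
no_boundaries = np.zeros(10, dtype=int)
inds = boundary_inds_from_boundary_labels(
    no_boundaries, force_boundary_first_ind=True
)
assert inds.tolist() == [0]  # forced single boundary at index 0
```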
