Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure videos are allocated into all specified splits #169

Merged
merged 10 commits into from
Dec 14, 2021
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def forward(self, x, *args, **kwargs):
class DummyTrainConfig(TrainConfig):
    """TrainConfig variant with minimal settings so tests run fast with the "dummy" model.

    NOTE: no extra type annotations are added below on purpose — in a pydantic
    model, annotating an assignment changes whether it is treated as a field.
    """

    # let model name be "dummy" without causing errors
    model_name: str
    batch_size = 1  # smallest possible batch to keep test runs fast
    max_epochs = 1  # one epoch is enough for a smoke test
    model_name = "dummy"
    skip_load_validation = True  # NOTE(review): presumably skips checking that video files load — confirm against TrainConfig
    auto_lr_find = False  # no learning-rate search during tests


@pytest.fixture(scope="session")
Expand Down
14 changes: 12 additions & 2 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,18 @@ def test_labels_with_invalid_split(labels_absolute_path):


def test_labels_no_splits(labels_no_splits, tmp_path):
    """Every requested split must receive at least one video, even with few videos."""
    # keep only the first three videos and ask for three splits
    subset = pd.read_csv(labels_no_splits).head(3)
    TrainConfig(
        data_dir=TEST_VIDEOS_DIR,
        labels=subset,
        save_dir=tmp_path,
        split_proportions={"train": 3, "val": 1, "holdout": 1},
    )

    written_splits = pd.read_csv(tmp_path / "splits.csv")
    assert set(written_splits.split.unique()) == {"train", "val", "holdout"}


def test_labels_split_proportions(labels_no_splits, tmp_path):
Expand Down
24 changes: 14 additions & 10 deletions tests/test_model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@ def test_model_manager(dummy_trainer):
assert not (dummy_trainer.model.model[3].weight == 0).all()


def test_no_early_stopping(
    labels_absolute_path, tmp_path, dummy_checkpoint, dummy_video_loader_config
):
    """Training runs end to end when early stopping is explicitly disabled."""
    train_model(
        train_config=DummyTrainConfig(
            labels=labels_absolute_path,
            data_dir=TEST_VIDEOS_DIR,
            checkpoint=dummy_checkpoint,
            early_stopping_config=None,
            save_dir=tmp_path / "my_model",
            num_workers=1,
        ),
        video_loader_config=dummy_video_loader_config,
    )


def test_save_checkpoint(dummy_trained_model_checkpoint):
checkpoint = torch.load(dummy_trained_model_checkpoint)

Expand Down Expand Up @@ -76,14 +90,9 @@ def test_save_metrics_less_than_two_classes(
train_config=DummyTrainConfig(
labels=labels,
data_dir=TEST_VIDEOS_DIR,
model_name="dummy",
checkpoint=dummy_checkpoint,
max_epochs=1,
batch_size=1,
auto_lr_find=False,
num_workers=2,
save_dir=tmp_path / "my_model",
skip_load_validation=True,
),
video_loader_config=dummy_video_loader_config,
)
Expand Down Expand Up @@ -136,14 +145,9 @@ def test_train_save_dir_overwrite(
config = DummyTrainConfig(
labels=labels_absolute_path,
data_dir=TEST_VIDEOS_DIR,
model_name="dummy",
checkpoint=dummy_checkpoint,
save_dir=tmp_path / "my_model",
skip_load_validation=True,
overwrite=True,
max_epochs=1,
batch_size=1,
auto_lr_find=False,
num_workers=1,
)

Expand Down
28 changes: 19 additions & 9 deletions zamba/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import appdirs
import ffmpeg
from loguru import logger
import numpy as np
import pandas as pd
from pydantic import BaseModel
from pydantic import DirectoryPath, FilePath, validator, root_validator
Expand Down Expand Up @@ -495,19 +496,28 @@ def preprocess_labels(cls, values):
labels["site"], proportions=values["split_proportions"]
)
else:
logger.info(
"No 'site' column found so videos will be randomly allocated to splits."
)
# otherwise randomly allocate
random.seed(SPLIT_SEED)
labels["split"] = random.choices(
list(values["split_proportions"].keys()),
weights=list(values["split_proportions"].values()),
k=len(labels),
logger.info(
"No 'site' column found so videos will be randomly allocated using split proportions."
)

expected_labels = [k for k, v in values["split_proportions"].items() if v > 0]
ejm714 marked this conversation as resolved.
Show resolved Hide resolved
labels["split"] = ""
seed = SPLIT_SEED

while len(np.setdiff1d(expected_labels, labels.split.unique())):
ejm714 marked this conversation as resolved.
Show resolved Hide resolved

random.seed(seed)
labels["split"] = random.choices(
list(values["split_proportions"].keys()),
weights=list(values["split_proportions"].values()),
k=len(labels),
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has one weird edge case that is likely rare, so just worth filing an issue for:

v0.mp4, antelope
v1.mp4, antelope  # v1 assigned test by antelope grouping
v1.mp4, cow       # subsequently v1 assigned train by cow grouping
v2.mp4, antelope
v4.mp4, cow
v5.mp4, cow

# test set is now missing antelope


seed += 1

logger.info(
f"Writing out split information to {values['save_dir'] / 'splits.csv'}."
f"Writing out split information to {values['save_dir'] / 'splits.csv'}. Used random seed {seed}."
)

# create the directory to save if we need to.
Expand Down
8 changes: 6 additions & 2 deletions zamba/models/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,12 @@ def train_model(
model_checkpoint = ModelCheckpoint(
dirpath=logging_and_save_dir,
filename=train_config.model_name,
monitor=train_config.early_stopping_config.monitor,
mode=train_config.early_stopping_config.mode,
monitor=train_config.early_stopping_config.monitor
if train_config.early_stopping_config is not None
else None,
mode=train_config.early_stopping_config.mode
if train_config.early_stopping_config is not None
else "min",
)

callbacks = [model_checkpoint]
Expand Down