From ad38de6dde1c0647cac11786c900516e7df931c7 Mon Sep 17 00:00:00 2001 From: Emily Miller Date: Mon, 13 Dec 2021 18:00:54 -0800 Subject: [PATCH] error if not enough species, seed and then randomly allocate --- tests/test_config.py | 27 ++++++++++++++++++++------- zamba/models/config.py | 33 +++++++++++++++++----------------- 2 files changed, 37 insertions(+), 23 deletions(-) diff --git a/tests/test_config.py b/tests/test_config.py index 9c714558..2d3aed47 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -203,18 +203,31 @@ def test_labels_with_invalid_split(labels_absolute_path): def test_labels_no_splits(labels_no_splits, tmp_path): - labels_three_videos = pd.read_csv(labels_no_splits).head(3) - # test with fewer videos and ensure we still get one of each + # ensure species are allocated to both sets + labels_four_videos = pd.read_csv(labels_no_splits).head(4) + labels_four_videos["label"] = ["gorilla"] * 2 + ["elephant"] * 2 _ = TrainConfig( data_dir=TEST_VIDEOS_DIR, - labels=labels_three_videos, + labels=labels_four_videos, save_dir=tmp_path, - split_proportions=dict(train=3, val=1, holdout=1), + split_proportions=dict(train=1, val=1, holdout=0), ) - assert set(pd.read_csv(tmp_path / "splits.csv").split.unique()) == set( - ["train", "val", "holdout"] - ) + assert ( + pd.read_csv(tmp_path / "splits.csv").split.values == ["train", "val", "train", "val"] + ).all() + + # remove the first row which puts antelope_duiker at 2 instead of 3 + labels_with_too_few_videos = pd.read_csv(labels_no_splits).iloc[1:, :] + with pytest.raises(ValueError) as error: + TrainConfig( + data_dir=TEST_VIDEOS_DIR, + labels=labels_with_too_few_videos, + save_dir=tmp_path, + ) + assert ( + "Not all species have enough videos to allocate into the following splits: train, val, holdout. A minimum of 3 videos per label is required. Found the following counts: {'antelope_duiker': 2}. Either remove these labels or add more videos." 
+ ) == error.value.errors()[0]["msg"] def test_labels_split_proportions(labels_no_splits, tmp_path): diff --git a/zamba/models/config.py b/zamba/models/config.py index 1028d5af..39931a8b 100644 --- a/zamba/models/config.py +++ b/zamba/models/config.py @@ -7,7 +7,6 @@ import appdirs import ffmpeg from loguru import logger -import numpy as np import pandas as pd from pydantic import BaseModel from pydantic import DirectoryPath, FilePath, validator, root_validator @@ -502,28 +501,30 @@ def preprocess_labels(cls, values): ) expected_splits = [k for k, v in values["split_proportions"].items() if v > 0] + random.seed(SPLIT_SEED) - # check we have at least as many videos as we have splits - if len(expected_splits) > len(labels): + # check we have at least as many videos per species as we have splits + # labels are OHE at this point + num_videos_per_species = labels.filter(regex="species_").sum().to_dict() + too_few = { + k.split("species_", 1)[1]: v + for k, v in num_videos_per_species.items() + if v < len(expected_splits) + } + + if len(too_few) > 0: raise ValueError( - f"Not enough videos to allocate into {', '.join(expected_splits)} splits. Only {len(labels)} video(s) found." + f"Not all species have enough videos to allocate into the following splits: {', '.join(expected_splits)}. A minimum of {len(expected_splits)} videos per label is required. Found the following counts: {too_few}. Either remove these labels or add more videos." 
) - # seed splits by putting one video in each - labels["split"] = np.nan - labels.iloc[: len(expected_splits), -1] = expected_splits - - random.seed(SPLIT_SEED) - - # labels are OHE at this point - for c in labels.filter(regex="species").columns: - # within each species, allocate videos based on split proportions - species_df = labels[(labels[c] > 0) & labels.split.isnull()] + for c in labels.filter(regex="species_").columns: + species_df = labels[labels[c] > 0] - labels.loc[species_df.index, "split"] = random.choices( + # within each species, seed splits by putting one video in each set and then allocate videos based on split proportions + labels.loc[species_df.index, "split"] = expected_splits + random.choices( list(values["split_proportions"].keys()), weights=list(values["split_proportions"].values()), - k=len(species_df), + k=len(species_df) - len(expected_splits), ) logger.info(f"{labels.split.value_counts()}")