Skip to content

Commit

Permalink
error if not enough species, seed and then randomly allocate
Browse files Browse the repository at this point in the history
  • Loading branch information
ejm714 committed Dec 14, 2021
1 parent 9ab1f9d commit ad38de6
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 23 deletions.
27 changes: 20 additions & 7 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,18 +203,31 @@ def test_labels_with_invalid_split(labels_absolute_path):


def test_labels_no_splits(labels_no_splits, tmp_path):
    """Check split allocation when the labels file provides no `split` column.

    Two scenarios:
    1. With exactly as many videos per species as non-zero splits, every
       species must be seeded into each split (deterministic allocation).
    2. With fewer videos for a species than there are splits, TrainConfig
       must raise a ValueError naming the under-represented species.
    """
    # ensure species are allocated to both sets
    labels_four_videos = pd.read_csv(labels_no_splits).head(4)
    labels_four_videos["label"] = ["gorilla"] * 2 + ["elephant"] * 2
    _ = TrainConfig(
        data_dir=TEST_VIDEOS_DIR,
        labels=labels_four_videos,
        save_dir=tmp_path,
        split_proportions=dict(train=1, val=1, holdout=0),
    )

    # With 2 videos per species and 2 non-zero splits, the seeding step
    # places one video of each species in each split, so the order is fixed.
    assert (
        pd.read_csv(tmp_path / "splits.csv").split.values == ["train", "val", "train", "val"]
    ).all()

    # remove the first row which puts antelope_duiker at 2 instead of 3
    labels_with_too_few_videos = pd.read_csv(labels_no_splits).iloc[1:, :]
    with pytest.raises(ValueError) as error:
        TrainConfig(
            data_dir=TEST_VIDEOS_DIR,
            labels=labels_with_too_few_videos,
            save_dir=tmp_path,
        )
    # NOTE(review): "minimumm" reproduces a typo in the raised message in
    # zamba/models/config.py; the strings must stay in sync — fix both together.
    assert (
        "Not all species have enough videos to allocate into the following splits: train, val, holdout. A minimumm of 3 videos per label is required. Found the following counts: {'antelope_duiker': 2}. Either remove these labels or add more videos."
    ) == error.value.errors()[0]["msg"]


def test_labels_split_proportions(labels_no_splits, tmp_path):
Expand Down
33 changes: 17 additions & 16 deletions zamba/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import appdirs
import ffmpeg
from loguru import logger
import numpy as np
import pandas as pd
from pydantic import BaseModel
from pydantic import DirectoryPath, FilePath, validator, root_validator
Expand Down Expand Up @@ -502,28 +501,30 @@ def preprocess_labels(cls, values):
)

expected_splits = [k for k, v in values["split_proportions"].items() if v > 0]
random.seed(SPLIT_SEED)

# check we have at least as many videos as we have splits
if len(expected_splits) > len(labels):
# check we have at least as many videos per species as we have splits
# labels are OHE at this point
num_videos_per_species = labels.filter(regex="species_").sum().to_dict()
too_few = {
k.split("species_", 1)[1]: v
for k, v in num_videos_per_species.items()
if v < len(expected_splits)
}

if len(too_few) > 0:
raise ValueError(
f"Not enough videos to allocate into {', '.join(expected_splits)} splits. Only {len(labels)} video(s) found."
f"Not all species have enough videos to allocate into the following splits: {', '.join(expected_splits)}. A minimumm of {len(expected_splits)} videos per label is required. Found the following counts: {too_few}. Either remove these labels or add more videos."
)

# seed splits by putting one video in each
labels["split"] = np.nan
labels.iloc[: len(expected_splits), -1] = expected_splits

random.seed(SPLIT_SEED)

# labels are OHE at this point
for c in labels.filter(regex="species").columns:
# within each species, allocate videos based on split proportions
species_df = labels[(labels[c] > 0) & labels.split.isnull()]
for c in labels.filter(regex="species_").columns:
species_df = labels[labels[c] > 0]

labels.loc[species_df.index, "split"] = random.choices(
# within each species, seed splits by putting one video in each set and then allocate videos based on split proportions
labels.loc[species_df.index, "split"] = expected_splits + random.choices(
list(values["split_proportions"].keys()),
weights=list(values["split_proportions"].values()),
k=len(species_df),
k=len(species_df) - len(expected_splits),
)

logger.info(f"{labels.split.value_counts()}")
Expand Down

0 comments on commit ad38de6

Please sign in to comment.