Merge pull request #215 from drivendataorg/binary-classifier

Support binary case

AllenDowney authored Aug 23, 2022
2 parents 1d273dd + 22b3ff4 commit ebc698e
Showing 3 changed files with 95 additions and 59 deletions.
5 changes: 2 additions & 3 deletions tests/test_config.py
@@ -236,9 +236,8 @@ def test_labels_no_splits(labels_no_splits, tmp_path):
         split_proportions=dict(train=1, val=1, holdout=0),
     )
 
-    assert (
-        pd.read_csv(tmp_path / "splits.csv").split.values == ["train", "val", "train", "val"]
-    ).all()
+    split = pd.read_csv(tmp_path / "splits.csv")["split"].values
+    assert (split == ["train", "val", "train", "val"]).all()
 
     # remove the first row which puts antelope_duiker at 2 instead of 3
     labels_with_too_few_videos = pd.read_csv(labels_no_splits).iloc[1:, :]
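
For context, the rewritten assertion just extracts the intermediate value before comparing. A minimal, self-contained sketch of the pattern, using an invented splits.csv rather than the test fixture's tmp_path:

```python
import pandas as pd

# toy splits.csv standing in for the one written during label preprocessing
pd.DataFrame(
    {
        "filepath": ["a.mp4", "b.mp4", "c.mp4", "d.mp4"],
        "split": ["train", "val", "train", "val"],
    }
).to_csv("splits.csv", index=False)

# the refactored check: pull out the column first, then assert on it
split = pd.read_csv("splits.csv")["split"].values
assert (split == ["train", "val", "train", "val"]).all()
```
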
12 changes: 11 additions & 1 deletion tests/test_model_manager.py
@@ -116,7 +116,17 @@ def test_save_metrics_less_than_two_classes(
         }
     )
 
-    assert metrics.keys() == metric_names
+    removed_in_binary_case = {
+        "species/test_precision/B",
+        "species/test_recall/B",
+        "species/test_accuracy/B",
+        "species/test_f1/B",
+        "species/val_precision/B",
+        "species/val_recall/B",
+        "species/val_accuracy/B",
+        "species/val_f1/B",
+    }
+    assert metrics.keys() == metric_names - removed_in_binary_case
 
 
 def test_save_configuration(dummy_trainer):
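
The new assertion relies on Python's set algebra: dict.keys() returns a view that compares equal to a plain set, so the expected keys are the full metric set minus the second class's per-class metrics. A toy sketch with invented class names and values:

```python
# full set of expected metric names for classes A and B (invented)
metric_names = {
    "species/val_f1/A",
    "species/val_f1/B",
    "species/test_f1/A",
    "species/test_f1/B",
}

# in the binary case, per-class metrics for the second class are not logged
removed_in_binary_case = {"species/val_f1/B", "species/test_f1/B"}

# toy logged metrics: only the first class's metrics appear
metrics = {"species/val_f1/A": 0.9, "species/test_f1/A": 0.8}

# dict.keys() compares equal to a set, mirroring the test's assertion
assert metrics.keys() == metric_names - removed_in_binary_case
```
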
137 changes: 82 additions & 55 deletions zamba/models/config.py
@@ -481,78 +481,105 @@ def validate_filepaths_and_labels(cls, values):

     @root_validator(skip_on_failure=True)
     def preprocess_labels(cls, values):
+        """One hot encode, add splits, and check for binary case.
+        Replaces values['labels'] with modified DataFrame.
+        Args:
+            values: dictionary containing 'labels' and other config info
+        """
         logger.info("Preprocessing labels into one hot encoded labels with one row per video.")
+        labels = values["labels"]
+
         # one hot encode collapse to one row per video
         labels = (
-            pd.get_dummies(
-                values["labels"].rename(columns={"label": "species"}), columns=["species"]
-            )
+            pd.get_dummies(labels.rename(columns={"label": "species"}), columns=["species"])
             .groupby("filepath")
             .max()
         )
 
         # if no "split" column, set up train, val, and holdout split
         if "split" not in labels.columns:
-            logger.info(
-                f"Dividing videos into train, val, and holdout sets using the following split proportions: {values['split_proportions']}."
-            )
+            make_split(labels, values)
 
-            # use site info if we have it
-            if "site" in labels.columns:
-                logger.info("Using provided 'site' column to do a site-specific split")
-                labels["split"] = create_site_specific_splits(
-                    labels["site"], proportions=values["split_proportions"]
-                )
-            else:
-                # otherwise randomly allocate
-                logger.info(
-                    "No 'site' column found so videos for each species will be randomly allocated across splits using provided split proportions."
-                )
-
-                expected_splits = [k for k, v in values["split_proportions"].items() if v > 0]
-                random.seed(SPLIT_SEED)
-
-                # check we have at least as many videos per species as we have splits
-                # labels are OHE at this point
-                num_videos_per_species = labels.filter(regex="species_").sum().to_dict()
-                too_few = {
-                    k.split("species_", 1)[1]: v
-                    for k, v in num_videos_per_species.items()
-                    if v < len(expected_splits)
-                }
-
-                if len(too_few) > 0:
-                    raise ValueError(
-                        f"Not all species have enough videos to allocate into the following splits: {', '.join(expected_splits)}. A minimum of {len(expected_splits)} videos per label is required. Found the following counts: {too_few}. Either remove these labels or add more videos."
-                    )
-
-                for c in labels.filter(regex="species_").columns:
-                    species_df = labels[labels[c] > 0]
-
-                    # within each species, seed splits by putting one video in each set and then allocate videos based on split proportions
-                    labels.loc[species_df.index, "split"] = expected_splits + random.choices(
-                        list(values["split_proportions"].keys()),
-                        weights=list(values["split_proportions"].values()),
-                        k=len(species_df) - len(expected_splits),
-                    )
-
-            logger.info(f"{labels.split.value_counts()}")
-            logger.info(
-                f"Writing out split information to {values['save_dir'] / 'splits.csv'}."
-            )
+        # if there are only two species columns and every video belongs to one of them,
+        # drop the second species column so the problem is treated as a binary classification
+        species_cols = labels.filter(regex="species_").columns
+        sums = labels[species_cols].sum(axis=1)
 
-            # create the directory to save if we need to
-            values["save_dir"].mkdir(parents=True, exist_ok=True)
-
-            labels.reset_index()[["filepath", "split"]].drop_duplicates().to_csv(
-                values["save_dir"] / "splits.csv", index=False
-            )
+        if len(species_cols) == 2 and (sums == 1).all():
+            logger.warning(
+                f"Binary case detected so only one species column will be kept. Output will be the binary case of {species_cols[0]}."
+            )
+            labels = labels.drop(columns=species_cols[1])
 
         # filepath becomes column instead of index
         values["labels"] = labels.reset_index()
         return values
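
To make the new binary-case handling concrete, here is a self-contained sketch (an illustration, not part of the diff) of the one-hot collapse followed by the two-column check, using invented filenames and species; the extracted make_split helper follows in the diff below:

```python
import pandas as pd

# toy labels: one row per video, two species (names invented)
labels = pd.DataFrame(
    {
        "filepath": ["a.mp4", "b.mp4", "c.mp4"],
        "label": ["blank", "chimp", "blank"],
    }
)

# one hot encode and collapse to one row per video, as in preprocess_labels
labels = (
    pd.get_dummies(labels.rename(columns={"label": "species"}), columns=["species"])
    .groupby("filepath")
    .max()
)

# binary case: exactly two species columns and every video in exactly one class
species_cols = labels.filter(regex="species_").columns
sums = labels[species_cols].sum(axis=1)
if len(species_cols) == 2 and (sums == 1).all():
    labels = labels.drop(columns=species_cols[1])  # keeps only species_blank
```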


+def make_split(labels, values):
+    """Add a split column to `labels`.
+    Args:
+        labels: DataFrame with one row per video
+        values: dictionary with config info
+    """
+    logger.info(
+        f"Dividing videos into train, val, and holdout sets using the following split proportions: {values['split_proportions']}."
+    )
+
+    # use site info if we have it
+    if "site" in labels.columns:
+        logger.info("Using provided 'site' column to do a site-specific split")
+        labels["split"] = create_site_specific_splits(
+            labels["site"], proportions=values["split_proportions"]
+        )
+    else:
+        # otherwise randomly allocate
+        logger.info(
+            "No 'site' column found so videos for each species will be randomly allocated across splits using provided split proportions."
+        )
+
+        expected_splits = [k for k, v in values["split_proportions"].items() if v > 0]
+        random.seed(SPLIT_SEED)
+
+        # check we have at least as many videos per species as we have splits
+        # labels are OHE at this point
+        num_videos_per_species = labels.filter(regex="species_").sum().to_dict()
+        too_few = {
+            k.split("species_", 1)[1]: v
+            for k, v in num_videos_per_species.items()
+            if v < len(expected_splits)
+        }
+
+        if len(too_few) > 0:
+            raise ValueError(
+                f"Not all species have enough videos to allocate into the following splits: {', '.join(expected_splits)}. A minimum of {len(expected_splits)} videos per label is required. Found the following counts: {too_few}. Either remove these labels or add more videos."
+            )
+
+        for c in labels.filter(regex="species_").columns:
+            species_df = labels[labels[c] > 0]
+
+            # within each species, seed splits by putting one video in each set and then allocate videos based on split proportions
+            labels.loc[species_df.index, "split"] = expected_splits + random.choices(
+                list(values["split_proportions"].keys()),
+                weights=list(values["split_proportions"].values()),
+                k=len(species_df) - len(expected_splits),
+            )
+
+    logger.info(f"{labels.split.value_counts()}")
+
+    # write splits.csv
+    filename = values["save_dir"] / "splits.csv"
+    logger.info(f"Writing out split information to {filename}.")
+
+    # create the directory to save if we need to
+    values["save_dir"].mkdir(parents=True, exist_ok=True)
+
+    labels.reset_index()[["filepath", "split"]].drop_duplicates().to_csv(filename, index=False)
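
For intuition about the allocation step, a standalone sketch: each nonzero split is first seeded with one video, and the remaining videos are drawn with the split proportions as weights. The seed and counts here are placeholders, not zamba's actual SPLIT_SEED:

```python
import random

split_proportions = {"train": 3, "val": 1, "holdout": 1}
expected_splits = [k for k, v in split_proportions.items() if v > 0]

random.seed(2022)  # placeholder; zamba seeds with its own SPLIT_SEED constant

n_videos = 10  # videos for one species
splits = expected_splits + random.choices(
    list(split_proportions.keys()),
    weights=list(split_proportions.values()),
    k=n_videos - len(expected_splits),
)
print(splits)  # first three entries cover train/val/holdout once each
```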


 class PredictConfig(ZambaBaseModel):
     """
     Configuration for using a model for inference.
