diff --git a/tests/assets/labels.csv b/tests/assets/labels.csv
index 5aef26ad..cfc9c6c7 100644
--- a/tests/assets/labels.csv
+++ b/tests/assets/labels.csv
@@ -1,5 +1,5 @@
 filepath,label,split
-data/raw/savanna/Grumeti_Tanzania/K38_check3/09190048_Hyena.AVI,gorilla,train
+data/raw/savanna/Grumeti_Tanzania/K38_check3/09190048_Hyena.AVI,antelope_duiker,train
 data/raw/savanna/Grumeti_Tanzania/G41_check2/09100029_Eland.MP4,antelope_duiker,train
 data/raw/savanna/Gorongosa_Mozambique/2017 Videos/F13_Cam027/Baboon/07120049.AVI,elephant,train
 data/raw/goualougo_2013/chimp_MPI_FID_2013/MPI_FID_31_Abel/06-May-2013/FID_31_Abel_2013-5-6_0027.AVI,gorilla,train
@@ -15,6 +15,6 @@ data/raw/chimpandsee/Kay_H12/Kay_vid2_0803042_1459433_20130528/PICT0265.AVI,gori
 data/raw/chimpandsee/Kor_C5/Kor_vid6_0485466_0567428_20141120/EK000019.AVI,gorilla,val
 data/raw/chimpandsee/Trip_5/card 1.2_location/PICT0074.ASF,gorilla,val
 data/raw/chimpandsee/Sap_D1/Sap_vid24_0524356_0589414_20110915/PICT0021.ASF,gorilla,holdout
-data/raw/chimpandsee/Bwi_A3/bwi_vid7_806110_9885227_20120911/PICT0146.AVI,gorilla,holdout
+data/raw/chimpandsee/Bwi_A3/bwi_vid7_806110_9885227_20120911/PICT0146.AVI,elephant,holdout
 data/raw/chimpandsee/Gas_D4/Gas_Vid20_0786029_0808836_20130725/EK000018.AVI,elephant,holdout
 data/raw/salonga/MISSION_3/Team_Pedro/Trans_322__Cam_1_488454_9765353_Date_24-11-2017/01110103.AVI,antelope_duiker,holdout
diff --git a/tests/conftest.py b/tests/conftest.py
index 79e0c46f..8549bc21 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -71,6 +71,11 @@ def forward(self, x, *args, **kwargs):
 class DummyTrainConfig(TrainConfig):
     # let model name be "dummy" without causing errors
     model_name: str
+    batch_size = 1
+    max_epochs = 1
+    model_name = "dummy"
+    skip_load_validation = True
+    auto_lr_find = False


 @pytest.fixture(scope="session")
diff --git a/tests/test_config.py b/tests/test_config.py
index 5d4cb0d6..2d3aed47 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -203,8 +203,31 @@ def test_labels_with_invalid_split(labels_absolute_path):


 def test_labels_no_splits(labels_no_splits, tmp_path):
-    config = TrainConfig(data_dir=TEST_VIDEOS_DIR, labels=labels_no_splits, save_dir=tmp_path)
-    assert set(config.labels.split.unique()) == set(("holdout", "train", "val"))
+    # ensure each species is allocated to both the train and val sets
+    labels_four_videos = pd.read_csv(labels_no_splits).head(4)
+    labels_four_videos["label"] = ["gorilla"] * 2 + ["elephant"] * 2
+    _ = TrainConfig(
+        data_dir=TEST_VIDEOS_DIR,
+        labels=labels_four_videos,
+        save_dir=tmp_path,
+        split_proportions=dict(train=1, val=1, holdout=0),
+    )
+
+    assert (
+        pd.read_csv(tmp_path / "splits.csv").split.values == ["train", "val", "train", "val"]
+    ).all()
+
+    # remove the first row, which puts antelope_duiker at 2 videos instead of 3
+    labels_with_too_few_videos = pd.read_csv(labels_no_splits).iloc[1:, :]
+    with pytest.raises(ValueError) as error:
+        TrainConfig(
+            data_dir=TEST_VIDEOS_DIR,
+            labels=labels_with_too_few_videos,
+            save_dir=tmp_path,
+        )
+    assert (
+        "Not all species have enough videos to allocate into the following splits: train, val, holdout. A minimum of 3 videos per label is required. Found the following counts: {'antelope_duiker': 2}. Either remove these labels or add more videos."
+    ) == error.value.errors()[0]["msg"]


 def test_labels_split_proportions(labels_no_splits, tmp_path):
@@ -214,7 +237,7 @@ def test_labels_split_proportions(labels_no_splits, tmp_path):
         split_proportions={"a": 3, "b": 1},
         save_dir=tmp_path,
     )
-    assert config.labels.split.value_counts().to_dict() == {"a": 14, "b": 5}
+    assert config.labels.split.value_counts().to_dict() == {"a": 13, "b": 6}


 def test_from_scratch(labels_absolute_path):
diff --git a/tests/test_model_manager.py b/tests/test_model_manager.py
index 682a032d..8539cbda 100644
--- a/tests/test_model_manager.py
+++ b/tests/test_model_manager.py
@@ -17,6 +17,20 @@ def test_model_manager(dummy_trainer):
     assert not (dummy_trainer.model.model[3].weight == 0).all()


+def test_no_early_stopping(
+    labels_absolute_path, tmp_path, dummy_checkpoint, dummy_video_loader_config
+):
+    config = DummyTrainConfig(
+        labels=labels_absolute_path,
+        data_dir=TEST_VIDEOS_DIR,
+        checkpoint=dummy_checkpoint,
+        early_stopping_config=None,
+        save_dir=tmp_path / "my_model",
+        num_workers=1,
+    )
+    train_model(train_config=config, video_loader_config=dummy_video_loader_config)
+
+
 def test_save_checkpoint(dummy_trained_model_checkpoint):
     checkpoint = torch.load(dummy_trained_model_checkpoint)

@@ -76,14 +90,9 @@ def test_save_metrics_less_than_two_classes(
         train_config=DummyTrainConfig(
             labels=labels,
             data_dir=TEST_VIDEOS_DIR,
-            model_name="dummy",
             checkpoint=dummy_checkpoint,
-            max_epochs=1,
-            batch_size=1,
-            auto_lr_find=False,
             num_workers=2,
             save_dir=tmp_path / "my_model",
-            skip_load_validation=True,
         ),
         video_loader_config=dummy_video_loader_config,
     )
@@ -136,14 +145,9 @@ def test_train_save_dir_overwrite(
     config = DummyTrainConfig(
         labels=labels_absolute_path,
         data_dir=TEST_VIDEOS_DIR,
-        model_name="dummy",
         checkpoint=dummy_checkpoint,
         save_dir=tmp_path / "my_model",
-        skip_load_validation=True,
         overwrite=True,
-        max_epochs=1,
-        batch_size=1,
-        auto_lr_find=False,
         num_workers=1,
     )

diff --git a/zamba/models/config.py b/zamba/models/config.py
index d12f13d0..39931a8b 100644
--- a/zamba/models/config.py
+++ b/zamba/models/config.py
@@ -495,22 +495,44 @@ def preprocess_labels(cls, values):
                     labels["site"], proportions=values["split_proportions"]
                 )
             else:
+                # otherwise randomly allocate
                 logger.info(
-                    "No 'site' column found so videos will be randomly allocated to splits."
+                    "No 'site' column found so videos for each species will be randomly allocated across splits using provided split proportions."
                 )
-                # otherwise randomly allocate
+
+                expected_splits = [k for k, v in values["split_proportions"].items() if v > 0]
                 random.seed(SPLIT_SEED)
-                labels["split"] = random.choices(
-                    list(values["split_proportions"].keys()),
-                    weights=list(values["split_proportions"].values()),
-                    k=len(labels),
-                )
+                # check we have at least as many videos per species as we have splits
+                # labels are OHE at this point
+                num_videos_per_species = labels.filter(regex="species_").sum().to_dict()
+                too_few = {
+                    k.split("species_", 1)[1]: v
+                    for k, v in num_videos_per_species.items()
+                    if v < len(expected_splits)
+                }
+
+                if len(too_few) > 0:
+                    raise ValueError(
+                        f"Not all species have enough videos to allocate into the following splits: {', '.join(expected_splits)}. A minimum of {len(expected_splits)} videos per label is required. Found the following counts: {too_few}. Either remove these labels or add more videos."
+                    )
+
+                for c in labels.filter(regex="species_").columns:
+                    species_df = labels[labels[c] > 0]
+
+                    # within each species, seed splits by putting one video in each set and then allocate videos based on split proportions
+                    labels.loc[species_df.index, "split"] = expected_splits + random.choices(
+                        list(values["split_proportions"].keys()),
+                        weights=list(values["split_proportions"].values()),
+                        k=len(species_df) - len(expected_splits),
+                    )
+
+                logger.info(f"{labels.split.value_counts()}")

             logger.info(
                 f"Writing out split information to {values['save_dir'] / 'splits.csv'}."
             )

-            # create the directory to save if we need to.
+            # create the directory to save if we need to
             values["save_dir"].mkdir(parents=True, exist_ok=True)

             labels.reset_index()[["filepath", "split"]].drop_duplicates().to_csv(
diff --git a/zamba/models/model_manager.py b/zamba/models/model_manager.py
index 11073859..3012711d 100644
--- a/zamba/models/model_manager.py
+++ b/zamba/models/model_manager.py
@@ -245,8 +245,12 @@ def train_model(
     model_checkpoint = ModelCheckpoint(
         dirpath=logging_and_save_dir,
         filename=train_config.model_name,
-        monitor=train_config.early_stopping_config.monitor,
-        mode=train_config.early_stopping_config.mode,
+        monitor=train_config.early_stopping_config.monitor
+        if train_config.early_stopping_config is not None
+        else None,
+        mode=train_config.early_stopping_config.mode
+        if train_config.early_stopping_config is not None
+        else "min",
     )

     callbacks = [model_checkpoint]
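Not part of the diff: the sketch below is a minimal standalone illustration of the per-species split allocation that preprocess_labels now performs (seed one video per non-zero split, then fill the rest with random.choices weighted by split proportions, after checking each species has at least as many videos as splits). The toy DataFrame, the SPLIT_SEED value, and the species column names are illustrative assumptions, not the zamba API.

# illustrative sketch only; assumes one-hot "species_*" columns, one row per video
import random

import pandas as pd

SPLIT_SEED = 4007  # assumed value; any fixed seed makes the allocation reproducible

split_proportions = {"train": 3, "val": 1, "holdout": 1}
expected_splits = [k for k, v in split_proportions.items() if v > 0]

# toy one-hot-encoded labels
labels = pd.DataFrame(
    {
        "filepath": [f"video_{i}.avi" for i in range(8)],
        "species_gorilla": [1, 1, 1, 1, 0, 0, 0, 0],
        "species_elephant": [0, 0, 0, 0, 1, 1, 1, 1],
    }
).set_index("filepath")

# every species needs at least one video per non-zero split
num_videos_per_species = labels.filter(regex="species_").sum().to_dict()
too_few = {k: v for k, v in num_videos_per_species.items() if v < len(expected_splits)}
if too_few:
    raise ValueError(f"Species with too few videos: {too_few}")

random.seed(SPLIT_SEED)
for c in labels.filter(regex="species_").columns:
    species_df = labels[labels[c] > 0]
    # seed one video into each split, then allocate the remainder by proportion
    labels.loc[species_df.index, "split"] = expected_splits + random.choices(
        list(split_proportions.keys()),
        weights=list(split_proportions.values()),
        k=len(species_df) - len(expected_splits),
    )

print(labels["split"].value_counts())

Because the first len(expected_splits) videos of each species are assigned deterministically, no species can end up missing from a split, which is the failure mode the new ValueError in config.py guards against.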