diff --git a/tests/test_load_video_frames.py b/tests/test_load_video_frames.py
index 58c8a609..164083b9 100644
--- a/tests/test_load_video_frames.py
+++ b/tests/test_load_video_frames.py
@@ -10,11 +10,11 @@ from pydantic import BaseModel, ValidationError
 
 from zamba.data.video import (
-    cached_load_video_frames,
     load_video_frames,
-    VideoLoaderConfig,
     MegadetectorLiteYoloXConfig,
+    VideoLoaderConfig,
 )
+from zamba.pytorch.dataloaders import FfmpegZambaVideoDataset
 
 from conftest import ASSETS_DIR, TEST_VIDEOS_DIR
 
 
@@ -384,32 +384,41 @@ def test_load_video_frames(case: Case, video_metadata: Dict[str, Any]):
         assert video_shape[field] == value
 
 
-def test_same_filename_new_kwargs(tmp_path):
+def test_same_filename_new_kwargs(tmp_path, train_metadata):
     """Test that load_video_frames does not load the npz file if the params change."""
-    # use first test video
-    test_vid = [f for f in TEST_VIDEOS_DIR.rglob("*") if f.is_file()][0]
     cache = tmp_path / "test_cache"
 
+    # prep labels for one video
+    labels = (
+        train_metadata[train_metadata.split == "train"]
+        .set_index("filepath")
+        .filter(regex="species")
+        .head(1)
+    )
+
+    def _generate_dataset(config):
+        """Return loaded video from FFmpegZambaVideoDataset."""
+        return FfmpegZambaVideoDataset(annotations=labels, video_loader_config=config).__getitem__(
+            index=0
+        )[0]
+
     with mock.patch.dict(os.environ, {"VIDEO_CACHE_DIR": str(cache)}):
         # confirm cache is set in environment variable
         assert os.environ["VIDEO_CACHE_DIR"] == str(cache)
 
-        first_load = cached_load_video_frames(filepath=test_vid, config=VideoLoaderConfig(fps=1))
-        new_params_same_name = cached_load_video_frames(
-            filepath=test_vid, config=VideoLoaderConfig(fps=2)
-        )
+        first_load = _generate_dataset(config=VideoLoaderConfig(fps=1))
+        new_params_same_name = _generate_dataset(config=VideoLoaderConfig(fps=2))
         assert first_load.shape != new_params_same_name.shape
 
         # check no params
-        first_load = cached_load_video_frames(filepath=test_vid)
-        assert first_load.shape != new_params_same_name.shape
+        no_params_same_name = _generate_dataset(config=None)
+        assert first_load.shape != new_params_same_name.shape != no_params_same_name.shape
 
         # multiple params in config
-        c1 = VideoLoaderConfig(scene_threshold=0.2)
-        c2 = VideoLoaderConfig(scene_threshold=0.2, crop_bottom_pixels=2)
-
-        first_load = cached_load_video_frames(filepath=test_vid, config=c1)
-        new_params_same_name = cached_load_video_frames(filepath=test_vid, config=c2)
+        first_load = _generate_dataset(config=VideoLoaderConfig(scene_threshold=0.2))
+        new_params_same_name = _generate_dataset(
+            config=VideoLoaderConfig(scene_threshold=0.2, crop_bottom_pixels=2)
+        )
         assert first_load.shape != new_params_same_name.shape
 
 
@@ -506,39 +515,55 @@ def test_validate_total_frames():
     assert config.total_frames == 8
 
 
-def test_caching(tmp_path, caplog):
+def test_caching(tmp_path, caplog, train_metadata):
     cache = tmp_path / "video_cache"
-    test_vid = [f for f in TEST_VIDEOS_DIR.rglob("*") if f.is_file()][0]
+
+    # prep labels for one video
+    labels = (
+        train_metadata[train_metadata.split == "train"]
+        .set_index("filepath")
+        .filter(regex="species")
+        .head(1)
+    )
 
     # no caching by default
-    _ = cached_load_video_frames(filepath=test_vid, config=VideoLoaderConfig(fps=1))
+    _ = FfmpegZambaVideoDataset(
+        annotations=labels,
+    ).__getitem__(index=0)
     assert not cache.exists()
 
     # caching can be specifed in config
-    _ = cached_load_video_frames(
-        filepath=test_vid, config=VideoLoaderConfig(fps=1, cache_dir=cache)
-    )
+    _ = FfmpegZambaVideoDataset(
+        annotations=labels, video_loader_config=VideoLoaderConfig(fps=1, cache_dir=cache)
+    ).__getitem__(index=0)
+
     # one file in cache
     assert len([f for f in cache.rglob("*") if f.is_file()]) == 1
     shutil.rmtree(cache)
 
     # or caching can be specified in environment variable
     with mock.patch.dict(os.environ, {"VIDEO_CACHE_DIR": str(cache)}):
-        _ = cached_load_video_frames(filepath=test_vid)
+        _ = FfmpegZambaVideoDataset(
+            annotations=labels,
+        ).__getitem__(index=0)
         assert len([f for f in cache.rglob("*") if f.is_file()]) == 1
 
         # changing cleanup in config does not prompt new hashing of videos
        with mock.patch.dict(os.environ, {"LOG_LEVEL": "DEBUG"}):
-            _ = cached_load_video_frames(
-                filepath=test_vid, config=VideoLoaderConfig(cleanup_cache=True)
-            )
+            _ = FfmpegZambaVideoDataset(
+                annotations=labels, video_loader_config=VideoLoaderConfig(cleanup_cache=True)
+            ).__getitem__(index=0)
+
             assert "Loading from cache" in caplog.text
 
         # if no config is passed, this is equivalent to specifying None/False in all non-cache related VLC params
-        no_config = cached_load_video_frames(filepath=test_vid, config=None)
-        config_with_nones = cached_load_video_frames(
-            filepath=test_vid,
-            config=VideoLoaderConfig(
+        no_config = FfmpegZambaVideoDataset(annotations=labels, video_loader_config=None).__getitem__(
+            index=0
+        )[0]
+
+        config_with_nones = FfmpegZambaVideoDataset(
+            annotations=labels,
+            video_loader_config=VideoLoaderConfig(
                 crop_bottom_pixels=None,
                 i_frames=False,
                 scene_threshold=None,
@@ -555,5 +580,6 @@ def test_caching(tmp_path, caplog):
                 model_input_height=None,
                 model_input_width=None,
             ),
-        )
+        ).__getitem__(index=0)[0]
+
     assert np.array_equal(no_config, config_with_nones)
diff --git a/zamba/data/video.py b/zamba/data/video.py
index 0daccbdb..1f798fb5 100644
--- a/zamba/data/video.py
+++ b/zamba/data/video.py
@@ -326,8 +326,8 @@ def validate_total_frames(cls, values):
 
 
 class npy_cache:
-    def __init__(self, path: Optional[Path] = None, cleanup: bool = False):
-        self.tmp_path = path
+    def __init__(self, cache_path: Optional[Path] = None, cleanup: bool = False):
+        self.cache_path = cache_path
         self.cleanup = cleanup
 
     def __call__(self, f):
@@ -358,7 +358,7 @@ def _wrapped(*args, **kwargs):
             if isinstance(vid_path, S3Path):
                 vid_path = AnyPath(vid_path.key)
 
-            npy_path = self.tmp_path / hash_str / vid_path.with_suffix(".npy")
+            npy_path = self.cache_path / hash_str / vid_path.with_suffix(".npy")
 
             # make parent directories since we're using absolute paths
             npy_path.parent.mkdir(parents=True, exist_ok=True)
@@ -372,32 +372,24 @@ def _wrapped(*args, **kwargs):
                 logger.debug(f"Wrote to cache {npy_path}: size {npy_path.stat().st_size}")
             return loaded_video
 
-        if self.tmp_path is not None:
+        if self.cache_path is not None:
             return _wrapped
         else:
             return f
 
     def __del__(self):
-        if hasattr(self, "tmp_path") and self.cleanup and self.tmp_path.exists():
-            if self.tmp_path.parents[0] == tempfile.gettempdir():
-                logger.info(f"Deleting cache dir {self.tmp_path}.")
-                rmtree(self.tmp_path)
+        if hasattr(self, "cache_path") and self.cleanup and self.cache_path.exists():
+            if self.cache_path.parents[0] == tempfile.gettempdir():
+                logger.info(f"Deleting cache dir {self.cache_path}.")
+                rmtree(self.cache_path)
             else:
                 logger.warning(
                     "Bravely refusing to delete directory that is not a subdirectory of the "
                     "system temp directory. If you really want to delete, do so manually using:\n "
-                    f"rm -r {self.tmp_path}"
+                    f"rm -r {self.cache_path}"
                 )
 
 
-def npy_cache_factory(path, callable, cleanup):
-    @npy_cache(path=path, cleanup=cleanup)
-    def decorated_callable(*args, **kwargs):
-        return callable(*args, **kwargs)
-
-    return decorated_callable
-
-
 def load_video_frames(
     filepath: os.PathLike,
     config: Optional[VideoLoaderConfig] = None,
@@ -503,21 +495,3 @@ def load_video_frames(
         arr = ensure_frame_number(arr, total_frames=config.total_frames)
 
     return arr
-
-
-def cached_load_video_frames(filepath: os.PathLike, config: Optional[VideoLoaderConfig] = None):
-    """Loads frames from videos using fast ffmpeg commands and caches to .npy file
-    if config.cache_dir is not None.
-
-    Args:
-        filepath (os.PathLike): Path to the video.
-        config (VideoLoaderConfig): Configuration for video loading.
-    """
-    if config is None:
-        # get environment variable for cache if it exists
-        config = VideoLoaderConfig()
-
-    decorated_load_video_frames = npy_cache_factory(
-        path=config.cache_dir, callable=load_video_frames, cleanup=config.cleanup_cache
-    )
-    return decorated_load_video_frames(filepath=filepath, config=config)
diff --git a/zamba/models/model_manager.py b/zamba/models/model_manager.py
index dbfd5de9..35eb5b81 100644
--- a/zamba/models/model_manager.py
+++ b/zamba/models/model_manager.py
@@ -81,7 +81,7 @@ def instantiate_model(
             hparams = yaml.safe_load(f)
 
     else:
-        if not Path(checkpoint).exists():
+        if not (model_cache_dir / checkpoint).exists():
             logger.info("Downloading weights for model.")
             checkpoint = download_weights(
                 filename=str(checkpoint),
diff --git a/zamba/pytorch/dataloaders.py b/zamba/pytorch/dataloaders.py
index fde47611..ed9008bd 100644
--- a/zamba/pytorch/dataloaders.py
+++ b/zamba/pytorch/dataloaders.py
@@ -11,7 +11,7 @@ from torchvision.datasets.vision import VisionDataset
 import torchvision.transforms.transforms
 
-from zamba.data.video import cached_load_video_frames, VideoLoaderConfig
+from zamba.data.video import npy_cache, load_video_frames, VideoLoaderConfig
 
 
 def get_datasets(
@@ -87,6 +87,11 @@ def __init__(
 
         self.targets = annotations
         self.transform = transform
+
+        # get environment variable for cache if it exists
+        if video_loader_config is None:
+            video_loader_config = VideoLoaderConfig()
+
         self.video_loader_config = video_loader_config
 
         super().__init__(root=None, transform=transform)
@@ -96,6 +101,11 @@ def __len__(self):
 
     def __getitem__(self, index: int):
         try:
+            cached_load_video_frames = npy_cache(
+                cache_path=self.video_loader_config.cache_dir,
+                cleanup=self.video_loader_config.cleanup_cache,
+            )(load_video_frames)
+
             video = cached_load_video_frames(
                 filepath=self.video_paths[index], config=self.video_loader_config
             )
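With cached_load_video_frames and npy_cache_factory removed, caching is now wired up inside FfmpegZambaVideoDataset.__getitem__ by applying npy_cache to load_video_frames, with cache_dir and cleanup_cache taken from VideoLoaderConfig (which falls back to the VIDEO_CACHE_DIR environment variable when no config is passed). Below is a minimal sketch of reproducing that pattern outside the dataset, assuming only the names introduced in this diff; the cache directory and video path are illustrative, not taken from the diff.

from pathlib import Path

from zamba.data.video import npy_cache, load_video_frames, VideoLoaderConfig

# Illustrative cache location -- not part of this diff.
config = VideoLoaderConfig(fps=1, cache_dir=Path("/tmp/zamba_video_cache"))

# npy_cache is used as a decorator: configure it with the cache settings,
# then wrap load_video_frames, mirroring FfmpegZambaVideoDataset.__getitem__.
cached_load_video_frames = npy_cache(
    cache_path=config.cache_dir,
    cleanup=config.cleanup_cache,
)(load_video_frames)

# First call decodes the video and writes a .npy file under cache_dir;
# repeating the call with the same config loads from the cache instead.
frames = cached_load_video_frames(filepath="path/to/video.mp4", config=config)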