Skip to content

Commit

Permalink
Simplify caching (#145)
Browse files Browse the repository at this point in the history
* simplify caching

* nest function

* simplify

* move

* alphabetize

* lint

* rename to cache path

* fix dataset
  • Loading branch information
ejm714 authored Oct 22, 2021
1 parent fc421df commit c41ae7e
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 68 deletions.
88 changes: 57 additions & 31 deletions tests/test_load_video_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from pydantic import BaseModel, ValidationError

from zamba.data.video import (
cached_load_video_frames,
load_video_frames,
VideoLoaderConfig,
MegadetectorLiteYoloXConfig,
VideoLoaderConfig,
)
from zamba.pytorch.dataloaders import FfmpegZambaVideoDataset

from conftest import ASSETS_DIR, TEST_VIDEOS_DIR

Expand Down Expand Up @@ -384,32 +384,41 @@ def test_load_video_frames(case: Case, video_metadata: Dict[str, Any]):
assert video_shape[field] == value


def test_same_filename_new_kwargs(tmp_path):
def test_same_filename_new_kwargs(tmp_path, train_metadata):
    """Test that load_video_frames does not load the npy file if the params change."""
# use first test video
test_vid = [f for f in TEST_VIDEOS_DIR.rglob("*") if f.is_file()][0]
cache = tmp_path / "test_cache"

# prep labels for one video
labels = (
train_metadata[train_metadata.split == "train"]
.set_index("filepath")
.filter(regex="species")
.head(1)
)

def _generate_dataset(config):
"""Return loaded video from FFmpegZambaVideoDataset."""
return FfmpegZambaVideoDataset(annotations=labels, video_loader_config=config).__getitem__(
index=0
)[0]

with mock.patch.dict(os.environ, {"VIDEO_CACHE_DIR": str(cache)}):
# confirm cache is set in environment variable
assert os.environ["VIDEO_CACHE_DIR"] == str(cache)

first_load = cached_load_video_frames(filepath=test_vid, config=VideoLoaderConfig(fps=1))
new_params_same_name = cached_load_video_frames(
filepath=test_vid, config=VideoLoaderConfig(fps=2)
)
first_load = _generate_dataset(config=VideoLoaderConfig(fps=1))
new_params_same_name = _generate_dataset(config=VideoLoaderConfig(fps=2))
assert first_load.shape != new_params_same_name.shape

# check no params
first_load = cached_load_video_frames(filepath=test_vid)
assert first_load.shape != new_params_same_name.shape
no_params_same_name = _generate_dataset(config=None)
assert first_load.shape != new_params_same_name.shape != no_params_same_name.shape

# multiple params in config
c1 = VideoLoaderConfig(scene_threshold=0.2)
c2 = VideoLoaderConfig(scene_threshold=0.2, crop_bottom_pixels=2)

first_load = cached_load_video_frames(filepath=test_vid, config=c1)
new_params_same_name = cached_load_video_frames(filepath=test_vid, config=c2)
first_load = _generate_dataset(config=VideoLoaderConfig(scene_threshold=0.2))
new_params_same_name = _generate_dataset(
config=VideoLoaderConfig(scene_threshold=0.2, crop_bottom_pixels=2)
)
assert first_load.shape != new_params_same_name.shape


Expand Down Expand Up @@ -506,39 +515,55 @@ def test_validate_total_frames():
assert config.total_frames == 8


def test_caching(tmp_path, caplog):
def test_caching(tmp_path, caplog, train_metadata):
cache = tmp_path / "video_cache"
test_vid = [f for f in TEST_VIDEOS_DIR.rglob("*") if f.is_file()][0]

# prep labels for one video
labels = (
train_metadata[train_metadata.split == "train"]
.set_index("filepath")
.filter(regex="species")
.head(1)
)

# no caching by default
_ = cached_load_video_frames(filepath=test_vid, config=VideoLoaderConfig(fps=1))
_ = FfmpegZambaVideoDataset(
annotations=labels,
).__getitem__(index=0)
assert not cache.exists()

    # caching can be specified in config
_ = cached_load_video_frames(
filepath=test_vid, config=VideoLoaderConfig(fps=1, cache_dir=cache)
)
_ = FfmpegZambaVideoDataset(
annotations=labels, video_loader_config=VideoLoaderConfig(fps=1, cache_dir=cache)
).__getitem__(index=0)

# one file in cache
assert len([f for f in cache.rglob("*") if f.is_file()]) == 1
shutil.rmtree(cache)

# or caching can be specified in environment variable
with mock.patch.dict(os.environ, {"VIDEO_CACHE_DIR": str(cache)}):
_ = cached_load_video_frames(filepath=test_vid)
_ = FfmpegZambaVideoDataset(
annotations=labels,
).__getitem__(index=0)
assert len([f for f in cache.rglob("*") if f.is_file()]) == 1

# changing cleanup in config does not prompt new hashing of videos
with mock.patch.dict(os.environ, {"LOG_LEVEL": "DEBUG"}):
_ = cached_load_video_frames(
filepath=test_vid, config=VideoLoaderConfig(cleanup_cache=True)
)
_ = FfmpegZambaVideoDataset(
annotations=labels, video_loader_config=VideoLoaderConfig(cleanup_cache=True)
).__getitem__(index=0)

assert "Loading from cache" in caplog.text

# if no config is passed, this is equivalent to specifying None/False in all non-cache related VLC params
no_config = cached_load_video_frames(filepath=test_vid, config=None)
config_with_nones = cached_load_video_frames(
filepath=test_vid,
config=VideoLoaderConfig(
no_config = FfmpegZambaVideoDataset(annotations=labels, video_loader_config=None).__getitem__(
index=0
)[0]

config_with_nones = FfmpegZambaVideoDataset(
annotations=labels,
video_loader_config=VideoLoaderConfig(
crop_bottom_pixels=None,
i_frames=False,
scene_threshold=None,
Expand All @@ -555,5 +580,6 @@ def test_caching(tmp_path, caplog):
model_input_height=None,
model_input_width=None,
),
)
).__getitem__(index=0)[0]

assert np.array_equal(no_config, config_with_nones)
44 changes: 9 additions & 35 deletions zamba/data/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,8 @@ def validate_total_frames(cls, values):


class npy_cache:
def __init__(self, path: Optional[Path] = None, cleanup: bool = False):
self.tmp_path = path
def __init__(self, cache_path: Optional[Path] = None, cleanup: bool = False):
self.cache_path = cache_path
self.cleanup = cleanup

def __call__(self, f):
Expand Down Expand Up @@ -358,7 +358,7 @@ def _wrapped(*args, **kwargs):
if isinstance(vid_path, S3Path):
vid_path = AnyPath(vid_path.key)

npy_path = self.tmp_path / hash_str / vid_path.with_suffix(".npy")
npy_path = self.cache_path / hash_str / vid_path.with_suffix(".npy")
# make parent directories since we're using absolute paths
npy_path.parent.mkdir(parents=True, exist_ok=True)

Expand All @@ -372,32 +372,24 @@ def _wrapped(*args, **kwargs):
logger.debug(f"Wrote to cache {npy_path}: size {npy_path.stat().st_size}")
return loaded_video

if self.tmp_path is not None:
if self.cache_path is not None:
return _wrapped
else:
return f

def __del__(self):
if hasattr(self, "tmp_path") and self.cleanup and self.tmp_path.exists():
if self.tmp_path.parents[0] == tempfile.gettempdir():
logger.info(f"Deleting cache dir {self.tmp_path}.")
rmtree(self.tmp_path)
if hasattr(self, "cache_path") and self.cleanup and self.cache_path.exists():
if self.cache_path.parents[0] == tempfile.gettempdir():
logger.info(f"Deleting cache dir {self.cache_path}.")
rmtree(self.cache_path)
else:
logger.warning(
"Bravely refusing to delete directory that is not a subdirectory of the "
"system temp directory. If you really want to delete, do so manually using:\n "
f"rm -r {self.tmp_path}"
f"rm -r {self.cache_path}"
)


def npy_cache_factory(path, callable, cleanup):
    """Return *callable* wrapped with the ``npy_cache`` caching decorator.

    Args:
        path: Directory for the .npy cache; caching is bypassed when None.
        callable: Function to wrap; invoked with the same args/kwargs.
        cleanup: Whether the cache decorator should remove its directory on teardown.
    """

    def decorated_callable(*args, **kwargs):
        return callable(*args, **kwargs)

    # apply the decorator explicitly instead of with @-syntax
    return npy_cache(path=path, cleanup=cleanup)(decorated_callable)


def load_video_frames(
filepath: os.PathLike,
config: Optional[VideoLoaderConfig] = None,
Expand Down Expand Up @@ -503,21 +495,3 @@ def load_video_frames(
arr = ensure_frame_number(arr, total_frames=config.total_frames)

return arr


def cached_load_video_frames(filepath: os.PathLike, config: Optional[VideoLoaderConfig] = None):
    """Loads frames from videos using fast ffmpeg commands and caches to .npy file
    if config.cache_dir is not None.

    Args:
        filepath (os.PathLike): Path to the video.
        config (VideoLoaderConfig): Configuration for video loading. If None, a
            default config is constructed (which picks up the cache directory
            from the environment variable, if set).
    """
    # default config reads the cache directory from the environment if present
    cfg = VideoLoaderConfig() if config is None else config

    cached_loader = npy_cache_factory(
        path=cfg.cache_dir, callable=load_video_frames, cleanup=cfg.cleanup_cache
    )
    return cached_loader(filepath=filepath, config=cfg)
2 changes: 1 addition & 1 deletion zamba/models/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def instantiate_model(
hparams = yaml.safe_load(f)

else:
if not Path(checkpoint).exists():
if not (model_cache_dir / checkpoint).exists():
logger.info("Downloading weights for model.")
checkpoint = download_weights(
filename=str(checkpoint),
Expand Down
12 changes: 11 additions & 1 deletion zamba/pytorch/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torchvision.datasets.vision import VisionDataset
import torchvision.transforms.transforms

from zamba.data.video import cached_load_video_frames, VideoLoaderConfig
from zamba.data.video import npy_cache, load_video_frames, VideoLoaderConfig


def get_datasets(
Expand Down Expand Up @@ -87,6 +87,11 @@ def __init__(
self.targets = annotations

self.transform = transform

# get environment variable for cache if it exists
if video_loader_config is None:
video_loader_config = VideoLoaderConfig()

self.video_loader_config = video_loader_config

super().__init__(root=None, transform=transform)
Expand All @@ -96,6 +101,11 @@ def __len__(self):

def __getitem__(self, index: int):
try:
cached_load_video_frames = npy_cache(
cache_path=self.video_loader_config.cache_dir,
cleanup=self.video_loader_config.cleanup_cache,
)(load_video_frames)

video = cached_load_video_frames(
filepath=self.video_paths[index], config=self.video_loader_config
)
Expand Down

0 comments on commit c41ae7e

Please sign in to comment.