Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Simplify caching #145

Merged
merged 8 commits into from
Oct 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 57 additions & 31 deletions tests/test_load_video_frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from pydantic import BaseModel, ValidationError

from zamba.data.video import (
cached_load_video_frames,
load_video_frames,
VideoLoaderConfig,
MegadetectorLiteYoloXConfig,
VideoLoaderConfig,
)
from zamba.pytorch.dataloaders import FfmpegZambaVideoDataset

from conftest import ASSETS_DIR, TEST_VIDEOS_DIR

Expand Down Expand Up @@ -384,32 +384,41 @@ def test_load_video_frames(case: Case, video_metadata: Dict[str, Any]):
assert video_shape[field] == value


def test_same_filename_new_kwargs(tmp_path):
def test_same_filename_new_kwargs(tmp_path, train_metadata):
"""Test that load_video_frames does not load the npz file if the params change."""
# use first test video
test_vid = [f for f in TEST_VIDEOS_DIR.rglob("*") if f.is_file()][0]
cache = tmp_path / "test_cache"

# prep labels for one video
labels = (
train_metadata[train_metadata.split == "train"]
.set_index("filepath")
.filter(regex="species")
.head(1)
)

def _generate_dataset(config):
"""Return loaded video from FFmpegZambaVideoDataset."""
return FfmpegZambaVideoDataset(annotations=labels, video_loader_config=config).__getitem__(
index=0
)[0]

with mock.patch.dict(os.environ, {"VIDEO_CACHE_DIR": str(cache)}):
# confirm cache is set in environment variable
assert os.environ["VIDEO_CACHE_DIR"] == str(cache)

first_load = cached_load_video_frames(filepath=test_vid, config=VideoLoaderConfig(fps=1))
new_params_same_name = cached_load_video_frames(
filepath=test_vid, config=VideoLoaderConfig(fps=2)
)
first_load = _generate_dataset(config=VideoLoaderConfig(fps=1))
new_params_same_name = _generate_dataset(config=VideoLoaderConfig(fps=2))
assert first_load.shape != new_params_same_name.shape

# check no params
first_load = cached_load_video_frames(filepath=test_vid)
assert first_load.shape != new_params_same_name.shape
no_params_same_name = _generate_dataset(config=None)
assert first_load.shape != new_params_same_name.shape != no_params_same_name.shape

# multiple params in config
c1 = VideoLoaderConfig(scene_threshold=0.2)
c2 = VideoLoaderConfig(scene_threshold=0.2, crop_bottom_pixels=2)

first_load = cached_load_video_frames(filepath=test_vid, config=c1)
new_params_same_name = cached_load_video_frames(filepath=test_vid, config=c2)
first_load = _generate_dataset(config=VideoLoaderConfig(scene_threshold=0.2))
new_params_same_name = _generate_dataset(
config=VideoLoaderConfig(scene_threshold=0.2, crop_bottom_pixels=2)
)
assert first_load.shape != new_params_same_name.shape


Expand Down Expand Up @@ -506,39 +515,55 @@ def test_validate_total_frames():
assert config.total_frames == 8


def test_caching(tmp_path, caplog):
def test_caching(tmp_path, caplog, train_metadata):
cache = tmp_path / "video_cache"
test_vid = [f for f in TEST_VIDEOS_DIR.rglob("*") if f.is_file()][0]

# prep labels for one video
labels = (
train_metadata[train_metadata.split == "train"]
.set_index("filepath")
.filter(regex="species")
.head(1)
)

# no caching by default
_ = cached_load_video_frames(filepath=test_vid, config=VideoLoaderConfig(fps=1))
_ = FfmpegZambaVideoDataset(
annotations=labels,
).__getitem__(index=0)
assert not cache.exists()

# caching can be specifed in config
_ = cached_load_video_frames(
filepath=test_vid, config=VideoLoaderConfig(fps=1, cache_dir=cache)
)
_ = FfmpegZambaVideoDataset(
annotations=labels, video_loader_config=VideoLoaderConfig(fps=1, cache_dir=cache)
).__getitem__(index=0)

# one file in cache
assert len([f for f in cache.rglob("*") if f.is_file()]) == 1
shutil.rmtree(cache)

# or caching can be specified in environment variable
with mock.patch.dict(os.environ, {"VIDEO_CACHE_DIR": str(cache)}):
_ = cached_load_video_frames(filepath=test_vid)
_ = FfmpegZambaVideoDataset(
annotations=labels,
).__getitem__(index=0)
assert len([f for f in cache.rglob("*") if f.is_file()]) == 1

# changing cleanup in config does not prompt new hashing of videos
with mock.patch.dict(os.environ, {"LOG_LEVEL": "DEBUG"}):
_ = cached_load_video_frames(
filepath=test_vid, config=VideoLoaderConfig(cleanup_cache=True)
)
_ = FfmpegZambaVideoDataset(
annotations=labels, video_loader_config=VideoLoaderConfig(cleanup_cache=True)
).__getitem__(index=0)

assert "Loading from cache" in caplog.text

# if no config is passed, this is equivalent to specifying None/False in all non-cache related VLC params
no_config = cached_load_video_frames(filepath=test_vid, config=None)
config_with_nones = cached_load_video_frames(
filepath=test_vid,
config=VideoLoaderConfig(
no_config = FfmpegZambaVideoDataset(annotations=labels, video_loader_config=None).__getitem__(
index=0
)[0]

config_with_nones = FfmpegZambaVideoDataset(
annotations=labels,
video_loader_config=VideoLoaderConfig(
crop_bottom_pixels=None,
i_frames=False,
scene_threshold=None,
Expand All @@ -555,5 +580,6 @@ def test_caching(tmp_path, caplog):
model_input_height=None,
model_input_width=None,
),
)
).__getitem__(index=0)[0]

assert np.array_equal(no_config, config_with_nones)
44 changes: 9 additions & 35 deletions zamba/data/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,8 @@ def validate_total_frames(cls, values):


class npy_cache:
def __init__(self, path: Optional[Path] = None, cleanup: bool = False):
self.tmp_path = path
def __init__(self, cache_path: Optional[Path] = None, cleanup: bool = False):
self.cache_path = cache_path
self.cleanup = cleanup

def __call__(self, f):
Expand Down Expand Up @@ -358,7 +358,7 @@ def _wrapped(*args, **kwargs):
if isinstance(vid_path, S3Path):
vid_path = AnyPath(vid_path.key)

npy_path = self.tmp_path / hash_str / vid_path.with_suffix(".npy")
npy_path = self.cache_path / hash_str / vid_path.with_suffix(".npy")
# make parent directories since we're using absolute paths
npy_path.parent.mkdir(parents=True, exist_ok=True)

Expand All @@ -372,32 +372,24 @@ def _wrapped(*args, **kwargs):
logger.debug(f"Wrote to cache {npy_path}: size {npy_path.stat().st_size}")
return loaded_video

if self.tmp_path is not None:
if self.cache_path is not None:
return _wrapped
else:
return f

def __del__(self):
if hasattr(self, "tmp_path") and self.cleanup and self.tmp_path.exists():
if self.tmp_path.parents[0] == tempfile.gettempdir():
logger.info(f"Deleting cache dir {self.tmp_path}.")
rmtree(self.tmp_path)
if hasattr(self, "cache_path") and self.cleanup and self.cache_path.exists():
if self.cache_path.parents[0] == tempfile.gettempdir():
logger.info(f"Deleting cache dir {self.cache_path}.")
rmtree(self.cache_path)
else:
logger.warning(
"Bravely refusing to delete directory that is not a subdirectory of the "
"system temp directory. If you really want to delete, do so manually using:\n "
f"rm -r {self.tmp_path}"
f"rm -r {self.cache_path}"
)


def npy_cache_factory(path, callable, cleanup):
    """Return ``callable`` wrapped in an ``npy_cache`` caching decorator.

    Args:
        path: Cache directory handed to ``npy_cache`` (caching is skipped when None).
        callable: The loader function to wrap.
        cleanup: Whether the cache directory should be removed on teardown.

    Returns:
        A function with the same call signature as ``callable``, with caching applied.
    """
    cache_decorator = npy_cache(path=path, cleanup=cleanup)

    def decorated_callable(*args, **kwargs):
        # Delegate straight through; npy_cache handles the .npy read/write.
        return callable(*args, **kwargs)

    return cache_decorator(decorated_callable)


def load_video_frames(
filepath: os.PathLike,
config: Optional[VideoLoaderConfig] = None,
Expand Down Expand Up @@ -503,21 +495,3 @@ def load_video_frames(
arr = ensure_frame_number(arr, total_frames=config.total_frames)

return arr


def cached_load_video_frames(filepath: os.PathLike, config: Optional[VideoLoaderConfig] = None):
    """Load frames from a video with fast ffmpeg commands, caching the result
    to a .npy file when ``config.cache_dir`` is not None.

    Args:
        filepath (os.PathLike): Path to the video.
        config (VideoLoaderConfig): Configuration for video loading.
    """
    # A default config picks up the cache dir from the environment variable, if set.
    effective_config = VideoLoaderConfig() if config is None else config

    cached_loader = npy_cache_factory(
        path=effective_config.cache_dir,
        callable=load_video_frames,
        cleanup=effective_config.cleanup_cache,
    )
    return cached_loader(filepath=filepath, config=effective_config)
2 changes: 1 addition & 1 deletion zamba/models/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def instantiate_model(
hparams = yaml.safe_load(f)

else:
if not Path(checkpoint).exists():
if not (model_cache_dir / checkpoint).exists():
logger.info("Downloading weights for model.")
checkpoint = download_weights(
filename=str(checkpoint),
Expand Down
12 changes: 11 additions & 1 deletion zamba/pytorch/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torchvision.datasets.vision import VisionDataset
import torchvision.transforms.transforms

from zamba.data.video import cached_load_video_frames, VideoLoaderConfig
from zamba.data.video import npy_cache, load_video_frames, VideoLoaderConfig


def get_datasets(
Expand Down Expand Up @@ -87,6 +87,11 @@ def __init__(
self.targets = annotations

self.transform = transform

# get environment variable for cache if it exists
if video_loader_config is None:
video_loader_config = VideoLoaderConfig()

self.video_loader_config = video_loader_config

super().__init__(root=None, transform=transform)
Expand All @@ -96,6 +101,11 @@ def __len__(self):

def __getitem__(self, index: int):
try:
cached_load_video_frames = npy_cache(
r-b-g-b marked this conversation as resolved.
Show resolved Hide resolved
cache_path=self.video_loader_config.cache_dir,
cleanup=self.video_loader_config.cleanup_cache,
)(load_video_frames)

video = cached_load_video_frames(
filepath=self.video_paths[index], config=self.video_loader_config
)
Expand Down