diff --git a/tests/test_load_video_frames.py b/tests/test_load_video_frames.py index 08a46747..91a29392 100644 --- a/tests/test_load_video_frames.py +++ b/tests/test_load_video_frames.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import pytest import shutil import subprocess @@ -583,3 +584,14 @@ def test_caching(tmp_path, caplog, train_metadata): ).__getitem__(index=0)[0] assert np.array_equal(no_config, config_with_nones) + + +def test_validate_video_cache_dir(): + with mock.patch.dict(os.environ, {"VIDEO_CACHE_DIR": "example_cache_dir"}): + config = VideoLoaderConfig() + assert config.cache_dir == Path("example_cache_dir") + + for cache in ["", 0]: + with mock.patch.dict(os.environ, {"VIDEO_CACHE_DIR": str(cache)}): + config = VideoLoaderConfig() + assert config.cache_dir is None diff --git a/zamba/data/video.py b/zamba/data/video.py index c73f7a0e..813e435c 100644 --- a/zamba/data/video.py +++ b/zamba/data/video.py @@ -220,6 +220,9 @@ def validate_video_cache_dir(cls, cache_dir): if cache_dir is None: cache_dir = os.getenv("VIDEO_CACHE_DIR", None) + if cache_dir in ["", "0"]: + cache_dir = None + if cache_dir is not None: cache_dir = Path(cache_dir) cache_dir.mkdir(parents=True, exist_ok=True)