Skip to content

Commit

Permalink
Add a flag to disable caching (#92)
Browse files Browse the repository at this point in the history
This PR adds a flag that can disable caching
  • Loading branch information
cthoyt authored Jan 15, 2025
1 parent 22861de commit 55abd65
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/pystow/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,19 @@ class Cached(Generic[X], ABC):
def __init__(
self,
path: str | Path,
*,
force: bool = False,
cache: bool = True,
) -> None:
"""Instantiate the decorator.
:param path: The path to the cache for the file
:param cache: Should caching be done? Defaults to true, turn off for debugging purposes
:param force: Should a pre-existing file be disregared/overwritten?
"""
self.path = Path(path)
self.force = force
self.cache = cache

def __call__(self, func: Getter[X]) -> Getter[X]:
"""Apply this instance as a decorator.
Expand All @@ -69,6 +73,9 @@ def __call__(self, func: Getter[X]) -> Getter[X]:

@functools.wraps(func)
def _wrapped() -> X:
if not self.cache:
return func()

if self.path.is_file() and not self.force:
return self.load()
logger.debug("no cache found at %s", self.path)
Expand Down Expand Up @@ -158,6 +165,7 @@ class CachedDataFrame(Cached["pd.DataFrame"]):
def __init__(
self,
path: str | Path,
cache: bool = True,
force: bool = False,
sep: str | None = None,
dtype: Any | None = None,
Expand All @@ -172,7 +180,7 @@ def __init__(
:param read_csv_kwargs: Additional kwargs to pass to :func:`pd.read_csv`.
:raises ValueError: if sep is given as a kwarg and also in ``read_csv_kwargs``.
"""
super().__init__(path=path, force=force)
super().__init__(path=path, cache=cache, force=force)
self.read_csv_kwargs = read_csv_kwargs or {}
if "sep" not in self.read_csv_kwargs:
self.sep = sep or "\t"
Expand Down
21 changes: 21 additions & 0 deletions tests/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,24 @@ def _f2():

self.assertEqual(EXPECTED_2, _f2()) # overwrites the file
self.assertEqual(EXPECTED_2, _f1())

def test_no_cache(self) -> None:
"""Test that no caching happens."""
path = self.directory.joinpath("test.pkl")
sentinel_value = 5

self.assertFalse(path.is_file())

@CachedPickle(path=path, cache=False)
def _f1() -> int:
return sentinel_value

self.assertFalse(path.is_file(), msg="function has not been called")

# check the following twice, just for good measure!
for _ in range(2):
self.assertEqual(sentinel_value, _f1())
self.assertFalse(
path.is_file(),
msg="file should not have been created since caching was turned off",
)

0 comments on commit 55abd65

Please sign in to comment.