helpful error on unhashable parameters (#16049)
zzstoatzz authored Nov 18, 2024
1 parent d1e15a0 commit fdca916
Showing 5 changed files with 151 additions and 12 deletions.
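
Before the per-file diffs, a minimal sketch of the user-facing scenario this commit improves, assuming Prefect's default input-based cache policy; the script and the task name `my_task` are illustrative and not part of the commit:

```python
# Hypothetical reproduction script (not part of the commit).
import threading

from prefect import task


@task(persist_result=True)  # the default cache policy hashes the task's inputs
def my_task(x, lock_obj):
    return x


# A lock cannot be serialized by JSON or cloudpickle, so computing the cache
# key fails. Per the new test below, the task run log now carries a message
# explaining the failure and pointing at `cache_key_fn` or `cache_policy=NONE`
# instead of a bare serialization traceback.
my_task(42, lock_obj=threading.Lock())
```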
15 changes: 13 additions & 2 deletions src/prefect/cache_policies.py
@@ -7,6 +7,7 @@
from typing_extensions import Self

from prefect.context import TaskRunContext
from prefect.exceptions import HashError
from prefect.utilities.hashing import hash_objects

if TYPE_CHECKING:
@@ -223,7 +224,6 @@ def compute_key(
lines = task_ctx.task.fn.__code__.co_code
else:
raise

return hash_objects(lines, raise_on_failure=True)


@@ -293,7 +293,18 @@ def compute_key(
if key not in exclude:
hashed_inputs[key] = val

return hash_objects(hashed_inputs, raise_on_failure=True)
try:
return hash_objects(hashed_inputs, raise_on_failure=True)
except HashError as exc:
msg = (
f"{exc}\n\n"
"This often occurs when task inputs contain objects that cannot be cached "
"like locks, file handles, or other system resources.\n\n"
"To resolve this, you can:\n"
" 1. Exclude these arguments by defining a custom `cache_key_fn`\n"
" 2. Disable caching by passing `cache_policy=NONE`\n"
)
raise ValueError(msg) from exc

def __sub__(self, other: str) -> "CachePolicy":
if not isinstance(other, str):
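
A condensed sketch of the two resolutions the new error message suggests (the same approaches exercised by the new test further down); the task and helper names are illustrative, and it assumes `NONE` is importable from `prefect.cache_policies`:

```python
# Illustrative only; not part of the commit.
import threading

from prefect import task
from prefect.cache_policies import NONE


def cache_on_x_only(context, parameters):
    # Option 1: key the cache on the hashable argument and ignore the lock.
    return str(parameters.get("x"))


@task(cache_key_fn=cache_on_x_only, persist_result=True)
def keyed_task(x, lock_obj):
    return x


@task(cache_policy=NONE, persist_result=True)  # Option 2: opt out of caching
def uncached_task(x, lock_obj):
    return x


lock = threading.Lock()
keyed_task(1, lock_obj=lock)
uncached_task(1, lock_obj=lock)
```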
4 changes: 4 additions & 0 deletions src/prefect/exceptions.py
@@ -443,3 +443,7 @@ class ProfileSettingsValidationError(PrefectException):

def __init__(self, errors: List[Tuple[Any, ValidationError]]) -> None:
self.errors = errors


class HashError(PrefectException):
"""Raised when hashing objects fails"""
35 changes: 28 additions & 7 deletions src/prefect/utilities/hashing.py
@@ -6,6 +6,7 @@

import cloudpickle

from prefect.exceptions import HashError
from prefect.serializers import JSONSerializer

if sys.version_info[:2] >= (3, 9):
@@ -53,19 +54,39 @@ def hash_objects(
) -> Optional[str]:
"""
Attempt to hash objects by dumping to JSON or serializing with cloudpickle.
On failure of both, `None` will be returned; to raise on failure, set
`raise_on_failure=True`.
Args:
*args: Positional arguments to hash
hash_algo: Hash algorithm to use
raise_on_failure: If True, raise exceptions instead of returning None
**kwargs: Keyword arguments to hash
Returns:
A hash string or None if hashing failed
Raises:
HashError: If objects cannot be hashed and raise_on_failure is True
"""
json_error = None
pickle_error = None

try:
serializer = JSONSerializer(dumps_kwargs={"sort_keys": True})
return stable_hash(serializer.dumps((args, kwargs)), hash_algo=hash_algo)
except Exception:
pass
except Exception as e:
json_error = str(e)

try:
return stable_hash(cloudpickle.dumps((args, kwargs)), hash_algo=hash_algo)
except Exception:
if raise_on_failure:
raise
except Exception as e:
pickle_error = str(e)

if raise_on_failure:
msg = (
"Unable to create hash - objects could not be serialized.\n"
f" JSON error: {json_error}\n"
f" Pickle error: {pickle_error}"
)
raise HashError(msg)

return None
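
The docstring above summarizes the contract; here is a short sketch of both modes, mirroring the new tests (the unhashable value is illustrative):

```python
import threading

from prefect.exceptions import HashError
from prefect.utilities.hashing import hash_objects

unhashable = {"data": "hello", "lock": threading.Lock()}

# Default: fall back to None when neither JSON nor cloudpickle can serialize.
assert hash_objects(unhashable) is None

# With raise_on_failure=True: surface both serializer errors in one HashError.
try:
    hash_objects(unhashable, raise_on_failure=True)
except HashError as exc:
    print(exc)
```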
79 changes: 77 additions & 2 deletions tests/test_tasks.py
@@ -2,6 +2,7 @@
import datetime
import inspect
import json
import threading
import time
from asyncio import Event, sleep
from functools import partial
@@ -1969,6 +1970,80 @@ def foo(x):
):
foo(1)

async def test_unhashable_input_provides_helpful_error(self, caplog):
"""Test that trying to cache a task with unhashable inputs provides helpful error message"""
lock = threading.Lock()

@task(persist_result=True)
def foo(x, lock_obj):
return x

foo(42, lock_obj=lock)

error_msg = caplog.text

# First we see the cache policy's message
assert (
"This often occurs when task inputs contain objects that cannot be cached"
in error_msg
)
assert "like locks, file handles, or other system resources." in error_msg
assert "To resolve this, you can:" in error_msg
assert (
"1. Exclude these arguments by defining a custom `cache_key_fn`"
in error_msg
)
assert "2. Disable caching by passing `cache_policy=NONE`" in error_msg

# Then we see the original HashError details
assert "Unable to create hash - objects could not be serialized." in error_msg
assert (
"JSON error: Unable to serialize unknown type: <class '_thread.lock'>"
in error_msg
)
assert "Pickle error: cannot pickle '_thread.lock' object" in error_msg

async def test_unhashable_input_workarounds(self):
"""Test workarounds for handling unhashable inputs"""
lock = threading.Lock()

# Solution 1: Use cache_key_fn to exclude problematic argument
def cache_on_x_only(context, parameters):
return str(parameters.get("x"))

@task(cache_key_fn=cache_on_x_only, persist_result=True)
def foo_with_key_fn(x, lock_obj):
return x

# Solution 2: Disable caching entirely
@task(cache_policy=NONE, persist_result=True)
def foo_with_none_policy(x, lock_obj):
return x

@flow
def test_flow():
# Both approaches should work without errors
return (
foo_with_key_fn(42, lock_obj=lock, return_state=True),
foo_with_key_fn(42, lock_obj=lock, return_state=True),
foo_with_none_policy(42, lock_obj=lock, return_state=True),
foo_with_none_policy(42, lock_obj=lock, return_state=True),
)

s1, s2, s3, s4 = test_flow()

# Key fn approach should still cache based on x
assert s1.name == "Completed"
assert s2.name == "Cached"
assert await s1.result() == 42
assert await s2.result() == 42

# NONE policy approach should never cache
assert s3.name == "Completed"
assert s4.name == "Completed"
assert await s3.result() == 42
assert await s4.result() == 42


class TestCacheFunctionBuiltins:
async def test_task_input_hash_within_flows(
@@ -2038,7 +2113,7 @@ def __init__(self, x):
self.x = x

def __eq__(self, other) -> bool:
return type(self) == type(other) and self.x == other.x
return type(self) is type(other) and self.x == other.x

@task(
cache_key_fn=task_input_hash,
@@ -2071,7 +2146,7 @@ def __init__(self, x):
self.x = x

def __eq__(self, other) -> bool:
return type(self) == type(other) and self.x == other.x
return type(self) is type(other) and self.x == other.x

@task(
cache_key_fn=task_input_hash,
30 changes: 29 additions & 1 deletion tests/utilities/test_hashing.py
@@ -1,8 +1,11 @@
import hashlib
import threading
from unittest.mock import MagicMock

import pytest

from prefect.utilities.hashing import file_hash, stable_hash
from prefect.exceptions import HashError
from prefect.utilities.hashing import file_hash, hash_objects, stable_hash


@pytest.mark.parametrize(
@@ -55,3 +58,28 @@ def test_file_hash_hashes(self, tmp_path):
assert val == hashlib.md5(b"0").hexdigest()
# Check if the hash is stable
assert val == "cfcd208495d565ef66e7dff9f98764da"


class TestHashObjects:
def test_hash_objects_handles_unhashable_objects_gracefully(self):
"""Test that unhashable objects return None by default"""
lock = threading.Lock()
result = hash_objects({"data": "hello", "lock": lock})
assert result is None

def test_hash_objects_raises_with_helpful_message(self):
"""Test that unhashable objects raise HashError when raise_on_failure=True"""
lock = threading.Lock()
mock_file = MagicMock()
mock_file.__str__ = lambda _: "<file object>"

with pytest.raises(HashError) as exc:
hash_objects(
{"data": "hello", "lock": lock, "file": mock_file},
raise_on_failure=True,
)

error_msg = str(exc.value)
assert "Unable to create hash" in error_msg
assert "JSON error" in error_msg
assert "Pickle error" in error_msg
