Skip to content

Commit

Permalink
Improved localisation of errors to particular arguments.
Browse files Browse the repository at this point in the history
This commit does two things.

First of all, it fixes a bug in which we forgot to add in default arguments when calling `_get_problem_arg`. This meant that in practice, if you had a default argument, then the very first argument would be what is reported.

Second, it rearranges things into a couple of extra stack frames, for an easier debugging experience.
- When looking at the main error message, this will now occur on the line that actually raised it, and not the `finally: pop_stack_memo()` line.
- The argument-specific `_get_problem_arg` error is what is now attached as the cause, rather than the overall failure of the whole typechecking.
  • Loading branch information
patrick-kidger committed Jul 1, 2024
1 parent b0bbff9 commit 6236167
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 86 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.7
hooks:
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
- id: ruff # linter
types_or: [ python, pyi, jupyter ]
args: [ --fix ]
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
182 changes: 98 additions & 84 deletions jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import itertools as it
import sys
import warnings
from typing import Any, get_args, get_origin, get_type_hints, overload
from typing import Any, get_args, get_origin, get_type_hints, NoReturn, overload

from jaxtyping import AbstractArray

Expand Down Expand Up @@ -405,30 +405,15 @@ def wrapped_fn(*args, **kwargs): # pyright: ignore
full_fn = typechecker(full_fn)
param_fn = typechecker(param_fn)

@ft.wraps(fn)
def wrapped_fn(*args, **kwargs):
__tracebackhide__ = True
if (
config.jaxtyping_disable
or getattr(fn, "__no_type_check__", False)
or getattr(wrapped_fn, "__no_type_check__", False)
):
return fn(*args, **kwargs)

# Raise bind-time errors before we do any shape analysis. (I.e. skip
# the pointless jaxtyping information for a non-typechecking failure.)
bound = param_signature.bind(*args, **kwargs)
bound.apply_defaults()

memos = push_shape_memo(bound.arguments)
def wrapped_fn_impl(args, kwargs, bound, memos):
# First type-check just the parameters before the function is
# called.
try:
# First type-check just the parameters before the function is
# called.
param_fn(*args, **kwargs)
except AnnotationError:
raise
except Exception:
try:
param_fn(*args, **kwargs)
except AnnotationError:
raise
except Exception as e:
argmsg = _get_problem_arg(
param_signature,
args,
Expand All @@ -437,6 +422,8 @@ def wrapped_fn(*args, **kwargs):
module,
typechecker,
)
except TypeCheckError as e:
argmsg = str(e)
try:
module_name = fn.__module__
qualname = fn.__qualname__
Expand All @@ -458,62 +445,81 @@ def wrapped_fn(*args, **kwargs):
else:
raise TypeCheckError(msg) from e

# Actually call the function.
out = fn(*args, **kwargs)

if full_signature.return_annotation is not inspect.Signature.empty:
# Now type-check the return value. We need to include the
# parameters in the type-checking here in case there are any
# type variables shared across the parameters and return.
#
# Incidentally this does mean that if `fn` mutates its arguments
# so that they no longer satisfy their type annotations, this
# will throw an error here. But that's like, super weird, so
# don't do that. An error in that scenario is probably still
# desirable.
#
# There is a small performance concern here when used in
# non-jit'd contexts, like PyTorch, due to the duplicate
# checking of the parameters. Unfortunately there doesn't seem
# to be a way around that, so c'est la vie.
kwargs[output_name] = out
# Actually call the function.
out = fn(*args, **kwargs)

if full_signature.return_annotation is not inspect.Signature.empty:
# Now type-check the return value. We need to include the
# parameters in the type-checking here in case there are any
# type variables shared across the parameters and return.
#
# Incidentally this does mean that if `fn` mutates its arguments
# so that they no longer satisfy their type annotations, this
# will throw an error here. But that's like, super weird, so
# don't do that. An error in that scenario is probably still
# desirable.
#
# There is a small performance concern here when used in
# non-jit'd contexts, like PyTorch, due to the duplicate
# checking of the parameters. Unfortunately there doesn't seem
# to be a way around that, so c'est la vie.
kwargs[output_name] = out
try:
full_fn(*args, **kwargs)
except AnnotationError:
raise
except Exception as e:
try:
full_fn(*args, **kwargs)
except AnnotationError:
raise
except Exception as e:
try:
module_name = fn.__module__
qualname = fn.__qualname__
except AttributeError:
module_name = fn.__class__.__module__
qualname = fn.__class__.__qualname__
param_values = _pformat(bound.arguments, short_self=True)
return_value = _pformat(out, short_self=False)
param_hints = _remove_typing(param_signature)
return_hint = _remove_typing(
full_signature.return_annotation
)
if return_hint.startswith(
"<class '"
) and return_hint.endswith("'>"):
return_hint = return_hint[8:-2]
msg = (
"Type-check error whilst checking the return value "
f"of {module_name}.{qualname}.\n"
f"Actual value: {return_value}\n"
f"Expected type: {return_hint}.\n"
"----------------------\n"
f"Called with parameters: {param_values}\n"
f"Parameter annotations: {param_hints}.\n"
+ shape_str(memos)
)
if config.jaxtyping_remove_typechecker_stack:
raise TypeCheckError(msg) from None
else:
raise TypeCheckError(msg) from e
module_name = fn.__module__
qualname = fn.__qualname__
except AttributeError:
module_name = fn.__class__.__module__
qualname = fn.__class__.__qualname__
param_values = _pformat(bound.arguments, short_self=True)
return_value = _pformat(out, short_self=False)
param_hints = _remove_typing(param_signature)
return_hint = _remove_typing(full_signature.return_annotation)
if return_hint.startswith("<class '") and return_hint.endswith(
"'>"
):
return_hint = return_hint[8:-2]
msg = (
"Type-check error whilst checking the return value "
f"of {module_name}.{qualname}.\n"
f"Actual value: {return_value}\n"
f"Expected type: {return_hint}.\n"
"----------------------\n"
f"Called with parameters: {param_values}\n"
f"Parameter annotations: {param_hints}.\n"
+ shape_str(memos)
)
if config.jaxtyping_remove_typechecker_stack:
raise TypeCheckError(msg) from None
else:
raise TypeCheckError(msg) from e

return out
return out

@ft.wraps(fn)
def wrapped_fn(*args, **kwargs):
__tracebackhide__ = True
if (
config.jaxtyping_disable
or getattr(fn, "__no_type_check__", False)
or getattr(wrapped_fn, "__no_type_check__", False)
):
return fn(*args, **kwargs)

# Raise bind-time errors before we do any shape analysis. (I.e. skip
# the pointless jaxtyping information for a non-typechecking failure.)
bound = param_signature.bind(*args, **kwargs)
bound.apply_defaults()

memos = push_shape_memo(bound.arguments)
try:
# Put this in a separate frame to make debugging easier, without
# just always ending up on the `pop_shape_memo` line below.
return wrapped_fn_impl(args, kwargs, bound, memos)
finally:
pop_shape_memo()

Expand Down Expand Up @@ -727,9 +733,13 @@ def _make_argpiece(p, name_to_annotation, name_to_default):

def _get_problem_arg(
param_signature: inspect.Signature, args, kwargs, arguments, module, typechecker
) -> str:
) -> NoReturn:
"""Determines which argument was likely to be the problematic one responsible for
raising a type-check error.
It returns the result by raising an Exception: you should grab this and extract the
string out of it. We do this, rather than just returning the value, to aid debugging
by making it possible to walk down the stack to the issue.
"""
# No performance concerns, as this is only used when we're about to raise an error
# anyway.
Expand All @@ -739,12 +749,16 @@ def _get_problem_arg(
for p_name, p in param_signature.parameters.items():
if p_name == keep_name:
new_parameters.append(
inspect.Parameter(p.name, p.kind, annotation=p.annotation)
inspect.Parameter(
p.name, p.kind, default=p.default, annotation=p.annotation
)
)
assert keep_annotation is sentinel
keep_annotation = _remove_typing(p.annotation)
else:
new_parameters.append(inspect.Parameter(p.name, p.kind))
new_parameters.append(
inspect.Parameter(p.name, p.kind, default=p.default)
)
assert keep_annotation is not sentinel
new_signature = inspect.Signature(new_parameters)
fn = _make_fn_with_signature(
Expand All @@ -753,17 +767,17 @@ def _get_problem_arg(
fn = typechecker(fn) # but no `jaxtyped`; keep the same environment.
try:
fn(*args, **kwargs)
except Exception:
except Exception as e:
keep_value = _pformat(arguments[keep_name], short_self=False)
return (
raise TypeCheckError(
f"\nThe problem arose whilst typechecking parameter '{keep_name}'.\n"
f"Actual value: {keep_value}\n"
f"Expected type: {keep_annotation}."
)
) from e
else:
# Could not localise the problem to a single argument -- probably due to
# e.g. a mismatched typevar, which each individual argument is okay with.
return ""
raise TypeCheckError("")


def _remove_typing(x):
Expand Down

0 comments on commit 6236167

Please sign in to comment.