diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a301d3d..6285997 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,8 +21,8 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.1.7 hooks: + - id: ruff-format # formatter + types_or: [ python, pyi, jupyter ] - id: ruff # linter types_or: [ python, pyi, jupyter ] args: [ --fix ] - - id: ruff-format # formatter - types_or: [ python, pyi, jupyter ] diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py index c32c6e4..2d0eac7 100644 --- a/jaxtyping/_decorator.py +++ b/jaxtyping/_decorator.py @@ -24,7 +24,7 @@ import itertools as it import sys import warnings -from typing import Any, get_args, get_origin, get_type_hints, overload +from typing import Any, get_args, get_origin, get_type_hints, NoReturn, overload from jaxtyping import AbstractArray @@ -405,30 +405,15 @@ def wrapped_fn(*args, **kwargs): # pyright: ignore full_fn = typechecker(full_fn) param_fn = typechecker(param_fn) - @ft.wraps(fn) - def wrapped_fn(*args, **kwargs): - __tracebackhide__ = True - if ( - config.jaxtyping_disable - or getattr(fn, "__no_type_check__", False) - or getattr(wrapped_fn, "__no_type_check__", False) - ): - return fn(*args, **kwargs) - - # Raise bind-time errors before we do any shape analysis. (I.e. skip - # the pointless jaxtyping information for a non-typechecking failure.) - bound = param_signature.bind(*args, **kwargs) - bound.apply_defaults() - - memos = push_shape_memo(bound.arguments) + def wrapped_fn_impl(args, kwargs, bound, memos): + # First type-check just the parameters before the function is + # called. try: - # First type-check just the parameters before the function is - # called. + param_fn(*args, **kwargs) + except AnnotationError: + raise + except Exception: try: - param_fn(*args, **kwargs) - except AnnotationError: - raise - except Exception as e: argmsg = _get_problem_arg( param_signature, args, @@ -437,6 +422,8 @@ def wrapped_fn(*args, **kwargs): module, typechecker, ) + except TypeCheckError as e: + argmsg = str(e) try: module_name = fn.__module__ qualname = fn.__qualname__ @@ -458,62 +445,81 @@ def wrapped_fn(*args, **kwargs): else: raise TypeCheckError(msg) from e - # Actually call the function. - out = fn(*args, **kwargs) - - if full_signature.return_annotation is not inspect.Signature.empty: - # Now type-check the return value. We need to include the - # parameters in the type-checking here in case there are any - # type variables shared across the parameters and return. - # - # Incidentally this does mean that if `fn` mutates its arguments - # so that they no longer satisfy their type annotations, this - # will throw an error here. But that's like, super weird, so - # don't do that. An error in that scenario is probably still - # desirable. - # - # There is a small performance concern here when used in - # non-jit'd contexts, like PyTorch, due to the duplicate - # checking of the parameters. Unfortunately there doesn't seem - # to be a way around that, so c'est la vie. - kwargs[output_name] = out + # Actually call the function. + out = fn(*args, **kwargs) + + if full_signature.return_annotation is not inspect.Signature.empty: + # Now type-check the return value. We need to include the + # parameters in the type-checking here in case there are any + # type variables shared across the parameters and return. + # + # Incidentally this does mean that if `fn` mutates its arguments + # so that they no longer satisfy their type annotations, this + # will throw an error here. But that's like, super weird, so + # don't do that. An error in that scenario is probably still + # desirable. + # + # There is a small performance concern here when used in + # non-jit'd contexts, like PyTorch, due to the duplicate + # checking of the parameters. Unfortunately there doesn't seem + # to be a way around that, so c'est la vie. + kwargs[output_name] = out + try: + full_fn(*args, **kwargs) + except AnnotationError: + raise + except Exception as e: try: - full_fn(*args, **kwargs) - except AnnotationError: - raise - except Exception as e: - try: - module_name = fn.__module__ - qualname = fn.__qualname__ - except AttributeError: - module_name = fn.__class__.__module__ - qualname = fn.__class__.__qualname__ - param_values = _pformat(bound.arguments, short_self=True) - return_value = _pformat(out, short_self=False) - param_hints = _remove_typing(param_signature) - return_hint = _remove_typing( - full_signature.return_annotation - ) - if return_hint.startswith( - ""): - return_hint = return_hint[8:-2] - msg = ( - "Type-check error whilst checking the return value " - f"of {module_name}.{qualname}.\n" - f"Actual value: {return_value}\n" - f"Expected type: {return_hint}.\n" - "----------------------\n" - f"Called with parameters: {param_values}\n" - f"Parameter annotations: {param_hints}.\n" - + shape_str(memos) - ) - if config.jaxtyping_remove_typechecker_stack: - raise TypeCheckError(msg) from None - else: - raise TypeCheckError(msg) from e + module_name = fn.__module__ + qualname = fn.__qualname__ + except AttributeError: + module_name = fn.__class__.__module__ + qualname = fn.__class__.__qualname__ + param_values = _pformat(bound.arguments, short_self=True) + return_value = _pformat(out, short_self=False) + param_hints = _remove_typing(param_signature) + return_hint = _remove_typing(full_signature.return_annotation) + if return_hint.startswith("" + ): + return_hint = return_hint[8:-2] + msg = ( + "Type-check error whilst checking the return value " + f"of {module_name}.{qualname}.\n" + f"Actual value: {return_value}\n" + f"Expected type: {return_hint}.\n" + "----------------------\n" + f"Called with parameters: {param_values}\n" + f"Parameter annotations: {param_hints}.\n" + + shape_str(memos) + ) + if config.jaxtyping_remove_typechecker_stack: + raise TypeCheckError(msg) from None + else: + raise TypeCheckError(msg) from e - return out + return out + + @ft.wraps(fn) + def wrapped_fn(*args, **kwargs): + __tracebackhide__ = True + if ( + config.jaxtyping_disable + or getattr(fn, "__no_type_check__", False) + or getattr(wrapped_fn, "__no_type_check__", False) + ): + return fn(*args, **kwargs) + + # Raise bind-time errors before we do any shape analysis. (I.e. skip + # the pointless jaxtyping information for a non-typechecking failure.) + bound = param_signature.bind(*args, **kwargs) + bound.apply_defaults() + + memos = push_shape_memo(bound.arguments) + try: + # Put this in a separate frame to make debugging easier, without + # just always ending up on the `pop_shape_memo` line below. + return wrapped_fn_impl(args, kwargs, bound, memos) finally: pop_shape_memo() @@ -727,9 +733,13 @@ def _make_argpiece(p, name_to_annotation, name_to_default): def _get_problem_arg( param_signature: inspect.Signature, args, kwargs, arguments, module, typechecker -) -> str: +) -> NoReturn: """Determines which argument was likely to be the problematic one responsible for raising a type-check error. + + It returns the result by raising an Exception: you should grab this and extract the + string out of it. We do this, rather than just returning the value, to aid debugging + by making it possible to walk down the stack to the issue. """ # No performance concerns, as this is only used when we're about to raise an error # anyway. @@ -739,12 +749,16 @@ def _get_problem_arg( for p_name, p in param_signature.parameters.items(): if p_name == keep_name: new_parameters.append( - inspect.Parameter(p.name, p.kind, annotation=p.annotation) + inspect.Parameter( + p.name, p.kind, default=p.default, annotation=p.annotation + ) ) assert keep_annotation is sentinel keep_annotation = _remove_typing(p.annotation) else: - new_parameters.append(inspect.Parameter(p.name, p.kind)) + new_parameters.append( + inspect.Parameter(p.name, p.kind, default=p.default) + ) assert keep_annotation is not sentinel new_signature = inspect.Signature(new_parameters) fn = _make_fn_with_signature( @@ -753,17 +767,17 @@ def _get_problem_arg( fn = typechecker(fn) # but no `jaxtyped`; keep the same environment. try: fn(*args, **kwargs) - except Exception: + except Exception as e: keep_value = _pformat(arguments[keep_name], short_self=False) - return ( + raise TypeCheckError( f"\nThe problem arose whilst typechecking parameter '{keep_name}'.\n" f"Actual value: {keep_value}\n" f"Expected type: {keep_annotation}." - ) + ) from e else: # Could not localise the problem to a single argument -- probably due to # e.g. a mismatched typevar, which each individual argument is okay with. - return "" + raise TypeCheckError("") def _remove_typing(x):