From 9a019cf28f07775b782c7e4eaeb8467fc244313b Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 11 Oct 2024 15:33:48 +0200 Subject: [PATCH] Fixes #254. --- jaxtyping/_decorator.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/jaxtyping/_decorator.py b/jaxtyping/_decorator.py index 2d0eac7..653e056 100644 --- a/jaxtyping/_decorator.py +++ b/jaxtyping/_decorator.py @@ -24,7 +24,17 @@ import itertools as it import sys import warnings -from typing import Any, get_args, get_origin, get_type_hints, NoReturn, overload +from collections.abc import Callable +from typing import ( + Any, + get_args, + get_origin, + get_type_hints, + NoReturn, + overload, + ParamSpec, + TypeVar, +) from jaxtyping import AbstractArray @@ -33,6 +43,10 @@ from ._storage import pop_shape_memo, push_shape_memo, shape_str +_Params = ParamSpec("_Params") +_Return = TypeVar("_Return") + + class _Sentinel: def __repr__(self): return "sentinel" @@ -43,12 +57,17 @@ def __repr__(self): @overload -def jaxtyped(*, typechecker=_sentinel): +def jaxtyped( + *, + typechecker=_sentinel, +) -> Callable[[Callable[_Params, _Return]], Callable[_Params, _Return]]: ... @overload -def jaxtyped(fn, *, typechecker=_sentinel): +def jaxtyped( + fn: Callable[_Params, _Return], *, typechecker=_sentinel +) -> Callable[_Params, _Return]: ...