Skip to content

Commit

Permalink
Fixes #254.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Oct 11, 2024
1 parent f4ca4c9 commit 9a019cf
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions jaxtyping/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,17 @@
import itertools as it
import sys
import warnings
from typing import Any, get_args, get_origin, get_type_hints, NoReturn, overload
from collections.abc import Callable
from typing import (
Any,
get_args,
get_origin,
get_type_hints,
NoReturn,
overload,
ParamSpec,
TypeVar,
)

from jaxtyping import AbstractArray

Expand All @@ -33,6 +43,10 @@
from ._storage import pop_shape_memo, push_shape_memo, shape_str


_Params = ParamSpec("_Params")
_Return = TypeVar("_Return")


class _Sentinel:
def __repr__(self):
return "sentinel"
Expand All @@ -43,12 +57,17 @@ def __repr__(self):


@overload
def jaxtyped(*, typechecker=_sentinel):
def jaxtyped(
*,
typechecker=_sentinel,
) -> Callable[[Callable[_Params, _Return]], Callable[_Params, _Return]]:
...


@overload
def jaxtyped(fn, *, typechecker=_sentinel):
def jaxtyped(
fn: Callable[_Params, _Return], *, typechecker=_sentinel
) -> Callable[_Params, _Return]:
...


Expand Down

0 comments on commit 9a019cf

Please sign in to comment.