Skip to content

Commit 4cb9f7f

Browse files
committed
Step 5
1 parent 21f7551 commit 4cb9f7f

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

.mypy.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ disallow_any_generics = True
2020
# These next few are various gradations of forcing use of type annotations
2121
disallow_untyped_calls = True
2222
disallow_incomplete_defs = True
23-
; disallow_untyped_defs = True
23+
disallow_untyped_defs = True
2424

2525
# This one isn't too hard to get passing, but return on investment is lower
2626
no_implicit_reexport = True

django_fsm/__init__.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
IntegerField = models.IntegerField[int, int]
4949
ForeignKey = models.ForeignKey[Any, Any]
5050

51+
_StateValue = str | int
5152
_Instance = models.Model # TODO: use real type
5253
_ToDo = Any # TODO: use real type
5354
else:
@@ -82,10 +83,10 @@ class ConcurrentTransition(Exception):
8283
class Transition:
8384
def __init__(
8485
self,
85-
method: Callable[..., str | int | None],
86-
source: str | int | Sequence[str | int] | State,
87-
target: str | int,
88-
on_error: str | int | None,
86+
method: Callable[..., _StateValue | Any],
87+
source: _StateValue | Sequence[_StateValue] | State,
88+
target: _StateValue,
89+
on_error: _StateValue | None,
8990
conditions: list[Callable[[_Instance], bool]],
9091
permission: str | Callable[[_Instance, UserWithPermissions], bool] | None,
9192
custom: dict[str, _StrOrPromise],
@@ -402,7 +403,7 @@ def _collect_transitions(self, *args: Any, **kwargs: Any) -> None:
402403
if not issubclass(sender, self.base_cls):
403404
return
404405

405-
def is_field_transition_method(attr):
406+
def is_field_transition_method(attr: _ToDo) -> bool:
406407
return (
407408
(inspect.ismethod(attr) or inspect.isfunction(attr))
408409
and hasattr(attr, "_django_fsm")
@@ -489,7 +490,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
489490
def state_fields(self) -> Iterable[Any]:
490491
return filter(lambda field: isinstance(field, FSMFieldMixin), self._meta.fields)
491492

492-
def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_update):
493+
def _do_update(self, base_qs, using, pk_val, values, update_fields, forced_update): # type: ignore[no-untyped-def]
493494
# _do_update is called once for each model class in the inheritance hierarchy.
494495
# We can only filter the base_qs on state fields (can be more than one!) present in this particular model.
495496

@@ -533,21 +534,21 @@ def save(self, *args: Any, **kwargs: Any) -> None:
533534

534535
def transition(
535536
field: FSMFieldMixin,
536-
source: str | int | Sequence[str | int] | State = "*",
537+
source: str | int | Sequence[str | int] = "*",
537538
target: str | int | State | None = None,
538539
on_error: str | int | None = None,
539540
conditions: list[Callable[[Any], bool]] = [],
540541
permission: str | Callable[[models.Model, UserWithPermissions], bool] | None = None,
541542
custom: dict[str, _StrOrPromise] = {},
542-
) -> _ToDo:
543+
) -> Callable[[Any], Any]:
543544
"""
544545
Method decorator to mark allowed transitions.
545546
546547
Set target to None if current state needs to be validated and
547548
has not changed after the function call.
548549
"""
549550

550-
def inner_transition(func):
551+
def inner_transition(func: _ToDo) -> _ToDo:
551552
wrapper_installed, fsm_meta = True, getattr(func, "_django_fsm", None)
552553
if not fsm_meta:
553554
wrapper_installed = False
@@ -608,15 +609,15 @@ def has_transition_perm(bound_method: _ToDo, user: UserWithPermissions) -> bool:
608609

609610

610611
class State:
611-
def get_state(self, model, transition, result, args=[], kwargs={}):
612+
def get_state(self, model: _Model, transition: Transition, result: Any, args: Any = [], kwargs: Any = {}) -> _ToDo:
612613
raise NotImplementedError
613614

614615

615616
class RETURN_VALUE(State):
616617
def __init__(self, *allowed_states: Sequence[str | int]) -> None:
617618
self.allowed_states = allowed_states if allowed_states else None
618619

619-
def get_state(self, model, transition, result, args=[], kwargs={}):
620+
def get_state(self, model: _Model, transition: Transition, result: Any, args: Any = [], kwargs: Any = {}) -> _ToDo:
620621
if self.allowed_states is not None:
621622
if result not in self.allowed_states:
622623
raise InvalidResultState(f"{result} is not in list of allowed states\n{self.allowed_states}")
@@ -628,7 +629,9 @@ def __init__(self, func: Callable[..., str | int], states: Sequence[str | int] |
628629
self.func = func
629630
self.allowed_states = states
630631

631-
def get_state(self, model, transition, result, args=[], kwargs={}):
632+
def get_state(
633+
self, model: _Model, transition: Transition, result: _StateValue | Any, args: Any = [], kwargs: Any = {}
634+
) -> _ToDo:
632635
result_state = self.func(model, *args, **kwargs)
633636
if self.allowed_states is not None:
634637
if result_state not in self.allowed_states:

0 commit comments

Comments
 (0)