48
48
IntegerField = models .IntegerField [int , int ]
49
49
ForeignKey = models .ForeignKey [Any , Any ]
50
50
51
+ _StateValue = str | int
51
52
_Instance = models .Model # TODO: use real type
52
53
_ToDo = Any # TODO: use real type
53
54
else :
@@ -82,10 +83,10 @@ class ConcurrentTransition(Exception):
82
83
class Transition :
83
84
def __init__ (
84
85
self ,
85
- method : Callable [..., str | int | None ],
86
- source : str | int | Sequence [str | int ] | State ,
87
- target : str | int ,
88
- on_error : str | int | None ,
86
+ method : Callable [..., _StateValue | Any ],
87
+ source : _StateValue | Sequence [_StateValue ] | State ,
88
+ target : _StateValue ,
89
+ on_error : _StateValue | None ,
89
90
conditions : list [Callable [[_Instance ], bool ]],
90
91
permission : str | Callable [[_Instance , UserWithPermissions ], bool ] | None ,
91
92
custom : dict [str , _StrOrPromise ],
@@ -402,7 +403,7 @@ def _collect_transitions(self, *args: Any, **kwargs: Any) -> None:
402
403
if not issubclass (sender , self .base_cls ):
403
404
return
404
405
405
- def is_field_transition_method (attr ) :
406
+ def is_field_transition_method (attr : _ToDo ) -> bool :
406
407
return (
407
408
(inspect .ismethod (attr ) or inspect .isfunction (attr ))
408
409
and hasattr (attr , "_django_fsm" )
@@ -489,7 +490,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
489
490
def state_fields (self ) -> Iterable [Any ]:
490
491
return filter (lambda field : isinstance (field , FSMFieldMixin ), self ._meta .fields )
491
492
492
- def _do_update (self , base_qs , using , pk_val , values , update_fields , forced_update ):
493
+ def _do_update (self , base_qs , using , pk_val , values , update_fields , forced_update ): # type: ignore[no-untyped-def]
493
494
# _do_update is called once for each model class in the inheritance hierarchy.
494
495
# We can only filter the base_qs on state fields (can be more than one!) present in this particular model.
495
496
@@ -533,21 +534,21 @@ def save(self, *args: Any, **kwargs: Any) -> None:
533
534
534
535
def transition (
535
536
field : FSMFieldMixin ,
536
- source : str | int | Sequence [str | int ] | State = "*" ,
537
+ source : str | int | Sequence [str | int ] = "*" ,
537
538
target : str | int | State | None = None ,
538
539
on_error : str | int | None = None ,
539
540
conditions : list [Callable [[Any ], bool ]] = [],
540
541
permission : str | Callable [[models .Model , UserWithPermissions ], bool ] | None = None ,
541
542
custom : dict [str , _StrOrPromise ] = {},
542
- ) -> _ToDo :
543
+ ) -> Callable [[ Any ], Any ] :
543
544
"""
544
545
Method decorator to mark allowed transitions.
545
546
546
547
Set target to None if current state needs to be validated and
547
548
has not changed after the function call.
548
549
"""
549
550
550
- def inner_transition (func ) :
551
+ def inner_transition (func : _ToDo ) -> _ToDo :
551
552
wrapper_installed , fsm_meta = True , getattr (func , "_django_fsm" , None )
552
553
if not fsm_meta :
553
554
wrapper_installed = False
@@ -608,15 +609,15 @@ def has_transition_perm(bound_method: _ToDo, user: UserWithPermissions) -> bool:
608
609
609
610
610
611
class State :
611
- def get_state (self , model , transition , result , args = [], kwargs = {}):
612
+ def get_state (self , model : _Model , transition : Transition , result : Any , args : Any = [], kwargs : Any = {}) -> _ToDo :
612
613
raise NotImplementedError
613
614
614
615
615
616
class RETURN_VALUE (State ):
616
617
def __init__ (self , * allowed_states : Sequence [str | int ]) -> None :
617
618
self .allowed_states = allowed_states if allowed_states else None
618
619
619
- def get_state (self , model , transition , result , args = [], kwargs = {}):
620
+ def get_state (self , model : _Model , transition : Transition , result : Any , args : Any = [], kwargs : Any = {}) -> _ToDo :
620
621
if self .allowed_states is not None :
621
622
if result not in self .allowed_states :
622
623
raise InvalidResultState (f"{ result } is not in list of allowed states\n { self .allowed_states } " )
@@ -628,7 +629,9 @@ def __init__(self, func: Callable[..., str | int], states: Sequence[str | int] |
628
629
self .func = func
629
630
self .allowed_states = states
630
631
631
- def get_state (self , model , transition , result , args = [], kwargs = {}):
632
+ def get_state (
633
+ self , model : _Model , transition : Transition , result : _StateValue | Any , args : Any = [], kwargs : Any = {}
634
+ ) -> _ToDo :
632
635
result_state = self .func (model , * args , ** kwargs )
633
636
if self .allowed_states is not None :
634
637
if result_state not in self .allowed_states :
0 commit comments