From 2f9a77148950655b1e4652bba0fbd4632c08e18d Mon Sep 17 00:00:00 2001 From: Michael Carlstrom Date: Fri, 4 Oct 2024 15:08:14 -0400 Subject: [PATCH] Adds types to Lifecycle Objects (#1338) Signed-off-by: Michael Carlstrom --- rclpy/rclpy/impl/_rclpy_pybind11.pyi | 97 ++++++++++++++--- rclpy/rclpy/lifecycle/managed_entity.py | 50 ++++++--- rclpy/rclpy/lifecycle/node.py | 138 +++++++++++++++++++----- rclpy/rclpy/lifecycle/publisher.py | 29 ++++- 4 files changed, 258 insertions(+), 56 deletions(-) diff --git a/rclpy/rclpy/impl/_rclpy_pybind11.pyi b/rclpy/rclpy/impl/_rclpy_pybind11.pyi index 679f6fa4f..f348b2519 100644 --- a/rclpy/rclpy/impl/_rclpy_pybind11.pyi +++ b/rclpy/rclpy/impl/_rclpy_pybind11.pyi @@ -16,7 +16,7 @@ from __future__ import annotations from enum import Enum, IntEnum from types import TracebackType -from typing import Any, Generic, Literal, overload, Sequence, TypedDict +from typing import Any, Generic, Literal, overload, Sequence, TypeAlias, TypedDict from rclpy.clock import JumpHandle from rclpy.clock_type import ClockType @@ -101,21 +101,23 @@ class rcl_duration_t: nanoseconds: int -class rcl_subscription_event_type_t(IntEnum): - RCL_SUBSCRIPTION_REQUESTED_DEADLINE_MISSED: int - RCL_SUBSCRIPTION_LIVELINESS_CHANGED: int - RCL_SUBSCRIPTION_REQUESTED_INCOMPATIBLE_QOS: int - RCL_SUBSCRIPTION_MESSAGE_LOST: int - RCL_SUBSCRIPTION_INCOMPATIBLE_TYPE: int - RCL_SUBSCRIPTION_MATCHED: int +class rcl_subscription_event_type_t(Enum): + _value_: int + RCL_SUBSCRIPTION_REQUESTED_DEADLINE_MISSED = ... + RCL_SUBSCRIPTION_LIVELINESS_CHANGED = ... + RCL_SUBSCRIPTION_REQUESTED_INCOMPATIBLE_QOS = ... + RCL_SUBSCRIPTION_MESSAGE_LOST = ... + RCL_SUBSCRIPTION_INCOMPATIBLE_TYPE = ... + RCL_SUBSCRIPTION_MATCHED = ... -class rcl_publisher_event_type_t(IntEnum): - RCL_PUBLISHER_OFFERED_DEADLINE_MISSED: int - RCL_PUBLISHER_LIVELINESS_LOST: int - RCL_PUBLISHER_OFFERED_INCOMPATIBLE_QOS: int - RCL_PUBLISHER_INCOMPATIBLE_TYPE: int - RCL_PUBLISHER_MATCHED: int +class rcl_publisher_event_type_t(Enum): + _value_: int + RCL_PUBLISHER_OFFERED_DEADLINE_MISSED = ... + RCL_PUBLISHER_LIVELINESS_LOST = ... + RCL_PUBLISHER_OFFERED_INCOMPATIBLE_QOS = ... + RCL_PUBLISHER_INCOMPATIBLE_TYPE = ... + RCL_PUBLISHER_MATCHED = ... class EventHandle(Destroyable): @@ -135,6 +137,73 @@ class EventHandle(Destroyable): """Get pending data from a ready event.""" +LifecycleStateMachineState: TypeAlias = tuple[int, str] + + +class LifecycleStateMachine(Destroyable): + + def __init__(self, node: Node, enable_com_interface: bool) -> None: ... + + @property + def initialized(self) -> bool: + """Check if state machine is initialized.""" + + @property + def current_state(self) -> LifecycleStateMachineState: + """Get the current state machine state.""" + + @property + def available_states(self) -> list[LifecycleStateMachineState]: + """Get the available states.""" + + @property + def available_transitions(self) -> list[tuple[int, str, int, str, int, str]]: + """Get the available transitions.""" + + @property + def transition_graph(self) -> list[tuple[int, str, int, str, int, str]]: + """Get the transition graph.""" + + def get_transition_by_label(self, label: str) -> int: + """Get the transition id from a transition label.""" + + def trigger_transition_by_id(self, transition_id: int, publish_update: bool) -> None: + """Trigger a transition by transition id.""" + + def trigger_transition_by_label(self, label: str, publish_update: bool) -> None: + """Trigger a transition by label.""" + + @property + def service_change_state(self) -> Service: + """Get the change state service.""" + + @property + def service_get_state(self) -> Service: + """Get the get state service.""" + + @property + def service_get_available_states(self) -> Service: + """Get the get available states service.""" + + @property + def service_get_available_transitions(self) -> Service: + """Get the get available transitions service.""" + + @property + def service_get_transition_graph(self) -> Service: + """Get the get transition graph service.""" + + +class TransitionCallbackReturnType(Enum): + _value_: int + SUCCESS = ... + FAILURE = ... + ERROR = ... + + def to_label(self) -> str: + """Convert the transition callback return code to a transition label.""" + + class GuardCondition(Destroyable): def __init__(self, context: Context) -> None: ... diff --git a/rclpy/rclpy/lifecycle/managed_entity.py b/rclpy/rclpy/lifecycle/managed_entity.py index 93aaf9b62..a3de05a7a 100644 --- a/rclpy/rclpy/lifecycle/managed_entity.py +++ b/rclpy/rclpy/lifecycle/managed_entity.py @@ -13,36 +13,41 @@ # limitations under the License. from functools import wraps +from typing import Any, Callable, Dict, List, Optional, overload, TYPE_CHECKING, Union from ..impl.implementation_singleton import rclpy_implementation as _rclpy +if TYPE_CHECKING: + from typing import TypeAlias + from rclpy.lifecycle.node import LifecycleState -TransitionCallbackReturn = _rclpy.TransitionCallbackReturnType + +TransitionCallbackReturn: 'TypeAlias' = _rclpy.TransitionCallbackReturnType class ManagedEntity: - def on_configure(self, state) -> TransitionCallbackReturn: + def on_configure(self, state: 'LifecycleState') -> TransitionCallbackReturn: """Handle configure transition request.""" return TransitionCallbackReturn.SUCCESS - def on_cleanup(self, state) -> TransitionCallbackReturn: + def on_cleanup(self, state: 'LifecycleState') -> TransitionCallbackReturn: """Handle cleanup transition request.""" return TransitionCallbackReturn.SUCCESS - def on_shutdown(self, state) -> TransitionCallbackReturn: + def on_shutdown(self, state: 'LifecycleState') -> TransitionCallbackReturn: """Handle shutdown transition request.""" return TransitionCallbackReturn.SUCCESS - def on_activate(self, state) -> TransitionCallbackReturn: + def on_activate(self, state: 'LifecycleState') -> TransitionCallbackReturn: """Handle activate transition request.""" return TransitionCallbackReturn.SUCCESS - def on_deactivate(self, state) -> TransitionCallbackReturn: + def on_deactivate(self, state: 'LifecycleState') -> TransitionCallbackReturn: """Handle deactivate transition request.""" return TransitionCallbackReturn.SUCCESS - def on_error(self, state) -> TransitionCallbackReturn: + def on_error(self, state: 'LifecycleState') -> TransitionCallbackReturn: """Handle error transition request.""" return TransitionCallbackReturn.SUCCESS @@ -50,26 +55,43 @@ def on_error(self, state) -> TransitionCallbackReturn: class SimpleManagedEntity(ManagedEntity): """A simple managed entity that only sets a flag when activated/deactivated.""" - def __init__(self): + def __init__(self) -> None: self._enabled = False - def on_activate(self, state) -> TransitionCallbackReturn: + def on_activate(self, state: 'LifecycleState') -> TransitionCallbackReturn: self._enabled = True return TransitionCallbackReturn.SUCCESS - def on_deactivate(self, state) -> TransitionCallbackReturn: + def on_deactivate(self, state: 'LifecycleState') -> TransitionCallbackReturn: self._enabled = False return TransitionCallbackReturn.SUCCESS @property - def is_activated(self): + def is_activated(self) -> bool: return self._enabled @staticmethod - def when_enabled(wrapped=None, *, when_not_enabled=None): - def decorator(wrapped): + @overload + def when_enabled(wrapped: None, *, + when_not_enabled: Optional[Callable[..., None]] = None + ) -> Callable[[Callable[..., None]], Callable[..., None]]: ... + + @staticmethod + @overload + def when_enabled(wrapped: Callable[..., None], *, + when_not_enabled: Optional[Callable[..., None]] = None + ) -> Callable[..., None]: ... + + @staticmethod + def when_enabled(wrapped: Optional[Callable[..., None]] = None, *, + when_not_enabled: Optional[Callable[..., None]] = None) -> Union[ + Callable[..., None], + Callable[[Callable[..., None]], Callable[..., None]] + ]: + def decorator(wrapped: Callable[..., None]) -> Callable[..., None]: @wraps(wrapped) - def only_when_enabled_wrapper(self: SimpleManagedEntity, *args, **kwargs): + def only_when_enabled_wrapper(self: SimpleManagedEntity, *args: List[Any], + **kwargs: Dict[str, Any]) -> None: if not self._enabled: if when_not_enabled is not None: when_not_enabled() diff --git a/rclpy/rclpy/lifecycle/node.py b/rclpy/rclpy/lifecycle/node.py index 6528007d5..7fd424e07 100644 --- a/rclpy/rclpy/lifecycle/node.py +++ b/rclpy/rclpy/lifecycle/node.py @@ -12,11 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from typing import Callable from typing import Dict +from typing import List +from typing import Literal from typing import NamedTuple from typing import Optional from typing import Set +from typing import Type +from typing import TYPE_CHECKING +from typing import TypedDict +from typing import Union import lifecycle_msgs.msg import lifecycle_msgs.srv @@ -26,13 +33,25 @@ from rclpy.node import Node from rclpy.qos import QoSProfile from rclpy.service import Service -from rclpy.type_support import check_is_valid_srv_type +from rclpy.type_support import check_is_valid_srv_type, MsgT from .managed_entity import ManagedEntity from .publisher import LifecyclePublisher +if TYPE_CHECKING: + from typing import TypeAlias + from typing import Unpack -TransitionCallbackReturn = _rclpy.TransitionCallbackReturnType + from rclpy.context import Context + from rclpy.parameter import Parameter + from rclpy.qos_overriding_options import QoSOverridingOptions + from rclpy.event_handler import PublisherEventCallbacks + +TransitionCallbackReturn: 'TypeAlias' = _rclpy.TransitionCallbackReturnType + + +CallbackNames = Literal['on_configure', 'on_cleanup', 'on_shutdown', 'on_activate', + 'on_deactivate', 'on_error'] class LifecycleState(NamedTuple): @@ -40,6 +59,12 @@ class LifecycleState(NamedTuple): state_id: int +class CreateLifecyclePublisherArgs(TypedDict): + callback_group: Optional[CallbackGroup] + event_callbacks: 'Optional[PublisherEventCallbacks]' + qos_overriding_options: 'Optional[QoSOverridingOptions]' + + class LifecycleNodeMixin(ManagedEntity): """ Mixin class to share as most code as possible between `Node` and `LifecycleNode`. @@ -62,6 +87,10 @@ def __init__( :param callback_group: Callback group that will be used by all the lifecycle node services. """ + if not isinstance(self, Node): + raise RuntimeError('LifecycleNodeMixin uses Node fields so Node needs to be' + 'in the inheritance tree.') + self._callbacks: Dict[int, Callable[[LifecycleState], TransitionCallbackReturn]] = {} # register all state machine transition callbacks self.__register_callback( @@ -147,13 +176,13 @@ def __init__( # Extend base class list of services, so they are added to the executor when spinning. self._services.extend(lifecycle_services) - def trigger_configure(self): + def trigger_configure(self) -> TransitionCallbackReturn: return self.__change_state(lifecycle_msgs.msg.Transition.TRANSITION_CONFIGURE) - def trigger_cleanup(self): + def trigger_cleanup(self) -> TransitionCallbackReturn: return self.__change_state(lifecycle_msgs.msg.Transition.TRANSITION_CLEANUP) - def trigger_shutdown(self): + def trigger_shutdown(self) -> TransitionCallbackReturn: current_state = self._state_machine.current_state[1] if current_state == 'unconfigured': return self.__change_state( @@ -165,22 +194,37 @@ def trigger_shutdown(self): return self.__change_state(lifecycle_msgs.msg.Transition.TRANSITION_ACTIVE_SHUTDOWN) raise _rclpy.RCLError('Shutdown transtion not possible') - def trigger_activate(self): + def trigger_activate(self) -> TransitionCallbackReturn: return self.__change_state(lifecycle_msgs.msg.Transition.TRANSITION_ACTIVATE) - def trigger_deactivate(self): + def trigger_deactivate(self) -> TransitionCallbackReturn: return self.__change_state(lifecycle_msgs.msg.Transition.TRANSITION_DEACTIVATE) - def add_managed_entity(self, entity: ManagedEntity): + def add_managed_entity(self, entity: ManagedEntity) -> None: if not isinstance(entity, ManagedEntity): raise TypeError('Expected a rclpy.lifecycle.ManagedEntity instance.') self._managed_entities.add(entity) - def __transition_callback_impl(self, callback_name: str, state: LifecycleState): + def __transition_callback_impl(self, callback_name: CallbackNames, + state: LifecycleState) -> TransitionCallbackReturn: for entity in self._managed_entities: - cb = getattr(entity, callback_name) + if callback_name == 'on_activate': + cb = entity.on_activate + elif callback_name == 'on_cleanup': + cb = entity.on_cleanup + elif callback_name == 'on_configure': + cb = entity.on_configure + elif callback_name == 'on_deactivate': + cb = entity.on_deactivate + elif callback_name == 'on_error': + cb = entity.on_error + elif callback_name == 'on_shutdown': + cb = entity.on_shutdown + else: + raise ValueError(f'Not valid callback name "{callback_name}" given.') + ret = cb(state) - if not isinstance(ret, TransitionCallbackReturn): + if not isinstance(ret, _rclpy.TransitionCallbackReturnType): raise TypeError( f'{callback_name}() return value of class {type(entity)} should be' ' `TransitionCallbackReturn`.\n' @@ -273,30 +317,52 @@ def on_error(self, state: LifecycleState) -> TransitionCallbackReturn: """ return self.__transition_callback_impl('on_error', state) - def create_lifecycle_publisher(self, *args, **kwargs): + def create_lifecycle_publisher( + self, + msg_type: Type[MsgT], + topic: str, + qos_profile: Union[QoSProfile, int], + *, + publisher_class: None = None, + **kwargs: 'Unpack[CreateLifecyclePublisherArgs]' + ) -> LifecyclePublisher[MsgT]: # TODO(ivanpauno): Should we override lifecycle publisher? # There is an issue with python using the overridden method # when creating publishers for builitin publishers (like parameters events). # We could override them after init, similar to what we do to override publish() # in LifecycleNode. # Having both options seem fine. - if 'publisher_class' in kwargs: + if not isinstance(self, Node): + raise RuntimeError('LifecycleNodeMixin uses Node fields so Node needs to be' + 'in the inheritance tree.') + + if publisher_class: raise TypeError( "create_publisher() got an unexpected keyword argument 'publisher_class'") - pub = Node.create_publisher(self, *args, **kwargs, publisher_class=LifecyclePublisher) + pub = Node.create_publisher(self, msg_type, topic, qos_profile, + publisher_class=LifecyclePublisher, + **kwargs) + + if not isinstance(pub, LifecyclePublisher): + raise RuntimeError('Node failed to create LifecyclePublisher.') + self._managed_entities.add(pub) return pub - def destroy_lifecycle_publisher(self, publisher: LifecyclePublisher): + def destroy_lifecycle_publisher(self, publisher: LifecyclePublisher[Any]) -> bool: + if not isinstance(self, Node): + raise RuntimeError('LifecycleNodeMixin uses Node fields so Node needs to be' + 'in the inheritance tree.') + try: self._managed_entities.remove(publisher) except KeyError: pass - return Node.destroy_publisher(self, publisher) + return self.destroy_publisher(publisher) def __register_callback( self, state_id: int, callback: Callable[[LifecycleState], TransitionCallbackReturn] - ) -> bool: + ) -> Literal[True]: """ Register a callback that will be triggered when transitioning to state_id. @@ -337,7 +403,7 @@ def __change_state(self, transition_id: int) -> TransitionCallbackReturn: self._state_machine.trigger_transition_by_label(error_cb_ret_code.to_label(), True) return cb_return_code - def __check_is_initialized(self): + def __check_is_initialized(self) -> None: if not self._state_machine.initialized: raise RuntimeError( 'Internal error: got service request while lifecycle state machine ' @@ -347,7 +413,7 @@ def __on_change_state( self, req: lifecycle_msgs.srv.ChangeState.Request, resp: lifecycle_msgs.srv.ChangeState.Response - ): + ) -> lifecycle_msgs.srv.ChangeState.Response: self.__check_is_initialized() transition_id = req.transition.id if req.transition.label: @@ -364,7 +430,7 @@ def __on_get_state( self, req: lifecycle_msgs.srv.GetState.Request, resp: lifecycle_msgs.srv.GetState.Response - ): + ) -> lifecycle_msgs.srv.GetState.Response: self.__check_is_initialized() resp.current_state.id, resp.current_state.label = self._state_machine.current_state return resp @@ -373,7 +439,7 @@ def __on_get_available_states( self, req: lifecycle_msgs.srv.GetAvailableStates.Request, resp: lifecycle_msgs.srv.GetAvailableStates.Response - ): + ) -> lifecycle_msgs.srv.GetAvailableStates.Response: self.__check_is_initialized() for state_id, label in self._state_machine.available_states: resp.available_states.append(lifecycle_msgs.msg.State(id=state_id, label=label)) @@ -383,7 +449,7 @@ def __on_get_available_transitions( self, req: lifecycle_msgs.srv.GetAvailableTransitions.Request, resp: lifecycle_msgs.srv.GetAvailableTransitions.Response - ): + ) -> lifecycle_msgs.srv.GetAvailableTransitions.Response: self.__check_is_initialized() for transition_description in self._state_machine.available_transitions: transition_id, transition_label, start_id, start_label, goal_id, goal_label = \ @@ -402,7 +468,7 @@ def __on_get_transition_graph( self, req: lifecycle_msgs.srv.GetAvailableTransitions.Request, resp: lifecycle_msgs.srv.GetAvailableTransitions.Response - ): + ) -> lifecycle_msgs.srv.GetAvailableTransitions.Response: self.__check_is_initialized() for transition_description in self._state_machine.transition_graph: transition_id, transition_label, start_id, start_label, goal_id, goal_label = \ @@ -418,6 +484,19 @@ def __on_get_transition_graph( return resp +class LifecycleNodeArgs(TypedDict): + context: 'Optional[Context]' + cli_args: Optional[List[str]] + namespace: Optional[str] + use_global_arguments: bool + enable_rosout: bool + start_parameter_services: bool + parameter_overrides: 'Optional[List[Parameter]]' + allow_undeclared_parameters: bool + automatically_declare_parameters_from_overrides: bool + enable_logger_service: bool + + class LifecycleNode(LifecycleNodeMixin, Node): """ A ROS 2 managed node. @@ -426,14 +505,23 @@ class LifecycleNode(LifecycleNodeMixin, Node): Methods in LifecycleNodeMixin override the ones in Node. """ - def __init__(self, node_name, *, enable_communication_interface: bool = True, **kwargs): + def __init__( + self, + node_name: str, + *, + enable_communication_interface: bool = True, + **kwargs: 'Unpack[LifecycleNodeArgs]', + ) -> None: """ Create a lifecycle node. See rclpy.lifecycle.LifecycleNodeMixin.__init__() and rclpy.node.Node() for the documentation of each parameter. """ - Node.__init__(self, node_name, **kwargs) + Node.__init__( + self, + node_name, + **kwargs) LifecycleNodeMixin.__init__( self, enable_communication_interface=enable_communication_interface) diff --git a/rclpy/rclpy/lifecycle/publisher.py b/rclpy/rclpy/lifecycle/publisher.py index 63999fab2..af6b33572 100644 --- a/rclpy/rclpy/lifecycle/publisher.py +++ b/rclpy/rclpy/lifecycle/publisher.py @@ -12,18 +12,41 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from __future__ import annotations +from typing import Generic, Tuple, Type, TYPE_CHECKING, TypedDict, Union + +from rclpy.callback_groups import CallbackGroup +from rclpy.event_handler import PublisherEventCallbacks +from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy from rclpy.publisher import Publisher +from rclpy.qos import QoSProfile from rclpy.type_support import MsgT from .managed_entity import SimpleManagedEntity +if TYPE_CHECKING: + from typing import TypeAlias, Unpack + LifecyclePublisherArgs: TypeAlias = Tuple[_rclpy.Publisher[MsgT], Type[MsgT], str, QoSProfile, + PublisherEventCallbacks, CallbackGroup] + + class LifecyclePublisherKWArgs(TypedDict, Generic[MsgT]): + publisher_impl: _rclpy.Publisher[MsgT] + msg_type: Type[MsgT] + topic: str + qos_profile: QoSProfile + event_callbacks: PublisherEventCallbacks + callback_group: CallbackGroup + -class LifecyclePublisher(SimpleManagedEntity, Publisher): +class LifecyclePublisher(SimpleManagedEntity, Publisher[MsgT]): """Managed publisher entity.""" - def __init__(self, *args, **kwargs): + def __init__( + self, + *args: 'Unpack[LifecyclePublisherArgs]', + **kwargs: 'Unpack[LifecyclePublisherKWArgs[MsgT]]' + ) -> None: SimpleManagedEntity.__init__(self) Publisher.__init__(self, *args, **kwargs)