From cef4ce9d0ad8d461b81c7ae8cd642ccd1a2e3c51 Mon Sep 17 00:00:00 2001 From: Chris Lalancette Date: Mon, 17 Jun 2024 14:54:31 -0400 Subject: [PATCH] Make rclpy initialization context-manager aware. (#1298) * Make rclpy initialization context-manager aware. This PR does two somewhat controversial things: 1. It switches to an intermediate object for initialization. That way we can use context managers, but actually properly clean up, including uninstalling signal handlers. We also switch to tracking node resources in the context, so that when the context goes away, all nodes associated with it are automatically destroyed. 2. It changes the context manager implementation of Context so that it warns if the context did not have init() called before entering the context. This definitely changes the semantics, but as it stands that initialization doesn't make sense because it can't take arguments. I think this change is warranted, though we may have to search through documentation and examples to make sure this doesn't break anything. Signed-off-by: Chris Lalancette Co-authored-by: Shane Loretz --- rclpy/rclpy/__init__.py | 55 ++++++++++++++++++---- rclpy/rclpy/context.py | 79 ++++++++++++++++++++++++++------ rclpy/rclpy/node.py | 4 ++ rclpy/test/test_context.py | 2 + rclpy/test/test_init_shutdown.py | 6 +++ 5 files changed, 121 insertions(+), 25 deletions(-) diff --git a/rclpy/rclpy/__init__.py b/rclpy/rclpy/__init__.py index 6e94dce51..d53d66797 100644 --- a/rclpy/rclpy/__init__.py +++ b/rclpy/rclpy/__init__.py @@ -40,8 +40,10 @@ This will invalidate all entities derived from the context. """ +from types import TracebackType from typing import List from typing import Optional +from typing import Type from typing import TYPE_CHECKING from rclpy.context import Context @@ -62,13 +64,52 @@ from rclpy.node import Node # noqa: F401 +class InitContextManager: + """ + A proxy object for initialization. + + One of these is returned when calling `rclpy.init`, and can be used with context managers to + properly cleanup after initialization. + """ + + def __init__(self, + args: Optional[List[str]], + context: Optional[Context], + domain_id: Optional[int], + signal_handler_options: Optional[SignalHandlerOptions]) -> None: + self.context = get_default_context() if context is None else context + if signal_handler_options is None: + if context is None or context is get_default_context(): + signal_handler_options = SignalHandlerOptions.ALL + else: + signal_handler_options = SignalHandlerOptions.NO + + if signal_handler_options == SignalHandlerOptions.NO: + self.installed_signal_handlers = False + else: + self.installed_signal_handlers = True + install_signal_handlers(signal_handler_options) + self.context.init(args, domain_id=domain_id) + + def __enter__(self) -> 'InitContextManager': + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + shutdown(context=self.context, uninstall_handlers=self.installed_signal_handlers) + + def init( *, args: Optional[List[str]] = None, context: Optional[Context] = None, domain_id: Optional[int] = None, signal_handler_options: Optional[SignalHandlerOptions] = None, -) -> None: +) -> InitContextManager: """ Initialize ROS communications for a given context. @@ -78,15 +119,9 @@ def init( :param domain_id: ROS domain id. :param signal_handler_options: Indicate which signal handlers to install. If `None`, SIGINT and SIGTERM will be installed when initializing the default context. + :return: an InitContextManager that can be used with Python context managers to cleanup. """ - context = get_default_context() if context is None else context - if signal_handler_options is None: - if context is None or context is get_default_context(): - signal_handler_options = SignalHandlerOptions.ALL - else: - signal_handler_options = SignalHandlerOptions.NO - install_signal_handlers(signal_handler_options) - return context.init(args, domain_id=domain_id) + return InitContextManager(args, context, domain_id, signal_handler_options) # The global spin functions need an executor to do the work @@ -125,7 +160,7 @@ def shutdown( :param uninstall_handlers: If `None`, signal handlers will be uninstalled when shutting down the default context. If `True`, signal handlers will be uninstalled. - If not, signal handlers won't be uninstalled. + If `False`, signal handlers won't be uninstalled. """ _shutdown(context=context) if ( diff --git a/rclpy/rclpy/context.py b/rclpy/rclpy/context.py index c3a27953e..5563b09b1 100644 --- a/rclpy/rclpy/context.py +++ b/rclpy/rclpy/context.py @@ -22,11 +22,16 @@ from typing import Optional from typing import Protocol from typing import Type +from typing import TYPE_CHECKING from typing import Union -from weakref import WeakMethod +import warnings +import weakref from rclpy.destroyable import DestroyableType +if TYPE_CHECKING: + from rclpy.node import Node + class ContextHandle(DestroyableType, Protocol): @@ -60,10 +65,11 @@ class Context(ContextManager['Context']): """ def __init__(self) -> None: - self._lock = threading.Lock() - self._callbacks: List[Union['WeakMethod[MethodType]', Callable[[], None]]] = [] + self._lock = threading.RLock() + self._callbacks: List[Union['weakref.WeakMethod[MethodType]', Callable[[], None]]] = [] self._logging_initialized = False self.__context: Optional[ContextHandle] = None + self.__node_weak_ref_list: List[weakref.ReferenceType['Node']] = [] @property def handle(self) -> Optional[ContextHandle]: @@ -82,6 +88,10 @@ def init(self, Initialize ROS communications for a given context. :param args: List of command line arguments. + :param initialize_logging: Whether to initialize logging for the whole process. + The default is to initialize logging. + :param domain_id: Which domain ID to use for this context. + If None (the default), domain ID 0 is used. """ # imported locally to avoid loading extensions on module import from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy @@ -106,6 +116,36 @@ def init(self, _rclpy.rclpy_logging_configure(self.__context) self._logging_initialized = True + def track_node(self, node: 'Node') -> None: + """ + Track a Node associated with this Context. + + When the Context is destroyed, it will destroy every Node it tracks. + + :param node: The node to take a weak reference to. + """ + with self._lock: + self.__node_weak_ref_list.append(weakref.ref(node)) + + def untrack_node(self, node: 'Node') -> None: + """ + Stop tracking a Node associated with this Context. + + If a Node is destroyed before the context, we no longer need to track it for destruction of + the Context, so remove it here. + """ + with self._lock: + for index, weak_node in enumerate(self.__node_weak_ref_list): + node_in_list = weak_node() + if node_in_list is node: + found_index = index + break + else: + # Odd that we didn't find the node in the list, but just get out + return + + del self.__node_weak_ref_list[found_index] + def ok(self) -> bool: """Check if context hasn't been shut down.""" with self._lock: @@ -121,14 +161,22 @@ def _call_on_shutdown_callbacks(self) -> None: callback() self._callbacks = [] + def _cleanup(self) -> None: + for weak_node in self.__node_weak_ref_list: + node = weak_node() + if node is not None: + node.destroy_node() + + self.__context.shutdown() + self._call_on_shutdown_callbacks() + self._logging_fini() + def shutdown(self) -> None: """Shutdown this context.""" if self.__context is None: raise RuntimeError('Context must be initialized before it can be shutdown') with self.__context, self._lock: - self.__context.shutdown() - self._call_on_shutdown_callbacks() - self._logging_fini() + self._cleanup() def try_shutdown(self) -> None: """Shutdown this context, if not already shutdown.""" @@ -136,11 +184,9 @@ def try_shutdown(self) -> None: return with self.__context, self._lock: if self.__context.ok(): - self.__context.shutdown() - self._call_on_shutdown_callbacks() - self._logging_fini() + self._cleanup() - def _remove_callback(self, weak_method: 'WeakMethod[MethodType]') -> None: + def _remove_callback(self, weak_method: 'weakref.WeakMethod[MethodType]') -> None: self._callbacks.remove(weak_method) def on_shutdown(self, callback: Callable[[], None]) -> None: @@ -151,7 +197,7 @@ def on_shutdown(self, callback: Callable[[], None]) -> None: if self.__context is None: with self._lock: if ismethod(callback): - self._callbacks.append(WeakMethod(callback, self._remove_callback)) + self._callbacks.append(weakref.WeakMethod(callback, self._remove_callback)) else: self._callbacks.append(callback) return @@ -161,7 +207,7 @@ def on_shutdown(self, callback: Callable[[], None]) -> None: callback() else: if ismethod(callback): - self._callbacks.append(WeakMethod(callback, self._remove_callback)) + self._callbacks.append(weakref.WeakMethod(callback, self._remove_callback)) else: self._callbacks.append(callback) @@ -187,9 +233,12 @@ def get_domain_id(self) -> int: return self.__context.get_domain_id() def __enter__(self) -> 'Context': - # We do not accept parameters here. If one wants to customize the init() call, - # they would have to call it manually and not use the ContextManager convenience - self.init() + if self.__context is None: + # init() hasn't been called yet; for backwards compatibility, initialize and warn + warnings.warn('init() must be called on a Context before using it in a Python context ' + 'manager. Calling init() with no arguments, this usage is deprecated') + self.init() + return self def __exit__( diff --git a/rclpy/rclpy/node.py b/rclpy/rclpy/node.py index 6a3b4d843..3b7fd37b9 100644 --- a/rclpy/rclpy/node.py +++ b/rclpy/rclpy/node.py @@ -247,6 +247,8 @@ def __init__( self._type_description_service = TypeDescriptionService(self) + self._context.track_node(self) + @property def publishers(self) -> Iterator[Publisher]: """Get publishers that have been created on this node.""" @@ -1944,6 +1946,8 @@ def destroy_node(self): * :func:`create_guard_condition` """ + self._context.untrack_node(self) + # Drop extra reference to parameter event publisher. # It will be destroyed with other publishers below. self._parameter_event_publisher = None diff --git a/rclpy/test/test_context.py b/rclpy/test/test_context.py index f1335ef9f..7831597fd 100644 --- a/rclpy/test/test_context.py +++ b/rclpy/test/test_context.py @@ -61,6 +61,8 @@ def test_context_manager(): assert not context.ok(), 'the context should not be ok() before init() is called' + context.init() + with context as the_context: # Make sure the correct instance is returned assert the_context is context diff --git a/rclpy/test/test_init_shutdown.py b/rclpy/test/test_init_shutdown.py index 633142dec..43a514448 100644 --- a/rclpy/test/test_init_shutdown.py +++ b/rclpy/test/test_init_shutdown.py @@ -109,3 +109,9 @@ def test_signal_handlers(): def test_init_with_invalid_domain_id(): with pytest.raises(RuntimeError): rclpy.init(domain_id=-1) + + +def test_managed_init(): + with rclpy.init(domain_id=123) as init: + assert init.context.get_domain_id() == 123 + assert init.context.ok()