Skip to content

Commit

Permalink
Add types to duration.py (ros2#1233)
Browse files Browse the repository at this point in the history
* Add types to logging_service.py (ros2#1227)

* add types to logging_service

* Add types to duration.py

* Add newlines for class definintions

* update type alias name

* Update to use Protocols

Signed-off-by: Michael Carlstrom <rmc@carlstrom.com>
  • Loading branch information
InvincibleRMC committed Mar 24, 2024
1 parent 500893d commit e2ab502
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions rclpy/rclpy/duration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Protocol, Union

import builtin_interfaces.msg
from rclpy.constants import S_TO_NS
from rclpy.impl.implementation_singleton import rclpy_implementation as _rclpy


class DurationType(Protocol):
"""Type alias of _rclpy.rcl_duration_t."""

nanoseconds: int


class Duration:
"""A period between two time points, with nanosecond precision."""

def __init__(self, *, seconds=0, nanoseconds=0):
def __init__(self, *, seconds: Union[int, float] = 0, nanoseconds: int = 0):
"""
Create an instance of :class:`Duration`, combined from given seconds and nanoseconds.
Expand All @@ -33,51 +41,51 @@ def __init__(self, *, seconds=0, nanoseconds=0):
# pybind11 would raise TypeError, but we want OverflowError
raise OverflowError(
'Total nanoseconds value is too large to store in C duration.')
self._duration_handle = _rclpy.rcl_duration_t(total_nanoseconds)
self._duration_handle: DurationType = _rclpy.rcl_duration_t(total_nanoseconds)

@property
def nanoseconds(self):
def nanoseconds(self) -> int:
return self._duration_handle.nanoseconds

def __repr__(self):
def __repr__(self) -> str:
return 'Duration(nanoseconds={0})'.format(self.nanoseconds)

def __str__(self):
def __str__(self) -> str:
if self == Infinite:
return 'Infinite'
return f'{self.nanoseconds} nanoseconds'

def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if isinstance(other, Duration):
return self.nanoseconds == other.nanoseconds
return NotImplemented

def __ne__(self, other):
def __ne__(self, other: object) -> bool:
if isinstance(other, Duration):
return not self.__eq__(other)
return NotImplemented

def __lt__(self, other):
def __lt__(self, other: object) -> bool:
if isinstance(other, Duration):
return self.nanoseconds < other.nanoseconds
return NotImplemented

def __le__(self, other):
def __le__(self, other: object) -> bool:
if isinstance(other, Duration):
return self.nanoseconds <= other.nanoseconds
return NotImplemented

def __gt__(self, other):
def __gt__(self, other: object) -> bool:
if isinstance(other, Duration):
return self.nanoseconds > other.nanoseconds
return NotImplemented

def __ge__(self, other):
def __ge__(self, other: object) -> bool:
if isinstance(other, Duration):
return self.nanoseconds >= other.nanoseconds
return NotImplemented

def to_msg(self):
def to_msg(self) -> builtin_interfaces.msg.Duration:
"""
Get duration as :class:`builtin_interfaces.msg.Duration`.
Expand All @@ -88,7 +96,7 @@ def to_msg(self):
return builtin_interfaces.msg.Duration(sec=seconds, nanosec=nanoseconds)

@classmethod
def from_msg(cls, msg):
def from_msg(cls, msg: builtin_interfaces.msg.Duration) -> 'Duration':
"""
Create an instance of :class:`Duration` from a duration message.
Expand All @@ -98,7 +106,7 @@ def from_msg(cls, msg):
raise TypeError('Must pass a builtin_interfaces.msg.Duration object')
return cls(seconds=msg.sec, nanoseconds=msg.nanosec)

def get_c_duration(self):
def get_c_duration(self) -> DurationType:
return self._duration_handle


Expand Down

0 comments on commit e2ab502

Please sign in to comment.