diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml deleted file mode 100644 index a204f18e..00000000 --- a/.github/workflows/black.yml +++ /dev/null @@ -1,17 +0,0 @@ -name: Black formatting - -on: - push: - branches: - - main - pull_request: - branches: - - main -jobs: - lint: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: psf/black@stable - with: - src: "spot_wrapper" diff --git a/.github/workflows/util_pre-commit.yml b/.github/workflows/util_pre-commit.yml new file mode 100644 index 00000000..853b6dec --- /dev/null +++ b/.github/workflows/util_pre-commit.yml @@ -0,0 +1,20 @@ +# Copyright (c) 2023 Boston Dynamics AI Institute, Inc. All rights reserved. + +name: Util - Pre-Commit Runner + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + pre-commit: + name: util_pre-commit + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + - uses: pre-commit/action@v3.0.0 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..0424ec83 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,23 @@ +# Copyright (c) 2023 Boston Dynamics AI Institute, Inc. All rights reserved. + +repos: +- repo: https://github.com/charliermarsh/ruff-pre-commit + # Ruff version. + rev: 'v0.0.263' + hooks: + - id: ruff + args: ['--fix', '--config', 'pyproject.toml'] +- repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black + language_version: python3.10 + args: ['--config', 'pyproject.toml'] + verbose: true +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-yaml + - id: check-added-large-files + - id: check-toml + - id: end-of-file-fixer diff --git a/LICENSE b/LICENSE index 4b166890..1bffbaa3 100644 --- a/LICENSE +++ b/LICENSE @@ -56,4 +56,4 @@ Copyright (c) 2020 Boston Dynamics, Inc. All rights reserved. Downloading, reproducing, distributing or otherwise using the SDK Software is subject to the terms and conditions of the Boston Dynamics Software -Development Kit License (20191101-BDSDK-SL). \ No newline at end of file +Development Kit License (20191101-BDSDK-SL). diff --git a/README.md b/README.md index e16f119d..b978eafc 100644 --- a/README.md +++ b/README.md @@ -19,3 +19,11 @@ To update requirements.txt, use ```commandline pipreqs . --force ``` + +# Contributing +This repository enforces `ruff` and `black` linting. To verify that your code will pass inspection, install `pre-commit` and run: +```bash +pre-commit install +pre-commit run --all-files +``` +The [Google Style Guide](https://google.github.io/styleguide/) is followed for default formatting. diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..28efd24c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,48 @@ +[tool.ruff] +# Enable pycodestyle (`E`), Pyflakes (`F`), and import sorting (`I`) +select = ["E", "F", "I"] +ignore = [] +fixable = ["ALL"] +unfixable = [] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "docker/ros", +] +line-length = 120 +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +target-version = "py38" + +[tool.ruff.per-file-ignores] +"__init__.py" = ["F401"] + +[tool.ruff.mccabe] +max-complexity = 10 + +[tool.black] +line-length = 120 +target-version = ['py38'] +include = '\.pyi?$' +force-exclude = ''' +/( +)/ +''' +preview = true diff --git a/setup.py b/setup.py index fd82e952..b49e47c2 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup( name="spot_wrapper", diff --git a/spot_wrapper/cam_webrtc_client.py b/spot_wrapper/cam_webrtc_client.py index d79aeb1b..a84b793d 100644 --- a/spot_wrapper/cam_webrtc_client.py +++ b/spot_wrapper/cam_webrtc_client.py @@ -69,9 +69,7 @@ def send_sdp_answer_to_spot_cam(self, token, offer_id, sdp_answer): server_url = f"https://{self.hostname}:{self.sdp_port}/{self.sdp_filename}" payload = {"id": offer_id, "sdp": base64.b64encode(sdp_answer).decode("utf8")} - r = requests.post( - server_url, verify=self.cam_ssl_cert, json=payload, headers=headers - ) + r = requests.post(server_url, verify=self.cam_ssl_cert, json=payload, headers=headers) if r.status_code != 200: raise ValueError(r) @@ -79,7 +77,8 @@ async def start(self): # first get a token try: token = self.get_bearer_token() - except: + except Exception as e: + print(f"Could not get bearer token, mocking instead. Exception: {e}") token = self.get_bearer_token(mock=True) offer_id, sdp_offer = self.get_sdp_offer_from_spot_cam(token) @@ -101,9 +100,7 @@ async def _on_ice_connection_state_change(): print(f"ICE connection state changed to: {self.pc.iceConnectionState}") if self.pc.iceConnectionState == "checking": - self.send_sdp_answer_to_spot_cam( - token, offer_id, self.pc.localDescription.sdp.encode() - ) + self.send_sdp_answer_to_spot_cam(token, offer_id, self.pc.localDescription.sdp.encode()) @self.pc.on("track") def _on_track(track): diff --git a/spot_wrapper/cam_wrapper.py b/spot_wrapper/cam_wrapper.py index a46887ad..dc317d1c 100644 --- a/spot_wrapper/cam_wrapper.py +++ b/spot_wrapper/cam_wrapper.py @@ -52,9 +52,7 @@ class LEDPosition(enum.Enum): def __init__(self, robot: Robot, logger): self.logger = logger - self.client: LightingClient = robot.ensure_client( - LightingClient.default_service_name - ) + self.client: LightingClient = robot.ensure_client(LightingClient.default_service_name) def set_led_brightness(self, brightness): """ @@ -138,9 +136,7 @@ class CompositorWrapper: def __init__(self, robot: Robot, logger): self.logger = logger - self.client: CompositorClient = robot.ensure_client( - CompositorClient.default_service_name - ) + self.client: CompositorClient = robot.ensure_client(CompositorClient.default_service_name) def list_screens(self) -> typing.List[str]: """ @@ -210,9 +206,7 @@ class HealthWrapper: """ def __init__(self, robot, logger): - self.client: HealthClient = robot.ensure_client( - HealthClient.default_service_name - ) + self.client: HealthClient = robot.ensure_client(HealthClient.default_service_name) self.logger = logger def get_bit_status( @@ -242,10 +236,7 @@ def get_temperature(self) -> typing.Tuple[str, float]: Returns: Tuple of string and float indicating the component and its temperature in celsius """ - return [ - (composite.channel_name, composite.temperature / 1e3) - for composite in self.client.get_temperature() - ] + return [(composite.channel_name, composite.temperature / 1e3) for composite in self.client.get_temperature()] # def get_system_log(self): # """ @@ -344,9 +335,7 @@ class StreamQualityWrapper: """ def __init__(self, robot, logger): - self.client: StreamQualityClient = robot.ensure_client( - StreamQualityClient.default_service_name - ) + self.client: StreamQualityClient = robot.ensure_client(StreamQualityClient.default_service_name) self.logger = logger def set_stream_params(self, target_bitrate, refresh_interval, idr_interval, awb): @@ -414,9 +403,7 @@ class MediaLogWrapper: """ def __init__(self, robot, logger) -> None: - self.client: MediaLogClient = robot.ensure_client( - MediaLogClient.default_service_name - ) + self.client: MediaLogClient = robot.ensure_client(MediaLogClient.default_service_name) self.logger = logger def list_cameras(self) -> typing.List[Camera]: @@ -434,9 +421,7 @@ def list_logpoints(self) -> typing.List[Logpoint]: """ return self.client.list_logpoints() - def retrieve_logpoint( - self, name: str, raw: bool = False - ) -> typing.Tuple[Logpoint, DataChunk]: + def retrieve_logpoint(self, name: str, raw: bool = False) -> typing.Tuple[Logpoint, DataChunk]: """ Retrieve a logpoint from the camera @@ -477,9 +462,7 @@ def delete_logpoint(self, name: str) -> None: """ self.client.delete(logpoint=Logpoint(name=name)) - def store( - self, camera: SpotCamCamera, tag: typing.Optional[str] = None - ) -> Logpoint: + def store(self, camera: SpotCamCamera, tag: typing.Optional[str] = None) -> Logpoint: """ Take a snapshot of the data currently on the given camera and store it to a logpoint. @@ -490,9 +473,7 @@ def store( Returns: Logpoint containing information about the stored data """ - return self.client.store( - camera=Camera(name=camera.value), record_type=Logpoint.STILLIMAGE, tag=tag - ) + return self.client.store(camera=Camera(name=camera.value), record_type=Logpoint.STILLIMAGE, tag=tag) def tag(self, name: str, tag: str) -> None: """ @@ -553,7 +534,8 @@ def save_logpoint_image( path: Save the data to this directory base_filename: Use this filename as the base name for the image file raw: If true, retrieve raw data rather than processed data. Useful for IR images? - camera: If set, add the name of the camera to the output filename. The logpoint doesn't store this information + camera: If set, add the name of the camera to the output filename. The logpoint doesn't store this + information use_rgb24: If set, save the ptz image in .rgb24 format without compression. By default it is saved to png Returns: @@ -570,9 +552,7 @@ def save_logpoint_image( # Special case for 16 bit raw thermal image if logpoint.image_params.format == image_pb2.Image.PIXEL_FORMAT_GREYSCALE_U16: np_img = np.frombuffer(image, dtype=np.uint16).byteswap() - np_img = np_img.reshape( - (logpoint.image_params.height, logpoint.image_params.width, 1) - ) + np_img = np_img.reshape((logpoint.image_params.height, logpoint.image_params.width, 1)) full_path = os.path.join( save_path, self._build_filename(logpoint, base_filename, ".pgm", SpotCamCamera.IR), @@ -584,10 +564,7 @@ def save_logpoint_image( if ( logpoint.image_params.height == 4800 or logpoint.image_params.height == 2400 - or ( - logpoint.image_params.width == 640 - and logpoint.image_params.height == 512 - ) + or (logpoint.image_params.width == 640 and logpoint.image_params.height == 512) ): full_path = os.path.join( save_path, @@ -715,9 +692,7 @@ def _get_ptz_description(self, name): PtzDescription """ if name not in self.ptzs: - self.logger.warn( - f"Tried to retrieve description for ptz {name} but it does not exist." - ) + self.logger.warn(f"Tried to retrieve description for ptz {name} but it does not exist.") return None return self.ptzs[name] @@ -742,9 +717,7 @@ def _clamp_value_to_limits(self, value, limits: PtzDescription.Limits): return max(min(value, limits.max.value), limits.min.value) - def _clamp_request_to_limits( - self, ptz_name, pan, tilt, zoom - ) -> typing.Tuple[float, float, float]: + def _clamp_request_to_limits(self, ptz_name, pan, tilt, zoom) -> typing.Tuple[float, float, float]: """ Args: @@ -782,26 +755,20 @@ def set_ptz_position(self, ptz_name, pan, tilt, zoom, blocking=False): pan: Set the pan to this value in degrees tilt: Set the tilt to this value in degrees zoom: Set the zoom to this zoom level - blocking: If true, block for 3 seconds or until the ptz is within 1 degree of the requested pan and tilt values, and - 0.5 zoom levels of the requested zoom level + blocking: If true, block for 3 seconds or until the ptz is within 1 degree of the requested pan and tilt + values, and 0.5 zoom levels of the requested zoom level """ pan, tilt, zoom = self._clamp_request_to_limits(ptz_name, pan, tilt, zoom) - self.client.set_ptz_position( - self._get_ptz_description(ptz_name), pan, tilt, zoom - ) + self.client.set_ptz_position(self._get_ptz_description(ptz_name), pan, tilt, zoom) if blocking: start_time = datetime.datetime.now() - current_position = self.client.get_ptz_position( - self._get_ptz_description(ptz_name) - ) + current_position = self.client.get_ptz_position(self._get_ptz_description(ptz_name)) while not ( math.isclose(current_position.pan.value, pan, abs_tol=1) and math.isclose(current_position.tilt.value, tilt, abs_tol=1) and math.isclose(current_position.zoom.value, zoom, abs_tol=0.5) ) and datetime.datetime.now() - start_time < datetime.timedelta(seconds=3): - current_position = self.client.get_ptz_position( - self._get_ptz_description(ptz_name) - ) + current_position = self.client.get_ptz_position(self._get_ptz_description(ptz_name)) time.sleep(0.2) def get_ptz_velocity(self, ptz_name) -> PtzVelocity: @@ -827,9 +794,7 @@ def set_ptz_velocity(self, ptz_name, pan, tilt, zoom): zoom: Set the zoom to this value in zoom level per second """ # We do not clamp the velocity to the limits, as it is a rate - self.client.set_ptz_velocity( - self._get_ptz_description(ptz_name), pan, tilt, zoom - ) + self.client.set_ptz_velocity(self._get_ptz_description(ptz_name), pan, tilt, zoom) def initialise_lens(self): """ @@ -924,17 +889,13 @@ async def _process_func(self): while not self.client.audio_frame_queue.empty(): await self.client.audio_frame_queue.get() except Exception as e: - self.logger.error( - f"Image stream wrapper exception while discarding audio frames {e}" - ) + self.logger.error(f"Image stream wrapper exception while discarding audio frames {e}") self.shutdown_flag.set() class SpotCamWrapper: - def __init__( - self, hostname, username, password, logger, port: typing.Optional[int] = None - ): + def __init__(self, hostname, username, password, logger, port: typing.Optional[int] = None): self._hostname = hostname self._username = username self._password = password @@ -947,13 +908,9 @@ def __init__( self.robot = self.sdk.create_robot(self._hostname) if port is not None: self.robot.update_secure_channel_port(port) - SpotWrapper.authenticate( - self.robot, self._username, self._password, self._logger - ) + SpotWrapper.authenticate(self.robot, self._username, self._password, self._logger) - self.payload_client: PayloadClient = self.robot.ensure_client( - PayloadClient.default_service_name - ) + self.payload_client: PayloadClient = self.robot.ensure_client(PayloadClient.default_service_name) self.payload_details = None for payload in self.payload_client.list_payloads(): if payload.is_enabled and "Spot CAM" in payload.name: diff --git a/spot_wrapper/spot_arm.py b/spot_wrapper/spot_arm.py index b1ba2263..b0e0fd85 100644 --- a/spot_wrapper/spot_arm.py +++ b/spot_wrapper/spot_arm.py @@ -2,14 +2,16 @@ import time import typing -from bosdyn.api import arm_command_pb2 -from bosdyn.api import geometry_pb2 -from bosdyn.api import manipulation_api_pb2 -from bosdyn.api import robot_command_pb2 -from bosdyn.api import synchronized_command_pb2 -from bosdyn.api import trajectory_pb2 +from bosdyn.api import ( + arm_command_pb2, + geometry_pb2, + gripper_command_pb2, + manipulation_api_pb2, + robot_command_pb2, + synchronized_command_pb2, + trajectory_pb2, +) from bosdyn.client.manipulation_api_client import ManipulationApiClient -from bosdyn.api import gripper_command_pb2 from bosdyn.client.robot import Robot from bosdyn.client.robot_command import ( RobotCommandBuilder, @@ -21,7 +23,7 @@ from bosdyn.client.time_sync import TimeSyncEndpoint from bosdyn.util import seconds_to_duration -from spot_wrapper.wrapper_helpers import RobotState, ClaimAndPowerDecorator +from spot_wrapper.wrapper_helpers import ClaimAndPowerDecorator, RobotState class SpotArm: @@ -104,9 +106,7 @@ def manipulation_command(self, request: manipulation_api_pb2): ) def get_manipulation_command_feedback(self, cmd_id): - feedback_request = manipulation_api_pb2.ManipulationApiFeedbackRequest( - manipulation_cmd_id=cmd_id - ) + feedback_request = manipulation_api_pb2.ManipulationApiFeedbackRequest(manipulation_cmd_id=cmd_id) return self._manipulation_api_client.manipulation_api_feedback_command( manipulation_api_feedback_request=feedback_request @@ -145,9 +145,7 @@ def wait_for_arm_command_to_complete(self, cmd_id, timeout_sec=None): timeout_sec: After this time, timeout regardless of what the robot state is """ - block_until_arm_arrives( - self._robot_command_client, cmd_id=cmd_id, timeout_sec=timeout_sec - ) + block_until_arm_arrives(self._robot_command_client, cmd_id=cmd_id, timeout_sec=timeout_sec) def arm_stow(self) -> typing.Tuple[bool, str]: """ @@ -227,18 +225,10 @@ def make_arm_trajectory_command( """Helper function to create a RobotCommand from an ArmJointTrajectory. Copy from 'spot-sdk/python/examples/arm_joint_move/arm_joint_move.py'""" - joint_move_command = arm_command_pb2.ArmJointMoveCommand.Request( - trajectory=arm_joint_trajectory - ) - arm_command = arm_command_pb2.ArmCommand.Request( - arm_joint_move_command=joint_move_command - ) - sync_arm = synchronized_command_pb2.SynchronizedCommand.Request( - arm_command=arm_command - ) - arm_sync_robot_cmd = robot_command_pb2.RobotCommand( - synchronized_command=sync_arm - ) + joint_move_command = arm_command_pb2.ArmJointMoveCommand.Request(trajectory=arm_joint_trajectory) + arm_command = arm_command_pb2.ArmCommand.Request(arm_joint_move_command=joint_move_command) + sync_arm = synchronized_command_pb2.SynchronizedCommand.Request(arm_command=arm_command) + arm_sync_robot_cmd = robot_command_pb2.RobotCommand(synchronized_command=sync_arm) return RobotCommandBuilder.build_synchro_command(arm_sync_robot_cmd) def arm_joint_move(self, joint_targets) -> typing.Tuple[bool, str]: @@ -286,19 +276,15 @@ def arm_joint_move(self, joint_targets) -> typing.Tuple[bool, str]: self._logger.info(msg) return False, msg else: - trajectory_point = ( - RobotCommandBuilder.create_arm_joint_trajectory_point( - joint_targets[0], - joint_targets[1], - joint_targets[2], - joint_targets[3], - joint_targets[4], - joint_targets[5], - ) - ) - arm_joint_trajectory = arm_command_pb2.ArmJointTrajectory( - points=[trajectory_point] + trajectory_point = RobotCommandBuilder.create_arm_joint_trajectory_point( + joint_targets[0], + joint_targets[1], + joint_targets[2], + joint_targets[3], + joint_targets[4], + joint_targets[5], ) + arm_joint_trajectory = arm_command_pb2.ArmJointTrajectory(points=[trajectory_point]) arm_command = self.make_arm_trajectory_command(arm_joint_trajectory) # Send the request @@ -325,27 +311,17 @@ def force_trajectory(self, data) -> typing.Tuple[bool, str]: traj_duration = data.duration # first point on trajectory - wrench0 = self.create_wrench_from_forces_and_torques( - data.forces_pt0, data.torques_pt0 - ) + wrench0 = self.create_wrench_from_forces_and_torques(data.forces_pt0, data.torques_pt0) t0 = seconds_to_duration(0) - traj_point0 = trajectory_pb2.WrenchTrajectoryPoint( - wrench=wrench0, time_since_reference=t0 - ) + traj_point0 = trajectory_pb2.WrenchTrajectoryPoint(wrench=wrench0, time_since_reference=t0) # Second point on the trajectory - wrench1 = self.create_wrench_from_forces_and_torques( - data.forces_pt1, data.torques_pt1 - ) + wrench1 = self.create_wrench_from_forces_and_torques(data.forces_pt1, data.torques_pt1) t1 = seconds_to_duration(traj_duration) - traj_point1 = trajectory_pb2.WrenchTrajectoryPoint( - wrench=wrench1, time_since_reference=t1 - ) + traj_point1 = trajectory_pb2.WrenchTrajectoryPoint(wrench=wrench1, time_since_reference=t1) # Build the trajectory - trajectory = trajectory_pb2.WrenchTrajectory( - points=[traj_point0, traj_point1] - ) + trajectory = trajectory_pb2.WrenchTrajectory(points=[traj_point0, traj_point1]) # Build the trajectory request, putting all axes into force mode arm_cartesian_command = arm_command_pb2.ArmCartesianCommand.Request( @@ -358,17 +334,9 @@ def force_trajectory(self, data) -> typing.Tuple[bool, str]: ry_axis=arm_command_pb2.ArmCartesianCommand.Request.AXIS_MODE_FORCE, rz_axis=arm_command_pb2.ArmCartesianCommand.Request.AXIS_MODE_FORCE, ) - arm_command = arm_command_pb2.ArmCommand.Request( - arm_cartesian_command=arm_cartesian_command - ) - synchronized_command = ( - synchronized_command_pb2.SynchronizedCommand.Request( - arm_command=arm_command - ) - ) - robot_command = robot_command_pb2.RobotCommand( - synchronized_command=synchronized_command - ) + arm_command = arm_command_pb2.ArmCommand.Request(arm_cartesian_command=arm_cartesian_command) + synchronized_command = synchronized_command_pb2.SynchronizedCommand.Request(arm_command=arm_command) + robot_command = robot_command_pb2.RobotCommand(synchronized_command=synchronized_command) # Send the request cmd_id = self._robot_command_client.robot_command(robot_command) @@ -398,9 +366,7 @@ def gripper_open(self) -> typing.Tuple[bool, str]: # Command issue with RobotCommandClient cmd_id = self._robot_command_client.robot_command(command) self._logger.info("Command gripper open sent") - self.block_until_gripper_command_completes( - self._robot_command_client, cmd_id - ) + self.block_until_gripper_command_completes(self._robot_command_client, cmd_id) except Exception as e: return False, f"Exception occured while gripper was moving: {e}" @@ -426,9 +392,7 @@ def gripper_close(self) -> typing.Tuple[bool, str]: # Command issue with RobotCommandClient cmd_id = self._robot_command_client.robot_command(command) self._logger.info("Command gripper close sent") - self.block_until_gripper_command_completes( - self._robot_command_client, cmd_id - ) + self.block_until_gripper_command_completes(self._robot_command_client, cmd_id) except Exception as e: return False, f"Exception occured while gripper was moving: {e}" @@ -463,9 +427,7 @@ def gripper_angle_open(self, gripper_ang: float) -> typing.Tuple[bool, str]: # Command issue with RobotCommandClient cmd_id = self._robot_command_client.robot_command(command) self._logger.info("Command gripper open angle sent") - self.block_until_gripper_command_completes( - self._robot_command_client, cmd_id - ) + self.block_until_gripper_command_completes(self._robot_command_client, cmd_id) except Exception as e: return False, f"Exception occured while gripper was moving: {e}" @@ -510,32 +472,20 @@ def hand_pose(self, data) -> typing.Tuple[bool, str]: # Build the SE(3) pose of the desired hand position in the moving body frame. hand_pose = geometry_pb2.SE3Pose(position=position, rotation=rotation) - hand_pose_traj_point = trajectory_pb2.SE3TrajectoryPoint( - pose=hand_pose, time_since_reference=duration - ) - hand_trajectory = trajectory_pb2.SE3Trajectory( - points=[hand_pose_traj_point] - ) + hand_pose_traj_point = trajectory_pb2.SE3TrajectoryPoint(pose=hand_pose, time_since_reference=duration) + hand_trajectory = trajectory_pb2.SE3Trajectory(points=[hand_pose_traj_point]) arm_cartesian_command = arm_command_pb2.ArmCartesianCommand.Request( root_frame_name=data.frame, pose_trajectory_in_task=hand_trajectory, force_remain_near_current_joint_configuration=True, ) - arm_command = arm_command_pb2.ArmCommand.Request( - arm_cartesian_command=arm_cartesian_command - ) - synchronized_command = ( - synchronized_command_pb2.SynchronizedCommand.Request( - arm_command=arm_command - ) - ) + arm_command = arm_command_pb2.ArmCommand.Request(arm_cartesian_command=arm_cartesian_command) + synchronized_command = synchronized_command_pb2.SynchronizedCommand.Request(arm_command=arm_command) - robot_command = robot_command_pb2.RobotCommand( - synchronized_command=synchronized_command - ) + robot_command = robot_command_pb2.RobotCommand(synchronized_command=synchronized_command) - command = RobotCommandBuilder.build_synchro_command(robot_command) + RobotCommandBuilder.build_synchro_command(robot_command) # Send the request cmd_id = self._robot_command_client.robot_command(robot_command) @@ -574,9 +524,7 @@ def block_until_gripper_command_completes( while timeout_sec is None or now < end_time: feedback_resp = robot_command_client.robot_command_feedback(cmd_id) - gripper_state = ( - feedback_resp.feedback.gripper_command_feedback.claw_gripper_feedback.status - ) + gripper_state = feedback_resp.feedback.gripper_command_feedback.claw_gripper_feedback.status if gripper_state in [ gripper_command_pb2.ClawGripperCommand.Feedback.STATUS_AT_GOAL, @@ -585,10 +533,7 @@ def block_until_gripper_command_completes( # If the gripper is commanded to close, it is successful either if it reaches the goal, or if it is # applying a force. Applying a force stops the command and puts it into force control mode. return True - if ( - gripper_state - == gripper_command_pb2.ClawGripperCommand.Feedback.STATUS_UNKNOWN - ): + if gripper_state == gripper_command_pb2.ClawGripperCommand.Feedback.STATUS_UNKNOWN: return False time.sleep(0.1) @@ -619,9 +564,7 @@ def block_until_manipulation_completes( now = time.time() while timeout_sec is None or now < end_time: - feedback_request = manipulation_api_pb2.ManipulationApiFeedbackRequest( - manipulation_cmd_id=cmd_id - ) + feedback_request = manipulation_api_pb2.ManipulationApiFeedbackRequest(manipulation_cmd_id=cmd_id) # Send the request response = manipulation_client.manipulation_api_feedback_command( @@ -638,9 +581,7 @@ def block_until_manipulation_completes( now = time.time() return False - def grasp_3d( - self, frame: str, object_rt_frame: typing.List[float] - ) -> typing.Tuple[bool, str]: + def grasp_3d(self, frame: str, object_rt_frame: typing.List[float]) -> typing.Tuple[bool, str]: """ Attempt to grasp an object @@ -653,24 +594,18 @@ def grasp_3d( """ try: frm = str(frame) - pos = geometry_pb2.Vec3( - x=object_rt_frame[0], y=object_rt_frame[1], z=object_rt_frame[2] - ) + pos = geometry_pb2.Vec3(x=object_rt_frame[0], y=object_rt_frame[1], z=object_rt_frame[2]) grasp = manipulation_api_pb2.PickObject(frame_name=frm, object_rt_frame=pos) # Ask the robot to pick up the object - grasp_request = manipulation_api_pb2.ManipulationApiRequest( - pick_object=grasp - ) + grasp_request = manipulation_api_pb2.ManipulationApiRequest(pick_object=grasp) # Send the request cmd_response = self._manipulation_api_client.manipulation_api_command( manipulation_api_request=grasp_request ) - success = self.block_until_manipulation_completes( - self._manipulation_api_client, cmd_response.cmd_id - ) + success = self.block_until_manipulation_completes(self._manipulation_api_client, cmd_response.cmd_id) if success: msg = "Grasped successfully" diff --git a/spot_wrapper/spot_check.py b/spot_wrapper/spot_check.py index a23b2e1a..3541e1bd 100644 --- a/spot_wrapper/spot_check.py +++ b/spot_wrapper/spot_check.py @@ -4,11 +4,11 @@ from bosdyn.api import header_pb2 from bosdyn.client import robot_command -from bosdyn.client.lease import LeaseClient, LeaseWallet, Lease +from bosdyn.client.lease import Lease, LeaseClient, LeaseWallet from bosdyn.client.robot import Robot -from bosdyn.client.spot_check import SpotCheckClient, run_spot_check -from bosdyn.client.spot_check import spot_check_pb2 +from bosdyn.client.spot_check import SpotCheckClient, run_spot_check, spot_check_pb2 from google.protobuf.timestamp_pb2 import Timestamp + from spot_wrapper.wrapper_helpers import RobotState @@ -44,9 +44,7 @@ def _get_lease(self) -> Lease: self._lease = self._lease_wallet.get_lease() return self._lease - def _feedback_error_check( - self, resp: spot_check_pb2.SpotCheckFeedbackResponse - ) -> typing.Tuple[bool, str]: + def _feedback_error_check(self, resp: spot_check_pb2.SpotCheckFeedbackResponse) -> typing.Tuple[bool, str]: """Check for errors in the feedback response""" # Save results from Spot Check @@ -55,7 +53,9 @@ def _feedback_error_check( errorcode_mapping = { spot_check_pb2.SpotCheckFeedbackResponse.ERROR_UNEXPECTED_POWER_CHANGE: "Unexpected power change", spot_check_pb2.SpotCheckFeedbackResponse.ERROR_INIT_IMU_CHECK: "Robot body is not flat on the ground", - spot_check_pb2.SpotCheckFeedbackResponse.ERROR_INIT_NOT_SITTING: "Robot body is not close to a sitting pose", + spot_check_pb2.SpotCheckFeedbackResponse.ERROR_INIT_NOT_SITTING: ( + "Robot body is not close to a sitting pose" + ), spot_check_pb2.SpotCheckFeedbackResponse.ERROR_LOADCELL_TIMEOUT: "Timeout during loadcell calibration", spot_check_pb2.SpotCheckFeedbackResponse.ERROR_POWER_ON_FAILURE: "Error enabling motor power", spot_check_pb2.SpotCheckFeedbackResponse.ERROR_ENDSTOP_TIMEOUT: "Timeout during endstop calibration", @@ -64,7 +64,9 @@ def _feedback_error_check( spot_check_pb2.SpotCheckFeedbackResponse.ERROR_GROUND_CHECK: "Flat ground check failed", spot_check_pb2.SpotCheckFeedbackResponse.ERROR_POWER_OFF_FAILURE: "Robot failed to power off", spot_check_pb2.SpotCheckFeedbackResponse.ERROR_REVERT_FAILURE: "Robot failed to revert calibration", - spot_check_pb2.SpotCheckFeedbackResponse.ERROR_FGKC_FAILURE: "Robot failed to do flat ground kinematic calibration", + spot_check_pb2.SpotCheckFeedbackResponse.ERROR_FGKC_FAILURE: ( + "Robot failed to do flat ground kinematic calibration" + ), spot_check_pb2.SpotCheckFeedbackResponse.ERROR_GRIPPER_CAL_TIMEOUT: "Timeout during gripper calibration", spot_check_pb2.SpotCheckFeedbackResponse.ERROR_ARM_CHECK_COLLISION: "Arm motion would cause collisions", spot_check_pb2.SpotCheckFeedbackResponse.ERROR_ARM_CHECK_TIMEOUT: "Timeout during arm joint check", @@ -87,16 +89,12 @@ def _req_feedback(self) -> spot_check_pb2.SpotCheckFeedbackResponse: start_time_seconds, start_time_ns = int(time.time()), int(time.time_ns() % 1e9) req = spot_check_pb2.SpotCheckFeedbackRequest( header=header_pb2.RequestHeader( - request_timestamp=Timestamp( - seconds=start_time_seconds, nanos=start_time_ns - ), + request_timestamp=Timestamp(seconds=start_time_seconds, nanos=start_time_ns), client_name="spot-check", disable_rpc_logging=False, ) ) - resp: spot_check_pb2.SpotCheckFeedbackResponse = ( - self._spot_check_client.spot_check_feedback(req) - ) + resp: spot_check_pb2.SpotCheckFeedbackResponse = self._spot_check_client.spot_check_feedback(req) self._spot_check_resp = resp @@ -107,9 +105,7 @@ def _spot_check_cmd(self, command: spot_check_pb2.SpotCheckCommandRequest): start_time_seconds, start_time_ns = int(time.time()), int(time.time_ns() % 1e9) req = spot_check_pb2.SpotCheckCommandRequest( header=header_pb2.RequestHeader( - request_timestamp=Timestamp( - seconds=start_time_seconds, nanos=start_time_ns - ), + request_timestamp=Timestamp(seconds=start_time_seconds, nanos=start_time_ns), client_name="spot-check", disable_rpc_logging=False, ), @@ -163,9 +159,7 @@ def start_check(self) -> typing.Tuple[bool, str]: try: self._robot.power_on() if not self._robot_state.is_sitting: - robot_command.blocking_sit( - command_client=self._robot_command_client, timeout_sec=10 - ) + robot_command.blocking_sit(command_client=self._robot_command_client, timeout_sec=10) self._logger.info("Spot is sitting") else: self._logger.info("Spot is already sitting") @@ -206,9 +200,7 @@ def blocking_check( # Make sure we're powered on and sitting self._robot.power_on() if not self._robot_state.is_sitting: - robot_command.blocking_sit( - command_client=self._robot_command_client, timeout_sec=10 - ) + robot_command.blocking_sit(command_client=self._robot_command_client, timeout_sec=10) self._logger.info("Spot is sitting") else: self._logger.info("Spot is already sitting") @@ -229,9 +221,7 @@ def blocking_check( return success, status except Exception as e: - self._logger.error( - "Exception thrown during blocking spot check: {}".format(e) - ) + self._logger.error("Exception thrown during blocking spot check: {}".format(e)) return False, str(e) def get_feedback(self) -> spot_check_pb2.SpotCheckFeedbackResponse: diff --git a/spot_wrapper/spot_dance.py b/spot_wrapper/spot_dance.py index 6faf2c88..89ec983f 100644 --- a/spot_wrapper/spot_dance.py +++ b/spot_wrapper/spot_dance.py @@ -1,35 +1,31 @@ import logging -import time -import tempfile import os +import tempfile +import time +from typing import List, Tuple, Union -from bosdyn.choreography.client.choreography import ( - ChoreographyClient, -) -from bosdyn.client import ResponseError -from bosdyn.client.common import FutureWrapper -from bosdyn.client.exceptions import UnauthenticatedError -from bosdyn.client.robot import Robot -from bosdyn.choreography.client.choreography import ( - ChoreographyClient, - AnimationValidationFailedError, -) -from bosdyn.choreography.client.animation_file_to_proto import ( - convert_animation_file_to_proto, -) from bosdyn.api.spot.choreography_sequence_pb2 import ( Animation, - ChoreographyStatusResponse, ChoreographySequence, + ChoreographyStatusResponse, ExecuteChoreographyResponse, MoveParams, StartRecordingStateResponse, StopRecordingStateResponse, UploadAnimatedMoveResponse, - UploadChoreographyResponse, ) +from bosdyn.choreography.client.animation_file_to_proto import ( + convert_animation_file_to_proto, +) +from bosdyn.choreography.client.choreography import ( + AnimationValidationFailedError, + ChoreographyClient, +) +from bosdyn.client import ResponseError +from bosdyn.client.common import FutureWrapper +from bosdyn.client.exceptions import UnauthenticatedError +from bosdyn.client.robot import Robot from google.protobuf import text_format -from typing import Tuple, List, Union class SpotDance: @@ -43,9 +39,7 @@ def __init__( self._choreography_client = choreography_client self._logger = logger - def upload_animation( - self, animation_name: str, animation_file_content: str - ) -> Tuple[bool, str]: + def upload_animation(self, animation_name: str, animation_file_content: str) -> Tuple[bool, str]: """uploads an animation file""" # Load the animation file by saving the content to a temp file with tempfile.TemporaryDirectory() as temp_dir: @@ -66,18 +60,17 @@ def upload_animation_proto(self, animation: Animation) -> Tuple[bool, str]: result_message = "" try: self._logger.info(f"Uploading the name {animation.name}") - upload_response = self._choreography_client.upload_animated_move( - animation, animation.name - ) + upload_response = self._choreography_client.upload_animated_move(animation, animation.name) result = upload_response.status == UploadAnimatedMoveResponse.STATUS_OK if result: result_message = "Successfully uploaded" if upload_response.warnings: - result_message += ( - f" with warnings from validator {upload_response.warnings}" - ) + result_message += f" with warnings from validator {upload_response.warnings}" else: - result_message = f"Failed to upload animation with status {upload_response.status} and warnings: {upload_response.warnings}" + result_message = ( + f"Failed to upload animation with status {upload_response.status} and warnings:" + f" {upload_response.warnings}" + ) except AnimationValidationFailedError as e: result_message = f"Failed to upload animation: {e}" if e.response.warnings: @@ -129,9 +122,7 @@ def get_choreography_status(self) -> Tuple[bool, str, ChoreographyStatusResponse response, ) - def start_recording_state( - self, duration_seconds: float - ) -> Tuple[bool, str, StartRecordingStateResponse]: + def start_recording_state(self, duration_seconds: float) -> Tuple[bool, str, StartRecordingStateResponse]: """start recording robot motion as choreography""" try: status = self._choreography_client.start_recording_state(duration_seconds) @@ -162,9 +153,7 @@ def choreography_log_to_animation_file( ) -> Tuple[bool, str, str]: """save a choreography log to a file as an animation""" try: - file_name = self._choreography_client.choreography_log_to_animation_file( - name, fpath, has_arm, **kwargs - ) + file_name = self._choreography_client.choreography_log_to_animation_file(name, fpath, has_arm, **kwargs) return True, "success", file_name except Exception as e: return ( @@ -190,12 +179,8 @@ def stop_choreography(self) -> Tuple[bool, str]: upload_res, upload_msg = self.upload_choreography(template_sequence) # Try ro execute regardless of upload success - may have uploaded successfully in the past - execute_res, execute_msg = self.execute_choreography_by_name( - CHOREO_NAME, start_slice=0 - ) - combined_message = ( - f"Stop upload msg: {upload_msg}\n, Stop execute message: {execute_msg}" - ) + execute_res, execute_msg = self.execute_choreography_by_name(CHOREO_NAME, start_slice=0) + combined_message = f"Stop upload msg: {upload_msg}\n, Stop execute message: {execute_msg}" return execute_res, combined_message def execute_choreography_by_name( @@ -220,24 +205,18 @@ def execute_choreography_by_name( client_start_time=client_start_time, choreography_starting_slice=start_slice, ) - result = ( - execute_response.status == ExecuteChoreographyResponse.STATUS_OK - ) + result = execute_response.status == ExecuteChoreographyResponse.STATUS_OK msg = "Success" if result else "Failure" return (result, msg) except Exception as e: error_msg = f"Exception: {e}" return (False, error_msg) - def upload_choreography( - self, choreography_sequence: ChoreographySequence - ) -> Tuple[bool, str]: + def upload_choreography(self, choreography_sequence: ChoreographySequence) -> Tuple[bool, str]: """Upload choreography sequence for later playback""" try: - upload_response = self._choreography_client.upload_choreography( - choreography_sequence, non_strict_parsing=True - ) - except UnauthenticatedError as err: + self._choreography_client.upload_choreography(choreography_sequence, non_strict_parsing=True) + except UnauthenticatedError: error_msg = ( "The robot license must contain 'choreography' permissions to upload and execute dances. " "Please contact Boston Dynamics Support to get the appropriate license file. " @@ -268,10 +247,7 @@ def execute_dance(self, data: Union[ChoreographySequence, str]) -> Tuple[bool, s choreography = ChoreographySequence() text_format.Merge(data, choreography) except Exception as execp: - error_msg = ( - "Failed to read choreography from file. Raised exception: " - + str(execp) - ) + error_msg = "Failed to read choreography from file. Raised exception: " + str(execp) return False, error_msg (result, message) = self.upload_choreography(choreography) @@ -283,9 +259,7 @@ def execute_dance(self, data: Union[ChoreographySequence, str]) -> Tuple[bool, s # Setup common response in case of exception result_msg = f"Choreography uploaded with message: {message} \n" self._robot.power_on() - (result, message) = self.execute_choreography_by_name( - choreography.name, start_slice=0, use_async=False - ) + (result, message) = self.execute_choreography_by_name(choreography.name, start_slice=0, use_async=False) if result: result_msg += "Success: Dance Execution" @@ -295,9 +269,7 @@ def execute_dance(self, data: Union[ChoreographySequence, str]) -> Tuple[bool, s total_choreography_slices = 0 for move in choreography.moves: total_choreography_slices += move.requested_slices - estimated_time_seconds = ( - total_choreography_slices / choreography.slices_per_minute * 60.0 - ) + estimated_time_seconds = total_choreography_slices / choreography.slices_per_minute * 60.0 time.sleep(estimated_time_seconds) return result, result_msg except Exception as e: diff --git a/spot_wrapper/spot_docking.py b/spot_wrapper/spot_docking.py index 085152cf..3e5aab46 100644 --- a/spot_wrapper/spot_docking.py +++ b/spot_wrapper/spot_docking.py @@ -7,9 +7,9 @@ from bosdyn.client.robot import Robot from spot_wrapper.wrapper_helpers import ( - RobotState, - RobotCommandData, ClaimAndPowerDecorator, + RobotCommandData, + RobotState, ) @@ -37,9 +37,7 @@ def __init__( self._claim_and_power_decorator = claim_and_power_decorator # Decorate the functions so that they take the lease. Dock function needs to power on because it might have # to move the robot, the undock - self._claim_and_power_decorator.decorate_functions( - self, decorated_funcs=[self.dock, self.undock] - ) + self._claim_and_power_decorator.decorate_functions(self, decorated_funcs=[self.dock, self.undock]) def dock(self, dock_id: int) -> typing.Tuple[bool, str]: """Dock the robot to the docking station with fiducial ID [dock_id].""" @@ -47,9 +45,7 @@ def dock(self, dock_id: int) -> typing.Tuple[bool, str]: # Make sure we're powered on and standing self._robot.power_on() if not self._robot_state.is_standing: - robot_command.blocking_stand( - command_client=self._robot_command_client, timeout_sec=10 - ) + robot_command.blocking_stand(command_client=self._robot_command_client, timeout_sec=10) self._logger.info("Spot is standing") else: self._logger.info("Spot is already standing") diff --git a/spot_wrapper/spot_eap.py b/spot_wrapper/spot_eap.py index 476447db..aa2739ad 100644 --- a/spot_wrapper/spot_eap.py +++ b/spot_wrapper/spot_eap.py @@ -44,9 +44,7 @@ def _start_query(self) -> typing.Optional[FutureWrapper]: raise TypeError("Point cloud requests must be a list.") if self._callback is not None and len(self._point_cloud_requests) > 0: - callback_future = self._client.get_point_cloud_async( - self._point_cloud_requests - ) + callback_future = self._client.get_point_cloud_async(self._point_cloud_requests) callback_future.add_done_callback(self._callback) return callback_future diff --git a/spot_wrapper/spot_graph_nav.py b/spot_wrapper/spot_graph_nav.py index 72fa0f78..4b39080f 100644 --- a/spot_wrapper/spot_graph_nav.py +++ b/spot_wrapper/spot_graph_nav.py @@ -4,13 +4,10 @@ import time import typing -from bosdyn.api.graph_nav import graph_nav_pb2 -from bosdyn.api.graph_nav import map_pb2 -from bosdyn.api.graph_nav import map_processing_pb2 -from bosdyn.api.graph_nav import nav_pb2 +from bosdyn.api.graph_nav import graph_nav_pb2, map_pb2, map_processing_pb2, nav_pb2 from bosdyn.client.frame_helpers import get_odom_tform_body from bosdyn.client.graph_nav import GraphNavClient -from bosdyn.client.lease import LeaseClient, LeaseWallet, LeaseKeepAlive, Lease +from bosdyn.client.lease import Lease, LeaseClient, LeaseKeepAlive, LeaseWallet from bosdyn.client.map_processing import MapProcessingServiceClient from bosdyn.client.robot import Robot from bosdyn.client.robot_state import RobotStateClient @@ -61,9 +58,7 @@ def _init_current_graph_nav_state(self): self._current_waypoint_snapshots = {} # maps id to waypoint snapshot self._current_edge_snapshots = {} # maps id to edge snapshot self._current_annotation_name_to_wp_id = {} - self._current_anchored_world_objects = ( - {} - ) # maps object id to a (wo, waypoint, fiducial) + self._current_anchored_world_objects = {} # maps object id to a (wo, waypoint, fiducial) self._current_anchors = {} # maps anchor id to anchor def list_graph(self) -> typing.List[str]: @@ -73,12 +68,7 @@ def list_graph(self) -> typing.List[str]: """ ids, eds = self._list_graph_waypoint_and_edge_ids() - return [ - v - for k, v in sorted( - ids.items(), key=lambda id: int(id[0].replace("waypoint_", "")) - ) - ] + return [v for k, v in sorted(ids.items(), key=lambda id: int(id[0].replace("waypoint_", "")))] def navigate_initial_localization( self, @@ -112,9 +102,7 @@ def navigate_initial_localization( self._upload_graph_and_snapshots(upload_filepath) else: self._download_current_graph() - self._logger.info( - "Re-using existing graph on robot. Check that the correct graph is loaded!" - ) + self._logger.info("Re-using existing graph on robot. Check that the correct graph is loaded!") if initial_localization_fiducial: self.set_initial_localization_fiducial() if initial_localization_waypoint: @@ -152,9 +140,7 @@ def download_navigation_graph(self, download_path: str) -> typing.List[str]: self._download_full_graph() return self.list_graph() - def navigation_close_loops( - self, close_fiducial_loops: bool, close_odometry_loops: bool - ) -> typing.Tuple[bool, str]: + def navigation_close_loops(self, close_fiducial_loops: bool, close_odometry_loops: bool) -> typing.Tuple[bool, str]: return self._auto_close_loops(close_fiducial_loops, close_odometry_loops) def optmize_anchoring(self) -> typing.Tuple[bool, str]: @@ -171,9 +157,7 @@ def _write_bytes_while_download(self, filepath: str, data: bytes): f.write(data) f.close() - def _download_graph_and_snapshots( - self, download_path: str - ) -> typing.Tuple[bool, str]: + def _download_graph_and_snapshots(self, download_path: str) -> typing.Tuple[bool, str]: """Download the graph and snapshots from the robot. Args: download_path (str): Directory where graph and snapshots are downloaded from robot. @@ -184,19 +168,13 @@ def _download_graph_and_snapshots( if graph is None: return False, "Failed to download the graph." graph_bytes = graph.SerializeToString() - self._write_bytes_while_download( - os.path.join(download_path, "graph"), graph_bytes - ) + self._write_bytes_while_download(os.path.join(download_path, "graph"), graph_bytes) # Download the waypoint and edge snapshots. for waypoint in graph.waypoints: try: - waypoint_snapshot = self._graph_nav_client.download_waypoint_snapshot( - waypoint.snapshot_id - ) + waypoint_snapshot = self._graph_nav_client.download_waypoint_snapshot(waypoint.snapshot_id) except Exception: - self._logger.warning( - "Failed to download waypoint snapshot: %s", waypoint.snapshot_id - ) + self._logger.warning("Failed to download waypoint snapshot: %s", waypoint.snapshot_id) continue self._write_bytes_while_download( os.path.join(download_path, "waypoint_snapshots", waypoint.snapshot_id), @@ -204,13 +182,9 @@ def _download_graph_and_snapshots( ) for edge in graph.edges: try: - edge_snapshot = self._graph_nav_client.download_edge_snapshot( - edge.snapshot_id - ) + edge_snapshot = self._graph_nav_client.download_edge_snapshot(edge.snapshot_id) except Exception: - self._logger.warning( - "Failed to download edge snapshot: %s", edge.snapshot_id - ) + self._logger.warning("Failed to download edge snapshot: %s", edge.snapshot_id) continue self._write_bytes_while_download( os.path.join(download_path, "edge_snapshots", edge.snapshot_id), @@ -233,7 +207,8 @@ def clear_graph(self) -> typing.Tuple[bool, str]: def upload_graph(self, upload_path: str) -> typing.Tuple[bool, str]: """Upload the specified graph and snapshots from local to a robot. - While this method, if there are snapshots already in the robot, they will be loaded from the robot's disk without uploading. + While this method, if there are snapshots already in the robot, they will be loaded from the robot's disk + without uploading. Graph and snapshots to be uploaded should be placed like Directory specified with upload_path arg | @@ -266,9 +241,7 @@ def download_graph(self, download_path: str) -> typing.Tuple[bool, str]: Returns: (bool, str) tuple indicating whether the command was successfully sent, and a message """ try: - success, message = self._download_graph_and_snapshots( - download_path=download_path - ) + success, message = self._download_graph_and_snapshots(download_path=download_path) return success, message except Exception as e: return ( @@ -276,7 +249,9 @@ def download_graph(self, download_path: str) -> typing.Tuple[bool, str]: f"Got an error during downloading graph and snapshots from the robot: {e}", ) - ## Copied from https://github.com/boston-dynamics/spot-sdk/blob/master/python/examples/graph_nav_command_line/recording_command_line.py and https://github.com/boston-dynamics/spot-sdk/blob/master/python/examples/graph_nav_command_line/graph_nav_command_line.py with minor modifications + # Copied from https://github.com/boston-dynamics/spot-sdk/blob/master/python/examples/graph_nav_command_line/recording_command_line.py + # and https://github.com/boston-dynamics/spot-sdk/blob/master/python/examples/graph_nav_command_line/graph_nav_command_line.py + # with minor modifications # Copyright (c) 2020 Boston Dynamics, Inc. All rights reserved. # # Downloading, reproducing, distributing or otherwise using the SDK Software @@ -286,19 +261,13 @@ def _get_localization_state(self, *args): """Get the current localization and state of the robot.""" state = self._graph_nav_client.get_localization_state() self._logger.info(f"Got localization: \n{str(state.localization)}") - odom_tform_body = get_odom_tform_body( - state.robot_kinematics.transforms_snapshot - ) - self._logger.info( - f"Got robot state in kinematic odometry frame: \n{str(odom_tform_body)}" - ) + odom_tform_body = get_odom_tform_body(state.robot_kinematics.transforms_snapshot) + self._logger.info(f"Got robot state in kinematic odometry frame: \n{str(odom_tform_body)}") def set_initial_localization_fiducial(self, *args): """Trigger localization when near a fiducial.""" robot_state = self._robot_state_client.get_robot_state() - current_odom_tform_body = get_odom_tform_body( - robot_state.kinematic_state.transforms_snapshot - ).to_proto() + current_odom_tform_body = get_odom_tform_body(robot_state.kinematic_state.transforms_snapshot).to_proto() # Create an empty instance for initial localization since we are asking it to localize # based on the nearest fiducial. localization = nav_pb2.Localization() @@ -325,9 +294,7 @@ def set_initial_localization_waypoint(self, *args): return robot_state = self._robot_state_client.get_robot_state() - current_odom_tform_body = get_odom_tform_body( - robot_state.kinematic_state.transforms_snapshot - ).to_proto() + current_odom_tform_body = get_odom_tform_body(robot_state.kinematic_state.transforms_snapshot).to_proto() # Create an initial localization to the specified waypoint as the identity. localization = nav_pb2.Localization() localization.waypoint_id = destination_waypoint @@ -357,9 +324,7 @@ def _download_full_graph(self, *args): return self._write_full_graph(graph) self._logger.info( - "Graph downloaded with {} waypoints and {} edges".format( - len(graph.waypoints), len(graph.edges) - ) + "Graph downloaded with {} waypoints and {} edges".format(len(graph.waypoints), len(graph.edges)) ) # Download the waypoint and edge snapshots. self._download_and_write_waypoint_snapshots(graph.waypoints) @@ -377,14 +342,10 @@ def _download_and_write_waypoint_snapshots(self, waypoints): if len(waypoint.snapshot_id) == 0: continue try: - waypoint_snapshot = self._graph_nav_client.download_waypoint_snapshot( - waypoint.snapshot_id - ) + waypoint_snapshot = self._graph_nav_client.download_waypoint_snapshot(waypoint.snapshot_id) except Exception: # Failure in downloading waypoint snapshot. Continue to next snapshot. - self._logger.error( - "Failed to download waypoint snapshot: " + waypoint.snapshot_id - ) + self._logger.error("Failed to download waypoint snapshot: " + waypoint.snapshot_id) continue self._write_bytes( self._download_filepath + "/waypoint_snapshots", @@ -407,14 +368,10 @@ def _download_and_write_edge_snapshots(self, edges): continue num_to_download += 1 try: - edge_snapshot = self._graph_nav_client.download_edge_snapshot( - edge.snapshot_id - ) + edge_snapshot = self._graph_nav_client.download_edge_snapshot(edge.snapshot_id) except Exception: # Failure in downloading edge snapshot. Continue to next snapshot. - self._logger.error( - "Failed to download edge snapshot: " + edge.snapshot_id - ) + self._logger.error("Failed to download edge snapshot: " + edge.snapshot_id) continue self._write_bytes( self._download_filepath + "/edge_snapshots", @@ -423,9 +380,7 @@ def _download_and_write_edge_snapshots(self, edges): ) num_edge_snapshots_downloaded += 1 self._logger.info( - "Downloaded {} of the total {} edge snapshots.".format( - num_edge_snapshots_downloaded, num_to_download - ) + "Downloaded {} of the total {} edge snapshots.".format(num_edge_snapshots_downloaded, num_to_download) ) def _write_bytes(self, filepath: str, filename: str, data): @@ -441,9 +396,7 @@ def _list_graph_waypoint_and_edge_ids(self, *args): # Download current graph graph = self._download_current_graph() - localization_id = ( - self._graph_nav_client.get_localization_state().localization.waypoint_id - ) + localization_id = self._graph_nav_client.get_localization_state().localization.waypoint_id # Update and print waypoints and edges ( @@ -470,17 +423,13 @@ def _upload_graph_and_snapshots(self, upload_filepath: str): # Load the waypoint snapshots from disk. if len(waypoint.snapshot_id) == 0: continue - waypoint_filepath = os.path.join( - upload_filepath, "waypoint_snapshots", waypoint.snapshot_id - ) + waypoint_filepath = os.path.join(upload_filepath, "waypoint_snapshots", waypoint.snapshot_id) if not os.path.exists(waypoint_filepath): continue with open(waypoint_filepath, "rb") as snapshot_file: waypoint_snapshot = map_pb2.WaypointSnapshot() waypoint_snapshot.ParseFromString(snapshot_file.read()) - self._current_waypoint_snapshots[ - waypoint_snapshot.id - ] = waypoint_snapshot + self._current_waypoint_snapshots[waypoint_snapshot.id] = waypoint_snapshot for fiducial in waypoint_snapshot.objects: if not fiducial.HasField("apriltag_properties"): @@ -502,9 +451,7 @@ def _upload_graph_and_snapshots(self, upload_filepath: str): # Load the edge snapshots from disk. if len(edge.snapshot_id) == 0: continue - edge_filepath = os.path.join( - upload_filepath, "edge_snapshots", edge.snapshot_id - ) + edge_filepath = os.path.join(upload_filepath, "edge_snapshots", edge.snapshot_id) if not os.path.exists(edge_filepath): continue with open(edge_filepath, "rb") as snapshot_file: @@ -522,13 +469,9 @@ def _upload_graph_and_snapshots(self, upload_filepath: str): ) return - response = self._graph_nav_client.upload_graph( - lease=self._lease.lease_proto, graph=self._current_graph - ) + response = self._graph_nav_client.upload_graph(lease=self._lease.lease_proto, graph=self._current_graph) # Upload the snapshots to the robot. - for index, waypoint_snapshot_id in enumerate( - response.unknown_waypoint_snapshot_ids - ): + for index, waypoint_snapshot_id in enumerate(response.unknown_waypoint_snapshot_ids): waypoint_snapshot = self._current_waypoint_snapshots[waypoint_snapshot_id] self._graph_nav_client.upload_waypoint_snapshot(waypoint_snapshot) self._logger.info( @@ -568,9 +511,7 @@ def _navigate_to(self, waypoint_id: str) -> typing.Tuple[bool, str]: self._logger, ) if not destination_waypoint: - self._logger.error( - "Failed to find the appropriate unique waypoint id for the navigation command." - ) + self._logger.error("Failed to find the appropriate unique waypoint id for the navigation command.") return ( False, "Failed to find the appropriate unique waypoint id for the navigation command.", @@ -587,9 +528,7 @@ def _navigate_to(self, waypoint_id: str) -> typing.Tuple[bool, str]: while not is_finished: # Issue the navigation command about twice a second such that it is easy to terminate the # navigation command (with estop or killing the program). - nav_to_cmd_id = self._graph_nav_client.navigate_to( - destination_waypoint, 1.0, leases=[sublease.lease_proto] - ) + nav_to_cmd_id = self._graph_nav_client.navigate_to(destination_waypoint, 1.0, leases=[sublease.lease_proto]) time.sleep(0.5) # Sleep for half a second to allow for command execution. # Poll the robot for feedback to determine if the navigation command is complete. is_finished = self._check_success(nav_to_cmd_id) @@ -598,10 +537,7 @@ def _navigate_to(self, waypoint_id: str) -> typing.Tuple[bool, str]: self._lease_keepalive = LeaseKeepAlive(self._lease_client) status = self._graph_nav_client.navigation_feedback(nav_to_cmd_id) - if ( - status.status - == graph_nav_pb2.NavigationFeedbackResponse.STATUS_REACHED_GOAL - ): + if status.status == graph_nav_pb2.NavigationFeedbackResponse.STATUS_REACHED_GOAL: return True, "Successfully completed the navigation commands!" elif status.status == graph_nav_pb2.NavigationFeedbackResponse.STATUS_LOST: return ( @@ -613,17 +549,12 @@ def _navigate_to(self, waypoint_id: str) -> typing.Tuple[bool, str]: False, "Robot got stuck when navigating the route, the robot will now sit down.", ) - elif ( - status.status - == graph_nav_pb2.NavigationFeedbackResponse.STATUS_ROBOT_IMPAIRED - ): + elif status.status == graph_nav_pb2.NavigationFeedbackResponse.STATUS_ROBOT_IMPAIRED: return False, "Robot is impaired." else: return False, "Navigation command is not complete yet." - def _navigate_route( - self, waypoint_ids: typing.List[str] - ) -> typing.Tuple[bool, str]: + def _navigate_route(self, waypoint_ids: typing.List[str]) -> typing.Tuple[bool, str]: """Navigate through a specific route of waypoints. Note that each waypoint must have an edge between them, aka be adjacent. """ @@ -635,9 +566,7 @@ def _navigate_route( self._logger, ) if not waypoint_ids[i]: - self._logger.error( - "navigate_route: Failed to find the unique waypoint id." - ) + self._logger.error("navigate_route: Failed to find the unique waypoint id.") return False, "Failed to find the unique waypoint id." edge_ids_list = [] @@ -650,9 +579,7 @@ def _navigate_route( if edge_id is not None: edge_ids_list.append(edge_id) else: - self._logger.error( - f"Failed to find an edge between waypoints: {start_wp} and {end_wp}" - ) + self._logger.error(f"Failed to find an edge between waypoints: {start_wp} and {end_wp}") return ( False, f"Failed to find an edge between waypoints: {start_wp} and {end_wp}", @@ -694,26 +621,16 @@ def _check_success(self, command_id: int = -1) -> bool: # No command, so we have not status to check. return False status = self._graph_nav_client.navigation_feedback(command_id) - if ( - status.status - == graph_nav_pb2.NavigationFeedbackResponse.STATUS_REACHED_GOAL - ): + if status.status == graph_nav_pb2.NavigationFeedbackResponse.STATUS_REACHED_GOAL: # Successfully completed the navigation commands! return True elif status.status == graph_nav_pb2.NavigationFeedbackResponse.STATUS_LOST: - self._logger.error( - "Robot got lost when navigating the route, the robot will now sit down." - ) + self._logger.error("Robot got lost when navigating the route, the robot will now sit down.") return True elif status.status == graph_nav_pb2.NavigationFeedbackResponse.STATUS_STUCK: - self._logger.error( - "Robot got stuck when navigating the route, the robot will now sit down." - ) + self._logger.error("Robot got stuck when navigating the route, the robot will now sit down.") return True - elif ( - status.status - == graph_nav_pb2.NavigationFeedbackResponse.STATUS_ROBOT_IMPAIRED - ): + elif status.status == graph_nav_pb2.NavigationFeedbackResponse.STATUS_ROBOT_IMPAIRED: self._logger.error("Robot is impaired.") return True else: @@ -732,69 +649,44 @@ def _match_edge( for edge_from_id in current_edges[edge_to_id]: if (waypoint1 == edge_to_id) and (waypoint2 == edge_from_id): # This edge matches the pair of waypoints! Add it the edge list and continue. - return map_pb2.Edge.Id( - from_waypoint=waypoint2, to_waypoint=waypoint1 - ) + return map_pb2.Edge.Id(from_waypoint=waypoint2, to_waypoint=waypoint1) elif (waypoint2 == edge_to_id) and (waypoint1 == edge_from_id): # This edge matches the pair of waypoints! Add it the edge list and continue. - return map_pb2.Edge.Id( - from_waypoint=waypoint1, to_waypoint=waypoint2 - ) + return map_pb2.Edge.Id(from_waypoint=waypoint1, to_waypoint=waypoint2) return None - def _auto_close_loops( - self, close_fiducial_loops: bool, close_odometry_loops: bool, *args - ): + def _auto_close_loops(self, close_fiducial_loops: bool, close_odometry_loops: bool, *args): """Automatically find and close all loops in the graph.""" - response: map_processing_pb2.ProcessTopologyResponse = ( - self._map_processing_client.process_topology( - params=map_processing_pb2.ProcessTopologyRequest.Params( - do_fiducial_loop_closure=wrappers_pb2.BoolValue( - value=close_fiducial_loops - ), - do_odometry_loop_closure=wrappers_pb2.BoolValue( - value=close_odometry_loops - ), - ), - modify_map_on_server=True, - ) - ) - self._logger.info( - "Created {} new edge(s).".format(len(response.new_subgraph.edges)) + response: map_processing_pb2.ProcessTopologyResponse = self._map_processing_client.process_topology( + params=map_processing_pb2.ProcessTopologyRequest.Params( + do_fiducial_loop_closure=wrappers_pb2.BoolValue(value=close_fiducial_loops), + do_odometry_loop_closure=wrappers_pb2.BoolValue(value=close_odometry_loops), + ), + modify_map_on_server=True, ) + self._logger.info("Created {} new edge(s).".format(len(response.new_subgraph.edges))) if response.status == map_processing_pb2.ProcessTopologyResponse.STATUS_OK: return True, "Successfully closed loops." - elif ( - response.status - == map_processing_pb2.ProcessTopologyResponse.STATUS_MISSING_WAYPOINT_SNAPSHOTS - ): + elif response.status == map_processing_pb2.ProcessTopologyResponse.STATUS_MISSING_WAYPOINT_SNAPSHOTS: return False, "Missing waypoint snapshots." - elif ( - response.status - == map_processing_pb2.ProcessTopologyResponse.STATUS_INVALID_GRAPH - ): + elif response.status == map_processing_pb2.ProcessTopologyResponse.STATUS_INVALID_GRAPH: return False, "Invalid graph." - elif ( - response.status - == map_processing_pb2.ProcessTopologyResponse.STATUS_MAP_MODIFIED_DURING_PROCESSING - ): + elif response.status == map_processing_pb2.ProcessTopologyResponse.STATUS_MAP_MODIFIED_DURING_PROCESSING: return False, "Map modified during processing." else: return False, "Unknown error during map processing." def _optimize_anchoring(self, *args): - """Call anchoring optimization on the server, producing a globally optimal reference frame for waypoints to be expressed in.""" - response: map_processing_pb2.ProcessAnchoringResponse = ( - self._map_processing_client.process_anchoring( - params=map_processing_pb2.ProcessAnchoringRequest.Params(), - modify_anchoring_on_server=True, - stream_intermediate_results=False, - ) + """Call anchoring optimization on the server, producing a globally optimal reference frame for waypoints to be + expressed in. + """ + response: map_processing_pb2.ProcessAnchoringResponse = self._map_processing_client.process_anchoring( + params=map_processing_pb2.ProcessAnchoringRequest.Params(), + modify_anchoring_on_server=True, + stream_intermediate_results=False, ) if response.status == map_processing_pb2.ProcessAnchoringResponse.STATUS_OK: - self._logger.info( - "Optimized anchoring after {} iteration(s).".format(response.iteration) - ) + self._logger.info("Optimized anchoring after {} iteration(s).".format(response.iteration)) return True, "Successfully optimized anchoring." else: self._logger.error("Error optimizing {}".format(response)) @@ -896,17 +788,13 @@ def _update_waypoints_and_edges( # Determine the timestamp that this waypoint was created at. timestamp = -1.0 try: - timestamp = ( - waypoint.annotations.creation_time.seconds - + waypoint.annotations.creation_time.nanos / 1e9 - ) - except: + timestamp = waypoint.annotations.creation_time.seconds + waypoint.annotations.creation_time.nanos / 1e9 + except Exception as e: # Must be operating on an older graph nav map, since the creation_time is not # available within the waypoint annotations message. + logger.info(f"Unable process waypoint, ignoring. Exception: {e}") pass - waypoint_to_timestamp.append( - (waypoint.id, timestamp, waypoint.annotations.name) - ) + waypoint_to_timestamp.append((waypoint.id, timestamp, waypoint.annotations.name)) # Determine how many waypoints have the same short code. short_code = self._id_to_short_code(waypoint.id) @@ -928,17 +816,13 @@ def _update_waypoints_and_edges( # Sort the set of waypoints by their creation timestamp. If the creation timestamp is unavailable, # fallback to sorting by annotation name. - waypoint_to_timestamp = sorted( - waypoint_to_timestamp, key=lambda x: (x[1], x[2]) - ) + waypoint_to_timestamp = sorted(waypoint_to_timestamp, key=lambda x: (x[1], x[2])) # Print out the waypoints name, id, and short code in a ordered sorted by the timestamp from # when the waypoint was created. logger.info("%d waypoints:" % len(graph.waypoints)) for waypoint in waypoint_to_timestamp: - self._pretty_print_waypoints( - waypoint[0], waypoint[2], short_code_to_count, localization_id, logger - ) + self._pretty_print_waypoints(waypoint[0], waypoint[2], short_code_to_count, localization_id, logger) for edge in graph.edges: if edge.id.to_waypoint in edges: @@ -946,8 +830,6 @@ def _update_waypoints_and_edges( edges[edge.id.to_waypoint].append(edge.id.from_waypoint) else: edges[edge.id.to_waypoint] = [edge.id.from_waypoint] - logger.info( - f"(Edge) from waypoint id: {edge.id.from_waypoint} and to waypoint id: {edge.id.to_waypoint}" - ) + logger.info(f"(Edge) from waypoint id: {edge.id.from_waypoint} and to waypoint id: {edge.id.to_waypoint}") return name_to_id, edges diff --git a/spot_wrapper/spot_images.py b/spot_wrapper/spot_images.py index c9a15362..827af79b 100644 --- a/spot_wrapper/spot_images.py +++ b/spot_wrapper/spot_images.py @@ -34,12 +34,8 @@ "left_depth_in_visual_frame", "back_depth_in_visual_frame", ] -ImageBundle = namedtuple( - "ImageBundle", ["frontleft", "frontright", "left", "right", "back"] -) -ImageWithHandBundle = namedtuple( - "ImageBundle", ["frontleft", "frontright", "left", "right", "back", "hand"] -) +ImageBundle = namedtuple("ImageBundle", ["frontleft", "frontright", "left", "right", "back"]) +ImageWithHandBundle = namedtuple("ImageBundle", ["frontleft", "frontright", "left", "right", "back", "hand"]) IMAGE_SOURCES_BY_CAMERA = { "frontleft": { @@ -238,25 +234,19 @@ def __init__( else: quality = self._image_quality.robot_image_quality elif camera != "hand": - self._logger.info( - f"Switching {camera}:{image_type} to greyscale image format." - ) + self._logger.info(f"Switching {camera}:{image_type} to greyscale image format.") pixel_format = image_pb2.Image.PIXEL_FORMAT_GREYSCALE_U8 quality = self._image_quality.robot_image_quality source = IMAGE_SOURCES_BY_CAMERA[camera][image_type] - self._image_requests_by_camera[camera][ - image_type - ] = build_image_request( + self._image_requests_by_camera[camera][image_type] = build_image_request( source, image_format=image_format, pixel_format=pixel_format, quality_percent=quality, ) - def get_rgb_image( - self, image_source: str - ) -> typing.Optional[image_pb2.ImageResponse]: + def get_rgb_image(self, image_source: str) -> typing.Optional[image_pb2.ImageResponse]: """ Args: @@ -406,22 +396,16 @@ def get_images_by_cameras( cameras_specified = set() for item in camera_sources: if item.camera_name in cameras_specified: - self._logger.error( - f"Duplicated camera source for camera {item.camera_name}" - ) + self._logger.error(f"Duplicated camera source for camera {item.camera_name}") return None image_types = item.image_types if image_types is None: image_types = IMAGE_TYPES for image_type in image_types: try: - image_requests.append( - self._image_requests_by_camera[item.camera_name][image_type] - ) + image_requests.append(self._image_requests_by_camera[item.camera_name][image_type]) except KeyError: - self._logger.error( - f"Unexpected camera name '{item.camera_name}' or image type '{image_type}'" - ) + self._logger.error(f"Unexpected camera name '{item.camera_name}' or image type '{image_type}'") return None source_types.append((item.camera_name, image_type)) cameras_specified.add(item.camera_name) @@ -431,8 +415,7 @@ def get_images_by_cameras( image_responses = self._image_client.get_image(image_requests) except UnsupportedPixelFormatRequestedError: self._logger.error( - "UnsupportedPixelFormatRequestedError. " - "Likely pixel_format is set wrong for some image request" + "UnsupportedPixelFormatRequestedError. Likely pixel_format is set wrong for some image request" ) return None @@ -455,9 +438,7 @@ def set_gripper_camera_params( raise Exception("Gripper camera is not available") else: self._logger.info("Setting Gripper Camera Parameters") - response = self._gripper_cam_param_client.set_camera_params( - camera_param_request - ) + response = self._gripper_cam_param_client.set_camera_params(camera_param_request) return response def get_gripper_camera_params( @@ -468,7 +449,5 @@ def get_gripper_camera_params( raise Exception("Gripper camera is not available") else: self._logger.info("Getting Gripper Camera Parameters") - response = self._gripper_cam_param_client.get_camera_params( - camera_get_param_request - ) + response = self._gripper_cam_param_client.get_camera_params(camera_get_param_request) return response diff --git a/spot_wrapper/spot_world_objects.py b/spot_wrapper/spot_world_objects.py index ef0ee324..51c7893d 100644 --- a/spot_wrapper/spot_world_objects.py +++ b/spot_wrapper/spot_world_objects.py @@ -1,4 +1,5 @@ from __future__ import annotations + import logging import typing @@ -28,9 +29,7 @@ def __init__( rate: Rate (Hz) to trigger the query callback: Callback function to call when the results of the query are available """ - super().__init__( - "world-objects", client, logger, period_sec=1.0 / max(rate, 1.0) - ) + super().__init__("world-objects", client, logger, period_sec=1.0 / max(rate, 1.0)) self._callback = None if rate > 0.0: self._callback = callback @@ -92,6 +91,4 @@ def list_world_objects( List world object response containing the filtered list of world objects """ - return self._world_objects_client.list_world_objects( - object_types, time_start_point - ) + return self._world_objects_client.list_world_objects(object_types, time_start_point) diff --git a/spot_wrapper/testing/fixtures.py b/spot_wrapper/testing/fixtures.py index e68518c1..6eac9463 100644 --- a/spot_wrapper/testing/fixtures.py +++ b/spot_wrapper/testing/fixtures.py @@ -55,9 +55,7 @@ def fixture( def decorator(cls: typing.Type[BaseMockSpot]) -> typing.Callable: def fixturefunc(monkeypatch, **kwargs) -> typing.Iterator[SpotFixture]: - with concurrent.futures.ThreadPoolExecutor( - max_workers=max_workers - ) as thread_pool: + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as thread_pool: server = grpc.server(thread_pool) port = server.add_insecure_port(f"{address}:0") with cls(**kwargs) as mock: @@ -79,9 +77,7 @@ def mock_secure_channel(target, _, *args, **kwargs): if "monkeypatch" not in sig.parameters: sig = sig.replace( parameters=( - inspect.Parameter( - "monkeypatch", inspect.Parameter.POSITIONAL_OR_KEYWORD - ), + inspect.Parameter("monkeypatch", inspect.Parameter.POSITIONAL_OR_KEYWORD), *sig.parameters.values(), ) ) diff --git a/spot_wrapper/testing/grpc.py b/spot_wrapper/testing/grpc.py index dc5692e4..e2757630 100644 --- a/spot_wrapper/testing/grpc.py +++ b/spot_wrapper/testing/grpc.py @@ -48,9 +48,7 @@ def method_handlers( if hasattr(handler, "_method_handlers"): yield from handler._method_handlers.items() - def add_generic_rpc_handlers( - self, handlers: typing.Iterable[grpc.GenericRpcHandler] - ) -> None: + def add_generic_rpc_handlers(self, handlers: typing.Iterable[grpc.GenericRpcHandler]) -> None: """Implements `grpc.Server.add_generic_rcp_handlers`.""" self.handlers.extend(handlers) @@ -107,28 +105,20 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: unqualified_name = handler.stream_stream.__name__ underlying_callable = handler.stream_stream if self.autospec and not implemented(underlying_callable): - underlying_callable = DeferredStreamRpcHandler( - underlying_callable - ) + underlying_callable = DeferredStreamRpcHandler(underlying_callable) self.needs_shutdown.append(underlying_callable) if self.autotrack: - underlying_callable = TrackingStreamStreamRpcHandler( - underlying_callable - ) + underlying_callable = TrackingStreamStreamRpcHandler(underlying_callable) if underlying_callable is not handler.stream_stream: setattr(self, unqualified_name, underlying_callable) else: unqualified_name = handler.unary_stream.__name__ underlying_callable = handler.unary_stream if self.autospec and not implemented(underlying_callable): - underlying_callable = DeferredStreamRpcHandler( - underlying_callable - ) + underlying_callable = DeferredStreamRpcHandler(underlying_callable) self.needs_shutdown.append(underlying_callable) if self.autotrack: - underlying_callable = TrackingUnaryStreamRpcHandler( - underlying_callable - ) + underlying_callable = TrackingUnaryStreamRpcHandler(underlying_callable) if underlying_callable is not handler.unary_stream: setattr(self, unqualified_name, underlying_callable) else: @@ -136,28 +126,20 @@ def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: unqualified_name = handler.stream_unary.__name__ underlying_callable = handler.stream_unary if self.autospec and not implemented(underlying_callable): - underlying_callable = DeferredUnaryRpcHandler( - underlying_callable - ) + underlying_callable = DeferredUnaryRpcHandler(underlying_callable) self.needs_shutdown.append(underlying_callable) if self.autotrack: - underlying_callable = TrackingStreamUnaryRpcHandler( - underlying_callable - ) + underlying_callable = TrackingStreamUnaryRpcHandler(underlying_callable) if underlying_callable is not handler.stream_unary: setattr(self, unqualified_name, underlying_callable) else: unqualified_name = handler.unary_unary.__name__ underlying_callable = handler.unary_unary if self.autospec and not implemented(underlying_callable): - underlying_callable = DeferredUnaryRpcHandler( - underlying_callable - ) + underlying_callable = DeferredUnaryRpcHandler(underlying_callable) self.needs_shutdown.append(underlying_callable) if self.autotrack: - underlying_callable = TrackingUnaryUnaryRpcHandler( - underlying_callable - ) + underlying_callable = TrackingUnaryUnaryRpcHandler(underlying_callable) if underlying_callable is not handler.unary_unary: setattr(self, unqualified_name, underlying_callable) @@ -190,9 +172,7 @@ def __init__(self, handler: typing.Callable) -> None: self.requests: typing.List = [] self.num_calls = 0 - def __call__( - self, request: typing.Any, context: grpc.ServicerContext - ) -> typing.Any: + def __call__(self, request: typing.Any, context: grpc.ServicerContext) -> typing.Any: try: self.requests.append(request) return self.__wrapped__(request, context) @@ -208,9 +188,7 @@ def __init__(self, handler: typing.Callable) -> None: self.requests: typing.List = [] self.num_calls = 0 - def __call__( - self, request_iterator: typing.Iterator, context: grpc.ServicerContext - ) -> typing.Any: + def __call__(self, request_iterator: typing.Iterator, context: grpc.ServicerContext) -> typing.Any: try: request = list(request_iterator) self.requests.append(request) @@ -228,9 +206,7 @@ def __init__(self, handler: typing.Callable) -> None: self.requests: typing.List = [] self.num_calls = 0 - def __call__( - self, request: typing.Any, context: grpc.ServicerContext - ) -> typing.Iterator: + def __call__(self, request: typing.Any, context: grpc.ServicerContext) -> typing.Iterator: try: self.requests.append(request) yield from self.__wrapped__(request, context) @@ -246,9 +222,7 @@ def __init__(self, handler: typing.Callable) -> None: self.requests: typing.List = [] self.num_calls = 0 - def __call__( - self, request_iterator: typing.Iterator, context: grpc.ServicerContext - ) -> typing.Iterator: + def __call__(self, request_iterator: typing.Iterator, context: grpc.ServicerContext) -> typing.Iterator: try: request = list(request_iterator) self.requests.append(request) @@ -335,9 +309,7 @@ def returns(self, response: typing.Any) -> None: self._completed = True self._completion.notify_all() - def fails( - self, code: grpc.StatusCode, details: typing.Optional[str] = None - ) -> None: + def fails(self, code: grpc.StatusCode, details: typing.Optional[str] = None) -> None: """Fails the call by setting an error `code` and optional `details`.""" with self._completion: if self._completed: @@ -371,9 +343,7 @@ def returns(self, response: typing.Any) -> None: """Specifies the next call will succeed with the given `response`.""" self._changequeue.put(lambda call: call.returns(response)) - def fails( - self, code: grpc.StatusCode, details: typing.Optional[str] = None - ) -> None: + def fails(self, code: grpc.StatusCode, details: typing.Optional[str] = None) -> None: """Specifies the next call will fail with given error `code` and `details`.""" self._changequeue.put(lambda call: call.fails(code, details)) @@ -387,9 +357,7 @@ def shutdown(self) -> None: while not self._callqueue.empty(): call = self._callqueue.get_nowait() if "PYTEST_CURRENT_TEST" in os.environ: - logging.warning( - f"{self.__name__} call not served, dropped during shutdown" - ) + logging.warning(f"{self.__name__} call not served, dropped during shutdown") call.fails(grpc.StatusCode.ABORTED, "call dropped") @property @@ -397,9 +365,7 @@ def pending(self) -> bool: """Whether a call is waiting to be served.""" return not self._callqueue.empty() - def serve( - self, timeout: typing.Optional[float] = None - ) -> typing.Optional["DeferredRpcHandler.Call"]: + def serve(self, timeout: typing.Optional[float] = None) -> typing.Optional["DeferredRpcHandler.Call"]: """ Serve the next pending call, if any. @@ -424,9 +390,7 @@ def future(self) -> "DeferredRpcHandler.Future": class DeferredUnaryRpcHandler(DeferredRpcHandler): """A gRPC any-unary handler that decouples invocation and computation execution paths.""" - def __call__( - self, request: typing.Any, context: grpc.ServicerContext - ) -> typing.Any: + def __call__(self, request: typing.Any, context: grpc.ServicerContext) -> typing.Any: call = DeferredRpcHandler.Call(request, context) if not self._future.materialize(call): self._callqueue.put(call) @@ -439,9 +403,7 @@ def __call__( class DeferredStreamRpcHandler(DeferredRpcHandler): """A gRPC any-stream handler that decouples invocation and computation execution paths.""" - def __call__( - self, request: typing.Any, context: grpc.ServicerContext - ) -> typing.Any: + def __call__(self, request: typing.Any, context: grpc.ServicerContext) -> typing.Any: call = DeferredRpcHandler.Call(request, context) if not self._future.materialize(call): self._callqueue.put(call) diff --git a/spot_wrapper/testing/helpers.py b/spot_wrapper/testing/helpers.py index 58620aeb..4b0938b7 100644 --- a/spot_wrapper/testing/helpers.py +++ b/spot_wrapper/testing/helpers.py @@ -17,9 +17,7 @@ class GeneralizedDecorator: def wraps(wrapped: typing.Callable): def decorator(func: typing.Callable): class wrapper(GeneralizedDecorator): - def __call__( - self, *args: typing.Any, **kwargs: typing.Any - ) -> typing.Any: + def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Any: return func(*args, **kwargs) return wrapper(wrapped) @@ -36,9 +34,7 @@ def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Any: raise NotImplementedError() -UnaryUnaryHandlerCallable = typing.Callable[ - [typing.Any, grpc.ServicerContext], typing.Any -] +UnaryUnaryHandlerCallable = typing.Callable[[typing.Any, grpc.ServicerContext], typing.Any] def enforce_matching_headers( @@ -51,12 +47,8 @@ def wrapper(request: typing.Any, context: grpc.ServicerContext) -> typing.Any: response = handler(request, context) if hasattr(request, "header") and hasattr(response, "header"): response.header.request_header.CopyFrom(request.header) - response.header.request_received_timestamp.CopyFrom( - request.header.request_timestamp - ) - response.header.error.code = ( - response.header.error.code or CommonError.CODE_OK - ) + response.header.request_received_timestamp.CopyFrom(request.header.request_timestamp) + response.header.error.code = response.header.error.code or CommonError.CODE_OK return response return wrapper diff --git a/spot_wrapper/testing/mocks/auth.py b/spot_wrapper/testing/mocks/auth.py index 300f87f5..33aee486 100644 --- a/spot_wrapper/testing/mocks/auth.py +++ b/spot_wrapper/testing/mocks/auth.py @@ -8,9 +8,7 @@ class MockAuthService(AuthServiceServicer): """A mock Spot authentication service.""" - def GetAuthToken( - self, request: GetAuthTokenRequest, context: grpc.ServicerContext - ) -> GetAuthTokenResponse: + def GetAuthToken(self, request: GetAuthTokenRequest, context: grpc.ServicerContext) -> GetAuthTokenResponse: response = GetAuthTokenResponse() response.status = GetAuthTokenResponse.Status.STATUS_OK response.token = "mock-token" diff --git a/spot_wrapper/testing/mocks/directory.py b/spot_wrapper/testing/mocks/directory.py index e89c17d3..67ed12ca 100644 --- a/spot_wrapper/testing/mocks/directory.py +++ b/spot_wrapper/testing/mocks/directory.py @@ -28,9 +28,7 @@ class MockDirectoryService(DirectoryServiceServicer): DEFAULT_SERVICES = { entry.type: entry for entry in [ - ServiceEntry( - name="auth", type="bosdyn.api.AuthService", authority="auth.spot.robot" - ), + ServiceEntry(name="auth", type="bosdyn.api.AuthService", authority="auth.spot.robot"), ServiceEntry( name="payload-registration", type="bosdyn.api.RobotIdService", diff --git a/spot_wrapper/testing/mocks/estop.py b/spot_wrapper/testing/mocks/estop.py index 54fd6e5a..1cd7cb77 100644 --- a/spot_wrapper/testing/mocks/estop.py +++ b/spot_wrapper/testing/mocks/estop.py @@ -44,21 +44,15 @@ def RegisterEstopEndpoint( ) -> RegisterEstopEndpointResponse: response = RegisterEstopEndpointResponse() response.request.CopyFrom(request) - estop_configurations = { - config.unique_id: config for config in self.estop_configurations - } + estop_configurations = {config.unique_id: config for config in self.estop_configurations} if request.target_config_id not in estop_configurations: - response.status = ( - RegisterEstopEndpointResponse.Status.STATUS_CONFIG_MISMATCH - ) + response.status = RegisterEstopEndpointResponse.Status.STATUS_CONFIG_MISMATCH return response estop_configuration = estop_configurations[request.target_config_id] if request.target_endpoint.unique_id: estop_endpoints = {ep.unique_id: ep for ep in estop_configuration.endpoints} if request.target_endpoint.unique_id not in estop_endpoints: - response.status = ( - RegisterEstopEndpointResponse.Status.STATUS_ENDPOINT_MISMATCH - ) + response.status = RegisterEstopEndpointResponse.Status.STATUS_ENDPOINT_MISMATCH return response estop_endpoint = estop_endpoints[request.target_endpoint.unique_id] else: @@ -76,40 +70,23 @@ def DeregisterEstopEndpoint( ) -> DeregisterEstopEndpointResponse: response = DeregisterEstopEndpointResponse() response.request.CopyFrom(request) - estop_configurations = { - config.unique_id: config for config in self.estop_configurations - } + estop_configurations = {config.unique_id: config for config in self.estop_configurations} if request.target_config_id not in estop_configurations: - response.status = ( - RegisterEstopEndpointResponse.Status.STATUS_CONFIG_MISMATCH - ) + response.status = RegisterEstopEndpointResponse.Status.STATUS_CONFIG_MISMATCH return response estop_configuration = estop_configurations[request.target_config_id] - estop_endpoint_indices = { - ep.unique_id: index - for index, ep in enumerate(estop_configuration.endpoints) - } + estop_endpoint_indices = {ep.unique_id: index for index, ep in enumerate(estop_configuration.endpoints)} if request.target_endpoint.unique_id not in estop_endpoint_indices: - response.status = ( - RegisterEstopEndpointResponse.Status.STATUS_ENDPOINT_MISMATCH - ) + response.status = RegisterEstopEndpointResponse.Status.STATUS_ENDPOINT_MISMATCH return response - del estop_configuration.endpoints[ - estop_endpoint_indices[request.target_endpoint.unique_id] - ] + del estop_configuration.endpoints[estop_endpoint_indices[request.target_endpoint.unique_id]] response.status = DeregisterEstopEndpointResponse.Status.STATUS_SUCCESS return response - def EstopCheckIn( - self, request: EstopCheckInRequest, context: grpc.ServicerContext - ) -> EstopCheckInResponse: + def EstopCheckIn(self, request: EstopCheckInRequest, context: grpc.ServicerContext) -> EstopCheckInResponse: response = EstopCheckInResponse() response.request.CopyFrom(request) - estop_endpoints = { - ep.unique_id: ep - for cfg in self.estop_configurations - for ep in cfg.endpoints - } + estop_endpoints = {ep.unique_id: ep for cfg in self.estop_configurations for ep in cfg.endpoints} if request.endpoint.unique_id not in estop_endpoints: response.status = EstopCheckInResponse.Status.STATUS_ENDPOINT_UNKNOWN return response @@ -117,38 +94,28 @@ def EstopCheckIn( response.challenge = (request.challenge or 1) + 1 return response - def GetEstopConfig( - self, request: GetEstopConfigRequest, context: grpc.ServicerContext - ) -> GetEstopConfigResponse: + def GetEstopConfig(self, request: GetEstopConfigRequest, context: grpc.ServicerContext) -> GetEstopConfigResponse: response = GetEstopConfigResponse() response.request.CopyFrom(request) if not request.target_config_id: response.active_config.CopyFrom(self.active_estop_configuration) return response - estop_configurations = { - config.unique_id: config for config in self.estop_configurations - } + estop_configurations = {config.unique_id: config for config in self.estop_configurations} if request.target_config_id not in estop_configurations: response.header.error.code = CommonError.CODE_INVALID_REQUEST return response response.active_config.CopyFrom(estop_configurations[request.target_config_id]) return response - def SetEstopConfig( - self, request: SetEstopConfigRequest, context: grpc.ServicerContext - ) -> SetEstopConfigResponse: + def SetEstopConfig(self, request: SetEstopConfigRequest, context: grpc.ServicerContext) -> SetEstopConfigResponse: response = SetEstopConfigResponse() response.request.CopyFrom(request) if request.target_config_id: - estop_configurations = { - config.unique_id: config for config in self.estop_configurations - } + estop_configurations = {config.unique_id: config for config in self.estop_configurations} if request.target_config_id not in estop_configurations: response.status = SetEstopConfigResponse.Status.STATUS_INVALID_ID return response - self.active_estop_configuration = estop_configurations[ - request.target_config_id - ] + self.active_estop_configuration = estop_configurations[request.target_config_id] else: self.active_estop_configuration = EstopConfig() self.active_estop_configuration.unique_id = next(self._config_id_generator) diff --git a/spot_wrapper/testing/mocks/keepalive.py b/spot_wrapper/testing/mocks/keepalive.py index 7b4fd3e4..b2e490c1 100644 --- a/spot_wrapper/testing/mocks/keepalive.py +++ b/spot_wrapper/testing/mocks/keepalive.py @@ -32,9 +32,7 @@ def __init__(self, **kwargs: typing.Any) -> None: def policies(self) -> typing.Iterable[LivePolicy]: return self._policies.values() - def ModifyPolicy( - self, request: ModifyPolicyRequest, context: grpc.ServicerContext - ) -> ModifyPolicyResponse: + def ModifyPolicy(self, request: ModifyPolicyRequest, context: grpc.ServicerContext) -> ModifyPolicyResponse: response = ModifyPolicyResponse() for policy_id in request.policy_ids_to_remove: @@ -57,9 +55,7 @@ def ModifyPolicy( response.status = ModifyPolicyResponse.Status.STATUS_OK return response - def CheckIn( - self, request: CheckInRequest, context: grpc.ServicerContext - ) -> CheckInResponse: + def CheckIn(self, request: CheckInRequest, context: grpc.ServicerContext) -> CheckInResponse: response = CheckInResponse() if request.policy_id not in self._policies: response.status = CheckInResponse.Status.STATUS_INVALID_POLICY_ID @@ -69,9 +65,7 @@ def CheckIn( response.status = CheckInResponse.Status.STATUS_OK return response - def GetStatus( - self, request: GetStatusRequest, context: grpc.ServicerContext - ) -> GetStatusResponse: + def GetStatus(self, request: GetStatusRequest, context: grpc.ServicerContext) -> GetStatusResponse: response = GetStatusResponse() response.status.extend(self._policies.values()) return response diff --git a/spot_wrapper/testing/mocks/lease.py b/spot_wrapper/testing/mocks/lease.py index 28630403..6372ea88 100644 --- a/spot_wrapper/testing/mocks/lease.py +++ b/spot_wrapper/testing/mocks/lease.py @@ -55,18 +55,14 @@ def __init__( self._leasable_resources[resource.resource] = leasable_resource self._latest_lease: typing.Optional[Lease] = None - def AcquireLease( - self, request: AcquireLeaseRequest, context: grpc.ServicerContext - ) -> AcquireLeaseResponse: + def AcquireLease(self, request: AcquireLeaseRequest, context: grpc.ServicerContext) -> AcquireLeaseResponse: response = AcquireLeaseResponse() if request.resource not in self._leasable_resources: response.status = AcquireLeaseResponse.Status.STATUS_INVALID_RESOURCE return response leasable_resource = self._leasable_resources[request.resource] if not leasable_resource.is_stale: - response.status = ( - AcquireLeaseResponse.Status.STATUS_RESOURCE_ALREADY_CLAIMED - ) + response.status = AcquireLeaseResponse.Status.STATUS_RESOURCE_ALREADY_CLAIMED response.lease_owner.CopyFrom(leasable_resource.lease_owner) return response leasable_resource.lease.client_names.append(request.header.client_name) @@ -80,9 +76,7 @@ def AcquireLease( response.status = AcquireLeaseResponse.Status.STATUS_OK return response - def TakeLease( - self, request: TakeLeaseRequest, context: grpc.ServicerContext - ) -> TakeLeaseResponse: + def TakeLease(self, request: TakeLeaseRequest, context: grpc.ServicerContext) -> TakeLeaseResponse: response = TakeLeaseResponse() if request.resource not in self._leasable_resources: response.status = TakeLeaseResponse.Status.STATUS_INVALID_RESOURCE @@ -99,9 +93,7 @@ def TakeLease( response.status = TakeLeaseResponse.Status.STATUS_OK return response - def ReturnLease( - self, request: ReturnLeaseRequest, context: grpc.ServicerContext - ) -> ReturnLeaseResponse: + def ReturnLease(self, request: ReturnLeaseRequest, context: grpc.ServicerContext) -> ReturnLeaseResponse: response = ReturnLeaseResponse() if request.lease.resource not in self._leasable_resources: response.status = ReturnLeaseResponse.Status.STATUS_INVALID_RESOURCE @@ -115,25 +107,19 @@ def ReturnLease( response.status = ReturnLeaseResponse.Status.STATUS_OK return response - def ListLeases( - self, request: ListLeasesRequest, context: grpc.ServicerContext - ) -> ListLeasesResponse: + def ListLeases(self, request: ListLeasesRequest, context: grpc.ServicerContext) -> ListLeasesResponse: response = ListLeasesResponse() response.resources.extend(self._leasable_resources.values()) response.resource_tree.CopyFrom(self._resource_tree) return response - def RetainLease( - self, request: RetainLeaseRequest, context: grpc.ServicerContext - ) -> RetainLeaseResponse: + def RetainLease(self, request: RetainLeaseRequest, context: grpc.ServicerContext) -> RetainLeaseResponse: response = RetainLeaseResponse() response.lease_use_result.attempted_lease.CopyFrom(request.lease) if self._latest_lease is not None: response.lease_use_result.latest_known_lease.CopyFrom(self._latest_lease) if request.lease.resource not in self._leasable_resources: - response.lease_use_result.status = ( - LeaseUseResult.Status.STATUS_INVALID_LEASE - ) + response.lease_use_result.status = LeaseUseResult.Status.STATUS_INVALID_LEASE return response leasable_resource = self._leasable_resources[request.lease.resource] if leasable_resource.is_stale: diff --git a/spot_wrapper/testing/mocks/license.py b/spot_wrapper/testing/mocks/license.py index 0dbde5fb..b22ad0de 100644 --- a/spot_wrapper/testing/mocks/license.py +++ b/spot_wrapper/testing/mocks/license.py @@ -18,9 +18,7 @@ class MockLicenseService(LicenseServiceServicer): It provides a license that never expires for all features. """ - def GetLicenseInfo( - self, request: GetLicenseInfoRequest, context: grpc.ServicerContext - ) -> GetLicenseInfoResponse: + def GetLicenseInfo(self, request: GetLicenseInfoRequest, context: grpc.ServicerContext) -> GetLicenseInfoResponse: response = GetLicenseInfoResponse() response.license.status = LicenseInfo.Status.STATUS_VALID response.license.id = "0123210" diff --git a/spot_wrapper/testing/mocks/payload_registration.py b/spot_wrapper/testing/mocks/payload_registration.py index d8ed44c3..07e50f13 100644 --- a/spot_wrapper/testing/mocks/payload_registration.py +++ b/spot_wrapper/testing/mocks/payload_registration.py @@ -58,9 +58,7 @@ def GetPayloadAuthToken( ) -> GetPayloadAuthTokenResponse: response = GetPayloadAuthTokenResponse() if request.payload_credentials.guid not in self._payloads: - response.status = ( - GetPayloadAuthTokenResponse.Status.STATUS_INVALID_CREDENTIALS - ) + response.status = GetPayloadAuthTokenResponse.Status.STATUS_INVALID_CREDENTIALS return response response.status = GetPayloadAuthTokenResponse.Status.STATUS_OK response.token = "mock-payload-token" diff --git a/spot_wrapper/testing/mocks/robot_id.py b/spot_wrapper/testing/mocks/robot_id.py index 2a2e50ff..3602f478 100644 --- a/spot_wrapper/testing/mocks/robot_id.py +++ b/spot_wrapper/testing/mocks/robot_id.py @@ -8,9 +8,7 @@ class MockRobotIdService(RobotIdServiceServicer): """A mock Spot robot id service.""" - def GetRobotId( - self, request: RobotIdRequest, context: grpc.ServicerContext - ) -> RobotIdResponse: + def GetRobotId(self, request: RobotIdRequest, context: grpc.ServicerContext) -> RobotIdResponse: response = RobotIdResponse() response.robot_id.serial_number = "1234567890" response.robot_id.species = "mock-spot" diff --git a/spot_wrapper/testing/mocks/robot_state.py b/spot_wrapper/testing/mocks/robot_state.py index 41efd028..ec3b3a42 100644 --- a/spot_wrapper/testing/mocks/robot_state.py +++ b/spot_wrapper/testing/mocks/robot_state.py @@ -41,16 +41,12 @@ def __init__(self, **kwargs: typing.Any) -> None: self.robot_metrics = RobotMetrics() self.hardware_configuration = HardwareConfiguration() - def GetRobotState( - self, request: RobotStateRequest, context: grpc.ServicerContext - ) -> RobotStateResponse: + def GetRobotState(self, request: RobotStateRequest, context: grpc.ServicerContext) -> RobotStateResponse: response = RobotStateResponse() response.robot_state.CopyFrom(self.robot_state) return response - def GetRobotMetrics( - self, request: RobotMetricsRequest, context: grpc.ServicerContext - ) -> RobotMetricsResponse: + def GetRobotMetrics(self, request: RobotMetricsRequest, context: grpc.ServicerContext) -> RobotMetricsResponse: response = RobotMetricsResponse() response.robot_metrics.CopyFrom(self.robot_metrics) return response diff --git a/spot_wrapper/testing/mocks/time_sync.py b/spot_wrapper/testing/mocks/time_sync.py index 2a00bf54..2fd15bd9 100644 --- a/spot_wrapper/testing/mocks/time_sync.py +++ b/spot_wrapper/testing/mocks/time_sync.py @@ -16,9 +16,7 @@ class MockTimeSyncService(TimeSyncServiceServicer): Always perfect clock synchronization. """ - def TimeSyncUpdate( - self, request: TimeSyncUpdateRequest, context: grpc.ServicerContext - ) -> TimeSyncUpdateResponse: + def TimeSyncUpdate(self, request: TimeSyncUpdateRequest, context: grpc.ServicerContext) -> TimeSyncUpdateResponse: response = TimeSyncUpdateResponse() response.state.status = TimeSyncState.STATUS_OK response.state.best_estimate.clock_skew.seconds = 0 diff --git a/spot_wrapper/tests/test_graph_nav_util.py b/spot_wrapper/tests/test_graph_nav_util.py index 18ba04d2..7974941c 100755 --- a/spot_wrapper/tests/test_graph_nav_util.py +++ b/spot_wrapper/tests/test_graph_nav_util.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -import pytest import logging from bosdyn.api.graph_nav import map_pb2 @@ -18,14 +17,8 @@ def __init__(self) -> None: class TestGraphNavUtilShortCode: def test_id_to_short_code(self): - assert ( - graph_nav_util._id_to_short_code("ebony-pug-mUzxLNq.TkGlVIxga+UKAQ==") - == "ep" - ) - assert ( - graph_nav_util._id_to_short_code("erose-simian-sug9xpxhCxgft7Mtbhr98A==") - == "es" - ) + assert graph_nav_util._id_to_short_code("ebony-pug-mUzxLNq.TkGlVIxga+UKAQ==") == "ep" + assert graph_nav_util._id_to_short_code("erose-simian-sug9xpxhCxgft7Mtbhr98A==") == "es" class TestGraphNavUtilFindUniqueWaypointId: @@ -35,26 +28,11 @@ def test_short_code(self): self.graph = map_pb2.Graph() self.name_to_id = {"ABCDE": "Node1"} # Test normal short code - assert ( - graph_nav_util._find_unique_waypoint_id( - "AC", self.graph, self.name_to_id, self.logger - ) - == "AC" - ) + assert graph_nav_util._find_unique_waypoint_id("AC", self.graph, self.name_to_id, self.logger) == "AC" # Test annotation name that is known - assert ( - graph_nav_util._find_unique_waypoint_id( - "ABCDE", self.graph, self.name_to_id, self.logger - ) - == "Node1" - ) + assert graph_nav_util._find_unique_waypoint_id("ABCDE", self.graph, self.name_to_id, self.logger) == "Node1" # Test annotation name that is unknown - assert ( - graph_nav_util._find_unique_waypoint_id( - "ABCDEF", self.graph, self.name_to_id, self.logger - ) - == "ABCDEF" - ) + assert graph_nav_util._find_unique_waypoint_id("ABCDEF", self.graph, self.name_to_id, self.logger) == "ABCDEF" def test_short_code_with_graph(self): # Set up @@ -64,27 +42,12 @@ def test_short_code_with_graph(self): # Test short code that is in graph self.graph.waypoints.add(id="AB-CD-EF") - assert ( - graph_nav_util._find_unique_waypoint_id( - "AC", self.graph, self.name_to_id, self.logger - ) - == "AB-CD-EF" - ) + assert graph_nav_util._find_unique_waypoint_id("AC", self.graph, self.name_to_id, self.logger) == "AB-CD-EF" # Test short code that is not in graph - assert ( - graph_nav_util._find_unique_waypoint_id( - "AD", self.graph, self.name_to_id, self.logger - ) - == "AD" - ) + assert graph_nav_util._find_unique_waypoint_id("AD", self.graph, self.name_to_id, self.logger) == "AD" # Test multiple waypoints with the same short code self.graph.waypoints.add(id="AB-CD-EF-1") - assert ( - graph_nav_util._find_unique_waypoint_id( - "AC", self.graph, self.name_to_id, self.logger - ) - == "AC" - ) + assert graph_nav_util._find_unique_waypoint_id("AC", self.graph, self.name_to_id, self.logger) == "AC" class TestGraphNavUtilUpdateWaypointsEdges: @@ -94,9 +57,7 @@ def test_empty_graph(self): # Test empty graph self.graph = map_pb2.Graph() self.localization_id = "" - graph_nav_util._update_waypoints_and_edges( - self.graph, self.localization_id, self.logger - ) + graph_nav_util._update_waypoints_and_edges(self.graph, self.localization_id, self.logger) assert len(self.graph.waypoints) == 0 assert len(self.graph.edges) == 0 @@ -109,9 +70,7 @@ def test_one_waypoint(self): new_waypoint = map_pb2.Waypoint() new_waypoint.id = "ABCDE" new_waypoint.annotations.name = "Node1" - self.graph.waypoints.add( - id=new_waypoint.id, annotations=new_waypoint.annotations - ) + self.graph.waypoints.add(id=new_waypoint.id, annotations=new_waypoint.annotations) self.name_to_id, self.edges = graph_nav_util._update_waypoints_and_edges( self.graph, self.localization_id, self.logger ) @@ -130,14 +89,10 @@ def test_two_waypoints_with_edge(self): new_waypoint = map_pb2.Waypoint() new_waypoint.id = "ABCDE" new_waypoint.annotations.name = "Node1" - self.graph.waypoints.add( - id=new_waypoint.id, annotations=new_waypoint.annotations - ) + self.graph.waypoints.add(id=new_waypoint.id, annotations=new_waypoint.annotations) new_waypoint.id = "DE" new_waypoint.annotations.name = "Node2" - self.graph.waypoints.add( - id=new_waypoint.id, annotations=new_waypoint.annotations - ) + self.graph.waypoints.add(id=new_waypoint.id, annotations=new_waypoint.annotations) new_edge = map_pb2.Edge.Id(from_waypoint="ABCDE", to_waypoint="DE") self.graph.edges.add(id=new_edge) @@ -161,14 +116,10 @@ def test_two_waypoints_with_edge_and_localization(self): new_waypoint = map_pb2.Waypoint() new_waypoint.id = "ABCDE" new_waypoint.annotations.name = "Node1" - self.graph.waypoints.add( - id=new_waypoint.id, annotations=new_waypoint.annotations - ) + self.graph.waypoints.add(id=new_waypoint.id, annotations=new_waypoint.annotations) new_waypoint.id = "DE" new_waypoint.annotations.name = "Node2" - self.graph.waypoints.add( - id=new_waypoint.id, annotations=new_waypoint.annotations - ) + self.graph.waypoints.add(id=new_waypoint.id, annotations=new_waypoint.annotations) new_edge = map_pb2.Edge.Id(from_waypoint="ABCDE", to_waypoint="DE") self.graph.edges.add(id=new_edge) diff --git a/spot_wrapper/tests/test_wrapper.py b/spot_wrapper/tests/test_wrapper.py index 3024592b..cad189ee 100644 --- a/spot_wrapper/tests/test_wrapper.py +++ b/spot_wrapper/tests/test_wrapper.py @@ -35,21 +35,15 @@ def __init__(self, request) -> None: manipulator_state = self.robot_state.manipulator_state manipulator_state.is_gripper_holding_item = True - def PowerCommand( - self, request: PowerCommandRequest, context: grpc.ServicerContext - ) -> PowerCommandResponse: + def PowerCommand(self, request: PowerCommandRequest, context: grpc.ServicerContext) -> PowerCommandResponse: # Provide custom bosdyn.api.PowerService/PowerCommand implementation. response = PowerCommandResponse() power_state = self.robot_state.power_state if request.request == PowerCommandRequest.Request.REQUEST_ON_MOTORS: - power_state.motor_power_state = ( - PowerState.MotorPowerState.MOTOR_POWER_STATE_ON - ) + power_state.motor_power_state = PowerState.MotorPowerState.MOTOR_POWER_STATE_ON response.status = PowerCommandStatus.STATUS_SUCCESS elif request.request == PowerCommandRequest.Request.REQUEST_OFF_MOTORS: - power_state.motor_power_state = ( - PowerState.MotorPowerState.MOTOR_POWER_STATE_OFF - ) + power_state.motor_power_state = PowerState.MotorPowerState.MOTOR_POWER_STATE_OFF response.status = PowerCommandStatus.STATUS_SUCCESS else: response.status = PowerCommandStatus.STATUS_INTERNAL_ERROR @@ -69,9 +63,7 @@ def simple_spot_wrapper(simple_spot: SpotFixture) -> SpotWrapper: ) -def test_wrapper_setup( - simple_spot: SpotFixture, simple_spot_wrapper: SpotWrapper -) -> None: +def test_wrapper_setup(simple_spot: SpotFixture, simple_spot_wrapper: SpotWrapper) -> None: # spot_wrapper.testing.mocks.MockSpot dummy services enable basic usage. assert simple_spot_wrapper.is_valid diff --git a/spot_wrapper/wrapper.py b/spot_wrapper/wrapper.py index 406daad3..60c73c93 100644 --- a/spot_wrapper/wrapper.py +++ b/spot_wrapper/wrapper.py @@ -5,6 +5,7 @@ import bosdyn.client.auth from bosdyn.api import ( + basic_command_pb2, lease_pb2, manipulation_api_pb2, point_cloud_pb2, @@ -13,6 +14,17 @@ world_object_pb2, ) from bosdyn.api.spot import robot_command_pb2 as spot_command_pb2 +from bosdyn.api.spot.choreography_sequence_pb2 import ( + Animation, + ChoreographySequence, + ChoreographyStatusResponse, + StartRecordingStateResponse, + StopRecordingStateResponse, + UploadChoreographyResponse, +) +from bosdyn.choreography.client.choreography import ( + ChoreographyClient, +) from bosdyn.client import ( ResponseError, RpcError, @@ -42,31 +54,12 @@ from bosdyn.client.spot_check import SpotCheckClient from bosdyn.client.time_sync import TimeSyncEndpoint from bosdyn.client.world_object import WorldObjectClient - -from bosdyn.choreography.client.choreography import ( - ChoreographyClient, -) -from bosdyn.api.spot.choreography_sequence_pb2 import ( - Animation, - ChoreographySequence, - ChoreographyStatusResponse, - StartRecordingStateResponse, - StopRecordingStateResponse, - UploadChoreographyResponse, -) -from .spot_dance import SpotDance - from bosdyn.geometry import EulerZXY - -SPOT_CLIENT_NAME = "ros_spot" -MAX_COMMAND_DURATION = 1e5 -VELODYNE_SERVICE_NAME = "velodyne-point-cloud" - -from bosdyn.api import basic_command_pb2 from google.protobuf.timestamp_pb2 import Timestamp from .spot_arm import SpotArm from .spot_check import SpotCheck +from .spot_dance import SpotDance from .spot_docking import SpotDocking from .spot_eap import SpotEAP from .spot_graph_nav import SpotGraphNav @@ -74,6 +67,10 @@ from .spot_world_objects import SpotWorldObjects from .wrapper_helpers import ClaimAndPowerDecorator, RobotCommandData, RobotState +SPOT_CLIENT_NAME = "ros_spot" +MAX_COMMAND_DURATION = 1e5 +VELODYNE_SERVICE_NAME = "velodyne-point-cloud" + def robotToLocalTime(timestamp: Timestamp, robot: Robot) -> Timestamp: """Takes a timestamp and an estimated skew and return seconds and nano seconds in local time @@ -110,7 +107,8 @@ def __init__(self, message="Spot arm not available"): class AsyncRobotState(AsyncPeriodicQuery): - """Class to get robot state at regular intervals. get_robot_state_async query sent to the robot at every tick. Callback registered to defined callback function. + """Class to get robot state at regular intervals. get_robot_state_async query sent to the robot at every tick. + Callback registered to defined callback function. Attributes: client: The Client to a service on the robot @@ -120,9 +118,7 @@ class AsyncRobotState(AsyncPeriodicQuery): """ def __init__(self, client, logger, rate, callback): - super(AsyncRobotState, self).__init__( - "robot-state", client, logger, period_sec=1.0 / max(rate, 1.0) - ) + super(AsyncRobotState, self).__init__("robot-state", client, logger, period_sec=1.0 / max(rate, 1.0)) self._callback = None if rate > 0.0: self._callback = callback @@ -135,7 +131,8 @@ def _start_query(self): class AsyncMetrics(AsyncPeriodicQuery): - """Class to get robot metrics at regular intervals. get_robot_metrics_async query sent to the robot at every tick. Callback registered to defined callback function. + """Class to get robot metrics at regular intervals. get_robot_metrics_async query sent to the robot at every tick. + Callback registered to defined callback function. Attributes: client: The Client to a service on the robot @@ -145,9 +142,7 @@ class AsyncMetrics(AsyncPeriodicQuery): """ def __init__(self, client, logger, rate, callback): - super(AsyncMetrics, self).__init__( - "robot-metrics", client, logger, period_sec=1.0 / max(rate, 1.0) - ) + super(AsyncMetrics, self).__init__("robot-metrics", client, logger, period_sec=1.0 / max(rate, 1.0)) self._callback = None if rate > 0.0: self._callback = callback @@ -160,7 +155,8 @@ def _start_query(self): class AsyncLease(AsyncPeriodicQuery): - """Class to get lease state at regular intervals. list_leases_async query sent to the robot at every tick. Callback registered to defined callback function. + """Class to get lease state at regular intervals. list_leases_async query sent to the robot at every tick. + Callback registered to defined callback function. Attributes: client: The Client to a service on the robot @@ -170,9 +166,7 @@ class AsyncLease(AsyncPeriodicQuery): """ def __init__(self, client, logger, rate, callback): - super(AsyncLease, self).__init__( - "lease", client, logger, period_sec=1.0 / max(rate, 1.0) - ) + super(AsyncLease, self).__init__("lease", client, logger, period_sec=1.0 / max(rate, 1.0)) self._callback = None if rate > 0.0: self._callback = callback @@ -209,19 +203,13 @@ def __init__( def _start_query(self) -> None: if self._spot_wrapper.last_stand_command is not None: try: - response = self._client.robot_command_feedback( - self._spot_wrapper.last_stand_command - ) - status = ( - response.feedback.synchronized_feedback.mobility_command_feedback.stand_feedback.status - ) + response = self._client.robot_command_feedback(self._spot_wrapper.last_stand_command) + status = response.feedback.synchronized_feedback.mobility_command_feedback.stand_feedback.status self._spot_wrapper.is_sitting = False if status == basic_command_pb2.StandCommand.Feedback.STATUS_IS_STANDING: self._spot_wrapper.is_standing = True self._spot_wrapper.last_stand_command = None - elif ( - status == basic_command_pb2.StandCommand.Feedback.STATUS_IN_PROGRESS - ): + elif status == basic_command_pb2.StandCommand.Feedback.STATUS_IN_PROGRESS: self._spot_wrapper.is_standing = False else: self._logger.warning("Stand command in unknown state") @@ -233,9 +221,7 @@ def _start_query(self) -> None: if self._spot_wrapper.last_sit_command is not None: try: self._spot_wrapper.is_standing = False - response = self._client.robot_command_feedback( - self._spot_wrapper.last_sit_command - ) + response = self._client.robot_command_feedback(self._spot_wrapper.last_sit_command) if ( response.feedback.synchronized_feedback.mobility_command_feedback.sit_feedback.status == basic_command_pb2.SitCommand.Feedback.STATUS_IS_SITTING @@ -258,49 +244,31 @@ def _start_query(self) -> None: if self._spot_wrapper.last_trajectory_command is not None: try: - response = self._client.robot_command_feedback( - self._spot_wrapper.last_trajectory_command - ) + response = self._client.robot_command_feedback(self._spot_wrapper.last_trajectory_command) status = ( response.feedback.synchronized_feedback.mobility_command_feedback.se2_trajectory_feedback.status ) # STATUS_AT_GOAL always means that the robot reached the goal. If the trajectory command did not # request precise positioning, then STATUS_NEAR_GOAL also counts as reaching the goal - if ( - status - == basic_command_pb2.SE2TrajectoryCommand.Feedback.STATUS_AT_GOAL - or ( - status - == basic_command_pb2.SE2TrajectoryCommand.Feedback.STATUS_NEAR_GOAL - and not self._spot_wrapper.last_trajectory_command_precise - ) + if status == basic_command_pb2.SE2TrajectoryCommand.Feedback.STATUS_AT_GOAL or ( + status == basic_command_pb2.SE2TrajectoryCommand.Feedback.STATUS_NEAR_GOAL + and not self._spot_wrapper.last_trajectory_command_precise ): self._spot_wrapper.at_goal = True # Clear the command once at the goal self._spot_wrapper.last_trajectory_command = None self._spot_wrapper._trajectory_status_unknown = False - elif ( - status - == basic_command_pb2.SE2TrajectoryCommand.Feedback.STATUS_GOING_TO_GOAL - ): + elif status == basic_command_pb2.SE2TrajectoryCommand.Feedback.STATUS_GOING_TO_GOAL: is_moving = True - elif ( - status - == basic_command_pb2.SE2TrajectoryCommand.Feedback.STATUS_NEAR_GOAL - ): + elif status == basic_command_pb2.SE2TrajectoryCommand.Feedback.STATUS_NEAR_GOAL: is_moving = True self._spot_wrapper.near_goal = True - elif ( - status - == basic_command_pb2.SE2TrajectoryCommand.Feedback.STATUS_UNKNOWN - ): + elif status == basic_command_pb2.SE2TrajectoryCommand.Feedback.STATUS_UNKNOWN: self._spot_wrapper.trajectory_status_unknown = True self._spot_wrapper.last_trajectory_command = None else: self._logger.error( - "Received trajectory command status outside of expected range, value is {}".format( - status - ) + "Received trajectory command status outside of expected range, value is {}".format(status) ) self._spot_wrapper.last_trajectory_command = None except (ResponseError, RpcError) as e: @@ -334,9 +302,7 @@ class AsyncEStopMonitor(AsyncPeriodicQuery): """ def __init__(self, client, logger, rate, spot_wrapper): - super(AsyncEStopMonitor, self).__init__( - "estop_alive", client, logger, period_sec=1.0 / rate - ) + super(AsyncEStopMonitor, self).__init__("estop_alive", client, logger, period_sec=1.0 / rate) self._spot_wrapper = spot_wrapper def _start_query(self): @@ -345,20 +311,10 @@ def _start_query(self): return last_estop_status = self._spot_wrapper._estop_keepalive.status_queue.queue[-1] - if ( - last_estop_status[0] - == self._spot_wrapper._estop_keepalive.KeepAliveStatus.ERROR - ): - self._logger.error( - "Estop keepalive has an error: {}".format(last_estop_status[1]) - ) - elif ( - last_estop_status - == self._spot_wrapper._estop_keepalive.KeepAliveStatus.DISABLED - ): - self._logger.error( - "Estop keepalive is disabled: {}".format(last_estop_status[1]) - ) + if last_estop_status[0] == self._spot_wrapper._estop_keepalive.KeepAliveStatus.ERROR: + self._logger.error("Estop keepalive has an error: {}".format(last_estop_status[1])) + elif last_estop_status == self._spot_wrapper._estop_keepalive.KeepAliveStatus.DISABLED: + self._logger.error("Estop keepalive is disabled: {}".format(last_estop_status[1])) else: # estop keepalive is ok pass @@ -394,8 +350,10 @@ def __init__( start_estop: If true, the wrapper will be an estop endpoint estop_timeout: Timeout for the estop in seconds. The SDK will check in with the wrapper at a rate of estop_timeout/3 and if there is no communication the robot will execute a gentle stop. - rates: Dictionary of rates to apply when retrieving various data from the robot # TODO this should be an object to be unambiguous - callbacks: Dictionary of callbacks which should be called when certain data is retrieved # TODO this should be an object to be unambiguous + rates: Dictionary of rates to apply when retrieving various data from the robot + # TODO this should be an object to be unambiguous + callbacks: Dictionary of callbacks which should be called when certain data is retrieved + # TODO this should be an object to be unambiguous use_take_lease: Use take instead of acquire to get leases. This will forcefully take the lease from any other lease owner. get_lease_on_action: If true, attempt to acquire a lease when performing an action which requires a @@ -413,9 +371,7 @@ def __init__( self._rates = rates or {} self._callbacks = callbacks or {} self._use_take_lease = use_take_lease - self._claim_decorator = ClaimAndPowerDecorator( - self.power_on, self.claim, get_lease_on_action - ) + self._claim_decorator = ClaimAndPowerDecorator(self.power_on, self.claim, get_lease_on_action) self.decorate_functions() self._continually_try_stand = continually_try_stand self._rgb_cameras = rgb_cameras @@ -453,9 +409,7 @@ def __init__( self._robot, self._payload_credentials_file, self._logger ) else: - authenticated = self.authenticate( - self._robot, self._username, self._password, self._logger - ) + authenticated = self.authenticate(self._robot, self._username, self._password, self._logger) if not authenticated: self._valid = False @@ -470,43 +424,19 @@ def __init__( initialised = False while not initialised: try: - self._robot_state_client = self._robot.ensure_client( - RobotStateClient.default_service_name - ) - self._world_objects_client = self._robot.ensure_client( - WorldObjectClient.default_service_name - ) - self._robot_command_client = self._robot.ensure_client( - RobotCommandClient.default_service_name - ) - self._graph_nav_client = self._robot.ensure_client( - GraphNavClient.default_service_name - ) - self._map_processing_client = self._robot.ensure_client( - MapProcessingServiceClient.default_service_name - ) - self._power_client = self._robot.ensure_client( - PowerClient.default_service_name - ) - self._lease_client = self._robot.ensure_client( - LeaseClient.default_service_name - ) + self._robot_state_client = self._robot.ensure_client(RobotStateClient.default_service_name) + self._world_objects_client = self._robot.ensure_client(WorldObjectClient.default_service_name) + self._robot_command_client = self._robot.ensure_client(RobotCommandClient.default_service_name) + self._graph_nav_client = self._robot.ensure_client(GraphNavClient.default_service_name) + self._map_processing_client = self._robot.ensure_client(MapProcessingServiceClient.default_service_name) + self._power_client = self._robot.ensure_client(PowerClient.default_service_name) + self._lease_client = self._robot.ensure_client(LeaseClient.default_service_name) self._lease_wallet = self._lease_client.lease_wallet - self._image_client = self._robot.ensure_client( - ImageClient.default_service_name - ) - self._estop_client = self._robot.ensure_client( - EstopClient.default_service_name - ) - self._docking_client = self._robot.ensure_client( - DockingClient.default_service_name - ) - self._spot_check_client = self._robot.ensure_client( - SpotCheckClient.default_service_name - ) - self._license_client = self._robot.ensure_client( - LicenseClient.default_service_name - ) + self._image_client = self._robot.ensure_client(ImageClient.default_service_name) + self._estop_client = self._robot.ensure_client(EstopClient.default_service_name) + self._docking_client = self._robot.ensure_client(DockingClient.default_service_name) + self._spot_check_client = self._robot.ensure_client(SpotCheckClient.default_service_name) + self._license_client = self._robot.ensure_client(LicenseClient.default_service_name) if self._robot.has_arm(): self._gripper_cam_param_client = self._robot.ensure_client( GripperCameraParamClient.default_service_name @@ -514,22 +444,18 @@ def __init__( else: self._gripper_cam_param_client = None - if self._license_client.get_feature_enabled( - [ChoreographyClient.license_name] - )[ChoreographyClient.license_name]: + if self._license_client.get_feature_enabled([ChoreographyClient.license_name])[ + ChoreographyClient.license_name + ]: self._is_licensed_for_choreography = True - self._choreography_client = self._robot.ensure_client( - ChoreographyClient.default_service_name - ) + self._choreography_client = self._robot.ensure_client(ChoreographyClient.default_service_name) else: self._logger.info("Robot is not licensed for choreography") self._is_licensed_for_choreography = False self._choreography_client = None try: - self._point_cloud_client = self._robot.ensure_client( - VELODYNE_SERVICE_NAME - ) + self._point_cloud_client = self._robot.ensure_client(VELODYNE_SERVICE_NAME) except UnregisteredServiceError: self._point_cloud_client = None self._logger.info("No velodyne point cloud service is available.") @@ -547,9 +473,7 @@ def __init__( sleep_secs = 15 self._logger.warning( "Unable to create client service: {}. This usually means the robot hasn't " - "finished booting yet. Will wait {} seconds and try again.".format( - e, sleep_secs - ) + "finished booting yet. Will wait {} seconds and try again.".format(e, sleep_secs) ) time.sleep(sleep_secs) @@ -573,12 +497,8 @@ def __init__( max(0.0, self._rates.get("lease", 0.0)), self._callbacks.get("lease", None), ) - self._idle_task = AsyncIdle( - self._robot_command_client, self._logger, 10.0, self - ) - self._estop_monitor = AsyncEStopMonitor( - self._estop_client, self._logger, 20.0, self - ) + self._idle_task = AsyncIdle(self._robot_command_client, self._logger, 10.0, self) + self._estop_monitor = AsyncEStopMonitor(self._estop_client, self._logger, 20.0, self) self._estop_endpoint = None self._estop_keepalive = None @@ -666,9 +586,7 @@ def __init__( self._async_tasks = AsyncTasks(robot_tasks) if self._is_licensed_for_choreography: - self._spot_dance = SpotDance( - self._robot, self._choreography_client, self._logger - ) + self._spot_dance = SpotDance(self._robot, self._choreography_client, self._logger) self._robot_id = None self._lease = None @@ -700,14 +618,10 @@ def decorate_functions(self): self.toggle_power, ] - self._claim_decorator.decorate_functions( - self, decorated_funcs, decorated_funcs_no_power - ) + self._claim_decorator.decorate_functions(self, decorated_funcs, decorated_funcs_no_power) @staticmethod - def authenticate( - robot: Robot, username: str, password: str, logger: logging.Logger - ) -> bool: + def authenticate(robot: Robot, username: str, password: str, logger: logging.Logger) -> bool: """ Authenticate with a robot through the bosdyn API. A blocking function which will wait until authenticated (if the robot is still booting) or login fails @@ -763,13 +677,9 @@ def authenticate_from_payload_credentials( authenticated = False while not authenticated: try: - logger.info( - "Trying to authenticate with robot from payload credentials..." - ) + logger.info("Trying to authenticate with robot from payload credentials...") robot.authenticate_from_payload_credentials( - *bosdyn.client.util.read_payload_credentials( - payload_credentials_file - ) + *bosdyn.client.util.read_payload_credentials(payload_credentials_file) ) robot.time_sync.wait_for_sync(10) logger.info("Successfully authenticated.") @@ -1003,10 +913,7 @@ def claim(self) -> typing.Tuple[bool, str]: """Get a lease for the robot, a handle on the estop endpoint, and the ID of the robot.""" if self.lease is not None: for resource in self.lease: - if ( - resource.resource == "all-leases" - and SPOT_CLIENT_NAME in resource.lease_owner.client_name - ): + if resource.resource == "all-leases" and SPOT_CLIENT_NAME in resource.lease_owner.client_name: return True, "We already claimed the lease" try: @@ -1035,9 +942,7 @@ def updateTasks(self) -> None: def resetEStop(self) -> None: """Get keepalive for eStop""" - self._estop_endpoint = EstopEndpoint( - self._estop_client, SPOT_CLIENT_NAME, self._estop_timeout - ) + self._estop_endpoint = EstopEndpoint(self._estop_client, SPOT_CLIENT_NAME, self._estop_timeout) self._estop_endpoint.force_simple_setup() # Set this endpoint as the robot's sole estop. self._estop_keepalive = EstopKeepAlive(self._estop_endpoint) @@ -1045,7 +950,8 @@ def assertEStop(self, severe: bool = True) -> typing.Tuple[bool, str]: """Forces the robot into eStop state. Args: - severe: Default True - If true, will cut motor power immediately. If false, will try to settle the robot on the ground first + severe: Default True - If true, will cut motor power immediately. If false, will try to settle the robot + on the ground first """ try: if severe: @@ -1131,9 +1037,7 @@ def _robot_command( self._logger.error(f"Unable to execute robot command: {e}") return False, str(e), None - def _manipulation_request( - self, request_proto, end_time_secs=None, timesync_endpoint=None - ): + def _manipulation_request(self, request_proto, end_time_secs=None, timesync_endpoint=None): """Generic function for sending requests to the manipulation api of a robot. Args: @@ -1189,9 +1093,7 @@ def simple_stand(self, monitor_command: bool = True) -> typing.Tuple[bool, str]: Returns: Tuple of bool success and a string message """ - response = self._robot_command( - RobotCommandBuilder.synchro_stand_command(params=self._mobility_params) - ) + response = self._robot_command(RobotCommandBuilder.synchro_stand_command(params=self._mobility_params)) if monitor_command: self.last_stand_command = response[2] return response[0], response[1] @@ -1225,15 +1127,11 @@ def stand( # If any of the orientation parameters are nonzero use them to pose the body body_orientation = EulerZXY(yaw=body_yaw, pitch=body_pitch, roll=body_roll) response = self._robot_command( - RobotCommandBuilder.synchro_stand_command( - body_height=body_height, footprint_R_body=body_orientation - ) + RobotCommandBuilder.synchro_stand_command(body_height=body_height, footprint_R_body=body_orientation) ) else: # Otherwise just use the mobility params - response = self._robot_command( - RobotCommandBuilder.synchro_stand_command(params=self._mobility_params) - ) + response = self._robot_command(RobotCommandBuilder.synchro_stand_command(params=self._mobility_params)) if monitor_command: self.last_stand_command = response[2] @@ -1250,9 +1148,7 @@ def battery_change_pose(self, dir_hint: int = 1) -> typing.Tuple[bool, str]: Tuple of bool success and a string message """ if self.is_sitting: - response = self._robot_command( - RobotCommandBuilder.battery_change_pose_command(dir_hint) - ) + response = self._robot_command(RobotCommandBuilder.battery_change_pose_command(dir_hint)) return response[0], response[1] return False, "Call sit before trying to roll over" @@ -1266,9 +1162,7 @@ def safe_power_off(self) -> typing.Tuple[bool, str]: response = self._robot_command(RobotCommandBuilder.safe_power_off_command()) return response[0], response[1] - def clear_behavior_fault( - self, fault_id: int - ) -> typing.Tuple[bool, str, typing.Optional[bool]]: + def clear_behavior_fault(self, fault_id: int) -> typing.Tuple[bool, str, typing.Optional[bool]]: """ Clear the behavior fault defined by the given id. @@ -1276,9 +1170,7 @@ def clear_behavior_fault( Tuple of bool success, string message, and bool indicating whether the status was cleared """ try: - rid = self._robot_command_client.clear_behavior_fault( - behavior_fault_id=fault_id, lease=None - ) + rid = self._robot_command_client.clear_behavior_fault(behavior_fault_id=fault_id, lease=None) return True, "Success", rid except Exception as e: return False, f"Exception while clearing behavior fault: {e}", None @@ -1292,7 +1184,8 @@ def power_on(self) -> typing.Tuple[bool, str]: """ # Don't bother trying to power on if we are already powered on if not self.check_is_powered_on(): - # If we are requested to start the estop, we have to acquire it when powering on. Ignore if estop is already acquired. + # If we are requested to start the estop, we have to acquire it when powering on. + # Ignore if estop is already acquired. if self._start_estop and self._estop_keepalive is None: self.resetEStop() try: @@ -1305,9 +1198,7 @@ def power_on(self) -> typing.Tuple[bool, str]: return True, "Was already powered on" - def set_mobility_params( - self, mobility_params: spot_command_pb2.MobilityParams - ) -> None: + def set_mobility_params(self, mobility_params: spot_command_pb2.MobilityParams) -> None: """Set Params for mobility and movement Args: @@ -1330,16 +1221,15 @@ def velocity_cmd( v_x: Velocity in the X direction in meters v_y: Velocity in the Y direction in meters v_rot: Angular velocity around the Z axis in radians - cmd_duration: (optional) Time-to-live for the command in seconds. Default is 125ms (assuming 10Hz command rate). + cmd_duration: (optional) Time-to-live for the command in seconds. Default is 125ms (assuming 10Hz command + rate). Returns: Tuple of bool success and a string message """ end_time = time.time() + cmd_duration response = self._robot_command( - RobotCommandBuilder.synchro_velocity_command( - v_x=v_x, v_y=v_y, v_rot=v_rot, params=self._mobility_params - ), + RobotCommandBuilder.synchro_velocity_command(v_x=v_x, v_y=v_y, v_rot=v_rot, params=self._mobility_params), end_time_secs=end_time, timesync_endpoint=self._robot.time_sync.endpoint, ) @@ -1422,9 +1312,7 @@ def trajectory_cmd( self.last_trajectory_command = response[2] return response[0], response[1] - def robot_command( - self, robot_command: robot_command_pb2.RobotCommand - ) -> typing.Tuple[bool, str]: + def robot_command(self, robot_command: robot_command_pb2.RobotCommand) -> typing.Tuple[bool, str]: end_time = time.time() + MAX_COMMAND_DURATION return self._robot_command( robot_command, @@ -1440,15 +1328,11 @@ def manipulation_command(self, request): timesync_endpoint=self._robot.time_sync.endpoint, ) - def get_robot_command_feedback( - self, cmd_id: int - ) -> robot_command_pb2.RobotCommandFeedbackResponse: + def get_robot_command_feedback(self, cmd_id: int) -> robot_command_pb2.RobotCommandFeedbackResponse: return self._robot_command_client.robot_command_feedback(cmd_id) def get_manipulation_command_feedback(self, cmd_id): - feedback_request = manipulation_api_pb2.ManipulationApiFeedbackRequest( - manipulation_cmd_id=cmd_id - ) + feedback_request = manipulation_api_pb2.ManipulationApiFeedbackRequest(manipulation_cmd_id=cmd_id) return self._manipulation_api_client.manipulation_api_feedback_command( manipulation_api_feedback_request=feedback_request @@ -1463,13 +1347,8 @@ def toggle_power(self, should_power_on): motors_on = False while not motors_on: future = self._robot_state_client.get_robot_state_async() - state_response = future.result( - timeout=10 - ) # 10 second timeout for waiting for the state response. - if ( - state_response.power_state.motor_power_state - == robot_state_pb2.PowerState.STATE_ON - ): + state_response = future.result(timeout=10) # 10 second timeout for waiting for the state response. + if state_response.power_state.motor_power_state == robot_state_pb2.PowerState.STATE_ON: motors_on = True else: # Motors are not yet fully powered on. @@ -1500,9 +1379,7 @@ def stop_choreography(self) -> typing.Tuple[bool, str]: else: return False, "Spot is not licensed for choreography" - def execute_dance( - self, data: typing.Union[ChoreographySequence, str] - ) -> typing.Tuple[bool, str]: + def execute_dance(self, data: typing.Union[ChoreographySequence, str]) -> typing.Tuple[bool, str]: """Upload and execute dance. Data can be passed as - ChoreographySequence: proto passed directly to function - str: file contents of a .csq read directly from disk @@ -1517,9 +1394,7 @@ def execute_choreography_by_name( ) -> typing.Tuple[bool, str]: """Execute choreography that has already been uploaded to the robot""" if self._is_licensed_for_choreography: - return self._spot_dance.execute_choreography_by_name( - choreography_name, start_slice, use_async - ) + return self._spot_dance.execute_choreography_by_name(choreography_name, start_slice, use_async) else: return False, "Spot is not licensed for choreography" @@ -1532,19 +1407,13 @@ def upload_choreography( else: return False, "Spot is not licensed for choreography" - def upload_animation( - self, animation_name: str, animation_file_content: str - ) -> typing.Tuple[bool, str]: + def upload_animation(self, animation_name: str, animation_file_content: str) -> typing.Tuple[bool, str]: if self._is_licensed_for_choreography: - return self._spot_dance.upload_animation( - animation_name, animation_file_content - ) + return self._spot_dance.upload_animation(animation_name, animation_file_content) else: return False, "Spot is not licensed for choreography" - def upload_animation_proto( - self, animation_proto: Animation - ) -> typing.Tuple[bool, str]: + def upload_animation_proto(self, animation_proto: Animation) -> typing.Tuple[bool, str]: if self._is_licensed_for_choreography: return self._spot_dance.upload_animation_proto(animation_proto) else: @@ -1582,9 +1451,7 @@ def get_docking_state(self, **kwargs): state = self._docking_client.get_docking_state(**kwargs) return state - def start_recording_state( - self, duration_seconds: float - ) -> typing.Tuple[bool, str, StartRecordingStateResponse]: + def start_recording_state(self, duration_seconds: float) -> typing.Tuple[bool, str, StartRecordingStateResponse]: """start recording robot motion as choreography""" if self._is_licensed_for_choreography: return self._spot_dance.start_recording_state(duration_seconds) @@ -1607,8 +1474,6 @@ def choreography_log_to_animation_file( ) -> typing.Tuple[bool, str, str]: """save a choreography log to a file as an animation""" if self._is_licensed_for_choreography: - return self._spot_dance.choreography_log_to_animation_file( - name, fpath, has_arm, **kwargs - ) + return self._spot_dance.choreography_log_to_animation_file(name, fpath, has_arm, **kwargs) else: return False, "Spot is not licensed for choreography", "" diff --git a/spot_wrapper/wrapper_helpers.py b/spot_wrapper/wrapper_helpers.py index bf2f816b..10b411ae 100644 --- a/spot_wrapper/wrapper_helpers.py +++ b/spot_wrapper/wrapper_helpers.py @@ -1,7 +1,7 @@ """Helper classes for the wrapper. This file is necessary to prevent circular imports caused by the modules also using these classes""" -import typing import functools +import typing from dataclasses import dataclass @@ -28,9 +28,7 @@ def __init__( self.claim = claim_function self._get_lease_on_action = get_lease_on_action - def _make_function_take_lease_and_power_on( - self, func: typing.Callable, power_on: bool = True - ) -> typing.Callable: + def _make_function_take_lease_and_power_on(self, func: typing.Callable, power_on: bool = True) -> typing.Callable: """ Decorator which tries to acquire the lease before executing the wrapped function @@ -77,7 +75,7 @@ def make_function_take_lease_and_power_on( if not hasattr(decorated_object, function_name): raise AttributeError( f"Requested decoration of function {function_name} of object {decorated_object}, but the object does " - f"not have a function by that name." + "not have a function by that name." ) setattr( @@ -109,9 +107,7 @@ def decorate_functions( decorated_funcs_no_power = [] for func in decorated_funcs_no_power: - self.make_function_take_lease_and_power_on( - decorated_object, func, power_on=False - ) + self.make_function_take_lease_and_power_on(decorated_object, func, power_on=False) @dataclass()