Skip to content

Commit

Permalink
(refactor) Changes handling of image messages for publication
Browse files Browse the repository at this point in the history
- Adds support for CompressedImage messages
- Gathers image messages directly in vision component instead of getting them back from clients
  • Loading branch information
aleph-ra committed Nov 22, 2024
1 parent bca5adf commit a4b36ce
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 22 deletions.
4 changes: 3 additions & 1 deletion agents/agents/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
get_logger,
)

from ros_sugar.io.utils import image_pre_processing
from ros_sugar.io.utils import image_pre_processing, read_compressed_image

from .utils import create_detection_context

Expand Down Expand Up @@ -70,6 +70,8 @@ def _get_output(self, **_) -> Optional[np.ndarray]:
video = []
for img in self.msg.frames:
video.append(image_pre_processing(img))
for img in self.msg.compressed_frames:
video.append(read_compressed_image(img))
return np.array(video)


Expand Down
6 changes: 1 addition & 5 deletions agents/agents/clients/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def _inference(self, inference_input: Dict[str, Any]) -> Optional[Dict]:
}
inference_input.pop("query")

# make images parth of the latest message in message list
# make images part of the latest message in message list
if images := inference_input.get("images"):
input["messages"][-1]["images"] = [encode_arr_base64(img) for img in images]
inference_input.pop("images")
Expand All @@ -126,10 +126,6 @@ def _inference(self, inference_input: Dict[str, Any]) -> Optional[Dict]:

self.logger.debug(str(ollama_result))

# Add np images back in inference input
if images:
input["images"] = images

# make result part of the input
input["output"] = ollama_result["message"]["content"] # type: ignore

Expand Down
8 changes: 0 additions & 8 deletions agents/agents/clients/roboml.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,6 @@ def _inference(self, inference_input: Dict[str, Any]) -> Optional[Dict]:

self.logger.debug(str(result))

# replace np images back in inference input
if images:
inference_input["images"] = images

# make input query part of the result
result.update(inference_input)
return result

def _deinitialize(self) -> None:
Expand Down Expand Up @@ -463,8 +457,6 @@ def _inference(self, inference_input: Dict[str, Any]) -> Optional[Dict]:

self.logger.debug(str(result))

# make input query part of the result
result.update(inference_input)
return result

def _deinitialize(self) -> None:
Expand Down
20 changes: 18 additions & 2 deletions agents/agents/components/vision.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from typing import Any, Union, Optional, List, Dict
import numpy as np

from ..clients.model_base import ModelClient
from ..config import VisionConfig
from ..ros import Detections, FixedInput, Image, Topic, Trackings
from ..ros import (
Detections,
FixedInput,
Image,
Topic,
Trackings,
ROSImage,
ROSCompressedImage,
)
from ..utils import validate_func_args
from .model_component import ModelComponent
from .component_base import ComponentRunType
Expand Down Expand Up @@ -66,6 +75,8 @@ def __init__(
self.allowed_inputs = {"Required": [Image]}
self.handled_outputs = [Detections, Trackings]

self._images: List[Union[np.ndarray, ROSImage, ROSCompressedImage]] = []

super().__init__(
inputs,
outputs,
Expand All @@ -90,15 +101,20 @@ def _create_input(self, *_, **kwargs) -> Optional[Dict[str, Any]]:
:param kwargs:
:rtype: dict[str, Any]
"""
self._images = []
# set one image topic as query for event based trigger
if trigger := kwargs.get("topic"):
images = [self.trig_callbacks[trigger.name].get_output()]
if msg := kwargs.get("msg"):
self._images.append(msg)
else:
images = []

for i in self.callbacks.values():
if (item := i.get_output()) is not None:
images.append(item)
if i.msg:
self._images.append(i.msg) # Collect all images for publishing

if not images:
return None
Expand Down Expand Up @@ -140,6 +156,6 @@ def _execution_step(self, *args, **kwargs):
for publisher in self.publishers_dict.values():
publisher.publish(
**result,
frame_id=self.trig_callbacks[trigger.name].frame_id,
images=self._images,
time_stamp=self.get_ros_time(),
)
36 changes: 30 additions & 6 deletions agents/agents/ros.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
SupportedType,
Audio,
Image,
CompressedImage,
OccupancyGrid,
Odometry,
String,
ROSImage,
ROSCompressedImage,
)
from ros_sugar.io import (
get_all_msg_types,
Expand Down Expand Up @@ -44,6 +46,7 @@
"String",
"Audio",
"Image",
"CompressedImage",
"OccupancyGrid",
"Odometry",
"Topic",
Expand Down Expand Up @@ -76,14 +79,25 @@ class Video(SupportedType):
callback = VideoCallback

@classmethod
def convert(cls, output: Union[List[ROSImage], List[np.ndarray]], **_) -> ROSVideo:
def convert(
cls,
output: Union[List[ROSImage], List[ROSCompressedImage], List[np.ndarray]],
**_,
) -> ROSVideo:
"""
Takes a list of images and returns a video message (Image Array)
:return: Video
"""
frames = [Image.convert(frame) for frame in output]
msg = ROSVideo()
frames = []
compressed_frames = []
for frame in output:
if isinstance(frame, ROSCompressedImage):
compressed_frames.append(CompressedImage.convert(frame))
else:
frames.append(Image.convert(frame))
msg.frames = frames
msg.compressed_frames = compressed_frames
return msg


Expand All @@ -94,7 +108,9 @@ class Detection(SupportedType):
callback = None # not defined

@classmethod
def convert(cls, output: Dict, img: np.ndarray, **_) -> Detection2D:
def convert(
cls, output: Dict, img: Union[ROSImage, ROSCompressedImage, np.ndarray], **_
) -> Detection2D:
"""
Takes object detection data and converts it into a ROS message
of type Detection2D
Expand All @@ -113,7 +129,10 @@ def convert(cls, output: Dict, img: np.ndarray, **_) -> Detection2D:
boxes.append(box)

msg.boxes = boxes
msg.image = Image.convert(img)
if isinstance(img, ROSCompressedImage):
msg.compressed_image = CompressedImage.convert(img)
else:
msg.image = Image.convert(img)
return msg


Expand Down Expand Up @@ -145,7 +164,9 @@ class Tracking(SupportedType):
callback = None # Not defined

@classmethod
def convert(cls, output: Dict, img: np.ndarray, **_) -> ROSTracking:
def convert(
cls, output: Dict, img: Union[ROSImage, ROSCompressedImage, np.ndarray], **_
) -> ROSTracking:
"""
Takes tracking data and converts it into a ROS message
of type Tracking
Expand Down Expand Up @@ -183,7 +204,10 @@ def convert(cls, output: Dict, img: np.ndarray, **_) -> ROSTracking:
msg.boxes = tracked_boxes
msg.centroids = centroids
msg.estimated_velocities = estimated_velocities
msg.image = Image.convert(img)
if isinstance(img, ROSCompressedImage):
msg.compressed_image = CompressedImage.convert(img)
else:
msg.image = Image.convert(img)
return msg


Expand Down
2 changes: 2 additions & 0 deletions agents_interfaces/msg/Detection2D.msg
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ float64[] scores
string[] labels
Bbox2D[] boxes

# Either an image or compressed image
sensor_msgs/Image image
sensor_msgs/CompressedImage compressed_image
2 changes: 2 additions & 0 deletions agents_interfaces/msg/Tracking.msg
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ Bbox2D[] boxes
int8[] ids
Point2D[] estimated_velocities

# Either an image or compressed image
sensor_msgs/Image image
sensor_msgs/CompressedImage compressed_image
2 changes: 2 additions & 0 deletions agents_interfaces/msg/Video.msg
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
std_msgs/Header header

# Either a list of images or compressed images
sensor_msgs/Image[] frames
sensor_msgs/CompressedImage[] compressed_frames

0 comments on commit a4b36ce

Please sign in to comment.