Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

runner/live: Refactor pipeline streamer and add some docs #306

Merged
merged 7 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions runner/app/live/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,30 @@
sys.path.insert(0, infer_root)

from params_api import start_http_server
from streamer.trickle import TrickleStreamer
from streamer.zeromq import ZeroMQStreamer
from streamer.protocol.trickle import TrickleProtocol
from streamer.protocol.zeromq import ZeroMQProtocol


async def main(http_port: int, stream_protocol: str, subscribe_url: str, publish_url: str, control_url: str, pipeline: str, params: dict, input_timeout: int):
if stream_protocol == "trickle":
handler = TrickleStreamer(subscribe_url, publish_url, pipeline, input_timeout, params or {})
protocol = TrickleProtocol(subscribe_url, publish_url)
elif stream_protocol == "zeromq":
handler = ZeroMQStreamer(subscribe_url, publish_url, pipeline, input_timeout, params or {})
protocol = ZeroMQProtocol(subscribe_url, publish_url)
else:
raise ValueError(f"Unsupported protocol: {stream_protocol}")

streamer = PipelineStreamer(protocol, pipeline, input_timeout, params or {})

runner = None
try:
handler.start()
runner = await start_http_server(handler, http_port)
await streamer.start()
runner = await start_http_server(streamer, http_port)

tasks: List[asyncio.Task] = []
tasks.append(handler.wait())
tasks.append(streamer.wait())
tasks.append(asyncio.create_task(block_until_signal([signal.SIGINT, signal.SIGTERM])))
if control_url is not None and control_url.strip() != "":
tasks.append(asyncio.create_task(start_control_subscriber(handler, control_url)))
tasks.append(asyncio.create_task(start_control_subscriber(streamer, control_url)))

await asyncio.wait(tasks,
return_when=asyncio.FIRST_COMPLETED
Expand All @@ -49,7 +51,7 @@ async def main(http_port: int, stream_protocol: str, subscribe_url: str, publish
raise e
finally:
await runner.cleanup()
await handler.stop()
await streamer.stop()


async def block_until_signal(sigs: List[signal.Signals]):
Expand Down Expand Up @@ -82,8 +84,8 @@ async def start_control_subscriber(handler: PipelineStreamer, control_url: str):
except Exception as e:
logging.error(f"Error parsing control message: {e}")
continue
try:

try:
handler.update_params(data)
except Exception as e:
logging.error(f"Error updating model with control message: {e}")
Expand Down
34 changes: 34 additions & 0 deletions runner/app/live/pipelines/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,47 @@
from PIL import Image

class Pipeline(ABC):
    """Interface for frame-by-frame image processing pipelines.

    Implementations receive frames one at a time and can have their
    parameters changed between frames.

    Notes:
        - Each method is invoked one at a time in a separate process, so
          implementations do not need any locking.
        - The caller takes care of error handling, so implementations can
          simply let exceptions propagate for optimal error reporting.
    """

    def __init__(self, **params):
        """Create the pipeline, optionally configured by keyword parameters.

        Args:
            **params: Parameters to initialize the pipeline with.
        """
        ...

    @abstractmethod
    def process_frame(self, frame: Image.Image) -> Image.Image:
        """Run a single frame through the pipeline.

        Invoked sequentially, once per frame of the stream.

        Args:
            frame: Input PIL Image.

        Returns:
            The processed PIL Image.
        """
        ...

    @abstractmethod
    def update_params(self, **params):
        """Apply new parameters to the pipeline.

        On success the pipeline must be left in a valid state; on failure the
        previous state must be restored. Invoked sequentially with
        process_frame, so concurrency is not an issue.

        Args:
            **params: Implementation-specific parameters.
        """
        ...
10 changes: 1 addition & 9 deletions runner/app/live/pipelines/streamdiffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ class Config:
use_lcm_lora: bool = True
lcm_lora_id: str = "latent-consistency/lcm-lora-sdv1-5"
num_inference_steps: int = 50
t_index_list: Optional[List[int]] = None
t_index_ratio_list: Optional[List[float]] = [0.75, 0.9, 0.975]
t_index_list: Optional[List[int]] = [37, 45, 48]
scale: float = 1.0
acceleration: Literal["none", "xformers", "tensorrt"] = "tensorrt"
use_denoising_batch: bool = True
Expand All @@ -29,13 +28,6 @@ class Config:
do_add_noise: bool = False
similar_image_filter_threshold: float = 0.98

def __init__(self, **data):
super().__init__(**data)
if self.t_index_ratio_list is not None and self.t_index_list is None:
self.t_index_list = [
int(i * self.num_inference_steps) for i in self.t_index_ratio_list
]


class StreamDiffusion(Pipeline):
def __init__(self, **params):
Expand Down
2 changes: 1 addition & 1 deletion runner/app/live/streamer/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,4 @@ def process_loop(self):
except Exception as e:
logging.error(f"Error processing frame: {e}")
except Exception as e:
logging.error(f"Error in process run method: {e}")
logging.error(f"Error in process run method: {e}", exc_info=True)
27 changes: 27 additions & 0 deletions runner/app/live/streamer/protocol/protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from abc import ABC, abstractmethod
from typing import AsyncGenerator
from multiprocessing.synchronize import Event
from PIL import Image

class StreamProtocol(ABC):
    """Abstract transport for a live video stream: frames in, frames out."""

    @abstractmethod
    async def start(self):
        """Set up and start the underlying streaming transport."""
        ...

    @abstractmethod
    async def stop(self):
        """Stop the streaming transport and clean up its resources."""
        ...

    @abstractmethod
    async def ingress_loop(self, done: Event) -> AsyncGenerator[Image.Image, None]:
        """Async generator that yields the incoming (ingress) frames."""
        return
        yield  # unreachable; marks this as an async generator for type checkers

    @abstractmethod
    async def egress_loop(self, output_frames: AsyncGenerator[Image.Image, None]):
        """Consume the generated output frames and process them."""
        ...
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,28 @@

from trickle import media

from .streamer import PipelineStreamer
from .protocol import StreamProtocol
from .jpeg import to_jpeg_bytes, from_jpeg_bytes

class TrickleStreamer(PipelineStreamer):
def __init__(
self,
subscribe_url: str,
publish_url: str,
pipeline: str,
input_timeout: int,
params: dict,
):
super().__init__(pipeline, input_timeout, params)
class TrickleProtocol(StreamProtocol):
def __init__(self, subscribe_url: str, publish_url: str):
self.subscribe_url = subscribe_url
self.publish_url = publish_url
self.subscribe_queue = queue.Queue[bytearray]()
self.publish_queue = queue.Queue[bytearray]()
self.subscribe_task = None
self.publish_task = None

def start(self):
self.subscribe_task = asyncio.create_task(media.run_subscribe(self.subscribe_url, self.subscribe_queue.put))
self.publish_task = asyncio.create_task(media.run_publish(self.publish_url, self.publish_queue.get))
super().start()
async def start(self):
self.subscribe_task = asyncio.create_task(
media.run_subscribe(self.subscribe_url, self.subscribe_queue.put)
)
self.publish_task = asyncio.create_task(
media.run_publish(self.publish_url, self.publish_queue.get)
)

async def stop(self):
if not self.subscribe_task or not self.publish_task:
await super().stop()
return

# send sentinel None values to stop the trickle tasks
Expand All @@ -51,8 +47,6 @@ async def stop(self):
self.subscribe_task = None
self.publish_task = None

await super().stop()

async def ingress_loop(self, done: Event) -> AsyncGenerator[Image.Image, None]:
def dequeue_jpeg():
jpeg_bytes = self.subscribe_queue.get()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,26 @@
from multiprocessing.synchronize import Event
from typing import AsyncGenerator

from .streamer import PipelineStreamer
from .protocol import StreamProtocol
from .jpeg import to_jpeg_bytes, from_jpeg_bytes


class ZeroMQStreamer(PipelineStreamer):
def __init__(
self,
input_address: str,
output_address: str,
pipeline: str,
input_timeout: int,
params: dict,
):
super().__init__(pipeline, input_timeout, params)
class ZeroMQProtocol(StreamProtocol):
def __init__(self, input_address: str, output_address: str):
self.input_address = input_address
self.output_address = output_address

self.context = zmq.asyncio.Context()
self.input_socket = self.context.socket(zmq.SUB)
self.output_socket = self.context.socket(zmq.PUB)

def start(self):
async def start(self):
self.input_socket.connect(self.input_address)
self.input_socket.setsockopt_string(
zmq.SUBSCRIBE, ""
) # Subscribe to all messages
self.input_socket.setsockopt_string(zmq.SUBSCRIBE, "")
self.input_socket.set_hwm(10)

self.output_socket.connect(self.output_address)
self.output_socket.set_hwm(10)

super().start()

async def stop(self):
await super().stop()
self.input_socket.close()
self.output_socket.close()
self.context.term()
Expand Down
29 changes: 10 additions & 19 deletions runner/app/live/streamer/streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,20 @@
import os
import time
import traceback
from abc import ABC, abstractmethod
from multiprocessing.synchronize import Event
from typing import AsyncGenerator

from PIL import Image

from .process import PipelineProcess
from .protocol.protocol import StreamProtocol

fps_log_interval = 10


class PipelineStreamer(ABC):
def __init__(self, pipeline: str, input_timeout: int, params: dict):
class PipelineStreamer:
def __init__(self, protocol: StreamProtocol, pipeline: str, input_timeout: int, params: dict):
self.protocol = protocol
self.pipeline = pipeline
self.params = params
self.process = None
Expand All @@ -24,16 +25,18 @@ def __init__(self, pipeline: str, input_timeout: int, params: dict):
self.input_timeout = input_timeout # 0 means disabled
self.done_future = None

def start(self):
async def start(self):
self.done_future = asyncio.get_running_loop().create_future()
self._start_process()
await self.protocol.start()

async def wait(self):
if not self.done_future:
raise RuntimeError("Streamer not started")
return await self.done_future

async def stop(self):
await self.protocol.stop()
await self._stop_process()
if self.done_future and not self.done_future.done():
self.done_future.set_result(None)
Expand Down Expand Up @@ -66,7 +69,7 @@ async def _stop_process(self):

async def _restart(self):
try:
# don't call the start/stop methods since those might be overridden by the concrete implementations
# don't call the full start/stop methods since we don't want to restart the protocol
await self._stop_process()
self._start_process()
self.restart_count += 1
Expand Down Expand Up @@ -133,7 +136,7 @@ async def run_ingress_loop(self, done: Event):
frame_count = 0
start_time = time.time()
try:
async for frame in self.ingress_loop(done):
async for frame in self.protocol.ingress_loop(done):
if done.is_set() or not self.process:
return

Expand Down Expand Up @@ -187,21 +190,9 @@ async def gen_output_frames() -> AsyncGenerator[Image.Image, None]:
start_time = time.time()

try:
await self.egress_loop(gen_output_frames())
await self.protocol.egress_loop(gen_output_frames())
# automatically stop the streamer when the egress ends cleanly
await self.stop()
except Exception:
logging.error("Error running egress loop", exc_info=True)
await self._restart()

@abstractmethod
async def ingress_loop(self, done: Event) -> AsyncGenerator[Image.Image, None]:
"""Generator that yields the ingress frames."""
if False:
yield Image.new('RGB', (1, 1)) # dummy yield for linter to see this is a generator
pass

@abstractmethod
async def egress_loop(self, output_frames: AsyncGenerator[Image.Image, None]):
"""Consumes generated frames and processes them."""
pass
Loading