From 7dd9f488061092b440ed27047360d2f0a09d9028 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Sat, 7 Dec 2024 09:28:33 -0500 Subject: [PATCH 1/6] Make FlowManager args more explicit --- src/pipecat_flows/adapters.py | 8 +++++++- src/pipecat_flows/manager.py | 9 +++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/pipecat_flows/adapters.py b/src/pipecat_flows/adapters.py index 19b06eb..e024e01 100644 --- a/src/pipecat_flows/adapters.py +++ b/src/pipecat_flows/adapters.py @@ -158,4 +158,10 @@ def create_adapter(llm) -> LLMAdapter: return AnthropicAdapter() elif isinstance(llm, GoogleLLMService): return GeminiAdapter() - raise ValueError(f"Unsupported LLM type: {type(llm)}") + raise ValueError( + f"Unsupported LLM type: {type(llm)}\n" + "Must provide one of:\n" + "- OpenAILLMService\n" + "- AnthropicLLMService\n" + "- GoogleLLMService" + ) diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index 949e8ce..e40006e 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -95,6 +95,7 @@ async def handle_transitions(function_name: str, args: Dict, flow_manager): def __init__( self, + *, task: PipelineTask, llm: Union[OpenAILLMService, AnthropicLLMService, GoogleLLMService], tts: Optional[Any] = None, @@ -114,6 +115,14 @@ def __init__( transition_callback: Optional callback for handling transitions. Required for dynamic flows, ignored for static flows in favor of static transitions + + Example: + flow_manager = FlowManager( + task=pipeline_task, + llm=openai_service, + tts=tts_service, + flow_config=config + ) """ self.task = task self.llm = llm From bf73c891a6786b2e8eb501d379bada5b8c6ecfe2 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Sat, 7 Dec 2024 09:33:34 -0500 Subject: [PATCH 2/6] Update changelog, add docstring formatting, bump to version 0.0.8 --- CHANGELOG.md | 7 +++++++ pyproject.toml | 11 +++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e3e155f..9a41ded 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to **Pipecat Flows** will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.0.8] - 2024-12-07 + +### Changed + +- Improved type safety in FlowManager by requiring keyword arguments for initialization +- Enhanced error messages for LLM service type validation + ## [0.0.7] - 2024-12-06 ### Added diff --git a/pyproject.toml b/pyproject.toml index 641840c..90dfb8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pipecat-ai-flows" -version = "0.0.7" +version = "0.0.8" description = "Conversation Flow management for Pipecat AI applications" license = { text = "BSD 2-Clause License" } readme = "README.md" @@ -33,4 +33,11 @@ testpaths = ["tests"] asyncio_mode = "auto" [tool.ruff] -line-length = 100 \ No newline at end of file +line-length = 100 + +select = [ + "D", # Docstring rules +] + +[tool.ruff.pydocstyle] +convention = "google" \ No newline at end of file From bb5dc08cd21c3fda1d272d37cc4c8ad3c78dc7b5 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Sat, 7 Dec 2024 09:53:46 -0500 Subject: [PATCH 3/6] Improve docstrings --- pyproject.toml | 1 + src/pipecat_flows/__init__.py | 2 +- src/pipecat_flows/actions.py | 17 ++++ src/pipecat_flows/adapters.py | 141 ++++++++++++++++++++++++++++++-- src/pipecat_flows/exceptions.py | 13 +++ src/pipecat_flows/manager.py | 19 +++++ src/pipecat_flows/types.py | 12 +++ tests/test_actions.py | 69 +++++++++++++--- tests/test_adapters.py | 72 +++++++++++++--- tests/test_manager.py | 48 ++++++++++- 10 files changed, 358 insertions(+), 36 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 90dfb8e..fedecef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ line-length = 100 select = [ "D", # Docstring rules ] +ignore = ["D212"] [tool.ruff.pydocstyle] convention = "google" \ No newline at end of file diff --git a/src/pipecat_flows/__init__.py b/src/pipecat_flows/__init__.py index d467a5a..27b0f7f 100644 --- a/src/pipecat_flows/__init__.py +++ b/src/pipecat_flows/__init__.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: BSD 2-Clause License # """ -Pipecat Flows +Pipecat Flows. This package provides a framework for building structured conversations in Pipecat. The FlowManager can handle both static and dynamic conversation flows: diff --git a/src/pipecat_flows/actions.py b/src/pipecat_flows/actions.py index a3565c3..0969da0 100644 --- a/src/pipecat_flows/actions.py +++ b/src/pipecat_flows/actions.py @@ -4,6 +4,23 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Action management system for conversation flows. + +This module provides the ActionManager class which handles execution of actions +during conversation state transitions. It supports: +- Built-in actions (TTS, conversation ending) +- Custom action registration +- Synchronous and asynchronous handlers +- Pre and post-transition actions +- Error handling and validation + +Actions are used to perform side effects during conversations, such as: +- Text-to-speech output +- Database updates +- External API calls +- Custom integrations +""" + import asyncio from typing import Any, Callable, Dict, List, Optional diff --git a/src/pipecat_flows/adapters.py b/src/pipecat_flows/adapters.py index e024e01..0848ac0 100644 --- a/src/pipecat_flows/adapters.py +++ b/src/pipecat_flows/adapters.py @@ -4,6 +4,20 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""LLM provider adapters for normalizing function and message formats. + +This module provides adapters that normalize interactions between different +LLM providers (OpenAI, Anthropic, Gemini). It handles: +- Function name extraction +- Argument parsing +- Message content formatting +- Provider-specific schema conversion + +The adapter system allows the flow manager to work with different LLM +providers while maintaining a consistent internal format (based on OpenAI's +function calling convention). +""" + from abc import ABC, abstractmethod from typing import Any, Dict, List @@ -48,36 +62,112 @@ def format_functions(self, functions: List[Dict[str, Any]]) -> List[Dict[str, An class OpenAIAdapter(LLMAdapter): - """Format adapter for OpenAI.""" + """Format adapter for OpenAI. + + Handles OpenAI's function calling format, which is used as the default format + in the flow system. + """ def get_function_name(self, function_def: Dict[str, Any]) -> str: + """Extract function name from OpenAI function definition. + + Args: + function_def: OpenAI-formatted function definition dictionary + + Returns: + Function name from the definition + """ return function_def["function"]["name"] def get_function_args(self, function_call: Dict[str, Any]) -> dict: + """Extract arguments from OpenAI function call. + + Args: + function_call: OpenAI-formatted function call dictionary + + Returns: + Dictionary of function arguments, empty if none provided + """ return function_call.get("arguments", {}) def get_message_content(self, message: Dict[str, Any]) -> str: + """Extract content from OpenAI message format. + + Args: + message: OpenAI-formatted message dictionary + + Returns: + Message content as string + """ return message["content"] def format_functions(self, functions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - return functions # OpenAI format is our default + """Format functions for OpenAI use. + + Args: + functions: List of function definitions + + Returns: + Functions in OpenAI format (unchanged as this is our default format) + """ + return functions class AnthropicAdapter(LLMAdapter): - """Format adapter for Anthropic.""" + """Format adapter for Anthropic. + + Handles Anthropic's native function format, converting between OpenAI's format + and Anthropic's as needed. + """ def get_function_name(self, function_def: Dict[str, Any]) -> str: + """Extract function name from Anthropic function definition. + + Args: + function_def: Anthropic-formatted function definition dictionary + + Returns: + Function name from the definition + """ return function_def["name"] def get_function_args(self, function_call: Dict[str, Any]) -> dict: + """Extract arguments from Anthropic function call. + + Args: + function_call: Anthropic-formatted function call dictionary + + Returns: + Dictionary of function arguments, empty if none provided + """ return function_call.get("arguments", {}) def get_message_content(self, message: Dict[str, Any]) -> str: + """Extract content from Anthropic message format. + + Handles both string content and structured content arrays. + + Args: + message: Anthropic-formatted message dictionary + + Returns: + Message content as string, concatenated if from multiple parts + """ if isinstance(message.get("content"), list): return " ".join(item["text"] for item in message["content"] if item["type"] == "text") return message.get("content", "") def format_functions(self, functions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Format functions for Anthropic use. + + Converts from OpenAI format to Anthropic's native function format if needed. + + Args: + functions: List of function definitions in OpenAI format + + Returns: + Functions converted to Anthropic's format + """ formatted = [] for func in functions: if "function" in func: @@ -96,28 +186,61 @@ def format_functions(self, functions: List[Dict[str, Any]]) -> List[Dict[str, An class GeminiAdapter(LLMAdapter): - """Format adapter for Google's Gemini.""" + """Format adapter for Google's Gemini. + + Handles Gemini's function declarations format, converting between OpenAI's format + and Gemini's as needed. + """ def get_function_name(self, function_def: Dict[str, Any]) -> str: - """Extract function name from provider-specific function definition.""" + """Extract function name from Gemini function definition. + + Args: + function_def: Gemini-formatted function definition dictionary + + Returns: + Function name from the first declaration, or empty string if none found + """ logger.debug(f"Getting function name from: {function_def}") if "function_declarations" in function_def: declarations = function_def["function_declarations"] if declarations and isinstance(declarations, list): - # Return name of current function being processed return declarations[0]["name"] return "" def get_function_args(self, function_call: Dict[str, Any]) -> dict: - """Extract function arguments from provider-specific function call.""" + """Extract arguments from Gemini function call. + + Args: + function_call: Gemini-formatted function call dictionary + + Returns: + Dictionary of function arguments, empty if none provided + """ return function_call.get("args", {}) def get_message_content(self, message: Dict[str, Any]) -> str: - """Extract message content from provider-specific format.""" + """Extract content from Gemini message format. + + Args: + message: Gemini-formatted message dictionary + + Returns: + Message content as string + """ return message["content"] def format_functions(self, functions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """Format functions for provider-specific use.""" + """Format functions for Gemini use. + + Converts from OpenAI format to Gemini's function declarations format. + + Args: + functions: List of function definitions in OpenAI format + + Returns: + Functions converted to Gemini's format with declarations wrapper + """ all_declarations = [] for func in functions: if "function_declarations" in func: diff --git a/src/pipecat_flows/exceptions.py b/src/pipecat_flows/exceptions.py index cc5d831..fff7fa9 100644 --- a/src/pipecat_flows/exceptions.py +++ b/src/pipecat_flows/exceptions.py @@ -4,6 +4,19 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Custom exceptions for the conversation flow system. + +This module defines the exception hierarchy used throughout the flow system: +- FlowError: Base exception for all flow-related errors +- FlowInitializationError: Initialization failures +- FlowTransitionError: State transition issues +- InvalidFunctionError: Function registration/calling problems +- ActionError: Action execution failures + +These exceptions provide specific error types for better error handling +and debugging. +""" + class FlowError(Exception): """Base exception for all flow-related errors.""" diff --git a/src/pipecat_flows/manager.py b/src/pipecat_flows/manager.py index e40006e..c70b12a 100644 --- a/src/pipecat_flows/manager.py +++ b/src/pipecat_flows/manager.py @@ -4,6 +4,25 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Core conversation flow management system. + +This module provides the FlowManager class which orchestrates conversations +across different LLM providers. It supports: +- Static flows with predefined paths +- Dynamic flows with runtime-determined transitions +- State management and transitions +- Function registration and execution +- Action handling +- Cross-provider compatibility + +The flow manager coordinates all aspects of a conversation, including: +- LLM context management +- Function registration +- State transitions +- Action execution +- Error handling +""" + import copy import inspect from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Union diff --git a/src/pipecat_flows/types.py b/src/pipecat_flows/types.py index 8246eb6..d243b3e 100644 --- a/src/pipecat_flows/types.py +++ b/src/pipecat_flows/types.py @@ -4,6 +4,18 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Type definitions for the conversation flow system. + +This module defines the core types used throughout the flow system: +- FlowResult: Function return type +- FlowArgs: Function argument type +- NodeConfig: Node configuration type +- FlowConfig: Complete flow configuration type + +These types provide structure and validation for flow configurations +and function interactions. +""" + from typing import Any, Dict, List, TypedDict diff --git a/tests/test_actions.py b/tests/test_actions.py index 8ef75fa..0eefb89 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -1,3 +1,24 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Test suite for ActionManager functionality. + +This module tests the ActionManager class which handles execution of actions +during conversation flows. Tests cover: +- Built-in actions (TTS, end conversation) +- Custom action registration and execution +- Error handling and validation +- Action sequencing +- TTS service integration +- Frame queueing + +The tests use unittest.IsolatedAsyncioTestCase for async support and include +mocked dependencies for PipelineTask and TTS service. +""" + import unittest from unittest.mock import AsyncMock, patch @@ -8,6 +29,28 @@ class TestActionManager(unittest.IsolatedAsyncioTestCase): + """Test suite for ActionManager class. + + Tests functionality of ActionManager including: + - Built-in action handlers: + - TTS speech synthesis + - Conversation ending + - Custom action registration + - Action execution sequencing + - Error handling: + - Missing TTS service + - Invalid actions + - Failed handlers + - Multiple action execution + - Frame queueing validation + + Each test uses mocked dependencies to verify: + - Correct frame generation + - Proper service calls + - Error handling behavior + - Action sequencing + """ + def setUp(self): """ Set up test fixtures before each test. @@ -26,7 +69,7 @@ def setUp(self): self.action_manager = ActionManager(self.mock_task, self.mock_tts) async def test_initialization(self): - """Test ActionManager initialization and default handlers""" + """Test ActionManager initialization and default handlers.""" # Verify built-in action handlers are registered self.assertIn("tts_say", self.action_manager.action_handlers) self.assertIn("end_conversation", self.action_manager.action_handlers) @@ -36,7 +79,7 @@ async def test_initialization(self): self.assertIsNone(action_manager_no_tts.tts) async def test_tts_action(self): - """Test basic TTS action execution""" + """Test basic TTS action execution.""" action = {"type": "tts_say", "text": "Hello"} await self.action_manager.execute_actions([action]) @@ -45,7 +88,7 @@ async def test_tts_action(self): @patch("loguru.logger.error") async def test_tts_action_no_text(self, mock_logger): - """Test TTS action with missing text field""" + """Test TTS action with missing text field.""" action = {"type": "tts_say"} # Missing text field # The implementation logs error but doesn't raise @@ -59,7 +102,7 @@ async def test_tts_action_no_text(self, mock_logger): @patch("loguru.logger.warning") async def test_tts_action_no_service(self, mock_logger): - """Test TTS action when no TTS service is provided""" + """Test TTS action when no TTS service is provided.""" action_manager = ActionManager(self.mock_task, None) action = {"type": "tts_say", "text": "Hello"} @@ -73,7 +116,7 @@ async def test_tts_action_no_service(self, mock_logger): self.mock_task.queue_frame.assert_not_called() async def test_end_conversation_action(self): - """Test basic end conversation action""" + """Test basic end conversation action.""" action = {"type": "end_conversation"} await self.action_manager.execute_actions([action]) @@ -83,7 +126,7 @@ async def test_end_conversation_action(self): self.assertIsInstance(frame, EndFrame) async def test_end_conversation_with_goodbye(self): - """Test end conversation action with goodbye message""" + """Test end conversation action with goodbye message.""" action = {"type": "end_conversation", "text": "Goodbye!"} await self.action_manager.execute_actions([action]) @@ -100,7 +143,7 @@ async def test_end_conversation_with_goodbye(self): self.assertIsInstance(second_frame, EndFrame) async def test_custom_action(self): - """Test registering and executing custom actions""" + """Test registering and executing custom actions.""" mock_handler = AsyncMock() self.action_manager._register_action("custom", mock_handler) @@ -115,7 +158,7 @@ async def test_custom_action(self): mock_handler.assert_called_once_with(action) async def test_invalid_action(self): - """Test handling invalid actions""" + """Test handling invalid actions.""" # Test missing type with self.assertRaises(ActionError) as context: await self.action_manager.execute_actions([{}]) @@ -127,7 +170,7 @@ async def test_invalid_action(self): self.assertIn("No handler registered", str(context.exception)) async def test_multiple_actions(self): - """Test executing multiple actions in sequence""" + """Test executing multiple actions in sequence.""" actions = [ {"type": "tts_say", "text": "First"}, {"type": "tts_say", "text": "Second"}, @@ -140,7 +183,7 @@ async def test_multiple_actions(self): self.assertEqual(self.mock_tts.say.call_args_list, expected_calls) def test_register_invalid_handler(self): - """Test registering invalid action handlers""" + """Test registering invalid action handlers.""" # Test non-callable handler with self.assertRaises(ValueError) as context: self.action_manager._register_action("invalid", "not_callable") @@ -152,7 +195,7 @@ def test_register_invalid_handler(self): self.assertIn("must be callable", str(context.exception)) async def test_none_or_empty_actions(self): - """Test handling None or empty action lists""" + """Test handling None or empty action lists.""" # Test None actions await self.action_manager.execute_actions(None) self.mock_task.queue_frame.assert_not_called() @@ -165,7 +208,7 @@ async def test_none_or_empty_actions(self): @patch("loguru.logger.error") async def test_action_error_handling(self, mock_logger): - """Test error handling during action execution""" + """Test error handling during action execution.""" # Configure TTS mock to raise an error self.mock_tts.say.side_effect = Exception("TTS error") @@ -179,7 +222,7 @@ async def test_action_error_handling(self, mock_logger): self.mock_tts.say.assert_called_once() async def test_action_execution_error_handling(self): - """Test error handling during action execution""" + """Test error handling during action execution.""" action_manager = ActionManager(self.mock_task, self.mock_tts) # Test action with missing handler diff --git a/tests/test_adapters.py b/tests/test_adapters.py index 101bba1..d848e89 100644 --- a/tests/test_adapters.py +++ b/tests/test_adapters.py @@ -1,3 +1,33 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Test suite for LLM adapter implementations. + +This module tests the adapter system that normalizes interactions between +different LLM providers (OpenAI, Anthropic, Gemini). Tests cover: +- Abstract adapter interface enforcement +- Provider-specific format handling +- Function name and argument extraction +- Message content processing +- Schema validation +- Error cases and edge conditions + +Each adapter is tested with its respective provider's format: +- OpenAI: Function calling format +- Anthropic: Native function format +- Gemini: Function declarations format + +The tests use unittest and include comprehensive validation of: +- Format conversions +- Null/empty value handling +- Special character processing +- Schema validation +- Factory pattern implementation +""" + import unittest from unittest.mock import MagicMock @@ -15,10 +45,10 @@ class TestLLMAdapter(unittest.TestCase): - """Test the abstract base LLMAdapter class""" + """Test the abstract base LLMAdapter class.""" def test_abstract_methods(self): - """Verify that LLMAdapter cannot be instantiated without implementing all methods""" + """Verify that LLMAdapter cannot be instantiated without implementing all methods.""" class IncompleteAdapter(LLMAdapter): # Missing implementation of abstract methods @@ -38,8 +68,28 @@ def get_function_name(self, function_def): class TestLLMAdapters(unittest.TestCase): + """Test suite for concrete LLM adapter implementations. + + Tests adapter functionality for each LLM provider: + - OpenAI: Function calling format + - Anthropic: Native function format + - Gemini: Function declarations format + + Each adapter is tested for: + - Function name extraction + - Argument parsing + - Message content handling + - Format conversion + - Special character handling + - Null/empty value processing + - Schema validation + + The setUp method provides standardized test fixtures for each provider's format, + allowing consistent testing across all adapters. + """ + def setUp(self): - """Set up test cases with sample function definitions for each provider""" + """Set up test cases with sample function definitions for each provider.""" # OpenAI format self.openai_function = { "type": "function", @@ -87,7 +137,7 @@ def setUp(self): self.gemini_message = {"role": "user", "content": "Test message"} def test_openai_adapter(self): - """Test OpenAI format handling""" + """Test OpenAI format handling.""" adapter = OpenAIAdapter() # Test function name extraction @@ -109,7 +159,7 @@ def test_openai_adapter(self): self.assertEqual(formatted, [self.openai_function]) def test_anthropic_adapter(self): - """Test Anthropic format handling""" + """Test Anthropic format handling.""" adapter = AnthropicAdapter() # Test function name extraction @@ -129,7 +179,7 @@ def test_anthropic_adapter(self): self.assertEqual(formatted[0]["name"], "test_function") def test_gemini_adapter(self): - """Test Gemini format handling""" + """Test Gemini format handling.""" adapter = GeminiAdapter() # Test function name extraction from function declarations @@ -149,7 +199,7 @@ def test_gemini_adapter(self): self.assertTrue("function_declarations" in formatted[0]) def test_adapter_factory(self): - """Test adapter creation based on LLM service type""" + """Test adapter creation based on LLM service type.""" # Test with valid LLM services openai_llm = MagicMock(spec=OpenAILLMService) self.assertIsInstance(create_adapter(openai_llm), OpenAIAdapter) @@ -161,7 +211,7 @@ def test_adapter_factory(self): self.assertIsInstance(create_adapter(gemini_llm), GeminiAdapter) def test_adapter_factory_error_cases(self): - """Test error cases in adapter creation""" + """Test error cases in adapter creation.""" # Test with None with self.assertRaises(ValueError) as context: create_adapter(None) @@ -174,7 +224,7 @@ def test_adapter_factory_error_cases(self): self.assertIn("Unsupported LLM type", str(context.exception)) def test_null_and_empty_values(self): - """Test handling of null and empty values""" + """Test handling of null and empty values.""" adapters = [OpenAIAdapter(), AnthropicAdapter(), GeminiAdapter()] for adapter in adapters: @@ -187,7 +237,7 @@ def test_null_and_empty_values(self): self.assertEqual(adapter.get_message_content(empty_message), "") def test_special_characters_handling(self): - """Test handling of special characters in messages and function calls""" + """Test handling of special characters in messages and function calls.""" special_chars = "!@#$%^&*()_+-=[]{}|;:'\",.<>?/~`" # Test in message content @@ -223,7 +273,7 @@ def test_special_characters_handling(self): self.assertEqual(args["param1"], special_chars) def test_function_schema_validation(self): - """Test validation of function schemas during conversion""" + """Test validation of function schemas during conversion.""" adapters = [OpenAIAdapter(), AnthropicAdapter(), GeminiAdapter()] # Test with minimal valid schema diff --git a/tests/test_manager.py b/tests/test_manager.py index 6b1f6c5..ffd45e4 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -1,3 +1,23 @@ +# +# Copyright (c) 2024, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Test suite for FlowManager functionality. + +This module contains tests for the FlowManager class, which handles conversation +flow management across different LLM providers. Tests cover: +- Static and dynamic flow initialization +- State transitions and validation +- Function registration and execution +- Action handling +- Error cases + +The tests use unittest.IsolatedAsyncioTestCase for async support and +include mocked dependencies for PipelineTask, LLM services, and TTS. +""" + import unittest from unittest.mock import AsyncMock, MagicMock, patch @@ -9,8 +29,26 @@ class TestFlowManager(unittest.IsolatedAsyncioTestCase): + """Test suite for FlowManager class. + + Tests functionality of FlowManager including: + - Static and dynamic flow initialization + - State transitions + - Function registration + - Action execution + - Error handling + - Node validation + """ + async def asyncSetUp(self): - """Set up test fixtures.""" + """Set up test fixtures before each test. + + Creates: + - Mock PipelineTask for frame queueing + - Mock LLM service (OpenAI) + - Mock TTS service + - Sample node and flow configurations + """ self.mock_task = AsyncMock() self.mock_llm = MagicMock(spec=OpenAILLMService) self.mock_tts = AsyncMock() @@ -41,7 +79,13 @@ async def asyncSetUp(self): } async def test_static_flow_initialization(self): - """Test initialization of static flow.""" + """Test initialization of a static flow configuration. + + Verifies: + - Correct setup of static mode attributes + - Proper initialization of flow + - Message queueing to task + """ flow_manager = FlowManager( self.mock_task, self.mock_llm, self.mock_tts, flow_config=self.static_flow_config ) From 9a95c5bbef038deabffb74dcf6dc0eff48d34823 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Sat, 7 Dec 2024 09:58:13 -0500 Subject: [PATCH 4/6] Update FlowManager input for all examples --- examples/dynamic/insurance_anthropic.py | 4 +++- examples/dynamic/insurance_gemini.py | 4 +++- examples/dynamic/insurance_openai.py | 4 +++- examples/static/food_ordering.py | 2 +- examples/static/movie_explorer_anthropic.py | 2 +- examples/static/movie_explorer_gemini.py | 2 +- examples/static/movie_explorer_openai.py | 2 +- examples/static/patient_intake.py | 2 +- examples/static/restaurant_reservation.py | 2 +- examples/static/travel_planner.py | 2 +- 10 files changed, 16 insertions(+), 10 deletions(-) diff --git a/examples/dynamic/insurance_anthropic.py b/examples/dynamic/insurance_anthropic.py index 67e5d66..148e2db 100644 --- a/examples/dynamic/insurance_anthropic.py +++ b/examples/dynamic/insurance_anthropic.py @@ -421,7 +421,9 @@ async def main(): task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True)) # Initialize flow manager with transition callback - flow_manager = FlowManager(task, llm, tts, transition_callback=handle_insurance_transition) + flow_manager = FlowManager( + task=task, llm=llm, tts=tts, transition_callback=handle_insurance_transition + ) @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant): diff --git a/examples/dynamic/insurance_gemini.py b/examples/dynamic/insurance_gemini.py index b4864d7..197ed29 100644 --- a/examples/dynamic/insurance_gemini.py +++ b/examples/dynamic/insurance_gemini.py @@ -409,7 +409,9 @@ async def main(): task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True)) # Initialize flow manager with transition callback - flow_manager = FlowManager(task, llm, tts, transition_callback=handle_insurance_transition) + flow_manager = FlowManager( + task=task, llm=llm, tts=tts, transition_callback=handle_insurance_transition + ) @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant): diff --git a/examples/dynamic/insurance_openai.py b/examples/dynamic/insurance_openai.py index 35e44fc..4646005 100644 --- a/examples/dynamic/insurance_openai.py +++ b/examples/dynamic/insurance_openai.py @@ -417,7 +417,9 @@ async def main(): task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True)) # Initialize flow manager with transition callback - flow_manager = FlowManager(task, llm, tts, transition_callback=handle_insurance_transition) + flow_manager = FlowManager( + task=task, llm=llm, tts=tts, transition_callback=handle_insurance_transition + ) @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant): diff --git a/examples/static/food_ordering.py b/examples/static/food_ordering.py index db91d91..139de03 100644 --- a/examples/static/food_ordering.py +++ b/examples/static/food_ordering.py @@ -331,7 +331,7 @@ async def main(): task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True)) # Initialize flow manager in static mode - flow_manager = FlowManager(task, llm, tts, flow_config=flow_config) + flow_manager = FlowManager(task=task, llm=llm, tts=tts, flow_config=flow_config) @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant): diff --git a/examples/static/movie_explorer_anthropic.py b/examples/static/movie_explorer_anthropic.py index 50d1c37..9fd2439 100644 --- a/examples/static/movie_explorer_anthropic.py +++ b/examples/static/movie_explorer_anthropic.py @@ -481,7 +481,7 @@ async def main(): task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True)) # Initialize flow manager - flow_manager = FlowManager(task, llm, tts, flow_config) + flow_manager = FlowManager(task=task, llm=llm, tts=tts, flow_config=flow_config) @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant): diff --git a/examples/static/movie_explorer_gemini.py b/examples/static/movie_explorer_gemini.py index e8e0631..55aeaef 100644 --- a/examples/static/movie_explorer_gemini.py +++ b/examples/static/movie_explorer_gemini.py @@ -481,7 +481,7 @@ async def main(): task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True)) # Initialize flow manager - flow_manager = FlowManager(task, llm, tts, flow_config) + flow_manager = FlowManager(task=task, llm=llm, tts=tts, flow_config=flow_config) @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant): diff --git a/examples/static/movie_explorer_openai.py b/examples/static/movie_explorer_openai.py index 4e9830e..29324db 100644 --- a/examples/static/movie_explorer_openai.py +++ b/examples/static/movie_explorer_openai.py @@ -475,7 +475,7 @@ async def main(): task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True)) # Initialize flow manager - flow_manager = FlowManager(task, llm, tts, flow_config) + flow_manager = FlowManager(task=task, llm=llm, tts=tts, flow_config=flow_config) @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant): diff --git a/examples/static/patient_intake.py b/examples/static/patient_intake.py index 2098249..8614a4a 100644 --- a/examples/static/patient_intake.py +++ b/examples/static/patient_intake.py @@ -470,7 +470,7 @@ async def main(): task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True)) # Initialize flow manager with LLM - flow_manager = FlowManager(task, llm, tts, flow_config) + flow_manager = FlowManager(task=task, llm=llm, tts=tts, flow_config=flow_config) @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant): diff --git a/examples/static/restaurant_reservation.py b/examples/static/restaurant_reservation.py index d38f6cf..40e303e 100644 --- a/examples/static/restaurant_reservation.py +++ b/examples/static/restaurant_reservation.py @@ -230,7 +230,7 @@ async def main(): task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True)) # Initialize flow manager with LLM - flow_manager = FlowManager(task, llm, tts, flow_config) + flow_manager = FlowManager(task=task, llm=llm, tts=tts, flow_config=flow_config) @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant): diff --git a/examples/static/travel_planner.py b/examples/static/travel_planner.py index 0ed9fb0..4160952 100644 --- a/examples/static/travel_planner.py +++ b/examples/static/travel_planner.py @@ -393,7 +393,7 @@ async def main(): task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True)) # Initialize flow manager with LLM - flow_manager = FlowManager(task, llm, tts, flow_config) + flow_manager = FlowManager(task=task, llm=llm, tts=tts, flow_config=flow_config) @transport.event_handler("on_first_participant_joined") async def on_first_participant_joined(transport, participant): From 3f66562115a9821b0d9a4ca69bfa4912ee8c4662 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Sat, 7 Dec 2024 10:01:22 -0500 Subject: [PATCH 5/6] More docstrings linting changes --- pyproject.toml | 3 +++ tests/__init__.py | 1 + 2 files changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index fedecef..146be75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,9 @@ select = [ "D", # Docstring rules ] ignore = ["D212"] +exclude = [ + "examples" +] [tool.ruff.pydocstyle] convention = "google" \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py index e69de29..5ed335a 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test suite for pipecat-flows.""" From ba350b2157f88d2a19a0387399f9e3d306a42640 Mon Sep 17 00:00:00 2001 From: Mark Backman Date: Sat, 7 Dec 2024 10:08:29 -0500 Subject: [PATCH 6/6] Fix FlowManager inputs in test_manager.py --- tests/test_manager.py | 62 ++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/tests/test_manager.py b/tests/test_manager.py index ffd45e4..4806385 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -87,7 +87,10 @@ async def test_static_flow_initialization(self): - Message queueing to task """ flow_manager = FlowManager( - self.mock_task, self.mock_llm, self.mock_tts, flow_config=self.static_flow_config + task=self.mock_task, + llm=self.mock_llm, + tts=self.mock_tts, + flow_config=self.static_flow_config, ) # Verify static mode setup @@ -116,9 +119,9 @@ async def transition_callback(function_name, args, flow_manager): pass flow_manager = FlowManager( - self.mock_task, - self.mock_llm, - self.mock_tts, + task=self.mock_task, + llm=self.mock_llm, + tts=self.mock_tts, transition_callback=transition_callback, ) @@ -144,7 +147,10 @@ async def transition_callback(function_name, args, flow_manager): async def test_static_flow_transitions(self): """Test transitions in static flow.""" flow_manager = FlowManager( - self.mock_task, self.mock_llm, self.mock_tts, flow_config=self.static_flow_config + task=self.mock_task, + llm=self.mock_llm, + tts=self.mock_tts, + flow_config=self.static_flow_config, ) await flow_manager.initialize([]) @@ -162,9 +168,9 @@ async def transition_callback(function_name, args, flow_manager): await flow_manager.set_node("dynamic_node", self.sample_node_config) flow_manager = FlowManager( - self.mock_task, - self.mock_llm, - self.mock_tts, + task=self.mock_task, + llm=self.mock_llm, + tts=self.mock_tts, transition_callback=transition_callback, ) await flow_manager.initialize([]) @@ -180,7 +186,7 @@ async def transition_callback(function_name, args, flow_manager): async def test_node_validation(self): """Test node configuration validation.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) await flow_manager.initialize([]) # Initialize first # Test missing messages @@ -197,7 +203,7 @@ async def test_node_validation(self): async def test_function_registration(self): """Test function registration with LLM.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) await flow_manager.initialize([]) # Set node with function @@ -211,7 +217,7 @@ async def test_function_registration(self): async def test_action_execution(self): """Test execution of pre and post actions.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm, self.mock_tts) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm, tts=self.mock_tts) await flow_manager.initialize([]) # Add actions to node config @@ -229,7 +235,7 @@ async def test_action_execution(self): async def test_error_handling(self): """Test error handling in flow manager.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) # Test initialization before setting node with self.assertRaises(FlowTransitionError): @@ -245,7 +251,7 @@ async def test_error_handling(self): async def test_state_management(self): """Test state management across nodes.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) await flow_manager.initialize([]) # Set state data @@ -257,7 +263,7 @@ async def test_state_management(self): async def test_multiple_function_registration(self): """Test registration of multiple functions.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) await flow_manager.initialize([]) # Create node config with multiple functions @@ -284,7 +290,7 @@ async def test_multiple_function_registration(self): async def test_initialize_already_initialized(self): """Test initializing an already initialized flow manager.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) await flow_manager.initialize([]) # Try to initialize again @@ -294,7 +300,7 @@ async def test_initialize_already_initialized(self): async def test_register_action(self): """Test registering custom actions.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) async def custom_action(action): pass @@ -304,7 +310,7 @@ async def custom_action(action): async def test_call_handler_variations(self): """Test different handler signature variations.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) await flow_manager.initialize([]) # Test handler with args @@ -323,7 +329,7 @@ async def handler_no_args(): async def test_transition_func_error_handling(self): """Test error handling in transition functions.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) await flow_manager.initialize([]) async def error_handler(args): @@ -349,7 +355,7 @@ async def result_callback(result): async def test_node_validation_edge_cases(self): """Test edge cases in node validation.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) await flow_manager.initialize([]) # Test function with missing name @@ -393,7 +399,7 @@ def capture_warning(msg, *args, **kwargs): async def test_pre_post_actions(self): """Test pre and post actions in set_node.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) await flow_manager.initialize([]) # Create node config with pre and post actions @@ -417,7 +423,7 @@ async def failing_transition(function_name, args, flow_manager): raise ValueError("Transition error") flow_manager = FlowManager( - self.mock_task, self.mock_llm, transition_callback=failing_transition + task=self.mock_task, llm=self.mock_llm, transition_callback=failing_transition ) await flow_manager.initialize([]) @@ -432,7 +438,7 @@ async def result_callback(result): async def test_register_function_error_handling(self): """Test error handling in function registration.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) await flow_manager.initialize([]) # Mock LLM to raise error on register_function @@ -444,7 +450,7 @@ async def test_register_function_error_handling(self): async def test_action_execution_error_handling(self): """Test error handling in action execution.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) await flow_manager.initialize([]) # Create node config with actions that will fail @@ -468,7 +474,7 @@ async def test_action_execution_error_handling(self): async def test_update_llm_context_error_handling(self): """Test error handling in LLM context updates.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) await flow_manager.initialize([]) # Mock task to raise error on queue_frames @@ -481,7 +487,7 @@ async def test_update_llm_context_error_handling(self): async def test_handler_callback_completion(self): """Test handler completion callback and logging.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) await flow_manager.initialize([]) async def test_handler(args): @@ -502,7 +508,7 @@ async def result_callback(result): async def test_handler_removal_all_formats(self): """Test handler removal from different function configurations.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) await flow_manager.initialize([]) async def dummy_handler(args): @@ -534,7 +540,7 @@ async def dummy_handler(args): async def test_function_declarations_processing(self): """Test processing of function declarations format.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) await flow_manager.initialize([]) async def test_handler(args): @@ -574,7 +580,7 @@ async def test_handler(args): async def test_direct_handler_format(self): """Test processing of direct handler format.""" - flow_manager = FlowManager(self.mock_task, self.mock_llm) + flow_manager = FlowManager(task=self.mock_task, llm=self.mock_llm) await flow_manager.initialize([]) async def test_handler(args):