From 17de8fb091883c4801306ecfc78e5d4186dc6922 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Wed, 13 Dec 2023 21:57:50 -0800 Subject: [PATCH] Add safeguard on tokens returned by functions (#576) * swapping out hardcoded str for prefix (forgot to include in #569) * add extra failout when the summarizer tries to run on a single message * added function response validation code, currently will truncate responses based on character count * added return type hints (functions/tools should either return strings or None) * discuss function output length in custom function section * made the truncation more informative --- docs/functions.md | 12 +++++-- memgpt/agent.py | 13 +++++-- memgpt/constants.py | 3 ++ memgpt/functions/function_sets/base.py | 16 ++++----- memgpt/functions/functions.py | 8 ++--- memgpt/utils.py | 48 +++++++++++++++++++++++++- 6 files changed, 82 insertions(+), 18 deletions(-) diff --git a/docs/functions.md b/docs/functions.md index 02abe76164..7805ba0695 100644 --- a/docs/functions.md +++ b/docs/functions.md @@ -34,6 +34,12 @@ There are three steps to adding more MemGPT functions: The functions you write MUST have proper docstrings and type hints - this is because MemGPT will use these docstrings and types to automatically create a JSON schema that is used in the LLM prompt. Use the docstrings and types annotations from the [example functions](https://github.com/cpacker/MemGPT/blob/main/memgpt/functions/function_sets/base.py) for guidance. +!!! warning "Function output length" + + Your custom function should always return a string that is **capped in length**. If your string goes over the specified limit, it will be truncated internally. This is to prevent potential context overflows caused by uncapped string returns (for example, a rogue HTTP request that returns a string larger than the LLM context window). + + If you return any type other than `str` (e.g. 
`dict`) in your custom functions, MemGPT will attempt to cast the result to a string (and truncate the result if it is too long). It is preferable to return strings - think of your function returning a natural language description of the outcome (see the D20 example below). + In this simple example we'll give MemGPT the ability to roll a [D20 die](https://en.wikipedia.org/wiki/D20_System). First, let's create a python file `~/.memgpt/functions/d20.py`, and write some code that uses the `random` library to "roll a die": @@ -41,7 +47,7 @@ import random -def roll_d20(self) -> int: +def roll_d20(self) -> str: """ Simulate the roll of a 20-sided die (d20). @@ -55,7 +61,9 @@ def roll_d20(self) -> int: >>> roll_d20() 15 # This is an example output and may vary each time the function is called. """ - return random.randint(1, 20) + dice_roll_outcome = random.randint(1, 20) + output_string = f"You rolled a {dice_roll_outcome}" + return output_string ``` Notice how we used [type hints](https://docs.python.org/3/library/typing.html) and [docstrings](https://peps.python.org/pep-0257/#multi-line-docstrings) to describe how the function works. **These are required**, if you do not include them MemGPT will not be able to "link" to your function. This is because MemGPT needs a JSON schema description of how your function works, which we automatically generate for you using the type hints and docstring (which you write yourself). 
diff --git a/memgpt/agent.py b/memgpt/agent.py index 3777cb4436..2f7592f3d5 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -9,7 +9,7 @@ from memgpt.system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages from memgpt.memory import CoreMemory as Memory, summarize_messages from memgpt.openai_tools import create, is_context_overflow_error -from memgpt.utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff +from memgpt.utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff, validate_function_response from memgpt.constants import ( FIRST_MESSAGE_ATTEMPTS, MESSAGE_SUMMARY_WARNING_FRAC, @@ -518,7 +518,8 @@ def handle_ai_response(self, response_message): self.interface.function_message(f"Running {function_name}({function_args})") try: function_args["self"] = self # need to attach self to arg since it's dynamically linked - function_response_string = function_to_call(**function_args) + function_response = function_to_call(**function_args) + function_response_string = validate_function_response(function_response) function_args.pop("self", None) function_response = package_function_response(True, function_response_string) function_failed = False @@ -723,7 +724,13 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True) pass message_sequence_to_summarize = self.messages[1:cutoff] # do NOT get rid of the system message - printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self.messages)}") + if len(message_sequence_to_summarize) == 1: + # This prevents a potential infinite loop of summarizing the same message over and over + raise LLMError( + f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(message_sequence_to_summarize)} <= 1]" + ) + else: + printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages 
[1:{cutoff}] of {len(self.messages)}") # We can't do summarize logic properly if context_window is undefined if self.config.context_window is None: diff --git a/memgpt/constants.py b/memgpt/constants.py index a722a64c7b..0c0564a74a 100644 --- a/memgpt/constants.py +++ b/memgpt/constants.py @@ -55,6 +55,9 @@ CORE_MEMORY_PERSONA_CHAR_LIMIT = 2000 CORE_MEMORY_HUMAN_CHAR_LIMIT = 2000 +# Function return limits +FUNCTION_RETURN_CHAR_LIMIT = 2000 + MAX_PAUSE_HEARTBEATS = 360 # in min MESSAGE_CHATGPT_FUNCTION_MODEL = "gpt-3.5-turbo" diff --git a/memgpt/functions/function_sets/base.py b/memgpt/functions/function_sets/base.py index 38d7528618..fd49d367d7 100644 --- a/memgpt/functions/function_sets/base.py +++ b/memgpt/functions/function_sets/base.py @@ -11,7 +11,7 @@ # If the function fails, throw an exception -def send_message(self, message: str): +def send_message(self, message: str) -> Optional[str]: """ Sends a message to the human user. @@ -37,7 +37,7 @@ def send_message(self, message: str): """ -def pause_heartbeats(self, minutes: int): +def pause_heartbeats(self, minutes: int) -> Optional[str]: minutes = min(MAX_PAUSE_HEARTBEATS, minutes) # Record the current time @@ -51,7 +51,7 @@ def pause_heartbeats(self, minutes: int): pause_heartbeats.__doc__ = pause_heartbeats_docstring -def core_memory_append(self, name: str, content: str): +def core_memory_append(self, name: str, content: str) -> Optional[str]: """ Append to the contents of core memory. @@ -67,7 +67,7 @@ def core_memory_append(self, name: str, content: str): return None -def core_memory_replace(self, name: str, old_content: str, new_content: str): +def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]: """ Replace to the contents of core memory. To delete memories, use an empty string for new_content. 
@@ -84,7 +84,7 @@ def core_memory_replace(self, name: str, old_content: str, new_content: str): return None -def conversation_search(self, query: str, page: Optional[int] = 0): +def conversation_search(self, query: str, page: Optional[int] = 0) -> Optional[str]: """ Search prior conversation history using case-insensitive string matching. @@ -107,7 +107,7 @@ def conversation_search(self, query: str, page: Optional[int] = 0): return results_str -def conversation_search_date(self, start_date: str, end_date: str, page: Optional[int] = 0): +def conversation_search_date(self, start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]: """ Search prior conversation history using a date range. @@ -131,7 +131,7 @@ def conversation_search_date(self, start_date: str, end_date: str, page: Optiona return results_str -def archival_memory_insert(self, content: str): +def archival_memory_insert(self, content: str) -> Optional[str]: """ Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later. @@ -145,7 +145,7 @@ def archival_memory_insert(self, content: str): return None -def archival_memory_search(self, query: str, page: Optional[int] = 0): +def archival_memory_search(self, query: str, page: Optional[int] = 0) -> Optional[str]: """ Search archival memory using semantic (embedding-based) search. 
diff --git a/memgpt/functions/functions.py b/memgpt/functions/functions.py index e5296f7a28..af1ced56ff 100644 --- a/memgpt/functions/functions.py +++ b/memgpt/functions/functions.py @@ -69,18 +69,18 @@ def load_all_function_sets(merge=True): except ModuleNotFoundError as e: # Handle missing module imports missing_package = str(e).split("'")[1] # Extract the name of the missing package - print(f"Warning: skipped loading python file '{module_full_path}'!") + print(f"{CLI_WARNING_PREFIX}skipped loading python file '{module_full_path}'!") print( f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to MemGPT." ) continue except SyntaxError as e: # Handle syntax errors in the module - print(f"Warning: skipped loading python file '{file}' due to a syntax error: {e}") + print(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}' due to a syntax error: {e}") continue except Exception as e: # Handle other general exceptions - print(f"Warning: skipped loading python file '{file}': {e}") + print(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}': {e}") continue else: # For built-in scripts, use the existing method @@ -89,7 +89,7 @@ def load_all_function_sets(merge=True): module = importlib.import_module(full_module_name) except Exception as e: # Handle other general exceptions - print(f"Warning: skipped loading python module '{full_module_name}': {e}") + print(f"{CLI_WARNING_PREFIX}skipped loading python module '{full_module_name}': {e}") continue try: diff --git a/memgpt/utils.py b/memgpt/utils.py index 7b0ea1b8b0..46f2fab359 100644 --- a/memgpt/utils.py +++ b/memgpt/utils.py @@ -6,7 +6,7 @@ import os import tiktoken import memgpt -from memgpt.constants import MEMGPT_DIR +from memgpt.constants import MEMGPT_DIR, FUNCTION_RETURN_CHAR_LIMIT, CLI_WARNING_PREFIX # TODO: what is this? 
# DEBUG = True @@ -88,6 +88,52 @@ def parse_json(string): raise e +def validate_function_response(function_response_string: any, strict: bool = False) -> str: + """Check to make sure that a function used by MemGPT returned a valid response + + Responses need to be strings (or None) that fall under a certain text count limit. + """ + if not isinstance(function_response_string, str): + # Soft correction for a few basic types + + if function_response_string is None: + # function_response_string = "Empty (no function output)" + function_response_string = "None" # backcompat + + elif isinstance(function_response_string, dict): + if strict: + # TODO add better error message + raise ValueError(function_response_string) + + # Allow dict through since it will be cast to json.dumps() + try: + # TODO find a better way to do this that won't result in double escapes + function_response_string = json.dumps(function_response_string) + except: + raise ValueError(function_response_string) + + else: + if strict: + # TODO add better error message + raise ValueError(function_response_string) + + # Try to convert to a string, but throw a warning to alert the user + try: + function_response_string = str(function_response_string) + except: + raise ValueError(function_response_string) + + # Now check the length and make sure it doesn't go over the limit + # TODO we should change this to a max token limit that's variable based on tokens remaining (or context-window) + if len(function_response_string) > FUNCTION_RETURN_CHAR_LIMIT: + print( + f"{CLI_WARNING_PREFIX}function return was over limit ({len(function_response_string)} > {FUNCTION_RETURN_CHAR_LIMIT}) and was truncated" + ) + function_response_string = f"{function_response_string[:FUNCTION_RETURN_CHAR_LIMIT]}... 
[NOTE: function output was truncated since it exceeded the character limit ({len(function_response_string)} > {FUNCTION_RETURN_CHAR_LIMIT})]" + + return function_response_string + + def list_agent_config_files(sort="last_modified"): """List all agent config files, ignoring dotfiles.""" agent_dir = os.path.join(MEMGPT_DIR, "agents")