Skip to content

Commit

Permalink
Add safeguard on tokens returned by functions (letta-ai#576)
Browse files Browse the repository at this point in the history
* swapping out hardcoded str for prefix (forgot to include in letta-ai#569)

* add extra failout when the summarizer tries to run on a single message

* added function response validation code, currently will truncate responses based on character count

* added return type hints (functions/tools should either return strings or None)

* discuss function output length in custom function section

* made the truncation more informative
  • Loading branch information
cpacker authored and norton120 committed Feb 15, 2024
1 parent 19df83c commit 17de8fb
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 18 deletions.
12 changes: 10 additions & 2 deletions docs/functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,20 @@ There are three steps to adding more MemGPT functions:
The functions you write MUST have proper docstrings and type hints - this is because MemGPT will use these docstrings and types to automatically create a JSON schema that is used in the LLM prompt. Use the docstrings and types annotations from the [example functions](https://github.com/cpacker/MemGPT/blob/main/memgpt/functions/function_sets/base.py) for guidance.
!!! warning "Function output length"
Your custom function should always return a string that is **capped in length**. If your string goes over the specified limit, it will be truncated internally. This is to prevent potential context overflows caused by uncapped string returns (for example, a rogue HTTP request that returns a string larger than the LLM context window).
If you return any type other than `str` (e.g. `dict`) in your custom functions, MemGPT will attempt to cast the result to a string (and truncate the result if it is too long). It is preferable to return strings - think of your function returning a natural language description of the outcome (see the D20 example below).

In this simple example we'll give MemGPT the ability to roll a [D20 die](https://en.wikipedia.org/wiki/D20_System).

First, let's create a python file `~/.memgpt/functions/d20.py`, and write some code that uses the `random` library to "roll a die":
```python
import random
def roll_d20(self) -> int:
def roll_d20(self) -> str:
"""
Simulate the roll of a 20-sided die (d20).
Expand All @@ -55,7 +61,9 @@ def roll_d20(self) -> int:
>>> roll_d20()
15 # This is an example output and may vary each time the function is called.
"""
return random.randint(1, 20)
dice_roll_outcome = random.randint(1, 20)
output_string = f"You rolled a {dice_roll_outcome}"
return output_string
```

Notice how we used [type hints](https://docs.python.org/3/library/typing.html) and [docstrings](https://peps.python.org/pep-0257/#multi-line-docstrings) to describe how the function works. **These are required**, if you do not include them MemGPT will not be able to "link" to your function. This is because MemGPT needs a JSON schema description of how your function works, which we automatically generate for you using the type hints and docstring (which you write yourself).
Expand Down
13 changes: 10 additions & 3 deletions memgpt/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from memgpt.system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
from memgpt.memory import CoreMemory as Memory, summarize_messages
from memgpt.openai_tools import create, is_context_overflow_error
from memgpt.utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff
from memgpt.utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff, validate_function_response
from memgpt.constants import (
FIRST_MESSAGE_ATTEMPTS,
MESSAGE_SUMMARY_WARNING_FRAC,
Expand Down Expand Up @@ -518,7 +518,8 @@ def handle_ai_response(self, response_message):
self.interface.function_message(f"Running {function_name}({function_args})")
try:
function_args["self"] = self # need to attach self to arg since it's dynamically linked
function_response_string = function_to_call(**function_args)
function_response = function_to_call(**function_args)
function_response_string = validate_function_response(function_response)
function_args.pop("self", None)
function_response = package_function_response(True, function_response_string)
function_failed = False
Expand Down Expand Up @@ -723,7 +724,13 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True)
pass

message_sequence_to_summarize = self.messages[1:cutoff] # do NOT get rid of the system message
printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self.messages)}")
if len(message_sequence_to_summarize) == 1:
# This prevents a potential infinite loop of summarizing the same message over and over
raise LLMError(
f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(message_sequence_to_summarize)} <= 1]"
)
else:
printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self.messages)}")

# We can't do summarize logic properly if context_window is undefined
if self.config.context_window is None:
Expand Down
3 changes: 3 additions & 0 deletions memgpt/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@
CORE_MEMORY_PERSONA_CHAR_LIMIT = 2000
CORE_MEMORY_HUMAN_CHAR_LIMIT = 2000

# Function return limits
FUNCTION_RETURN_CHAR_LIMIT = 2000

MAX_PAUSE_HEARTBEATS = 360 # in min

MESSAGE_CHATGPT_FUNCTION_MODEL = "gpt-3.5-turbo"
Expand Down
16 changes: 8 additions & 8 deletions memgpt/functions/function_sets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# If the function fails, throw an exception


def send_message(self, message: str):
def send_message(self, message: str) -> Optional[str]:
"""
Sends a message to the human user.
Expand All @@ -37,7 +37,7 @@ def send_message(self, message: str):
"""


def pause_heartbeats(self, minutes: int):
def pause_heartbeats(self, minutes: int) -> Optional[str]:
minutes = min(MAX_PAUSE_HEARTBEATS, minutes)

# Record the current time
Expand All @@ -51,7 +51,7 @@ def pause_heartbeats(self, minutes: int):
pause_heartbeats.__doc__ = pause_heartbeats_docstring


def core_memory_append(self, name: str, content: str):
def core_memory_append(self, name: str, content: str) -> Optional[str]:
"""
Append to the contents of core memory.
Expand All @@ -67,7 +67,7 @@ def core_memory_append(self, name: str, content: str):
return None


def core_memory_replace(self, name: str, old_content: str, new_content: str):
def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]:
"""
Replace to the contents of core memory. To delete memories, use an empty string for new_content.
Expand All @@ -84,7 +84,7 @@ def core_memory_replace(self, name: str, old_content: str, new_content: str):
return None


def conversation_search(self, query: str, page: Optional[int] = 0):
def conversation_search(self, query: str, page: Optional[int] = 0) -> Optional[str]:
"""
Search prior conversation history using case-insensitive string matching.
Expand All @@ -107,7 +107,7 @@ def conversation_search(self, query: str, page: Optional[int] = 0):
return results_str


def conversation_search_date(self, start_date: str, end_date: str, page: Optional[int] = 0):
def conversation_search_date(self, start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]:
"""
Search prior conversation history using a date range.
Expand All @@ -131,7 +131,7 @@ def conversation_search_date(self, start_date: str, end_date: str, page: Optiona
return results_str


def archival_memory_insert(self, content: str):
def archival_memory_insert(self, content: str) -> Optional[str]:
"""
Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.
Expand All @@ -145,7 +145,7 @@ def archival_memory_insert(self, content: str):
return None


def archival_memory_search(self, query: str, page: Optional[int] = 0):
def archival_memory_search(self, query: str, page: Optional[int] = 0) -> Optional[str]:
"""
Search archival memory using semantic (embedding-based) search.
Expand Down
8 changes: 4 additions & 4 deletions memgpt/functions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,18 @@ def load_all_function_sets(merge=True):
except ModuleNotFoundError as e:
# Handle missing module imports
missing_package = str(e).split("'")[1] # Extract the name of the missing package
print(f"Warning: skipped loading python file '{module_full_path}'!")
print(f"{CLI_WARNING_PREFIX}skipped loading python file '{module_full_path}'!")
print(
f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to MemGPT."
)
continue
except SyntaxError as e:
# Handle syntax errors in the module
print(f"Warning: skipped loading python file '{file}' due to a syntax error: {e}")
print(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}' due to a syntax error: {e}")
continue
except Exception as e:
# Handle other general exceptions
print(f"Warning: skipped loading python file '{file}': {e}")
print(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}': {e}")
continue
else:
# For built-in scripts, use the existing method
Expand All @@ -89,7 +89,7 @@ def load_all_function_sets(merge=True):
module = importlib.import_module(full_module_name)
except Exception as e:
# Handle other general exceptions
print(f"Warning: skipped loading python module '{full_module_name}': {e}")
print(f"{CLI_WARNING_PREFIX}skipped loading python module '{full_module_name}': {e}")
continue

try:
Expand Down
48 changes: 47 additions & 1 deletion memgpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import tiktoken
import memgpt
from memgpt.constants import MEMGPT_DIR
from memgpt.constants import MEMGPT_DIR, FUNCTION_RETURN_CHAR_LIMIT, CLI_WARNING_PREFIX

# TODO: what is this?
# DEBUG = True
Expand Down Expand Up @@ -88,6 +88,52 @@ def parse_json(string):
raise e


def validate_function_response(function_response_string: object, strict: bool = False) -> str:
    """Check to make sure that a function used by MemGPT returned a valid response.

    Responses need to be strings (or None) that fall under a certain text count limit
    (FUNCTION_RETURN_CHAR_LIMIT); anything longer is truncated with an explanatory note.

    Args:
        function_response_string: The raw value returned by the function. Ideally a
            str; None, dict, and other types are coerced to a string unless strict.
        strict: If True, raise ValueError for any non-str return instead of coercing.

    Returns:
        The (possibly coerced and/or truncated) string representation of the output.

    Raises:
        ValueError: If the value is a non-str type and strict is True, or if the
            value cannot be converted to a string.
    """
    if not isinstance(function_response_string, str):
        # Soft correction for a few basic types

        if function_response_string is None:
            # function_response_string = "Empty (no function output)"
            function_response_string = "None"  # backcompat

        elif isinstance(function_response_string, dict):
            if strict:
                # TODO add better error message
                raise ValueError(function_response_string)

            # Allow dict through since it will be cast to json.dumps()
            try:
                # TODO find a better way to do this that won't result in double escapes
                function_response_string = json.dumps(function_response_string)
            except (TypeError, ValueError) as e:
                # dict contained a non-serializable value; chain so the cause is visible
                raise ValueError(function_response_string) from e

        else:
            if strict:
                # TODO add better error message
                raise ValueError(function_response_string)

            # Try to convert to a string, but throw a warning to alert the user
            try:
                function_response_string = str(function_response_string)
            except Exception as e:
                # __str__ itself failed; chain so the cause is visible
                raise ValueError(function_response_string) from e

    # Now check the length and make sure it doesn't go over the limit
    # TODO we should change this to a max token limit that's variable based on tokens remaining (or context-window)
    if len(function_response_string) > FUNCTION_RETURN_CHAR_LIMIT:
        print(
            f"{CLI_WARNING_PREFIX}function return was over limit ({len(function_response_string)} > {FUNCTION_RETURN_CHAR_LIMIT}) and was truncated"
        )
        # NOTE: the lengths in the appended message are evaluated before reassignment,
        # so they report the original (pre-truncation) length
        function_response_string = f"{function_response_string[:FUNCTION_RETURN_CHAR_LIMIT]}... [NOTE: function output was truncated since it exceeded the character limit ({len(function_response_string)} > {FUNCTION_RETURN_CHAR_LIMIT})]"

    return function_response_string


def list_agent_config_files(sort="last_modified"):
"""List all agent config files, ignoring dotfiles."""
agent_dir = os.path.join(MEMGPT_DIR, "agents")
Expand Down

0 comments on commit 17de8fb

Please sign in to comment.