From 4bf763beaec7e7d815b1c84a3f9efb7c90bcdb02 Mon Sep 17 00:00:00 2001
From: Ali Salimli <67149699+elisalimli@users.noreply.github.com>
Date: Mon, 13 May 2024 19:10:53 +0400
Subject: [PATCH] fix: getting context window sizes of models without prefixes
 (#994)

* fix: getting context window sizes of models without prefixes

* feat: limit split counts to 1
---
 libs/superagent/app/memory/buffer_memory.py | 23 ++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/libs/superagent/app/memory/buffer_memory.py b/libs/superagent/app/memory/buffer_memory.py
index 96c2467fa..5ba2fad18 100644
--- a/libs/superagent/app/memory/buffer_memory.py
+++ b/libs/superagent/app/memory/buffer_memory.py
@@ -7,7 +7,23 @@ from app.memory.message import BaseMessage
 
 
 DEFAULT_TOKEN_LIMIT_RATIO = 0.75
-DEFAULT_TOKEN_LIMIT = 3000
+DEFAULT_TOKEN_LIMIT = 3072
+
+
+def get_context_window(model: str) -> int:
+    max_input_tokens = model_cost.get(model, {}).get("max_input_tokens")
+
+    # Some models don't have a provider prefix in their name
+    # But they point to the same model
+    # Example: claude-3-haiku-20240307 and anthropic/claude-3-haiku-20240307
+    if not max_input_tokens:
+        model_parts = model.split("/", 1)
+        if len(model_parts) > 1:
+            model_without_prefix = model_parts[1]
+            max_input_tokens = model_cost.get(model_without_prefix, {}).get(
+                "max_input_tokens", DEFAULT_TOKEN_LIMIT
+            )
+    return max_input_tokens
 
 
 class BufferMemory(BaseMemory):
@@ -21,8 +37,9 @@ def __init__(
         self.memory_store = memory_store
         self.tokenizer_fn = tokenizer_fn
         self.model = model
-        context_window = model_cost.get(self.model, {}).get("max_input_tokens")
-        self.context_window = max_tokens or context_window * DEFAULT_TOKEN_LIMIT_RATIO
+        self.context_window = (
+            max_tokens or get_context_window(model=model) * DEFAULT_TOKEN_LIMIT_RATIO
+        )
 
     def add_message(self, message: BaseMessage) -> None:
         self.memory_store.add_message(message)
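
Reviewer note: a minimal usage sketch of the new helper, assuming the module
path from this patch and that litellm's model_cost table carries a
"max_input_tokens" entry for the unprefixed model name (actual values depend
on the installed litellm version):

    from app.memory.buffer_memory import get_context_window

    # Prefixed name absent from model_cost: the helper retries with the
    # provider prefix stripped, looking up "claude-3-haiku-20240307", and
    # falls back to DEFAULT_TOKEN_LIMIT if that key is missing too.
    print(get_context_window(model="anthropic/claude-3-haiku-20240307"))

    # Unprefixed name present in model_cost: resolved by the first lookup,
    # the prefix-stripping branch is never entered.
    print(get_context_window(model="claude-3-haiku-20240307"))

One caveat worth noting: DEFAULT_TOKEN_LIMIT only applies on the
prefix-stripped retry. A model name containing no "/" that is also absent
from model_cost leaves max_input_tokens as None, so the multiplication by
DEFAULT_TOKEN_LIMIT_RATIO in __init__ would raise a TypeError in that case.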