diff --git a/elm/base.py b/elm/base.py
index 09dea3f..409489a 100644
--- a/elm/base.py
+++ b/elm/base.py
@@ -53,6 +53,11 @@ class ApiBase(ABC):
     }
     """Optional mappings for unusual Azure names to tiktoken/openai names."""
 
+    TOKENIZER_PATTERNS = ('gpt-4o', 'gpt-4-32k', 'gpt-4')
+    """Order-prioritized tuple of model sub-strings to look for in model name
+    to send to tokenizer. As an alternative to alias lookup, this will use the
+    tokenizer pattern if found in the model string."""
+
     def __init__(self, model=None):
         """
         Parameters
@@ -348,7 +353,7 @@ def get_embedding(cls, text):
         return embedding
 
     @classmethod
-    def count_tokens(cls, text, model):
+    def count_tokens(cls, text, model, fallback_model='gpt-4'):
         """Return the number of tokens in a string.
 
         Parameters
@@ -357,6 +362,10 @@ def count_tokens(cls, text, model):
             Text string to get number of tokens for
         model : str
            specification of OpenAI model to use (e.g., "gpt-3.5-turbo")
+        fallback_model : str, default='gpt-4'
+            Model to be used for tokenizer if input model can't be found
+            in :obj:`TOKENIZER_ALIASES` and doesn't have any easily
+            noticeable patterns.
 
         Returns
         -------
@@ -364,7 +373,15 @@ def count_tokens(cls, text, model):
             Number of tokens in text
         """
 
-        token_model = cls.TOKENIZER_ALIASES.get(model, model)
+        if model in cls.TOKENIZER_ALIASES:
+            token_model = cls.TOKENIZER_ALIASES[model]
+        else:
+            token_model = fallback_model
+            for pattern in cls.TOKENIZER_PATTERNS:
+                if pattern in model:
+                    token_model = pattern
+                    break
+
 
         encoding = tiktoken.encoding_for_model(token_model)
         return len(encoding.encode(text))