diff --git a/gpt_engineer/ai.py b/gpt_engineer/ai.py
index da923c3cf7..429e7f26d0 100644
--- a/gpt_engineer/ai.py
+++ b/gpt_engineer/ai.py
@@ -9,6 +9,7 @@
 import openai
 import tiktoken
 
+from langchain.callbacks.openai_info import MODEL_COST_PER_1K_TOKENS
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
 from langchain.chat_models.base import BaseChatModel
@@ -271,6 +272,24 @@ def format_token_usage_log(self) -> str:
             result += str(log.total_tokens) + "\n"
         return result
 
+    def usage_cost(self) -> float:
+        """
+        Return the total cost in USD of the api usage.
+
+        Returns
+        -------
+        float
+            Cost in USD.
+        """
+        prompt_price = MODEL_COST_PER_1K_TOKENS[self.model_name]
+        completion_price = MODEL_COST_PER_1K_TOKENS[self.model_name + "-completion"]
+
+        result = 0
+        for log in self.token_usage_log:
+            result += log.total_prompt_tokens / 1000 * prompt_price
+            result += log.total_completion_tokens / 1000 * completion_price
+        return result
+
     def num_tokens(self, txt: str) -> int:
         """
         Get the number of tokens in a text.
diff --git a/gpt_engineer/main.py b/gpt_engineer/main.py
index 68ed8807d1..a1c88b9444 100644
--- a/gpt_engineer/main.py
+++ b/gpt_engineer/main.py
@@ -96,6 +96,8 @@ def main(
         messages = step(ai, dbs)
         dbs.logs[step.__name__] = AI.serialize_messages(messages)
 
+    print("Total api cost: $ ", ai.usage_cost())
+
     if collect_consent():
         collect_learnings(model, temperature, steps, dbs)
 