anthropic[patch]: add usage_metadata details (#27087)
fixes #27087
baskaryan authored Oct 4, 2024
1 parent e8e5d67 commit 0495b7f
Showing 3 changed files with 172 additions and 41 deletions.
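In short, the commit replaces three hand-rolled usage_metadata dicts in chat_models.py with a single _create_usage_metadata helper and, in doing so, starts reporting Anthropic's prompt-caching token counts under input_token_details. As a sketch, an AIMessage built from a cache-enabled response carries metadata shaped like this (values taken from the unit tests below, not from a live call):

usage_metadata = {
    "input_tokens": 2,
    "output_tokens": 1,
    "total_tokens": 3,
    "input_token_details": {"cache_creation": 3, "cache_read": 4},
}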
43 changes: 24 additions & 19 deletions libs/partners/anthropic/langchain_anthropic/chat_models.py
@@ -41,7 +41,7 @@
ToolCall,
ToolMessage,
)
-from langchain_core.messages.ai import UsageMetadata
+from langchain_core.messages.ai import InputTokenDetails, UsageMetadata
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
from langchain_core.output_parsers import (
JsonOutputKeyToolsParser,
@@ -766,12 +766,7 @@ def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
)
else:
msg = AIMessage(content=content)
-        # Collect token usage
-        msg.usage_metadata = {
-            "input_tokens": data.usage.input_tokens,
-            "output_tokens": data.usage.output_tokens,
-            "total_tokens": data.usage.input_tokens + data.usage.output_tokens,
-        }
+        msg.usage_metadata = _create_usage_metadata(data.usage)
return ChatResult(
generations=[ChatGeneration(message=msg)],
llm_output=llm_output,
@@ -1182,14 +1177,10 @@ def _make_message_chunk_from_anthropic_event(
message_chunk: Optional[AIMessageChunk] = None
# See https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/lib/streaming/_messages.py # noqa: E501
if event.type == "message_start" and stream_usage:
-        input_tokens = event.message.usage.input_tokens
+        usage_metadata = _create_usage_metadata(event.message.usage)
message_chunk = AIMessageChunk(
content="" if coerce_content_to_string else [],
-            usage_metadata=UsageMetadata(
-                input_tokens=input_tokens,
-                output_tokens=0,
-                total_tokens=input_tokens,
-            ),
+            usage_metadata=usage_metadata,
)
elif (
event.type == "content_block_start"
@@ -1235,14 +1226,10 @@ def _make_message_chunk_from_anthropic_event(
tool_call_chunks=[tool_call_chunk], # type: ignore
)
elif event.type == "message_delta" and stream_usage:
-        output_tokens = event.usage.output_tokens
+        usage_metadata = _create_usage_metadata(event.usage)
message_chunk = AIMessageChunk(
content="",
-            usage_metadata=UsageMetadata(
-                input_tokens=0,
-                output_tokens=output_tokens,
-                total_tokens=output_tokens,
-            ),
+            usage_metadata=usage_metadata,
response_metadata={
"stop_reason": event.delta.stop_reason,
"stop_sequence": event.delta.stop_sequence,
@@ -1257,3 +1244,21 @@ def _make_message_chunk_from_anthropic_event(
@deprecated(since="0.1.0", removal="0.3.0", alternative="ChatAnthropic")
class ChatAnthropicMessages(ChatAnthropic):
pass


def _create_usage_metadata(anthropic_usage: BaseModel) -> UsageMetadata:
input_token_details: Dict = {
"cache_read": getattr(anthropic_usage, "cache_read_input_tokens", None),
"cache_creation": getattr(anthropic_usage, "cache_creation_input_tokens", None),
}

input_tokens = getattr(anthropic_usage, "input_tokens", 0)
output_tokens = getattr(anthropic_usage, "output_tokens", 0)
return UsageMetadata(
input_tokens=input_tokens,
output_tokens=output_tokens,
total_tokens=input_tokens + output_tokens,
input_token_details=InputTokenDetails(
**{k: v for k, v in input_token_details.items() if v is not None}
),
)
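
A minimal sketch of the new helper's behavior. SimpleNamespace stands in for the Anthropic SDK usage models (Usage, PromptCachingBetaUsage) that the real code receives; because the helper reads attributes with getattr and drops None values, it handles both the plain and the prompt-caching shapes:

from types import SimpleNamespace

# Plain usage: no cache_* attributes, so input_token_details ends up empty.
plain = SimpleNamespace(input_tokens=2, output_tokens=1)
assert _create_usage_metadata(plain) == {
    "input_tokens": 2,
    "output_tokens": 1,
    "total_tokens": 3,
    "input_token_details": {},
}

# Prompt-caching usage: both cache counters are surfaced.
cached = SimpleNamespace(
    input_tokens=2,
    output_tokens=1,
    cache_creation_input_tokens=3,
    cache_read_input_tokens=4,
)
assert _create_usage_metadata(cached)["input_token_details"] == {
    "cache_creation": 3,
    "cache_read": 4,
}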
106 changes: 105 additions & 1 deletion libs/partners/anthropic/tests/integration_tests/test_standard.py
@@ -1,12 +1,16 @@
"""Standard LangChain interface tests"""

-from typing import Type
+from pathlib import Path
+from typing import List, Literal, Type, cast

from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import AIMessage
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests

from langchain_anthropic import ChatAnthropic

+REPO_ROOT_DIR = Path(__file__).parents[5]


class TestAnthropicStandard(ChatModelIntegrationTests):
@property
@@ -28,3 +32,103 @@ def supports_image_tool_message(self) -> bool:
@property
def supports_anthropic_inputs(self) -> bool:
return True

@property
def supported_usage_metadata_details(
self,
) -> List[
Literal[
"audio_input",
"audio_output",
"reasoning_output",
"cache_read_input",
"cache_creation_input",
]
]:
return ["cache_read_input", "cache_creation_input"]

def invoke_with_cache_creation_input(self, *, stream: bool = False) -> AIMessage:
llm = ChatAnthropic(
model="claude-3-5-sonnet-20240620", # type: ignore[call-arg]
extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"}, # type: ignore[call-arg]
)
with open(REPO_ROOT_DIR / "README.md", "r") as f:
readme = f.read()

input_ = f"""What's langchain? Here's the langchain README:
{readme}
"""
return _invoke(
llm,
[
{
"role": "user",
"content": [
{
"type": "text",
"text": input_,
"cache_control": {"type": "ephemeral"},
}
],
}
],
stream,
)

def invoke_with_cache_read_input(self, *, stream: bool = False) -> AIMessage:
llm = ChatAnthropic(
model="claude-3-5-sonnet-20240620", # type: ignore[call-arg]
extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"}, # type: ignore[call-arg]
)
with open(REPO_ROOT_DIR / "README.md", "r") as f:
readme = f.read()

input_ = f"""What's langchain? Here's the langchain README:
{readme}
"""

# invoke twice so first invocation is cached
_invoke(
llm,
[
{
"role": "user",
"content": [
{
"type": "text",
"text": input_,
"cache_control": {"type": "ephemeral"},
}
],
}
],
stream,
)
return _invoke(
llm,
[
{
"role": "user",
"content": [
{
"type": "text",
"text": input_,
"cache_control": {"type": "ephemeral"},
}
],
}
],
stream,
)


def _invoke(llm: ChatAnthropic, input_: list, stream: bool) -> AIMessage:
if stream:
full = None
for chunk in llm.stream(input_):
full = full + chunk if full else chunk # type: ignore[operator]
return cast(AIMessage, full)
else:
return cast(AIMessage, llm.invoke(input_))
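
The _invoke helper above leans on AIMessageChunk addition, which concatenates content and sums usage_metadata across chunks. With this commit, the message_start chunk contributes the input-token counts (including any cache details) and the message_delta chunk contributes the output tokens, so streamed totals line up with the non-streaming path. A sketch of the same pattern outside the test harness (assumes ANTHROPIC_API_KEY is set; the model name mirrors the tests above):

from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(model="claude-3-5-sonnet-20240620")
full = None
for chunk in llm.stream("What's LangChain?"):
    full = full + chunk if full else chunk  # merges content and token counts
print(full.usage_metadata)  # includes input_tokens, output_tokens, total_tokens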
64 changes: 43 additions & 21 deletions libs/partners/anthropic/tests/unit_tests/test_chat_models.py
@@ -5,8 +5,11 @@

import pytest
from anthropic.types import Message, TextBlock, Usage
+from anthropic.types.beta.prompt_caching import (
+    PromptCachingBetaMessage,
+    PromptCachingBetaUsage,
+)
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
-from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.runnables import RunnableBinding
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, SecretStr
@@ -89,30 +92,49 @@ def test__format_output() -> None:
usage=Usage(input_tokens=2, output_tokens=1),
type="message",
)
-    expected = ChatResult(
-        generations=[
-            ChatGeneration(
-                message=AIMessage(  # type: ignore[misc]
-                    "bar",
-                    usage_metadata={
-                        "input_tokens": 2,
-                        "output_tokens": 1,
-                        "total_tokens": 3,
-                    },
-                )
-            ),
-        ],
-        llm_output={
-            "id": "foo",
-            "model": "baz",
-            "stop_reason": None,
-            "stop_sequence": None,
-            "usage": {"input_tokens": 2, "output_tokens": 1},
+    expected = AIMessage(  # type: ignore[misc]
+        "bar",
+        usage_metadata={
+            "input_tokens": 2,
+            "output_tokens": 1,
+            "total_tokens": 3,
+            "input_token_details": {},
},
)
llm = ChatAnthropic(model="test", anthropic_api_key="test") # type: ignore[call-arg, call-arg]
actual = llm._format_output(anthropic_msg)
-    assert expected == actual
+    assert actual.generations[0].message == expected


def test__format_output_cached() -> None:
anthropic_msg = PromptCachingBetaMessage(
id="foo",
content=[TextBlock(type="text", text="bar")],
model="baz",
role="assistant",
stop_reason=None,
stop_sequence=None,
usage=PromptCachingBetaUsage(
input_tokens=2,
output_tokens=1,
cache_creation_input_tokens=3,
cache_read_input_tokens=4,
),
type="message",
)
expected = AIMessage( # type: ignore[misc]
"bar",
usage_metadata={
"input_tokens": 2,
"output_tokens": 1,
"total_tokens": 3,
"input_token_details": {"cache_creation": 3, "cache_read": 4},
},
)

llm = ChatAnthropic(model="test", anthropic_api_key="test") # type: ignore[call-arg, call-arg]
actual = llm._format_output(anthropic_msg)
assert actual.generations[0].message == expected


def test__merge_messages() -> None: