Skip to content

Commit

Permalink
feat(anthropic_engine): add citation support in stream mode
Browse files Browse the repository at this point in the history
- Added `citations` field to `Message` class in `conversation_model.py`.
- Implemented citation handling functions (`_handle_citations`, `get_current_citations`, `report_number_of_citations`, `show_citations`) in `conversation_tab.py`.
- Added 'Q' hotkey and context menu option for showing citations in the GUI.
- Extended `ProviderCapability` with `CITATION` for citation processing support.
- Updated `AnthropicEngine` to handle citation-related events in streaming completion responses.
- Enabled citation processing in document attachments for anthropic engine.
  • Loading branch information
AAClause committed Feb 16, 2025
1 parent 8cb8235 commit 2bab418
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 7 deletions.
2 changes: 2 additions & 0 deletions basilisk/conversation/conversation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import enum
from datetime import datetime
from typing import Any

from pydantic import BaseModel, Field, field_validator
from upath import UPath
Expand Down Expand Up @@ -47,6 +48,7 @@ class Message(BaseModel):
role: MessageRoleEnum
content: str
attachments: list[AttachmentFile | ImageFile] | None = Field(default=None)
citations: list[dict[str, Any]] | None = Field(default=None)


class MessageBlock(BaseModel):
Expand Down
128 changes: 126 additions & 2 deletions basilisk/gui/conversation_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,107 @@ def navigate_message(self, previous: bool):
start, end = self.get_range_for_current_message()
current_message = self.messages.GetRange(start, end)
self._handle_accessible_output(current_message)
self.report_number_of_citations()

def _handle_citations(self, citations: list[dict[str, Any]]) -> str:
"""Format a list of citations for display.
Args:
citations: The list of citations to format
Returns:
A formatted string containing the citations
"""
citations_str = []
for i, citation in enumerate(citations):
location_text = ""
cited_text = citation.get("cited_text")
document_index = citation.get("document_index")
document_title = citation.get("document_title")
match citation.get("type"):
case "char_location":
start_char_index = citation.get("start_char_index", 0)
end_char_index = citation.get("end_char_index", 0)
location_text = _("C.{start} .. {end}").format(
start=start_char_index, end=end_char_index
)
case "page_location":
start_page_number = citation.get("start_page_number", 0)
end_page_number = citation.get("end_page_number", 0)
location_text = _("P.{start} .. {end}").format(
start=start_page_number, end=end_page_number
)
case _:
location_text = _("Unknown location")
log.warning(f"Unknown citation type: {citation}")
if document_index is not None:
if document_title:
location_text = _(
"{document_title} / {location_text}"
).format(
document_title=document_title,
location_text=location_text,
)
else:
location_text = _(
"Document {document_index} / {location_text}"
).format(
document_index=document_index,
location_text=location_text,
)
if cited_text:
citations_str.append(
_("{location_text}: “{cited_text}”").format(
location_text=location_text,
cited_text=cited_text.strip(),
)
)
return "\n_--_\n".join(citations_str)

def get_current_citations(self) -> list[dict[str, Any]]:
"""Get the citations for the current message.
Returns:
The list of citations for the current message
"""
cursor_pos = self.messages.GetInsertionPoint()
self.message_segment_manager.absolute_position = cursor_pos
message_block = (
self.message_segment_manager.current_segment.message_block()
)
if not message_block:
wx.Bell()
return []
return message_block.response.citations

def report_number_of_citations(self):
"""Report the number of citations for the current message."""
citations = self.get_current_citations()
if not citations:
return
nb_citations = len(citations)
self.SetStatusText(
_("%d citations in the current message") % nb_citations
)

def show_citations(self, event: wx.CommandEvent | None = None):
"""Show the citations for the current message.
Args:
event: The event that triggered the action
"""
citations = self.get_current_citations()
if not citations:
self._handle_accessible_output(
_("No citations for this message"), braille=True
)
wx.Bell()
return
citations_str = self._handle_citations(citations)
if not citations_str:
wx.Bell()
return
ReadOnlyMessageDialog(self, _("Citations"), citations_str).ShowModal()

def go_to_previous_message(self, event: wx.CommandEvent | None = None):
"""Navigate to the previous message in the conversation.
Expand Down Expand Up @@ -1052,6 +1153,7 @@ def on_messages_key_down(self, event: wx.KeyEvent):
- C: Copy current message
- B: Move to start of message
- N: Move to end of message
- Q: Show citations for current message
- Shift+Delete: Remove current message block
- F3: Search in messages (forward)
- Shift+F3: Search in messages (backward)
Expand All @@ -1069,6 +1171,7 @@ def on_messages_key_down(self, event: wx.KeyEvent):
key_actions = {
(wx.MOD_SHIFT, wx.WXK_SPACE): self.on_toggle_speak_stream,
(wx.MOD_NONE, wx.WXK_SPACE): self.on_read_current_message,
(wx.MOD_NONE, ord('Q')): self.show_citations,
(wx.MOD_NONE, ord('J')): self.go_to_previous_message,
(wx.MOD_NONE, ord('K')): self.go_to_next_message,
(wx.MOD_NONE, ord('S')): self.on_select_message,
Expand Down Expand Up @@ -1123,6 +1226,15 @@ def on_messages_context_menu(self, event: wx.ContextMenuEvent):
menu.Append(item)
self.Bind(wx.EVT_MENU, self.on_read_current_message, item)

item = wx.MenuItem(
menu,
wx.ID_ANY,
# Translators: This is a label for the Messages area context menu in the main window
_("Show citations") + " (Q)",
)
menu.Append(item)
self.Bind(wx.EVT_MENU, self.show_citations, item)

item = wx.MenuItem(
menu,
wx.ID_ANY,
Expand Down Expand Up @@ -1775,8 +1887,20 @@ def _handle_completion(self, engine: BaseEngine, **kwargs: dict[str, Any]):
if self._stop_completion or global_vars.app_should_exit:
log.debug("Stopping completion")
break
new_block.response.content += chunk
wx.CallAfter(self._handle_completion_with_stream, chunk)
if isinstance(chunk, str):
new_block.response.content += chunk
wx.CallAfter(self._handle_completion_with_stream, chunk)
elif isinstance(chunk, tuple):
chunk_type, chunk_data = chunk
match chunk_type:
case "citation":
if not new_block.response.citations:
new_block.response.citations = []
new_block.response.citations.append(chunk_data)
case _:
log.warning(
f"Unknown chunk type in streaming response: {chunk_type}"
)
wx.CallAfter(self._post_completion_with_stream, new_block)
else:
new_block = engine.completion_response_without_stream(
Expand Down
2 changes: 2 additions & 0 deletions basilisk/provider_capability.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ class ProviderCapability(enum.StrEnum):

# The provider supports document processing (excluding images)
DOCUMENT = enum.auto()
# The provider supports citation processing
CITATION = enum.auto()
# The provider supports image processing
IMAGE = enum.auto()
# The provider supports text processing
Expand Down
49 changes: 44 additions & 5 deletions basilisk/provider_engine/anthropic_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import logging
from functools import cached_property
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Iterator

from anthropic import Anthropic
from anthropic.types import Message as AnthropicMessage
Expand Down Expand Up @@ -50,6 +50,7 @@ class AnthropicEngine(BaseEngine):
ProviderCapability.TEXT,
ProviderCapability.IMAGE,
ProviderCapability.DOCUMENT,
ProviderCapability.CITATION,
}
supported_attachment_formats: set[str] = {
"image/gif",
Expand Down Expand Up @@ -219,13 +220,21 @@ def convert_message(self, message: Message) -> dict:
elif mime_type.startswith("application/"):
source["data"] = attachment.encode_base64()
contents.append(
DocumentBlockParam(source=source, type="document")
DocumentBlockParam(
type="document",
source=source,
citations={"enabled": True},
)
)
elif mime_type in ("text/plain"):
source["data"] = attachment.read_as_str()
source["type"] = "text"
contents.append(
TextBlockParam(source=source, type="document")
TextBlockParam(
type="document",
source=source,
citations={"enabled": True},
)
)
return {"role": message.role.value, "content": contents}

Expand Down Expand Up @@ -268,7 +277,7 @@ def completion(

def completion_response_with_stream(
self, stream: Stream[MessageStreamEvent]
):
) -> Iterator[TextBlock | dict]:
"""Processes streaming response from Anthropic API.
Args:
Expand All @@ -280,7 +289,37 @@ def completion_response_with_stream(
for event in stream:
match event.type:
case "content_block_delta":
yield event.delta.text
match event.delta.type:
case "text_delta":
yield event.delta.text
case "citations_delta":
citation = event.delta.citation
citation_chunk_data = {
"type": citation.type,
"cited_text": citation.cited_text,
"document_index": citation.document_index,
"document_title": citation.document_title,
}
match citation.type:
case "char_location":
citation_chunk_data.update(
{
"start_char_index": citation.start_char_index,
"end_char_index": citation.end_char_index,
}
)
case "page_location":
citation_chunk_data.update(
{
"start_page_number": citation.start_page_number, # inclusive,
"end_page_number": citation.end_page_number, # exclusive
}
)
case _:
log.warning(
f"Unsupported citation type: {citation.type}"
)
yield ("citation", citation_chunk_data)

def completion_response_without_stream(
self, response: AnthropicMessage, new_block: MessageBlock, **kwargs
Expand Down

0 comments on commit 2bab418

Please sign in to comment.