diff --git a/basilisk/conversation/__init__.py b/basilisk/conversation/__init__.py index 02b4cf1d..56e1c1bf 100644 --- a/basilisk/conversation/__init__.py +++ b/basilisk/conversation/__init__.py @@ -6,6 +6,7 @@ AttachmentFileTypes, ImageFile, NotImageError, + build_from_url, get_mime_type, parse_supported_attachment_formats, ) @@ -21,6 +22,7 @@ __all__ = [ "AttachmentFile", "AttachmentFileTypes", + "build_from_url", "Conversation", "get_mime_type", "ImageFile", diff --git a/basilisk/conversation/attached_file.py b/basilisk/conversation/attached_file.py index ebfe658e..dc92e2ed 100644 --- a/basilisk/conversation/attached_file.py +++ b/basilisk/conversation/attached_file.py @@ -29,10 +29,7 @@ log = logging.getLogger(__name__) -URL_PATTERN = re.compile( - r'(https?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+|data:image/\S+)', - re.IGNORECASE, -) +URL_PATTERN = re.compile(r'https?://[^\s<>"]+|data:\S+', re.IGNORECASE) def get_image_dimensions(reader: BufferedReader) -> tuple[int, int]: @@ -130,6 +127,50 @@ def get_mime_type(path: str) -> str | None: return mimetypes.guess_type(path)[0] +@measure_time +def build_from_url(url: str) -> AttachmentFile: + """Fetch a file from a given URL and create an AttachmentFile instance. + + This function retrieves a file from the specified URL and constructs an AttachmentFile with metadata about the file. + + Args: + url: The URL of the file to retrieve. + + Returns: + An instance of AttachmentFile with details about the retrieved file. + + Raises: + httpx.HTTPError: If there is an error during the HTTP request. 
+ + Example: + file = build_from_url("https://example.com/file.pdf") + image = build_from_url("https://example.com/image.jpg") + """ + r = httpx.get(url, follow_redirects=True) + r.raise_for_status() + size = r.headers.get("Content-Length") + if size and size.isdigit(): + size = int(size) + mime_type = r.headers.get("content-type", None) + if not mime_type: + raise NotImageError("No MIME type found") + if mime_type.startswith("image/"): + dimensions = get_image_dimensions(BytesIO(r.content)) + return ImageFile( + location=url, + type=AttachmentFileTypes.URL, + size=size, + mime_type=mime_type, + dimensions=dimensions, + ) + return AttachmentFile( + location=url, + type=AttachmentFileTypes.URL, + size=size, + mime_type=mime_type, + ) + + class AttachmentFileTypes(enum.StrEnum): """Enumeration of file types based on their source location.""" @@ -181,6 +222,7 @@ class AttachmentFile(BaseModel): name: str | None = None description: str | None = None size: int | None = None + mime_type: str | None = None @field_serializer("location", mode="wrap") @classmethod @@ -242,19 +284,19 @@ def validate_location( raise ValueError("Invalid location") return value - def __init__(self, /, **data: Any) -> None: + def __init__(self, /, **kwargs: Any) -> None: """Initialize an AttachmentFile instance with optional data. If no name is provided, automatically generates a name using the internal _get_name() method. If no size is set, retrieves the file size using _get_size() method. Args: - data: Keyword arguments for initializing the AttachmentFile instance. Can include optional attributes like name and size. + kwargs: Keyword arguments for initializing the AttachmentFile instance. Can include optional attributes like name, size, and description. 
""" - super().__init__(**data) - if not self.name: - self.name = self._get_name() - self.size = self._get_size() + super().__init__(**kwargs) + self.name = self.name or self._get_name() + self.mime_type = kwargs.get("mime_type") or self._get_mime_type() + self.size = kwargs.get("size") or self._get_size() __init__.__pydantic_base_init__ = True @@ -326,8 +368,7 @@ def send_location(self) -> UPath: """ return getattr(self, "resize_location", None) or self.location - @property - def mime_type(self) -> str | None: + def _get_mime_type(self) -> str | None: """Get the MIME type of the file. Returns: @@ -363,11 +404,11 @@ def remove_location(location: UPath): except Exception as e: log.error(f"Error deleting file at {location}: {e}") - def read_as_str(self) -> str: - """Read the file as a string. + def read_as_plain_text(self) -> str: + """Read the file as a plain text string. Returns: - The contents of the file as a string. + The contents of the file as a plain text string. """ with self.send_location.open(mode="r") as file: return file.read() @@ -396,6 +437,18 @@ def get_display_info(self) -> tuple[str, str, str]: """ return self.name, self.display_size, self.display_location + @property + def url(self) -> str: + """Get the URL of the file. + + Returns: + The URL of the file, or the base64-encoded data if the file is in memory. + """ + if self.type == AttachmentFileTypes.URL: + return str(self.location) + base64_data = self.encode_base64() + return f"data:{self.mime_type};base64,{base64_data}" + class ImageFile(AttachmentFile): """Represents an image file in a conversation.""" @@ -403,47 +456,7 @@ class ImageFile(AttachmentFile): dimensions: tuple[int, int] | None = None resize_location: PydanticUPath | None = Field(default=None, exclude=True) - @classmethod - @measure_time - def build_from_url(cls, url: str) -> ImageFile: - """Fetch an image from a given URL and create an ImageFile instance. 
- - This class method retrieves an image from the specified URL, validates that it is an image, - and constructs an ImageFile with metadata about the image. - - Args: - url: The URL of the image to retrieve. - - Returns: - An instance of ImageFile with details about the retrieved image. - - Raises: - httpx.HTTPError: If there is an error during the HTTP request. - NotImageError: If the URL does not point to an image (content type is not image/*). - - Example: - image = ImageFile.build_from_url("https://example.com/image.jpg") - """ - r = httpx.get(url, follow_redirects=True) - r.raise_for_status() - content_type = r.headers.get("content-type", "") - if not content_type.startswith("image/"): - e = NotImageError("URL does not point to an image") - e.content_type = content_type - raise e - size = r.headers.get("Content-Length") - if size and size.isdigit(): - size = int(size) - dimensions = get_image_dimensions(BytesIO(r.content)) - return cls( - location=url, - type=AttachmentFileTypes.URL, - size=size, - description=content_type, - dimensions=dimensions, - ) - - def __init__(self, /, **data: Any) -> None: + def __init__(self, /, **kwargs: Any) -> None: """Initialize an ImageFile instance with optional data. If no name is provided, automatically generates a name using the internal _get_name() method. @@ -451,11 +464,10 @@ def __init__(self, /, **data: Any) -> None: If no dimensions are specified, determines image dimensions using _get_dimensions() method. Args: - data: Keyword arguments for initializing the ImageFile instance. Can include optional attributes like name, size, and dimensions. + kwargs: Keyword arguments for initializing the ImageFile instance. Can include optional attributes like name, size, and dimensions. 
""" - super().__init__(**data) - if not self.dimensions: - self.dimensions = self._get_dimensions() + super().__init__(**kwargs) + self.dimensions = self.dimensions or self._get_dimensions() __init__.__pydantic_base_init__ = True @@ -511,7 +523,7 @@ def resize( self.resize_location = resize_location if success else None @measure_time - def encode_image(self) -> str: + def encode_base64(self) -> str: """Encode the image file as a base64 string. Returns: @@ -523,19 +535,6 @@ def encode_image(self) -> str: ) return super().encode_base64() - @property - def url(self) -> str: - """Get the URL of the image file. - - Returns: - The URL of the image file, or the base64-encoded image data if the image is in memory. - """ - if not isinstance(self.type, AttachmentFileTypes): - raise ValueError("Invalid image type") - if self.type == AttachmentFileTypes.URL: - return str(self.location) - return f"data:{self.mime_type};base64,{self.encode_image()}" - @property def display_location(self): """Get the display location of the image file. 
diff --git a/basilisk/gui/conversation_tab.py b/basilisk/gui/conversation_tab.py index 22d53826..364ed7d8 100644 --- a/basilisk/gui/conversation_tab.py +++ b/basilisk/gui/conversation_tab.py @@ -39,8 +39,8 @@ Message, MessageBlock, MessageRoleEnum, - NotImageError, SystemMessage, + build_from_url, get_mime_type, parse_supported_attachment_formats, ) @@ -392,8 +392,8 @@ def on_attachments_context_menu(self, event: wx.ContextMenuEvent): item = wx.MenuItem( menu, wx.ID_ANY, - # Translators: This is a label for remove selected image in the context menu - _("Remove selected image") + " Shift+Del", + # Translators: This is a label for remove selected attachment in the context menu + _("Remove selected attachment") + " Shift+Del", ) menu.Append(item) self.Bind(wx.EVT_MENU, self.on_attachments_remove, item) @@ -427,11 +427,11 @@ def on_attachments_context_menu(self, event: wx.ContextMenuEvent): item = wx.MenuItem( menu, wx.ID_ANY, - # Translators: This is a label for add image URL in the context menu - _("Add image URL...") + " Ctrl+U", + # Translators: This is a label for add attachment URL in the context menu + _("Add attachment URL...") + " Ctrl+U", ) menu.Append(item) - self.Bind(wx.EVT_MENU, self.add_image_url_dlg, item) + self.Bind(wx.EVT_MENU, self.add_attachment_url_dlg, item) self.attachments_list.PopupMenu(menu) menu.Destroy() @@ -441,8 +441,8 @@ def on_attachments_key_down(self, event: wx.KeyEvent): Supports: - Ctrl+C: Copy file location - - Ctrl+V: Paste image - - Delete: Remove selected image + - Ctrl+V: Paste attachments + - Delete: Remove selected attachment Args: event: The keyboard event @@ -468,7 +468,7 @@ def on_attachments_paste(self, event: wx.CommandEvent): Supports multiple clipboard data types: - Files: Adds files directly to the conversation - Text: - - If a valid URL is detected, adds the image URL + - If a valid URL is detected, adds the attachment URL - Otherwise, pastes text into the prompt input - Bitmap images: Saves the image to a temporary 
file and adds it to the conversation @@ -488,8 +488,8 @@ def on_attachments_paste(self, event: wx.CommandEvent): clipboard.GetData(text_data) text = text_data.GetText() if re.fullmatch(URL_PATTERN, text): - log.info("Pasting URL from clipboard, adding image") - self.add_image_url_thread(text) + log.info("Pasting URL from clipboard, adding attachment") + self.add_attachment_url_thread(text) else: log.info("Pasting text from clipboard") self.prompt.WriteText(text) @@ -543,17 +543,17 @@ def add_attachments_dlg(self, event: wx.CommandEvent = None): self.add_attachments(paths) file_dialog.Destroy() - def add_image_url_dlg(self, event: wx.CommandEvent | None): - """Open a dialog to input an image URL and add it to the conversation. + def add_attachment_url_dlg(self, event: wx.CommandEvent | None): + """Open a dialog to input an attachment URL and add it to the conversation. Args: - event: Event triggered by the add image URL action + event: Event triggered by the add attachment URL action """ url_dialog = wx.TextEntryDialog( self, - # Translators: This is a label for image URL in conversation tab - message=_("Enter the URL of the image:"), - caption=_("Add image URL"), + # Translators: This is a label for enter URL in add attachment dialog + message=_("Enter the URL of the file to attach:"), + caption=_("Add attachment from URL"), ) if url_dialog.ShowModal() != wx.ID_OK: return @@ -565,76 +565,49 @@ def add_image_url_dlg(self, event: wx.CommandEvent | None): _("Invalid URL, bad format."), _("Error"), wx.OK | wx.ICON_ERROR ) return - self.add_image_url_thread(url) + self.add_attachment_url_thread(url) url_dialog.Destroy() - def force_image_from_url(self, url: str, content_type: str): - """Handle adding an image from a URL with a non-image content type. - - Displays a warning message to the user and prompts for confirmation to proceed. 
- - Args: - url: The URL of the image - content_type: The content type of the URL - """ - log.warning( - f"The {url} URL seems to not point to an image. The content type is {content_type}." - ) - force_add = wx.MessageBox( - # Translators: This message is displayed when the image URL seems to not point to an image. - _( - "The URL seems to not point to an image (content type: %s). Do you want to continue?" - ) - % content_type, - _("Warning"), - wx.YES_NO | wx.ICON_WARNING | wx.NO_DEFAULT, - ) - if force_add == wx.YES: - log.info("Forcing image addition") - self.add_attachments([ImageFile(location=url)]) - - def add_image_from_url(self, url: str): - """Add an image to the conversation from a URL. + def add_attachment_from_url(self, url: str): + """Add an attachment to the conversation from a URL. Args: - url: The URL of the image to add + url: The URL of the file to attach """ - image_file = None + attachment_file = None try: - image_file = ImageFile.build_from_url(url) + attachment_file = build_from_url(url) except HTTPError as err: wx.CallAfter( wx.MessageBox, - # Translators: This message is displayed when the image URL returns an HTTP error. + # Translators: This message is displayed when an HTTP error occurs while adding a file from a URL. _("HTTP error %s.") % err, _("Error"), wx.OK | wx.ICON_ERROR, ) return - except NotImageError as err: - wx.CallAfter(self.force_image_from_url, url, err.content_type) except BaseException as err: - log.error(err) + log.error(err, exc_info=True) wx.CallAfter( wx.MessageBox, - # Translators: This message is displayed when an error occurs while getting image dimensions. - _("Error getting image dimensions: %s") % err, + # Translators: This message is displayed when an error occurs while adding a file from a URL. 
+ _("Error adding attachment from URL: %s") % err, _("Error"), wx.OK | wx.ICON_ERROR, ) return - wx.CallAfter(self.add_attachments, [image_file]) + wx.CallAfter(self.add_attachments, [attachment_file]) self.task = None @ensure_no_task_running - def add_image_url_thread(self, url: str): - """Start a thread to add an image to the conversation from a URL. + def add_attachment_url_thread(self, url: str): + """Start a thread to add an attachment to the conversation from a URL. Args: - url: The URL of the image to add + url: The URL of the file to attach """ self.task = threading.Thread( - target=self.add_image_from_url, args=(url,) + target=self.add_attachment_from_url, args=(url,) ) self.task.start() @@ -1158,15 +1131,12 @@ def _check_attachments_valid(self) -> bool: invalid_found = False attachments_copy = self.attachment_files[:] for attachment in attachments_copy: - if ( - attachment.mime_type not in supported_attachment_formats - or not attachment.location.exists() - ): - self.attachment_files.remove(attachment) + if attachment.mime_type not in supported_attachment_formats: msg = ( _( - "This attachment format is not supported by the current provider. Source:" + "This attachment format is not supported by the current provider. 
Source: %s" ) + % attachment.location if attachment.mime_type not in supported_attachment_formats else _("The attachment file does not exist: %s") % attachment.location @@ -1185,6 +1155,7 @@ def on_submit(self, event: wx.CommandEvent): if not self.submit_btn.IsEnabled(): return if not self._check_attachments_valid(): + self.attachments_list.SetFocus() return if not self.prompt.GetValue() and not self.attachment_files: self.prompt.SetFocus() diff --git a/basilisk/gui/main_frame.py b/basilisk/gui/main_frame.py index b137eb6a..525b8bac 100644 --- a/basilisk/gui/main_frame.py +++ b/basilisk/gui/main_frame.py @@ -656,7 +656,7 @@ def on_add_attachments( ) return if from_url: - current_tab.add_image_url_dlg(event) + current_tab.add_attachment_url_dlg(event) else: current_tab.add_attachments_dlg() diff --git a/basilisk/provider_engine/anthropic_engine.py b/basilisk/provider_engine/anthropic_engine.py index 0a144133..de02e090 100644 --- a/basilisk/provider_engine/anthropic_engine.py +++ b/basilisk/provider_engine/anthropic_engine.py @@ -13,13 +13,12 @@ from anthropic import Anthropic from anthropic.types import Message as AnthropicMessage from anthropic.types import TextBlock -from anthropic.types.document_block_param import DocumentBlockParam -from anthropic.types.image_block_param import ImageBlockParam, Source -from anthropic.types.text_block_param import TextBlockParam from basilisk.conversation import ( + AttachmentFile, AttachmentFileTypes, Conversation, + ImageFile, Message, MessageBlock, MessageRoleEnum, @@ -213,6 +212,58 @@ def models(self) -> list[ProviderAIModel]: ), ] + def get_attachment_source( + self, attachment: AttachmentFile | ImageFile + ) -> dict: + """Get the source for the attachment. + + Args: + attachment: Attachment to process. + + Returns: + Attachment source data. 
+ """ + if attachment.type == AttachmentFileTypes.URL: + return {"type": "url", "url": attachment.url} + elif attachment.type != AttachmentFileTypes.UNKNOWN: + source = {"media_type": attachment.mime_type} + match attachment.mime_type.split("/")[0]: + case "image" | "application": + source["type"] = "base64" + source["data"] = attachment.encode_base64() + case "text": + source["type"] = "text" + source["data"] = attachment.read_as_plain_text() + case _: + raise ValueError( + f"Unsupported attachment type: {attachment.type}" + ) + return source + + def get_attachment_extras( + self, attachment: AttachmentFile | ImageFile + ) -> dict: + """Get the extras for the attachment. + + Args: + attachment: Attachment to process. + + Returns: + Attachment extra data. + """ + extras = {} + match attachment.mime_type.split('/')[0]: + case "image": + extras["type"] = "image" + case "application" | "text": + extras["type"] = "document" + extras["citations"] = {"enabled": True} + case _: + raise ValueError( + f"Unsupported attachment type: {attachment.type}" + ) + return extras + def convert_message(self, message: Message) -> dict: """Converts internal message format to Anthropic API format. 
@@ -227,37 +278,13 @@ def convert_message(self, message: Message) -> dict: # TODO: implement "context" and "title" for documents # TODO: add support for custom content document format for attachment in message.attachments: - mime_type = attachment.mime_type - if attachment.type != AttachmentFileTypes.URL: - source = Source( - data=None, - media_type=attachment.mime_type, - type="base64", + source = self.get_attachment_source(attachment) + extras = self.get_attachment_extras(attachment) + if not source or not extras: + raise ValueError( + f"Unsupported attachment type: {attachment.type}" ) - if mime_type.startswith("image/"): - source["data"] = attachment.encode_image() - contents.append( - ImageBlockParam(source=source, type="image") - ) - elif mime_type.startswith("application/"): - source["data"] = attachment.encode_base64() - contents.append( - DocumentBlockParam( - type="document", - source=source, - citations={"enabled": True}, - ) - ) - elif mime_type in ("text/plain"): - source["data"] = attachment.read_as_str() - source["type"] = "text" - contents.append( - TextBlockParam( - type="document", - source=source, - citations={"enabled": True}, - ) - ) + contents.append({"source": source} | extras) return {"role": message.role.value, "content": contents} prepare_message_request = convert_message @@ -357,13 +384,56 @@ def _handle_thinking( else: return event.delta.thinking, started + def _handle_content_block_stop( + self, thinking_content_started: bool, current_block_type: str + ) -> tuple[str | None, bool]: + """Handles content block stop events from the stream. + + Args: + thinking_content_started: Flag indicating if thinking content has started. + current_block_type: Type of the current content block. + + Returns: + Tuple containing optional yield content and updated thinking_started flag. 
+ """ + if thinking_content_started and current_block_type == "thinking": + return "\n```\n\n", False + return None, thinking_content_started + + def _handle_content_block_delta( + self, event: MessageStreamEvent, thinking_content_started: bool + ) -> tuple[str | tuple[str, dict] | None, bool]: + """Handles content block delta events from the stream. + + Args: + event: The stream event to process. + thinking_content_started: Flag indicating if thinking content has started. + + Returns: + Tuple containing yield content and updated thinking_started flag. + """ + match event.delta.type: + case "text_delta": + return event.delta.text, thinking_content_started + case "thinking_delta": + text, updated_started = self._handle_thinking( + thinking_content_started, event + ) + return text, updated_started + case "citations_delta": + return ( + ("citation", self._handle_citation(event.delta.citation)), + thinking_content_started, + ) + return None, thinking_content_started + def completion_response_with_stream( self, stream: Stream[MessageStreamEvent] ) -> Iterator[TextBlock | dict]: """Processes streaming response from Anthropic API. Args: - stream: Stream of message events from the API. + stream: Stream of message events from the API. Yields: Text content from each event or thinking content. 
@@ -375,35 +445,26 @@ def completion_response_with_stream( case "content_block_start": current_block_type = event.content_block.type case "content_block_stop": - if ( - thinking_content_started - and current_block_type == "thinking" - ): - thinking_content_started = False - yield "\n```\n\n" + content, thinking_content_started = ( + self._handle_content_block_stop( + thinking_content_started, current_block_type + ) + ) + if content: + yield content case "content_block_delta": - match event.delta.type: - case "text_delta": - yield event.delta.text - case "thinking_delta": - text, thinking_content_started = ( - self._handle_thinking( - thinking_content_started, event - ) - ) - yield text - case "citations_delta": - yield ( - "citation", - self._handle_citation(event.delta.citation), - ) + content, thinking_content_started = ( + self._handle_content_block_delta( + event, thinking_content_started + ) + ) + if content: + yield content case "message_stop": if thinking_content_started: yield "\n```\n" break - # ruff: noqa: C901 - def completion_response_without_stream( self, response: AnthropicMessage, new_block: MessageBlock, **kwargs ) -> MessageBlock: diff --git a/basilisk/server_thread.py b/basilisk/server_thread.py index 444c4850..9cce486c 100644 --- a/basilisk/server_thread.py +++ b/basilisk/server_thread.py @@ -124,7 +124,9 @@ def manage_rcv_data(self, data: bytes) -> None: url = data.split(':', 1)[1].strip() if '\n' in url: url, name = url.split('\n', 1) - wx.CallAfter(self.frame.current_tab.add_image_url_thread, url) + wx.CallAfter( + self.frame.current_tab.add_attachment_url_thread, url + ) else: log.error(f"no action for data: {data}") diff --git a/tests/conversation/test_attachment.py b/tests/conversation/test_attachment.py index 1ebf6ff8..fcc893fd 100644 --- a/tests/conversation/test_attachment.py +++ b/tests/conversation/test_attachment.py @@ -11,6 +11,7 @@ AttachmentFile, AttachmentFileTypes, ImageFile, + build_from_url, parse_supported_attachment_formats, 
) @@ -79,7 +80,7 @@ def test_attachment_mime_type(self, text_file): def test_attachment_read_as_str(self, text_file): """Test reading attachment file as string.""" attachment = AttachmentFile(location=text_file) - assert attachment.read_as_str() == "test content" + assert attachment.read_as_plain_text() == "test content" def test_attachment_get_display_info(self, text_file): """Test getting display info tuple.""" @@ -262,7 +263,7 @@ def test_image_from_url(self, httpx_mock): }, ) - image = ImageFile.build_from_url(test_url) + image = build_from_url(test_url) assert image.type == AttachmentFileTypes.URL assert str(image.location) == test_url