Skip to content

Commit

Permalink
refactor: support for multiple attachment types
Browse files Browse the repository at this point in the history
- Refactored conversation tab to handle various attachment types instead of just images.
- Updated GUI labels and methods to reflect attachment handling.
- Implemented parsing of supported attachment formats for file dialog filters.
- Modified engine files to specify supported attachment formats per provider.
  • Loading branch information
AAClause committed Feb 1, 2025
1 parent 20277c8 commit 3e6137c
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 57 deletions.
19 changes: 19 additions & 0 deletions basilisk/conversation/image_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,25 @@ def resize_image(
return True


def parse_supported_attachment_formats(
supported_attachment_formats: set[str],
) -> str:
"""
Parse the supported attachment formats into a wildcard string for use in file dialogs.
"""
wildcard_parts = []
for mime_type in sorted(supported_attachment_formats):
exts = mimetypes.guess_all_extensions(mime_type)
if exts:
log.debug(f"Adding wildcard for MIME type {mime_type}: {exts}")
wildcard_parts.append("*" + ";*".join(exts))
else:
log.warning(f"No extensions found for MIME type {mime_type}")

wildcard = ";".join(wildcard_parts)
return wildcard


class ImageFileTypes(Enum):
UNKNOWN = "unknown"
IMAGE_LOCAL = "local"
Expand Down
128 changes: 73 additions & 55 deletions basilisk/gui/conversation_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
MessageBlock,
MessageRoleEnum,
NotImageError,
parse_supported_attachment_formats,
)
from basilisk.decorators import ensure_no_task_running
from basilisk.message_segment_manager import (
Expand Down Expand Up @@ -159,26 +160,30 @@ def init_ui(self):
sizer.Add(self.prompt, proportion=1, flag=wx.EXPAND)
self.prompt.SetFocus()

self.images_list_label = wx.StaticText(
self.attachments_list_label = wx.StaticText(
self,
# Translators: This is a label for models in the main window
label=_("&Images:"),
label=_("&Attachments:"),
)
sizer.Add(self.images_list_label, proportion=0, flag=wx.EXPAND)
self.images_list = wx.ListCtrl(
sizer.Add(self.attachments_list_label, proportion=0, flag=wx.EXPAND)
self.attachments_list = wx.ListCtrl(
self, size=(800, 100), style=wx.LC_REPORT
)
self.images_list.Bind(wx.EVT_CONTEXT_MENU, self.on_images_context_menu)
self.images_list.Bind(wx.EVT_KEY_DOWN, self.on_images_key_down)
self.images_list.InsertColumn(0, _("Name"))
self.images_list.InsertColumn(1, _("Size"))
self.images_list.InsertColumn(2, _("Dimensions"))
self.images_list.InsertColumn(3, _("Path"))
self.images_list.SetColumnWidth(0, 200)
self.images_list.SetColumnWidth(1, 100)
self.images_list.SetColumnWidth(2, 100)
self.images_list.SetColumnWidth(3, 200)
sizer.Add(self.images_list, proportion=0, flag=wx.ALL | wx.EXPAND)
self.attachments_list.Bind(
wx.EVT_CONTEXT_MENU, self.on_attachments_context_menu
)
self.attachments_list.Bind(
wx.EVT_KEY_DOWN, self.on_attachments_key_down
)
self.attachments_list.InsertColumn(0, _("Name"))
self.attachments_list.InsertColumn(1, _("Size"))
self.attachments_list.InsertColumn(2, _("Dimensions"))
self.attachments_list.InsertColumn(3, _("Path"))
self.attachments_list.SetColumnWidth(0, 200)
self.attachments_list.SetColumnWidth(1, 100)
self.attachments_list.SetColumnWidth(2, 100)
self.attachments_list.SetColumnWidth(3, 200)
sizer.Add(self.attachments_list, proportion=0, flag=wx.ALL | wx.EXPAND)
label = self.create_model_widget()
sizer.Add(label, proportion=0, flag=wx.EXPAND)
sizer.Add(self.model_list, proportion=0, flag=wx.ALL | wx.EXPAND)
Expand Down Expand Up @@ -237,6 +242,7 @@ def init_ui(self):
self.Bind(wx.EVT_CHAR_HOOK, self.on_char_hook)

def init_data(self, profile: Optional[config.ConversationProfile]):
self.refresh_attachments_list()
self.apply_profile(profile, True)
self.refresh_messages(need_clear=False)

Expand Down Expand Up @@ -267,16 +273,16 @@ def on_account_change(self, event: wx.CommandEvent):
ProviderCapability.STT in account.provider.engine_cls.capabilities
)

def on_images_context_menu(self, event: wx.ContextMenuEvent):
selected = self.images_list.GetFirstSelected()
def on_attachments_context_menu(self, event: wx.ContextMenuEvent):
selected = self.attachments_list.GetFirstSelected()
menu = wx.Menu()

if selected != wx.NOT_FOUND:
item = wx.MenuItem(
menu, wx.ID_ANY, _("Remove selected image") + " (Shift+Del)"
)
menu.Append(item)
self.Bind(wx.EVT_MENU, self.on_images_remove, item)
self.Bind(wx.EVT_MENU, self.on_attachments_remove, item)

item = wx.MenuItem(
menu, wx.ID_ANY, _("Copy image URL") + " (Ctrl+C)"
Expand All @@ -290,24 +296,24 @@ def on_images_context_menu(self, event: wx.ContextMenuEvent):
self.Bind(wx.EVT_MENU, self.on_image_paste, item)
item = wx.MenuItem(menu, wx.ID_ANY, _("Add image files..."))
menu.Append(item)
self.Bind(wx.EVT_MENU, self.add_image_files, item)
self.Bind(wx.EVT_MENU, self.add_attachments_dlg, item)

item = wx.MenuItem(menu, wx.ID_ANY, _("Add image URL..."))
menu.Append(item)
self.Bind(wx.EVT_MENU, self.add_image_url_dlg, item)

self.images_list.PopupMenu(menu)
self.attachments_list.PopupMenu(menu)
menu.Destroy()

def on_images_key_down(self, event: wx.KeyEvent):
def on_attachments_key_down(self, event: wx.KeyEvent):
key_code = event.GetKeyCode()
modifiers = event.GetModifiers()
if modifiers == wx.MOD_CONTROL and key_code == ord("C"):
self.on_copy_image_url(None)
if modifiers == wx.MOD_CONTROL and key_code == ord("V"):
self.on_image_paste(None)
if modifiers == wx.MOD_NONE and key_code == wx.WXK_DELETE:
self.on_images_remove(None)
self.on_attachments_remove(None)
event.Skip()

def on_image_paste(self, event: wx.CommandEvent):
Expand All @@ -317,7 +323,7 @@ def on_image_paste(self, event: wx.CommandEvent):
file_data = wx.FileDataObject()
clipboard.GetData(file_data)
paths = file_data.GetFilenames()
self.add_images(paths)
self.add_attachment(paths)
elif clipboard.IsSupported(wx.DataFormat(wx.DF_TEXT)):
log.debug("Pasting text from clipboard")
text_data = wx.TextDataObject()
Expand All @@ -344,22 +350,34 @@ def on_image_paste(self, event: wx.CommandEvent):
)
with path.open("wb") as f:
img.SaveFile(f, wx.BITMAP_TYPE_PNG)
self.add_images([ImageFile(location=path)])
self.add_attachment([ImageFile(location=path)])

else:
log.info("Unsupported clipboard data")

def add_image_files(self, event: wx.CommandEvent = None):
def add_attachments_dlg(self, event: wx.CommandEvent = None):
wilrdcard = parse_supported_attachment_formats(
self.current_engine.supported_attachment_formats
)
if not wilrdcard:
wx.MessageBox(
# Translators: This message is displayed when there are no supported attachment formats.
_("This provider does not support any attachment formats."),
_("Error"),
wx.OK | wx.ICON_ERROR,
)
return
wilrdcard = _("All supported formats") + f" ({wilrdcard})|{wilrdcard}"

file_dialog = wx.FileDialog(
self,
message=_("Select one or more image files"),
message=_("Select one or more files to attach"),
style=wx.FD_OPEN | wx.FD_FILE_MUST_EXIST | wx.FD_MULTIPLE,
wildcard=_("Image files")
+ " (*.png;*.jpeg;*.jpg;*.gif)|*.png;*.jpeg;*.jpg;*.gif",
wildcard=wilrdcard,
)
if file_dialog.ShowModal() == wx.ID_OK:
paths = file_dialog.GetPaths()
self.add_images(paths)
self.add_attachment(paths)
file_dialog.Destroy()

def add_image_url_dlg(self, event: wx.CommandEvent = None):
Expand Down Expand Up @@ -397,7 +415,7 @@ def force_image_from_url(self, url: str, content_type: str):
)
if force_add == wx.YES:
log.info("Forcing image addition")
self.add_image_files([ImageFile(location=url)])
self.add_attachments([ImageFile(location=url)])

def add_image_from_url(self, url: str):
image_file = None
Expand All @@ -424,7 +442,7 @@ def add_image_from_url(self, url: str):
wx.OK | wx.ICON_ERROR,
)
return
wx.CallAfter(self.add_images, [image_file])
wx.CallAfter(self.add_attachment, [image_file])

self.task = None

Expand All @@ -435,23 +453,23 @@ def add_image_url_thread(self, url: str):
)
self.task.start()

def on_images_remove(self, vent: wx.CommandEvent):
selection = self.images_list.GetFirstSelected()
def on_attachments_remove(self, vent: wx.CommandEvent):
selection = self.attachments_list.GetFirstSelected()
if selection == wx.NOT_FOUND:
return
self.image_files.pop(selection)
self.refresh_images_list()
if selection >= self.images_list.GetItemCount():
self.refresh_attachments_list()
if selection >= self.attachments_list.GetItemCount():
selection -= 1
if selection >= 0:
self.images_list.SetItemState(
self.attachments_list.SetItemState(
selection, wx.LIST_STATE_FOCUSED, wx.LIST_STATE_FOCUSED
)
else:
self.prompt.SetFocus()

def on_copy_image_url(self, event: wx.CommandEvent):
selected = self.images_list.GetFirstSelected()
selected = self.attachments_list.GetFirstSelected()
if selected == wx.NOT_FOUND:
return
url = self.image_files[selected].location
Expand All @@ -475,36 +493,36 @@ def refresh_accounts(self):
self.account_combo.SetSelection(0)
self.account_combo.SetFocus()

def refresh_images_list(self):
self.images_list.DeleteAllItems()
def refresh_attachments_list(self):
self.attachments_list.DeleteAllItems()
if not self.image_files:
self.images_list_label.Hide()
self.images_list.Hide()
self.attachments_list_label.Hide()
self.attachments_list.Hide()
self.Layout()
return
self.images_list_label.Show()
self.images_list.Show()
self.attachments_list_label.Show()
self.attachments_list.Show()
self.Layout()
for i, image in enumerate(self.image_files):
self.images_list.InsertItem(i, image.name)
self.images_list.SetItem(i, 1, image.display_size)
self.images_list.SetItem(i, 2, image.display_dimensions)
self.images_list.SetItem(i, 3, image.display_location)
self.images_list.SetItemState(
self.attachments_list.InsertItem(i, image.name)
self.attachments_list.SetItem(i, 1, image.display_size)
self.attachments_list.SetItem(i, 2, image.display_dimensions)
self.attachments_list.SetItem(i, 3, image.display_location)
self.attachments_list.SetItemState(
i, wx.LIST_STATE_FOCUSED, wx.LIST_STATE_FOCUSED
)
self.images_list.EnsureVisible(i)
self.attachments_list.EnsureVisible(i)

def add_images(self, paths: list[str | ImageFile]):
def add_attachment(self, paths: list[str | ImageFile]):
log.debug(f"Adding images: {paths}")
for path in paths:
if isinstance(path, ImageFile):
self.image_files.append(path)
else:
file = ImageFile(location=path)
self.image_files.append(file)
self.refresh_images_list()
self.images_list.SetFocus()
self.refresh_attachments_list()
self.attachments_list.SetFocus()

def on_config_change(self):
self.refresh_accounts()
Expand Down Expand Up @@ -961,7 +979,7 @@ def refresh_messages(self, need_clear: bool = True):
self.messages.Clear()
self.message_segment_manager.clear()
self.image_files.clear()
self.refresh_images_list()
self.refresh_attachments_list()
for block in self.conversation.messages:
self.display_new_block(block)

Expand Down Expand Up @@ -1189,7 +1207,7 @@ def _pre_handle_completion_with_stream(self, new_block: MessageBlock):
self.messages.SetInsertionPointEnd()
self.prompt.Clear()
self.image_files.clear()
self.refresh_images_list()
self.refresh_attachments_list()

def _handle_completion_with_stream(self, chunk: str):
self.stream_buffer += chunk
Expand Down Expand Up @@ -1298,7 +1316,7 @@ def _post_completion_without_stream(self, new_block: MessageBlock):
self._handle_accessible_output(new_block.response.content)
self.prompt.Clear()
self.image_files.clear()
self.refresh_images_list()
self.refresh_attachments_list()
if config.conf().conversation.focus_history_after_send:
self.messages.SetFocus()
self._end_task()
Expand Down
4 changes: 2 additions & 2 deletions basilisk/gui/main_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def screen_capture(

def post_screen_capture(self, imagefile: ImageFile | str):
log.debug("Screen capture received")
self.current_tab.add_images([imagefile])
self.current_tab.add_attachment([imagefile])
if not self.IsShown():
self.Show()
self.Restore()
Expand Down Expand Up @@ -469,7 +469,7 @@ def on_add_image(self, event, from_url=False):
if from_url:
current_tab.add_image_url_dlg()
else:
current_tab.add_image_files()
current_tab.add_attachments_dlg()

def on_transcribe_audio(
self, event: wx.Event, from_microphone: bool = False
Expand Down
8 changes: 8 additions & 0 deletions basilisk/provider_engine/anthropic_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ class AnthropicEngine(BaseEngine):
ProviderCapability.TEXT,
ProviderCapability.IMAGE,
}
supported_attachment_formats: set[str] = {
"image/gif",
"image/jpeg",
"image/png",
"image/webp",
"application/pdf",
"text/plain",
}

def __init__(self, account: Account) -> None:
super().__init__(account)
Expand Down
1 change: 1 addition & 0 deletions basilisk/provider_engine/base_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

class BaseEngine(ABC):
capabilities: set[ProviderCapability] = set()
supported_attachment_formats: set[str] = {}

def __init__(self, account: Account) -> None:
self.account = account
Expand Down
7 changes: 7 additions & 0 deletions basilisk/provider_engine/gemini_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ class GeminiEngine(BaseEngine):
ProviderCapability.TEXT,
ProviderCapability.IMAGE,
}
supported_attachment_formats: set[str] = {
"image/png",
"image/jpeg",
"image/webp",
"image/heic",
"image/heif",
}

def __init__(self, account: Account) -> None:
super().__init__(account)
Expand Down
6 changes: 6 additions & 0 deletions basilisk/provider_engine/openai_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ class OpenAIEngine(BaseEngine):
ProviderCapability.STT,
ProviderCapability.TTS,
}
supported_attachment_formats: set[str] = {
"image/gif",
"image/jpeg",
"image/png",
"image/webp",
}

def __init__(self, account: Account) -> None:
super().__init__(account)
Expand Down

0 comments on commit 3e6137c

Please sign in to comment.