Skip to content

Commit

Permalink
feat(provider_engine): add support for uploading text and PDF files to Anthropic
Browse files Browse the repository at this point in the history

Fixes #485
  • Loading branch information
AAClause committed Feb 1, 2025
1 parent 3e6137c commit 423db48
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 113 deletions.
14 changes: 13 additions & 1 deletion basilisk/conversation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,32 @@
from .attached_file import (
URL_PATTERN,
AttachmentFile,
AttachmentFileTypes,
ImageFile,
NotImageError,
get_mime_type,
parse_supported_attachment_formats,
)
from .conversation_helper import PROMPT_TITLE
from .conversation_model import (
Conversation,
Message,
MessageBlock,
MessageRoleEnum,
)
from .image_model import URL_PATTERN, ImageFile, ImageFileTypes, NotImageError

__all__ = [
"AttachmentFile",
"AttachmentFileTypes",
"Conversation",
"get_mime_type",
"ImageFile",
"ImageFileTypes",
"Message",
"MessageBlock",
"MessageRoleEnum",
"NotImageError",
"parse_supported_attachment_formats",
"PROMPT_TITLE",
"URL_PATTERN",
]
Original file line number Diff line number Diff line change
Expand Up @@ -105,32 +105,110 @@ def parse_supported_attachment_formats(
return wildcard


class ImageFileTypes(Enum):
def get_mime_type(path: str) -> str | None:
    """Guess the MIME type of a file from its path.

    Args:
        path: File path or URL handed to ``mimetypes.guess_type``.

    Returns:
        The guessed MIME type (for example ``"image/png"``), or ``None``
        when ``mimetypes`` cannot guess one.
    """
    # guess_type returns a (type, encoding) pair; only the type is wanted.
    guessed_type, _encoding = mimetypes.guess_type(path)
    return guessed_type


class AttachmentFileTypes(Enum):
UNKNOWN = "unknown"
IMAGE_LOCAL = "local"
IMAGE_MEMORY = "memory"
IMAGE_URL = "http"
LOCAL = "local"
MEMORY = "memory"
URL = "http"

@classmethod
def _missing_(cls, value: object) -> ImageFileTypes:
if isinstance(value, str) and value.lower() == "data":
return cls.IMAGE_URL
if isinstance(value, str) and value.lower() == "https":
return cls.IMAGE_URL
def _missing_(cls, value: object) -> AttachmentFileTypes:
if isinstance(value, str) and value.lower() in ("data", "https"):
return cls.URL
if isinstance(value, str) and value.lower() == "zip":
return cls.IMAGE_LOCAL
return cls.LOCAL
return cls.UNKNOWN


class NotImageError(ValueError):
    """ValueError subclass for files that are not images.

    NOTE(review): the raise sites are outside this view; the meaning is
    taken from the class name only.
    """


class ImageFile(BaseModel):
class AttachmentFile(BaseModel):
    """A file attached to a conversation message.

    The attachment is identified by ``location``. Its kind is derived from
    the location's protocol through ``AttachmentFileTypes`` (see ``type``).
    """

    # Where the attachment lives.
    location: PydanticUPath
    # Display name; filled from the location's final component when empty.
    name: str | None = None
    # Optional description of the attachment.
    description: str | None = None
    # Size in bytes, or None when it is not known.
    size: int | None = None

    def __init__(self, /, **data: Any) -> None:
        """Validate the fields, then fill in ``name`` and ``size``.

        ``size`` is refreshed from the location whenever it can be measured.
        When it cannot (URL attachments), a size supplied by the caller is
        kept instead of being overwritten with ``None``.
        """
        super().__init__(**data)
        if not self.name:
            self.name = self._get_name()
        measured_size = self._get_size()
        if measured_size is not None:
            self.size = measured_size

    @property
    def type(self) -> AttachmentFileTypes:
        """Kind of attachment, derived from the location's protocol."""
        return AttachmentFileTypes(self.location.protocol)

    def _get_name(self) -> str:
        """Return the final component of the location."""
        return self.location.name

    def _get_size(self) -> int | None:
        """Return the size in bytes, or ``None`` for URL attachments."""
        if self.type == AttachmentFileTypes.URL:
            return None
        return self.location.stat().st_size

    @property
    def display_size(self) -> str:
        """Human-readable size in B, KB or MB (1024-based)."""
        size = self.size
        if size is None:
            return _("Unknown")
        if size < 1024:
            return f"{size} B"
        if size < 1024 * 1024:
            return f"{size / 1024:.2f} KB"
        return f"{size / 1024 / 1024:.2f} MB"

    @property
    def send_location(self) -> UPath:
        """Location used by ``mime_type``; here simply ``location``."""
        return self.location

    @property
    def mime_type(self) -> str | None:
        """MIME type guessed from ``send_location``; ``None`` for URLs."""
        if self.type == AttachmentFileTypes.URL:
            return None
        # Index the result rather than unpacking into ``_``, which is the
        # name used for the translation function in ``display_size``.
        return mimetypes.guess_type(self.send_location)[0]

    @property
    def display_location(self) -> str:
        """Location as text, with ``data:`` URLs shortened for display."""
        location = str(self.location)
        if location.startswith("data:"):
            # Keep the first 50 and last 10 characters of the data URL.
            location = f"{location[:50]}...{location[-10:]}"
        return location

    @staticmethod
    def remove_location(location: UPath) -> None:
        """Delete ``location`` from its filesystem, logging any failure.

        Best effort: errors are logged rather than raised, because this is
        called from ``__del__``.
        """
        log.debug("Removing attachment at %s", location)
        try:
            fs = location.fs
            fs.rm(location.path)
        except Exception as e:
            log.error("Error deleting attachment at %s: %s", location, e)

    def read_as_str(self) -> str:
        """Return the attachment's content read in text mode.

        NOTE(review): no encoding is passed to ``open``, so the default
        text encoding presumably applies -- confirm this is intended.
        """
        with self.location.open(mode="r") as file:
            return file.read()

    def encode_base64(self) -> str:
        """Return the bytes at ``location`` as a base64 string."""
        with self.location.open(mode="rb") as file:
            return base64.b64encode(file.read()).decode("utf-8")

    def __del__(self) -> None:
        """Remove the backing file of a ``MEMORY`` attachment.

        Attachments of any other type are left untouched.
        """
        try:
            attachment_type = self.type
        except AttributeError:
            # __del__ also runs when __init__ raised (e.g. validation
            # failed), in which case ``location`` was never set.
            return
        if attachment_type == AttachmentFileTypes.MEMORY:
            self.remove_location(self.location)


class ImageFile(AttachmentFile):
dimensions: tuple[int, int] | None = None
resize_location: PydanticUPath | None = Field(default=None, exclude=True)

Expand All @@ -150,7 +228,7 @@ def build_from_url(cls, url: str) -> ImageFile:
dimensions = get_image_dimensions(BytesIO(r.content))
return cls(
location=url,
type=ImageFileTypes.IMAGE_URL,
type=AttachmentFileTypes.URL,
size=size,
description=content_type,
dimensions=dimensions,
Expand Down Expand Up @@ -188,39 +266,17 @@ def validate_location(

def __init__(self, /, **data: Any) -> None:
super().__init__(**data)
if not self.name:
self.name = self._get_name()
self.size = self._get_size()
if not self.dimensions:
self.dimensions = self._get_dimensions()

__init__.__pydantic_base_init__ = True

@property
def type(self) -> ImageFileTypes:
return ImageFileTypes(self.location.protocol)

def _get_name(self) -> str:
return self.location.name

def _get_size(self) -> int | None:
if self.type == ImageFileTypes.IMAGE_URL:
return None
return self.location.stat().st_size

@property
def display_size(self) -> str:
size = self.size
if size is None:
return _("Unknown")
if size < 1024:
return f"{size} B"
if size < 1024 * 1024:
return f"{size / 1024:.2f} KB"
return f"{size / 1024 / 1024:.2f} MB"
def send_location(self) -> UPath:
return self.resize_location or self.location

def _get_dimensions(self) -> tuple[int, int] | None:
if self.type == ImageFileTypes.IMAGE_URL:
if self.type == AttachmentFileTypes.URL:
return None
with self.location.open(mode="rb") as image_file:
return get_image_dimensions(image_file)
Expand All @@ -235,7 +291,7 @@ def display_dimensions(self) -> str:
def resize(
self, conv_folder: UPath, max_width: int, max_height: int, quality: int
):
if ImageFileTypes.IMAGE_URL == self.type:
if AttachmentFileTypes.URL == self.type:
return
log.debug("Resizing image")
resize_location = conv_folder.joinpath(
Expand All @@ -253,10 +309,6 @@ def resize(
)
self.resize_location = resize_location if success else None

@property
def send_location(self) -> UPath:
return self.resize_location or self.location

@measure_time
def encode_image(self) -> str:
if self.size and self.size > 1024 * 1024 * 1024:
Expand All @@ -266,29 +318,15 @@ def encode_image(self) -> str:
with self.send_location.open(mode="rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")

@property
def mime_type(self) -> str | None:
if self.type == ImageFileTypes.IMAGE_URL:
return None
mime_type, _ = mimetypes.guess_type(self.send_location)
return mime_type

@property
def url(self) -> str:
if not isinstance(self.type, ImageFileTypes):
if not isinstance(self.type, AttachmentFileTypes):
raise ValueError("Invalid image type")
if self.type == ImageFileTypes.IMAGE_URL:
if self.type == AttachmentFileTypes.URL:
return str(self.location)
base64_image = self.encode_image()
return f"data:{self.mime_type};base64,{base64_image}"

@property
def display_location(self):
location = str(self.location)
if location.startswith("data:image/"):
location = f"{location[:50]}...{location[-10:]}"
return location

@staticmethod
def remove_location(location: UPath):
log.debug(f"Removing image at {location}")
Expand All @@ -299,9 +337,6 @@ def remove_location(location: UPath):
log.error(f"Error deleting image at {location}: {e}")

def __del__(self):
if self.type == ImageFileTypes.IMAGE_URL:
return
if self.resize_location:
self.remove_location(self.resize_location)
if self.type == ImageFileTypes.IMAGE_MEMORY:
self.remove_location(self.location)
super().__del__()
10 changes: 6 additions & 4 deletions basilisk/conversation/conversation_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from basilisk.config import conf
from basilisk.decorators import measure_time

from .image_model import ImageFile, ImageFileTypes
from .attached_file import AttachmentFile, AttachmentFileTypes, ImageFile

if TYPE_CHECKING:
from .conversation_model import Conversation
Expand All @@ -23,11 +23,13 @@


def save_attachments(
attachments: list[ImageFile], attachment_path: str, fs: ZipFileSystem
attachments: list[AttachmentFile | ImageFile],
attachment_path: str,
fs: ZipFileSystem,
):
attachment_mapping = {}
for attachment in attachments:
if attachment.type == ImageFileTypes.IMAGE_URL:
if attachment.type == AttachmentFileTypes.URL:
continue
new_location = f"{attachment_path}/{attachment.location.name}"
with attachment.location.open(mode="rb") as attachment_file:
Expand Down Expand Up @@ -56,7 +58,7 @@ def create_conv_main_file(conversation: Conversation, fs: ZipFileSystem):

def restore_attachments(attachments: list[ImageFile], storage_path: UPath):
for attachment in attachments:
if attachment.type == ImageFileTypes.IMAGE_URL:
if attachment.type == AttachmentFileTypes.URL:
continue
new_path = storage_path / attachment.location.name
with attachment.location.open(mode="rb") as attachment_file:
Expand Down
4 changes: 2 additions & 2 deletions basilisk/conversation/conversation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from basilisk.provider_ai_model import AIModelInfo

from .attached_file import AttachmentFile, ImageFile
from .conversation_helper import create_bskc_file, open_bskc_file
from .image_model import ImageFile


class MessageRoleEnum(Enum):
Expand All @@ -21,7 +21,7 @@ class MessageRoleEnum(Enum):
class Message(BaseModel):
role: MessageRoleEnum
content: str
attachments: list[ImageFile] | None = Field(default=None)
attachments: list[AttachmentFile | ImageFile] | None = Field(default=None)


class MessageBlock(BaseModel):
Expand Down
Loading

0 comments on commit 423db48

Please sign in to comment.