Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Add Support for Data URI #385

Merged
merged 8 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/assets/openapi.json

Large diffs are not rendered by default.

79 changes: 37 additions & 42 deletions docs/docs/cli_v2.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion libs/infinity_emb/infinity_emb/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ async def classify(
return scores, usage

async def image_embed(
self, *, images: List[Union[str, "ImageClassType"]]
self, *, images: List[Union[str, "ImageClassType", bytes]]
) -> tuple[list[EmbeddingReturnType], int]:
"""embed multiple images

Expand Down
239 changes: 239 additions & 0 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/data_uri.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
import mimetypes
import re
import sys
import textwrap
from base64 import b64decode as decode64
from base64 import b64encode as encode64
from dataclasses import dataclass
from typing import Any, Dict, MutableMapping, Optional, Tuple, TypeVar, Union

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self

from urllib.parse import quote, unquote

# Generic type variable available to annotations in this module.
T = TypeVar("T")

# "<type>/<subtype>" where the type is any run of word characters.
MIMETYPE_REGEX = r"[\w]+\/[\w\-\+\.]+"
# Same shape, but the type is restricted to "audio" or "image".
MIMETYPE_REGEX_AUDIO_IMAGE = r"(audio|image)\/[\w\-\+\.]+"
# Anchored matcher built from the audio/image-only pattern.
_MIMETYPE_RE = re.compile(f"^{MIMETYPE_REGEX_AUDIO_IMAGE}$")

# Charset names: word characters plus "-", "+" and ".".
CHARSET_REGEX = r"[\w\-\+\.]+"
# Anchored matcher for a complete charset name.
_CHARSET_RE = re.compile(f"^{CHARSET_REGEX}$")

DATA_URI_REGEX = (
Copy link
Owner

@michaelfeil michaelfeil Oct 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stikkireddy Moved the regex directly into the validation of the pydantic type.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

libs/infinity_emb/infinity_emb/transformer/vision/utils.py

I assume you will remove that unnecessary code then?

r"data:"
+ r"(?P<mimetype>{})?".format(MIMETYPE_REGEX)
+ r"(?:\;name\=(?P<name>[\w\.\-%!*'~\(\)]+))?"
+ r"(?:\;charset\=(?P<charset>{}))?".format(CHARSET_REGEX)
+ r"(?P<base64>\;base64)?"
+ r",(?P<data>.*)"
)
_DATA_URI_RE = re.compile(r"^{}$".format(DATA_URI_REGEX), re.DOTALL)


class InvalidMimeType(ValueError):
    """Raised when a mimetype string does not pass validation."""


class InvalidCharset(ValueError):
    """Raised when a charset is rejected, or is missing where one is required."""


class InvalidDataURI(ValueError):
    """Raised when a string cannot be parsed as a data URI."""


@dataclass
class DataURIHolder:
    """Plain container for the decoded parts of a data URI."""

    # Media type from the URI header, or None when the URI did not carry one.
    mimetype: Optional[str]
    # Charset parameter from the URI header, or None when absent.
    charset: Optional[str]
    # True when the URI carried the ";base64" marker.
    base64: bool
    # Payload of the URI.
    data: Union[str, bytes]


class DataURI(str):
    """A ``str`` subclass that holds a ``data:`` URI.

    The string is parsed against ``_DATA_URI_RE`` on instantiation, so an
    instance always matched the pattern at creation time. The individual
    components (mimetype, name, charset, base64 flag, decoded payload) are
    exposed as read-only properties; each access re-parses the string.
    """

    @staticmethod
    def _encode_text(text: str, charset: str) -> bytes:
        """Encode *text* with *charset*.

        Raises:
            InvalidCharset: if *charset* is not a codec Python knows about.
                Converting the ``LookupError`` keeps every parse failure a
                ``ValueError`` subclass.
        """
        try:
            return bytes(text, charset)
        except LookupError as err:
            raise InvalidCharset("Invalid charset: %r" % charset) from err

    @classmethod
    def make(
        cls,
        mimetype: Optional[str],
        charset: Optional[str],
        base64: Optional[bool],
        data: Union[str, bytes],
    ) -> Self:
        """Build a data URI from its components.

        Args:
            mimetype: Media type to embed; must match ``_MIMETYPE_RE`` when given.
            charset: Charset parameter; must match ``_CHARSET_RE`` when given.
            base64: When truthy the payload is base64-encoded, otherwise it is
                percent-quoted.
            data: Payload. Text is encoded with ``charset`` (default ``utf-8``)
                before base64 encoding.

        Raises:
            InvalidMimeType: if ``mimetype`` does not match the allowed pattern.
            InvalidCharset: if ``charset`` does not match the allowed pattern or
                names an unknown codec.
        """
        parts = ["data:"]
        if mimetype is not None:
            if not _MIMETYPE_RE.match(mimetype):
                raise InvalidMimeType("Invalid mimetype: %r" % mimetype)
            parts.append(mimetype)
        if charset is not None:
            if not _CHARSET_RE.match(charset):
                raise InvalidCharset("Invalid charset: %r" % charset)
            parts.extend([";charset=", charset])
        if base64:
            parts.append(";base64")
            if isinstance(data, bytes):
                raw = data
            else:
                raw = cls._encode_text(data, charset or "utf-8")
            # Base64 output is pure ASCII; decoding it with the payload charset
            # (e.g. utf-16) would corrupt the URI text.
            encoded_data = encode64(raw).decode("ascii")
        else:
            encoded_data = quote(data)
        parts.extend([",", encoded_data])
        return cls("".join(parts))

    @classmethod
    def from_file(
        cls,
        filename: str,
        charset: Optional[str] = None,
        base64: Optional[bool] = True,
        mimetype: Optional[str] = None,
    ) -> Self:
        """Read *filename* as bytes and wrap its content in a data URI.

        When *mimetype* is None it is guessed from the file name with
        ``mimetypes.guess_type(..., strict=False)``. The result is passed to
        :meth:`make`, so the same validation errors apply.
        """
        if mimetype is None:
            mimetype, _ = mimetypes.guess_type(filename, strict=False)
        with open(filename, "rb") as fp:
            data = fp.read()

        return cls.make(mimetype, charset, base64, data)

    def __new__(cls, *args: Any, **kwargs: Any) -> Self:
        """Create the string and parse it once so invalid input fails early."""
        uri = super().__new__(cls, *args, **kwargs)
        uri._parse  # Trigger any ValueErrors on instantiation.
        return uri

    def __repr__(self) -> str:
        """Return ``DataURI(<text>)`` with the text cut to 80 characters."""
        truncated = str(self)
        if len(truncated) > 80:
            truncated = truncated[:79] + "…"
        return "DataURI(%s)" % (truncated,)

    def wrap(self, width: int = 76) -> str:
        """Return the URI broken into lines of at most *width* characters."""
        return "\n".join(textwrap.wrap(self, width, break_on_hyphens=False))

    @property
    def mimetype(self) -> Optional[str]:
        """Media type from the URI header, or None when absent."""
        return self._parse[0]

    @property
    def name(self) -> Optional[str]:
        """Percent-decoded ``name`` parameter, or None when absent."""
        name = self._parse[1]
        if name is not None:
            return unquote(name)
        return name

    @property
    def charset(self) -> Optional[str]:
        """Charset parameter from the URI header, or None when absent."""
        return self._parse[2]

    @property
    def is_base64(self) -> bool:
        """True when the URI carries the ``;base64`` marker."""
        return self._parse[3]

    @property
    def data(self) -> bytes:
        """Decoded payload bytes."""
        return self._parse[4]

    def convert_to_data_uri_holder(self) -> DataURIHolder:
        """Copy the parsed components into a :class:`DataURIHolder`."""
        return DataURIHolder(
            mimetype=self.mimetype,
            charset=self.charset,
            base64=self.is_base64,
            data=self.data,
        )

    @property
    def text(self) -> str:
        """Payload decoded to text using the URI's charset.

        Raises:
            InvalidCharset: if the URI does not declare a charset.
        """
        if self.charset is None:
            raise InvalidCharset("DataURI has no encoding set.")

        return self.data.decode(self.charset)

    @property
    def is_valid(self) -> bool:
        """True when the string matches the data-URI pattern."""
        return _DATA_URI_RE.match(self) is not None

    @property
    def _parse(
        self,
    ) -> Tuple[Optional[str], Optional[str], Optional[str], bool, bytes]:
        """Split the URI into ``(mimetype, name, charset, is_base64, data)``.

        Raises:
            InvalidDataURI: if the string does not match ``_DATA_URI_RE``.
            InvalidCharset: if the declared charset is not a known codec.
        """
        match = _DATA_URI_RE.match(self)
        if match is None:
            raise InvalidDataURI("Not a valid data URI: %r" % self)
        mimetype = match.group("mimetype") or None
        name = match.group("name") or None
        charset = match.group("charset") or None
        _charset = charset or "utf-8"
        is_base64 = bool(match.group("base64"))
        payload = match.group("data")

        if is_base64:
            # Reject unknown codec names up front so that a bad charset is
            # reported at construction regardless of the payload encoding.
            self._encode_text("", _charset)
            # The base64 alphabet is ASCII, so the declared charset must not be
            # used to turn the payload into bytes. Non-ASCII characters are
            # dropped here; b64decode would discard them as non-alphabet bytes.
            data = decode64(payload.encode("ascii", errors="ignore"))
        else:
            data = self._encode_text(unquote(payload), _charset)

        return mimetype, name, charset, is_base64, data

    # Pydantic methods
    @classmethod
    def __get_validators__(cls):
        """Yield the validators pydantic v1 should run for this type."""
        # one or more validators may be yielded which will be called in the
        # order to validate the input, each validator will receive as an input
        # the value returned from the previous validator
        yield cls.validate

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type: Any, handler: Any) -> Any:
        """Return a pydantic v2 core schema: a string passed through ``validate``."""
        from pydantic_core import core_schema

        return core_schema.no_info_after_validator_function(
            cls.validate, core_schema.str_schema()
        )

    @classmethod
    def validate(
        cls,
        value: str,
        values: Optional[MutableMapping[str, Any]] = None,
        config: Any = None,
        field: Any = None,
        **kwargs: Any,
    ) -> Self:
        """Validate *value* and return it as a :class:`DataURI`.

        The extra parameters are accepted and ignored.

        Raises:
            TypeError: if *value* is not a ``str``.
            ValueError: if *value* is not a valid data URI.
        """
        if not isinstance(value, str):
            raise TypeError("string required")

        m = cls(value)
        # Construction already parses the value; this check is kept as a guard.
        if not m.is_valid:
            raise ValueError("invalid data-uri format")
        return m

    @classmethod
    def __get_pydantic_json_schema__(
        cls, schema: MutableMapping[str, Any], handler: Any
    ) -> Any:
        """Add the data-URI pattern and an example to the generated JSON schema."""
        json_schema = handler(schema)
        json_schema.update(
            pattern=DATA_URI_REGEX,
            examples=[
                "data:text/plain;charset=utf-8;base64,"
                "VGhlIHF1aWNrIGJyb3duIGZveCBqdW1wZWQgb3ZlciB0aGUgbGF6eSBkb2cu"
            ],
        )
        return json_schema

    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        """Add the data-URI pattern and an example to *field_schema* in place."""
        # __modify_schema__ should mutate the dict it receives in place,
        # the returned value will be ignored
        field_schema.update(
            pattern=DATA_URI_REGEX,
            examples=[
                "data:text/plain;charset=utf-8;base64,VGhlIHF1aWNrIGJyb3duIGZveCBqdW1wZWQgb3ZlciB0aGUgbGF6eSBkb2cu"
            ],
        )
22 changes: 22 additions & 0 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/pydantic_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from pydantic import AnyUrl, HttpUrl, StringConstraints

# Names re-exported for the pydantic v2 code path.
__all__ = [
    "INPUT_STRING",
    "ITEMS_LIMIT",
    "ITEMS_LIMIT_SMALL",
    "AnyUrl",
    "HttpUrl",
]

# The string length cap is artificial: it is meant to surface splitting
# issues on the client side and is not a hard server-side limit.
INPUT_STRING = StringConstraints(strip_whitespace=True, max_length=8192 * 15)

# Length bounds for list inputs, passed as keyword arguments by callers.
ITEMS_LIMIT = dict(min_length=1, max_length=2048)
ITEMS_LIMIT_SMALL = dict(min_length=1, max_length=32)
33 changes: 16 additions & 17 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,13 @@
from pydantic import BaseModel, Field, conlist

try:
from pydantic import AnyUrl, HttpUrl, StringConstraints

# Note: adding artificial limit, this might reveal splitting
# issues on the client side
# and is not a hard limit on the server side.
INPUT_STRING = StringConstraints(max_length=8192 * 15, strip_whitespace=True)
ITEMS_LIMIT = {
"min_length": 1,
"max_length": 2048,
}
ITEMS_LIMIT_SMALL = {
"min_length": 1,
"max_length": 32,
}
from .data_uri import DataURI
from .pydantic_v2 import (
INPUT_STRING,
ITEMS_LIMIT,
ITEMS_LIMIT_SMALL,
HttpUrl,
)
except ImportError:
from pydantic import constr

Expand All @@ -49,12 +42,18 @@
"min_items": 1,
"max_items": 32,
}
HttpUrl, AnyUrl = str, str # type: ignore
HttpUrl = str # type: ignore
DataURI = str # type: ignore
DataURIorURL = Union[Annotated[DataURI, str], HttpUrl]

else:

class BaseModel: # type: ignore[no-redef]
pass

class DataURI: # type: ignore
pass

def Field(*args, **kwargs): # type: ignore
pass

Expand Down Expand Up @@ -83,10 +82,10 @@ class OpenAIEmbeddingInput(BaseModel):
class ImageEmbeddingInput(BaseModel):
input: Union[ # type: ignore
conlist( # type: ignore
Annotated[AnyUrl, HttpUrl],
DataURIorURL,
**ITEMS_LIMIT_SMALL,
),
Annotated[AnyUrl, HttpUrl],
DataURIorURL,
]
model: str = "default/not-specified"
encoding_format: EmbeddingEncodingFormat = EmbeddingEncodingFormat.float
Expand Down
2 changes: 1 addition & 1 deletion libs/infinity_emb/infinity_emb/inference/batch_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ async def classify(
async def image_embed(
self,
*,
images: List[Union[str, "ImageClassType"]],
images: List[Union[str, "ImageClassType", bytes]],
) -> tuple[list[EmbeddingReturnType], int]:
"""Schedule a images and sentences to be embedded. Awaits until embedded.

Expand Down
Loading
Loading