From d44882c1b5aa8ec59c73d5393ef677024a5e7e1a Mon Sep 17 00:00:00 2001 From: Yingchun Lai Date: Tue, 28 Jan 2025 10:52:01 +0800 Subject: [PATCH] refactor: reduce duplciate code by inheritance (#13073) --- .../text_embedding/text_embedding.py | 196 +----------------- 1 file changed, 9 insertions(+), 187 deletions(-) diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index 40ea4dc0118026..b3d9167bc66b77 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -1,29 +1,13 @@ -import json -import time -from decimal import Decimal from typing import Optional -from urllib.parse import urljoin - -import numpy as np -import requests from core.entities.embedding_type import EmbeddingInputType -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import ( - AIModelEntity, - FetchFrom, - ModelPropertyKey, - ModelType, - PriceConfig, - PriceType, +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import ( + OAICompatEmbeddingModel, ) -from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat -class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): +class PerfXCloudEmbeddingModel(OAICompatEmbeddingModel): """ Model class for an OpenAI API-compatible text embedding model. """ @@ -47,86 +31,10 @@ def _invoke( :return: embeddings result """ - # Prepare headers and payload for the request - headers = {"Content-Type": "application/json"} - - api_key = credentials.get("api_key") - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - endpoint_url: Optional[str] if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": - endpoint_url = "https://cloud.perfxlab.cn/v1/" - else: - endpoint_url = credentials.get("endpoint_url") - assert endpoint_url is not None, "endpoint_url is required in credentials" - if not endpoint_url.endswith("/"): - endpoint_url += "/" - - assert isinstance(endpoint_url, str) - endpoint_url = urljoin(endpoint_url, "embeddings") - - extra_model_kwargs = {} - if user: - extra_model_kwargs["user"] = user - - extra_model_kwargs["encoding_format"] = "float" - - # get model properties - context_size = self._get_context_size(model, credentials) - max_chunks = self._get_max_chunks(model, credentials) - - inputs = [] - indices = [] - used_tokens = 0 - - for i, text in enumerate(texts): - # Here token count is only an approximation based on the GPT2 tokenizer - # TODO: Optimize for better token estimation and chunking - num_tokens = self._get_num_tokens_by_gpt2(text) - - if num_tokens >= context_size: - cutoff = int(np.floor(len(text) * (context_size / num_tokens))) - # if num tokens is larger than context length, only use the start - inputs.append(text[0:cutoff]) - else: - inputs.append(text) - indices += [i] - - batched_embeddings = [] - _iter = range(0, len(inputs), max_chunks) - - for i in _iter: - # Prepare the payload for the request - payload = {"input": inputs[i : i + max_chunks], "model": model, **extra_model_kwargs} - - # Make the request to the OpenAI API - response = requests.post(endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) + credentials["endpoint_url"] = "https://cloud.perfxlab.cn/v1/" - response.raise_for_status() # Raise an exception for HTTP errors - response_data = response.json() - - # Extract embeddings and used tokens from the response - embeddings_batch = [data["embedding"] for data in response_data["data"]] - embedding_used_tokens = response_data["usage"]["total_tokens"] - - used_tokens += embedding_used_tokens - batched_embeddings += embeddings_batch - - # calc usage - usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens) - - return TextEmbeddingResult(embeddings=batched_embeddings, usage=usage, model=model) - - def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: - """ - Approximate number of tokens for given messages using GPT2 tokenizer - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :return: - """ - return sum(self._get_num_tokens_by_gpt2(text) for text in texts) + return OAICompatEmbeddingModel._invoke(self, model, credentials, texts, user, input_type) def validate_credentials(self, model: str, credentials: dict) -> None: """ @@ -136,93 +44,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :param credentials: model credentials :return: """ - try: - headers = {"Content-Type": "application/json"} - - api_key = credentials.get("api_key") - - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - - endpoint_url: Optional[str] - if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": - endpoint_url = "https://cloud.perfxlab.cn/v1/" - else: - endpoint_url = credentials.get("endpoint_url") - assert endpoint_url is not None, "endpoint_url is required in credentials" - if not endpoint_url.endswith("/"): - endpoint_url += "/" - - assert isinstance(endpoint_url, str) - endpoint_url = urljoin(endpoint_url, "embeddings") - - payload = {"input": "ping", "model": model} - - response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) - - if response.status_code != 200: - raise CredentialsValidateFailedError( - f"Credentials validation failed with status code {response.status_code}" - ) - - try: - json_result = response.json() - except json.JSONDecodeError as e: - raise CredentialsValidateFailedError("Credentials validation failed: JSON decode error") - - if "model" not in json_result: - raise CredentialsValidateFailedError("Credentials validation failed: invalid response") - except CredentialsValidateFailedError: - raise - except Exception as ex: - raise CredentialsValidateFailedError(str(ex)) - - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: - """ - generate custom model entities from credentials - """ - entity = AIModelEntity( - model=model, - label=I18nObject(en_US=model), - model_type=ModelType.TEXT_EMBEDDING, - fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)), - ModelPropertyKey.MAX_CHUNKS: 1, - }, - parameter_rules=[], - pricing=PriceConfig( - input=Decimal(credentials.get("input_price", 0)), - unit=Decimal(credentials.get("unit", 0)), - currency=credentials.get("currency", "USD"), - ), - ) - - return entity - - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: - """ - Calculate response usage - - :param model: model name - :param credentials: model credentials - :param tokens: input tokens - :return: usage - """ - # get input price info - input_price_info = self.get_price( - model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens - ) - - # transform usage - usage = EmbeddingUsage( - tokens=tokens, - total_tokens=tokens, - unit_price=input_price_info.unit_price, - price_unit=input_price_info.unit, - total_price=input_price_info.total_amount, - currency=input_price_info.currency, - latency=time.perf_counter() - self.started_at, - ) + if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": + credentials["endpoint_url"] = "https://cloud.perfxlab.cn/v1/" - return usage + OAICompatEmbeddingModel.validate_credentials(self, model, credentials)