-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor the code and add
tti
module
- Loading branch information
1 parent
f119cee
commit 26ba6e4
Showing
14 changed files
with
251 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from pydantic_settings import BaseSettings, SettingsConfigDict | ||
|
||
|
||
class SparkConfig(BaseSettings): | ||
app_id: str | ||
api_key: str | ||
api_secret: str | ||
api_model: str | ||
|
||
model_config = SettingsConfigDict( | ||
env_file='~/.sparkapi.env', | ||
env_prefix='SPARK_', | ||
case_sensitive=False, | ||
) | ||
|
||
|
||
class ChatConfig(BaseSettings): | ||
temperature: float = 0.5 | ||
max_tokens: int = 2048 | ||
top_k: int = 4 | ||
|
||
model_config = SettingsConfigDict( | ||
env_file='~/.sparkapi.env', | ||
env_prefix='SPARK_CHAT_', | ||
case_sensitive=False, | ||
) | ||
|
||
|
||
if __name__ == '__main__': | ||
print(1, SparkConfig(_env_file='.env')) | ||
print(2, ChatConfig(_env_file='.env')) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import base64 | ||
from io import BytesIO | ||
|
||
import requests | ||
from PIL import Image | ||
|
||
from sparkapi import MODELS as MODEL_MAP | ||
from sparkapi.util import get_auth_url | ||
from .query import QueryParams | ||
|
||
|
||
class SparkAPI(object): | ||
def __init__(self, app_id: str, api_key: str, api_secret: str, api_model: str, **kwargs): | ||
self.app_id = app_id | ||
self.api_key = api_key | ||
self.api_secret = api_secret | ||
self.api_model = api_model | ||
self._auth_url = None | ||
|
||
@property | ||
def auth_url(self): | ||
if self._auth_url is None: | ||
api_url = MODEL_MAP[self.api_model]['url'] | ||
self._auth_url = get_auth_url(api_url, self.api_secret, self.api_key, method='POST') | ||
return self._auth_url | ||
|
||
def build_query(self, messages, **kwargs): | ||
query = QueryParams( | ||
app_id=self.app_id, | ||
domain=MODEL_MAP[self.api_model]['domain'], | ||
text=messages, | ||
**kwargs | ||
) | ||
return query.dump() | ||
|
||
def get_completion(self, prompt: str, outfile=None, **kwargs): | ||
"""get completion from prompt | ||
注: 文生图目前仅开放单轮交互,单轮交互只需要传递一个user角色的数据 | ||
""" | ||
messages = [{'role': 'user', 'content': prompt}] | ||
query = self.build_query(messages, **kwargs) | ||
|
||
response = requests.post(self.auth_url, json=query, headers={'content-type': "application/json"}) | ||
data = response.json() | ||
|
||
if data['header']['code'] != 0: | ||
print(f'[ERROR]:{data}') | ||
return None | ||
|
||
image_data = data['payload']['choices']['text'][0]['content'] | ||
|
||
if outfile: | ||
return self.save_image(image_data, outfile) | ||
|
||
return data | ||
|
||
@staticmethod | ||
def save_image(image_data, outfile): | ||
if outfile == 'data_url': | ||
uri = f'data://image/png;base64,{image_data}' | ||
return uri | ||
im = Image.open(BytesIO(base64.b64decode(image_data))) | ||
im.save(outfile) | ||
return outfile | ||
|
||
|
||
if __name__ == '__main__': | ||
from sparkapi.core.tti.api import SparkAPI | ||
from sparkapi.config import SparkConfig | ||
api = SparkAPI(**SparkConfig(api_model='tti').model_dump()) | ||
# res = api.get_completion('帮我生成一张二次元风景图', outfile='data_url') | ||
res = api.get_completion('帮我生成一张二次元风景图', outfile='out.png') | ||
print(res) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from pydantic_settings import BaseSettings, SettingsConfigDict | ||
|
||
|
||
class SparkConfig(BaseSettings): | ||
app_id: str | ||
api_key: str | ||
api_secret: str | ||
api_model: str = 'tti' | ||
|
||
model_config = SettingsConfigDict( | ||
env_file='~/.sparkapi.env', | ||
env_prefix='SPARK_', | ||
case_sensitive=False, | ||
) | ||
|
||
|
||
class ChatConfig(BaseSettings): | ||
temperature: float = 0.5 | ||
max_tokens: int = 2048 | ||
top_k: int = 4 | ||
|
||
model_config = SettingsConfigDict( | ||
env_file='~/.sparkapi.env', | ||
env_prefix='SPARK_CHAT_', | ||
case_sensitive=False, | ||
) | ||
|
||
|
||
if __name__ == '__main__': | ||
print(1, SparkConfig(_env_file='.env')) | ||
print(2, ChatConfig(_env_file='.env')) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import json | ||
from enum import Enum | ||
from typing import Optional, List | ||
from dataclasses import dataclass, asdict | ||
|
||
|
||
class Role(Enum): | ||
USER = 'user' | ||
|
||
def __str__(self): | ||
return self.value | ||
|
||
|
||
@dataclass | ||
class Text: | ||
role: Role | ||
content: str | ||
|
||
|
||
@dataclass | ||
class QueryParams: | ||
app_id: str | ||
text: List[Text] | ||
domain: Text | ||
|
||
width: int = 512 | ||
height: int = 512 | ||
|
||
uid: Optional[str] = None | ||
|
||
temperature: float = 0.5 | ||
max_tokens: int = 2048 | ||
top_k: int = 1 | ||
chat_id: Optional[str] = None | ||
|
||
def dump(self): | ||
text = [ | ||
{'role': str(t.role), 'content': t.content} | ||
if isinstance(t, Text) else t | ||
for t in self.text | ||
] | ||
data = { | ||
'header': { | ||
'app_id': self.app_id, | ||
'uid': self.uid, | ||
}, | ||
'parameter': { | ||
'chat': { | ||
'domain': str(self.domain), | ||
'max_tokens': self.max_tokens, | ||
'temperature': self.temperature, | ||
'top_k': self.top_k, | ||
'chat_id': self.chat_id, | ||
} | ||
}, | ||
'payload': { | ||
'message': { | ||
'text': text | ||
} | ||
} | ||
} | ||
return data | ||
|
||
def dump_json(self): | ||
return json.dumps(self.dump(), ensure_ascii=False) | ||
|
||
|
||
if __name__ == '__main__': | ||
from pprint import pprint | ||
params = QueryParams( | ||
app_id='app_id', | ||
domain='general', | ||
# text=[Text(role=Role.USER, content='hello')], | ||
text=[{'role': 'user', 'content': 'hello'}], | ||
) | ||
pprint(params.dump()) | ||
pprint(params.dump_json()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
from .auth import get_wss_url | ||
from .auth import get_auth_url | ||
from .common import generate_rfc1123_date |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters