Skip to content

Commit

Permalink
refactor the code and add tti module
Browse files Browse the repository at this point in the history
  • Loading branch information
suqingdong committed Mar 4, 2024
1 parent f119cee commit 26ba6e4
Show file tree
Hide file tree
Showing 14 changed files with 251 additions and 23 deletions.
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ sparkapi prompt 详细介绍一下科大讯飞,输出Markdown结果

### Python
```python
from sparkapi.core.api import SparkAPI
from sparkapi.core.config import SparkConfig
from sparkapi.core.chat.api import SparkAPI
from sparkapi.core.chat.config import SparkConfig
config = SparkConfig().model_dump()
api = SparkAPI(**config)

Expand Down Expand Up @@ -70,3 +70,9 @@ print(''.join(res))

#### [1.0.3] - 2023-10-26
- Add support for model `v3.0`


### ToDO
- Function Call
- 图片生成
- 图片理解
4 changes: 2 additions & 2 deletions sparkapi/bin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@


from sparkapi import version_info, MODELS
from sparkapi.core.api import SparkAPI
from sparkapi.core.config import SparkConfig, ChatConfig
from sparkapi.core.chat.api import SparkAPI
from sparkapi.core.chat.config import SparkConfig, ChatConfig


CONTEXT_SETTINGS = dict(help_option_names=['-?', '-h', '--help'])
Expand Down
31 changes: 31 additions & 0 deletions sparkapi/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pydantic_settings import BaseSettings, SettingsConfigDict


class SparkConfig(BaseSettings):
app_id: str
api_key: str
api_secret: str
api_model: str

model_config = SettingsConfigDict(
env_file='~/.sparkapi.env',
env_prefix='SPARK_',
case_sensitive=False,
)


class ChatConfig(BaseSettings):
temperature: float = 0.5
max_tokens: int = 2048
top_k: int = 4

model_config = SettingsConfigDict(
env_file='~/.sparkapi.env',
env_prefix='SPARK_CHAT_',
case_sensitive=False,
)


if __name__ == '__main__':
print(1, SparkConfig(_env_file='.env'))
print(2, ChatConfig(_env_file='.env'))
10 changes: 10 additions & 0 deletions sparkapi/config/models.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,15 @@
"v3.5": {
"domain": "generalv3.5",
"url": "wss://spark-api.xf-yun.com/v3.5/chat"
},
"tti": {
"domain": "general",
"url": "https://spark-api.cn-huabei-1.xf-yun.com/v2.1/tti",
"desc": "图像生成"
},
"image": {
"domain": "general",
"url": "wss://spark-api.cn-huabei-1.xf-yun.com/v2.1/image",
"desc": "图像理解"
}
}
Empty file added sparkapi/core/chat/__init__.py
Empty file.
11 changes: 9 additions & 2 deletions sparkapi/core/api.py → sparkapi/core/chat/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from websockets.sync.client import connect as ws_connect

from sparkapi import MODELS as MODEL_MAP
from sparkapi.util import get_wss_url
from sparkapi.util import get_auth_url
from .query import QueryParams


Expand All @@ -20,7 +20,7 @@ def __init__(self, app_id: str, api_key: str, api_secret: str, api_model: str, *
def create_wss_connection(self):
if self._wss_url is None:
api_url = MODEL_MAP[self.api_model]['url']
self._wss_url = get_wss_url(api_url, self.api_secret, self.api_key)
self._wss_url = get_auth_url(api_url, self.api_secret, self.api_key)
return ws_connect(self._wss_url)

def build_query(self, messages, **kwargs):
Expand Down Expand Up @@ -67,3 +67,10 @@ def chat(self, **kwargs):

click.secho(f'>>> AI: {result}', fg='cyan', bold=True)
messages.append({'role': 'assistant', 'content': result})


if __name__ == '__main__':
from sparkapi.core.chat.api import SparkAPI
from sparkapi.config import SparkConfig
api = SparkAPI(**SparkConfig().model_dump())
print(''.join(api.get_completion('你是谁?')))
2 changes: 1 addition & 1 deletion sparkapi/core/config.py → sparkapi/core/chat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ class SparkConfig(BaseSettings):
app_id: str
api_key: str
api_secret: str
api_model: str = 'v1.5'
api_model: str

model_config = SettingsConfigDict(
env_file='~/.sparkapi.env',
Expand Down
14 changes: 3 additions & 11 deletions sparkapi/core/query.py → sparkapi/core/chat/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,8 @@
from dataclasses import dataclass, asdict


class Domain(Enum):
GENERAL = 'general'
GENERAL_V2 = 'generalv2'
GENERAL_V3 = 'generalv3'

def __str__(self):
return self.value


class Role(Enum):
SYSTEM = 'system'
USER = 'user'
ASSISTANT = 'assistant'

Expand All @@ -31,7 +23,7 @@ class Text:
class QueryParams:
app_id: str
text: List[Text]
domain: Domain
domain: Text

uid: Optional[str] = None

Expand Down Expand Up @@ -76,7 +68,7 @@ def dump_json(self):
from pprint import pprint
params = QueryParams(
app_id='app_id',
domain=Domain.GENERAL,
domain='general',
# text=[Text(role=Role.USER, content='hello')],
text=[{'role': 'user', 'content': 'hello'}],
)
Expand Down
Empty file added sparkapi/core/tti/__init__.py
Empty file.
75 changes: 75 additions & 0 deletions sparkapi/core/tti/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import base64
from io import BytesIO

import requests
from PIL import Image

from sparkapi import MODELS as MODEL_MAP
from sparkapi.util import get_auth_url
from .query import QueryParams


class SparkAPI(object):
def __init__(self, app_id: str, api_key: str, api_secret: str, api_model: str, **kwargs):
self.app_id = app_id
self.api_key = api_key
self.api_secret = api_secret
self.api_model = api_model
self._auth_url = None

@property
def auth_url(self):
if self._auth_url is None:
api_url = MODEL_MAP[self.api_model]['url']
self._auth_url = get_auth_url(api_url, self.api_secret, self.api_key, method='POST')
return self._auth_url

def build_query(self, messages, **kwargs):
query = QueryParams(
app_id=self.app_id,
domain=MODEL_MAP[self.api_model]['domain'],
text=messages,
**kwargs
)
return query.dump()

def get_completion(self, prompt: str, outfile=None, **kwargs):
"""get completion from prompt
注: 文生图目前仅开放单轮交互,单轮交互只需要传递一个user角色的数据
"""
messages = [{'role': 'user', 'content': prompt}]
query = self.build_query(messages, **kwargs)

response = requests.post(self.auth_url, json=query, headers={'content-type': "application/json"})
data = response.json()

if data['header']['code'] != 0:
print(f'[ERROR]:{data}')
return None

image_data = data['payload']['choices']['text'][0]['content']

if outfile:
return self.save_image(image_data, outfile)

return data

@staticmethod
def save_image(image_data, outfile):
if outfile == 'data_url':
uri = f'data://image/png;base64,{image_data}'
return uri
im = Image.open(BytesIO(base64.b64decode(image_data)))
im.save(outfile)
return outfile


if __name__ == '__main__':
from sparkapi.core.tti.api import SparkAPI
from sparkapi.config import SparkConfig
api = SparkAPI(**SparkConfig(api_model='tti').model_dump())
# res = api.get_completion('帮我生成一张二次元风景图', outfile='data_url')
res = api.get_completion('帮我生成一张二次元风景图', outfile='out.png')
print(res)

31 changes: 31 additions & 0 deletions sparkapi/core/tti/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pydantic_settings import BaseSettings, SettingsConfigDict


class SparkConfig(BaseSettings):
app_id: str
api_key: str
api_secret: str
api_model: str = 'tti'

model_config = SettingsConfigDict(
env_file='~/.sparkapi.env',
env_prefix='SPARK_',
case_sensitive=False,
)


class ChatConfig(BaseSettings):
temperature: float = 0.5
max_tokens: int = 2048
top_k: int = 4

model_config = SettingsConfigDict(
env_file='~/.sparkapi.env',
env_prefix='SPARK_CHAT_',
case_sensitive=False,
)


if __name__ == '__main__':
print(1, SparkConfig(_env_file='.env'))
print(2, ChatConfig(_env_file='.env'))
77 changes: 77 additions & 0 deletions sparkapi/core/tti/query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import json
from enum import Enum
from typing import Optional, List
from dataclasses import dataclass, asdict


class Role(Enum):
USER = 'user'

def __str__(self):
return self.value


@dataclass
class Text:
role: Role
content: str


@dataclass
class QueryParams:
app_id: str
text: List[Text]
domain: Text

width: int = 512
height: int = 512

uid: Optional[str] = None

temperature: float = 0.5
max_tokens: int = 2048
top_k: int = 1
chat_id: Optional[str] = None

def dump(self):
text = [
{'role': str(t.role), 'content': t.content}
if isinstance(t, Text) else t
for t in self.text
]
data = {
'header': {
'app_id': self.app_id,
'uid': self.uid,
},
'parameter': {
'chat': {
'domain': str(self.domain),
'max_tokens': self.max_tokens,
'temperature': self.temperature,
'top_k': self.top_k,
'chat_id': self.chat_id,
}
},
'payload': {
'message': {
'text': text
}
}
}
return data

def dump_json(self):
return json.dumps(self.dump(), ensure_ascii=False)


if __name__ == '__main__':
from pprint import pprint
params = QueryParams(
app_id='app_id',
domain='general',
# text=[Text(role=Role.USER, content='hello')],
text=[{'role': 'user', 'content': 'hello'}],
)
pprint(params.dump())
pprint(params.dump_json())
2 changes: 1 addition & 1 deletion sparkapi/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .auth import get_wss_url
from .auth import get_auth_url
from .common import generate_rfc1123_date
7 changes: 3 additions & 4 deletions sparkapi/util/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .common import generate_rfc1123_date


def get_wss_url(api_url, api_secret, api_key):
def get_auth_url(api_url, api_secret, api_key, method='GET'):
"""
Generate auth params for API request.
"""
Expand All @@ -20,7 +20,7 @@ def get_wss_url(api_url, api_secret, api_key):
signature_origin = textwrap.dedent(f'''
host: {api_host}
date: {rfc1123_date}
GET {api_path} HTTP/1.1
{method} {api_path} HTTP/1.1
''').strip()
signature_sha = hmac.new(
api_secret.encode(),
Expand All @@ -39,12 +39,11 @@ def get_wss_url(api_url, api_secret, api_key):
authorization_origin = ', '.join(f'{k}="{v}"' for k, v in authorization_payload.items())
authorization = base64.b64encode(authorization_origin.encode()).decode()

# step3: generate wss url
# step3: generate auth url
payload = {
'authorization': authorization,
'date': rfc1123_date,
'host': api_host
}
url = api_url + '?' + urlencode(payload)
# print(f'wss url: {url}')
return url

0 comments on commit 26ba6e4

Please sign in to comment.