Skip to content

Commit

Permalink
fix: lazy import bot-specific modules (#439)
Browse files Browse the repository at this point in the history
* fix: lazy import bot-specific modules

Signed-off-by: Frost Ming <me@frostming.com>

* Trigger Build

---------

Signed-off-by: Frost Ming <me@frostming.com>
Co-authored-by: yihong0618 <zouzou0208@gmail.com>
  • Loading branch information
frostming and yihong0618 authored Jan 25, 2024
1 parent 11b5c3a commit 5a02fdc
Show file tree
Hide file tree
Showing 9 changed files with 46 additions and 18 deletions.
3 changes: 2 additions & 1 deletion xiaogpt/bot/bard_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from typing import Any

from bardapi import BardAsync
from rich import print

from xiaogpt.bot.base_bot import BaseBot, ChatHistoryMixin
Expand All @@ -16,6 +15,8 @@ def __init__(
self,
bard_token: str,
) -> None:
from bardapi import BardAsync

self._bot = BardAsync(token=bard_token)
self.history = []

Expand Down
8 changes: 6 additions & 2 deletions xiaogpt/bot/chatgptapi_bot.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from __future__ import annotations

import dataclasses
from typing import ClassVar
from typing import TYPE_CHECKING, ClassVar

import httpx
import openai
from rich import print

from xiaogpt.bot.base_bot import BaseBot, ChatHistoryMixin
from xiaogpt.utils import split_sentences

if TYPE_CHECKING:
import openai


@dataclasses.dataclass
class ChatGPTBot(ChatHistoryMixin, BaseBot):
Expand All @@ -22,6 +24,8 @@ class ChatGPTBot(ChatHistoryMixin, BaseBot):
history: list[tuple[str, str]] = dataclasses.field(default_factory=list, init=False)

def _make_openai_client(self, sess: httpx.AsyncClient) -> openai.AsyncOpenAI:
import openai

if self.api_base and self.api_base.rstrip("/").endswith("openai.azure.com"):
return openai.AsyncAzureOpenAI(
api_key=self.openai_key,
Expand Down
4 changes: 2 additions & 2 deletions xiaogpt/bot/gemini_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from typing import Any

from rich import print
import google.generativeai as genai
from google.generativeai.types.generation_types import StopCandidateException

from xiaogpt.bot.base_bot import BaseBot, ChatHistoryMixin

Expand Down Expand Up @@ -34,6 +32,8 @@ class GeminiBot(ChatHistoryMixin, BaseBot):
name = "Gemini"

def __init__(self, gemini_key: str) -> None:
import google.generativeai as genai

genai.configure(api_key=gemini_key)
self.history = []
model = genai.GenerativeModel(
Expand Down
5 changes: 4 additions & 1 deletion xiaogpt/bot/glm_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from typing import Any

import zhipuai
from rich import print

from xiaogpt.bot.base_bot import BaseBot, ChatHistoryMixin
Expand All @@ -14,6 +13,8 @@ class GLMBot(ChatHistoryMixin, BaseBot):
default_options = {"model": "chatglm_turbo"}

def __init__(self, glm_key: str) -> None:
import zhipuai

self.history = []
zhipuai.api_key = glm_key

Expand All @@ -22,6 +23,8 @@ def from_config(cls, config):
return cls(glm_key=config.glm_key)

def ask(self, query, **options):
import zhipuai

ms = self.get_messages()
kwargs = {**self.default_options, **options}
kwargs["prompt"] = ms
Expand Down
5 changes: 4 additions & 1 deletion xiaogpt/bot/gpt3_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import ClassVar

import httpx
import openai
from rich import print

from xiaogpt.bot.base_bot import BaseBot, ChatHistoryMixin
Expand All @@ -26,6 +25,8 @@ def from_config(cls, config):
)

async def ask(self, query, **options):
import openai

data = {
"prompt": query,
"model": "text-davinci-003",
Expand All @@ -50,6 +51,8 @@ async def ask(self, query, **options):
return completion.choices[0].text

async def ask_stream(self, query, **options):
import openai

data = {
"prompt": query,
"model": "text-davinci-003",
Expand Down
8 changes: 6 additions & 2 deletions xiaogpt/bot/newbing_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import re

from EdgeGPT import Chatbot, ConversationStyle

from xiaogpt.bot.base_bot import BaseBot, ChatHistoryMixin
from xiaogpt.utils import split_sentences

Expand All @@ -19,6 +17,8 @@ def __init__(
bing_cookies: dict | None = None,
proxy: str | None = None,
):
from EdgeGPT import Chatbot

self.history = []
self._bot = Chatbot(
cookiePath=bing_cookie_path, cookies=bing_cookies, proxy=proxy
Expand All @@ -40,6 +40,8 @@ def clean_text(s):
return s.strip()

async def ask(self, query, **options):
from EdgeGPT import ConversationStyle

kwargs = {"conversation_style": ConversationStyle.balanced, **options}
completion = await self._bot.ask(prompt=query, **kwargs)
try:
Expand All @@ -51,6 +53,8 @@ async def ask(self, query, **options):
return text

async def ask_stream(self, query, **options):
from EdgeGPT import ConversationStyle

kwargs = {"conversation_style": ConversationStyle.balanced, **options}
try:
completion = self._bot.ask_stream(prompt=query, **kwargs)
Expand Down
14 changes: 10 additions & 4 deletions xiaogpt/bot/qwen_bot.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
"""ChatGLM bot"""
from __future__ import annotations

from http import HTTPStatus
from typing import Any

from http import HTTPStatus
import dashscope
from dashscope import Generation
from dashscope.api_entities.dashscope_response import Role
from rich import print

from xiaogpt.bot.base_bot import BaseBot, ChatHistoryMixin
Expand All @@ -16,6 +13,9 @@ class QwenBot(ChatHistoryMixin, BaseBot):
name = "Qian Wen"

def __init__(self, qwen_key: str) -> None:
import dashscope
from dashscope.api_entities.dashscope_response import Role

self.history = [
{"role": Role.SYSTEM, "content": "You are a helpful assistant."}
]
Expand All @@ -26,6 +26,9 @@ def from_config(cls, config):
return cls(qwen_key=config.qwen_key)

async def ask(self, query, **options):
from dashscope import Generation
from dashscope.api_entities.dashscope_response import Role

# from https://help.aliyun.com/zh/dashscope/developer-reference/api-details
self.history.append({"role": Role.USER, "content": query})

Expand Down Expand Up @@ -61,6 +64,9 @@ async def ask(self, query, **options):
return "没有返回"

async def ask_stream(self, query: str, **options: Any):
from dashscope import Generation
from dashscope.api_entities.dashscope_response import Role

self.history.append({"role": Role.USER, "content": query})
responses = Generation.call(
Generation.Models.qwen_turbo,
Expand Down
10 changes: 6 additions & 4 deletions xiaogpt/langchain/examples/email/mail_box.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import imaplib
import email
from datetime import datetime, timedelta
import html
from bs4 import BeautifulSoup
import imaplib
import re
import openai
import smtplib
from datetime import datetime, timedelta
from email.mime.text import MIMEText

from bs4 import BeautifulSoup


class Mailbox:
# Gmail account settings need to be configured
Expand Down Expand Up @@ -115,6 +115,8 @@ def get_email_content(self, mailbox, email_id):
return ""

def get_summary_by_ai(self, email_content: str, prompt: str) -> str:
import openai

print("Asking AI to summarize email content...")

# Request ChatGPT for summary
Expand Down
7 changes: 6 additions & 1 deletion xiaogpt/tts/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

import tempfile
from pathlib import Path
from typing import TYPE_CHECKING

import httpx
import openai

from xiaogpt.tts.base import AudioFileTTS
from xiaogpt.utils import calculate_tts_elapse

if TYPE_CHECKING:
import openai


class OpenAITTS(AudioFileTTS):
default_voice = "alloy"
Expand All @@ -32,6 +35,8 @@ async def make_audio_file(self, query: str, text: str) -> tuple[Path, float]:
return Path(output_file.name), calculate_tts_elapse(text)

def _make_openai_client(self, sess: httpx.AsyncClient) -> openai.AsyncOpenAI:
import openai

api_base = self.config.api_base
if api_base and api_base.rstrip("/").endswith("openai.azure.com"):
raise NotImplementedError("TTS is not supported for Azure OpenAI")
Expand Down

0 comments on commit 5a02fdc

Please sign in to comment.