Skip to content

Commit

Permalink
feat: support live TTS of fish audio (#555)
Browse files Browse the repository at this point in the history
* feat: support live TTS of fish audio

Signed-off-by: Frost Ming <me@frostming.com>
  • Loading branch information
frostming authored Oct 29, 2024
1 parent 7096933 commit 9e4a4b5
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 202 deletions.
131 changes: 58 additions & 73 deletions pdm.lock

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ edge-tts==6.1.10
exceptiongroup==1.2.0; python_version < "3.11"
frozenlist==1.4.1
google-ai-generativelanguage==0.6.10
google-api-core==2.15.0
google-api-core[grpc]==2.15.0
google-api-python-client==2.125.0
google-auth==2.26.1
google-auth-httplib2==0.2.0
Expand All @@ -37,7 +37,8 @@ grpcio-status==1.60.0
h11==0.14.0
httpcore==1.0.5
httplib2==0.22.0
httpx==0.27.2
httpx-ws==0.6.2
httpx[socks]==0.27.2
idna==3.7
jiter==0.5.0
jsonpatch==1.33
Expand Down Expand Up @@ -83,13 +84,13 @@ socksio==1.0.0
soupsieve==2.5
sqlalchemy==2.0.25
tenacity==8.2.3
tetos==0.3.1
tetos==0.4.1
tqdm==4.66.1
typing-extensions==4.12.2
typing-inspect==0.9.0
uritemplate==4.1.1
urllib3==2.1.0
websocket-client==1.8.0
websockets==12.0
wsproto==1.2.0
yarl==1.14.0
zhipuai==2.1.5.20230904
5 changes: 3 additions & 2 deletions xiaogpt/tts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from xiaogpt.tts.base import TTS
from xiaogpt.tts.file import TetosFileTTS
from xiaogpt.tts.live import TetosLiveTTS
from xiaogpt.tts.mi import MiTTS
from xiaogpt.tts.tetos import TetosTTS

__all__ = ["TTS", "TetosTTS", "MiTTS"]
__all__ = ["TTS", "TetosFileTTS", "MiTTS", "TetosLiveTTS"]
91 changes: 1 addition & 90 deletions xiaogpt/tts/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,10 @@

import abc
import asyncio
import functools
import json
import logging
import os
import random
import socket
import tempfile
import threading
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from typing import TYPE_CHECKING, AsyncIterator

from xiaogpt.utils import get_hostname

if TYPE_CHECKING:
from typing import TypeVar

Expand Down Expand Up @@ -46,7 +36,7 @@ async def wait_for_duration(self, duration: float) -> None:
break
await asyncio.sleep(1)

async def get_if_xiaoai_is_playing(self):
async def get_if_xiaoai_is_playing(self) -> bool:
playing_info = await self.mina_service.player_get_status(self.device_id)
# WTF xiaomi api
is_playing = (
Expand All @@ -59,82 +49,3 @@ async def get_if_xiaoai_is_playing(self):
async def synthesize(self, lang: str, text_stream: AsyncIterator[str]) -> None:
"""Synthesize speech from a stream of text."""
raise NotImplementedError


class HTTPRequestHandler(SimpleHTTPRequestHandler):
    """Static-file request handler that reports through the module ``logger``."""

    def copyfile(self, source, outputfile):
        """Copy *source* to *outputfile*, swallowing socket errors raised on the way."""
        try:
            super().copyfile(source, outputfile)
        except (socket.error, ConnectionResetError, BrokenPipeError):
            # Deliberately ignored for now. TODO: find out why these errors occur.
            return

    def log_error(self, format, *args):
        """Send error records to ``logger.error``, prefixed with the client address."""
        template = f"{self.address_string()} - {format}"
        logger.error(template, *args)

    def log_message(self, format, *args):
        """Send access records to ``logger.debug``, prefixed with the client address."""
        template = f"{self.address_string()} - {format}"
        logger.debug(template, *args)


class AudioFileTTS(TTS):
    """A TTS model that generates audio files locally and plays them via URL."""

    def __init__(
        self, mina_service: MiNAService, device_id: str, config: Config
    ) -> None:
        super().__init__(mina_service, device_id, config)
        # Generated audio lives in this directory, which the HTTP server exposes.
        self.dirname = tempfile.TemporaryDirectory(prefix="xiaogpt-tts-")
        self._start_http_server()

    @abc.abstractmethod
    async def make_audio_file(self, lang: str, text: str) -> tuple[Path, float]:
        """Synthesize speech from text and save it to a file.
        Return the file path and the duration of the audio in seconds.
        The file path must be relative to the self.dirname.
        """
        raise NotImplementedError

    async def synthesize(self, lang: str, text_stream: AsyncIterator[str]) -> None:
        """Synthesize every text piece of *text_stream* and play the results in order.

        A background task turns each piece into an audio file while this
        coroutine plays the already finished files one after another. Any
        exception raised by the background task is re-raised once the files
        produced before the failure have been played.
        """
        queue: asyncio.Queue[tuple[str, float]] = asyncio.Queue()
        finished = asyncio.Event()

        async def worker():
            try:
                async for text in text_stream:
                    path, duration = await self.make_audio_file(lang, text)
                    url = f"http://{self.hostname}:{self.port}/{path.name}"
                    await queue.put((url, duration))
            finally:
                # Signal completion even on failure; otherwise the playback
                # loop below would poll forever waiting for more audio.
                finished.set()

        task = asyncio.create_task(worker())

        while True:
            try:
                url, duration = queue.get_nowait()
            except asyncio.QueueEmpty:
                if finished.is_set():
                    break
                await asyncio.sleep(0.1)
                continue
            logger.debug("Playing URL %s (%s seconds)", url, duration)
            await asyncio.gather(
                self.mina_service.play_by_url(self.device_id, url, _type=1),
                self.wait_for_duration(duration),
            )
        # Propagates any exception raised inside the worker.
        await task

    def _start_http_server(self):
        """Serve ``self.dirname`` over HTTP from a daemon thread.

        The port comes from the ``XIAOGPT_PORT`` environment variable when set,
        otherwise it is picked at random from 8050-8089. Sets ``self.port``
        and ``self.hostname``.
        """
        # set the port range
        port_range = range(8050, 8090)
        # get a random port from the range
        self.port = int(os.getenv("XIAOGPT_PORT", random.choice(port_range)))
        # create the server
        handler = functools.partial(HTTPRequestHandler, directory=self.dirname.name)
        httpd = ThreadingHTTPServer(("", self.port), handler)
        # start the server in a new thread
        server_thread = threading.Thread(target=httpd.serve_forever)
        server_thread.daemon = True
        server_thread.start()

        self.hostname = get_hostname()
        logger.info(f"Serving on {self.hostname}:{self.port}")
103 changes: 103 additions & 0 deletions xiaogpt/tts/file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import asyncio
import functools
import os
import random
import socket
import tempfile
import threading
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from typing import AsyncIterator

from miservice import MiNAService

from xiaogpt.config import Config
from xiaogpt.tts.base import TTS, logger
from xiaogpt.utils import get_hostname


class HTTPRequestHandler(SimpleHTTPRequestHandler):
    """Handler for the generated audio files that logs via the shared ``logger``."""

    def copyfile(self, source, outputfile):
        """Stream *source* into *outputfile*; socket errors during the copy are dropped."""
        try:
            super().copyfile(source, outputfile)
        except (socket.error, ConnectionResetError, BrokenPipeError):
            # Intentionally swallowed. TODO: find out why the error happens.
            return

    def log_error(self, format, *args):
        """Route error output to ``logger.error`` with the client address in front."""
        prefixed = f"{self.address_string()} - {format}"
        logger.error(prefixed, *args)

    def log_message(self, format, *args):
        """Route request output to ``logger.debug`` with the client address in front."""
        prefixed = f"{self.address_string()} - {format}"
        logger.debug(prefixed, *args)


class TetosFileTTS(TTS):
    """A TTS model that generates audio files locally and plays them via URL."""

    def __init__(
        self, mina_service: MiNAService, device_id: str, config: Config
    ) -> None:
        """Start the local file server and build the tetos speaker named by ``config.tts``.

        Raises:
            ValueError: if the speaker cannot be constructed from
                ``config.tts_options`` (the speaker class raised ``TypeError``).
        """
        from tetos import get_speaker

        super().__init__(mina_service, device_id, config)
        # Generated audio lives in this directory, which the HTTP server exposes.
        self.dirname = tempfile.TemporaryDirectory(prefix="xiaogpt-tts-")
        self._start_http_server()

        assert config.tts and config.tts != "mi"
        speaker_cls = get_speaker(config.tts)
        try:
            self.speaker = speaker_cls(**config.tts_options)
        except TypeError as e:
            raise ValueError(f"{e}. Please add them via `tts_options` config") from e

    async def make_audio_file(self, lang: str, text: str) -> tuple[Path, float]:
        """Synthesize *text* into a new ``.mp3`` file inside ``self.dirname``.

        Returns the file path and the duration reported by the speaker.
        """
        # Only the unique file name is needed here. Closing our own handle
        # right away avoids leaking one descriptor per utterance and lets the
        # speaker reopen the path on platforms that forbid a second open.
        with tempfile.NamedTemporaryFile(
            suffix=".mp3", mode="wb", delete=False, dir=self.dirname.name
        ) as output_file:
            output_name = output_file.name
        duration = await self.speaker.synthesize(text, output_name, lang=lang)
        return Path(output_name), duration

    async def synthesize(self, lang: str, text_stream: AsyncIterator[str]) -> None:
        """Synthesize every text piece of *text_stream* and play the results in order.

        A background task turns each piece into an audio file while this
        coroutine plays the already finished files one after another. Any
        exception raised by the background task is re-raised once the files
        produced before the failure have been played.
        """
        queue: asyncio.Queue[tuple[str, float]] = asyncio.Queue()
        finished = asyncio.Event()

        async def worker():
            try:
                async for text in text_stream:
                    path, duration = await self.make_audio_file(lang, text)
                    url = f"http://{self.hostname}:{self.port}/{path.name}"
                    await queue.put((url, duration))
            finally:
                # Signal completion even on failure; otherwise the playback
                # loop below would poll forever waiting for more audio.
                finished.set()

        task = asyncio.create_task(worker())

        while True:
            try:
                url, duration = queue.get_nowait()
            except asyncio.QueueEmpty:
                if finished.is_set():
                    break
                await asyncio.sleep(0.1)
                continue
            logger.debug("Playing URL %s (%s seconds)", url, duration)
            await asyncio.gather(
                self.mina_service.play_by_url(self.device_id, url, _type=1),
                self.wait_for_duration(duration),
            )
        # Propagates any exception raised inside the worker.
        await task

    def _start_http_server(self):
        """Serve ``self.dirname`` over HTTP from a daemon thread.

        The port comes from the ``XIAOGPT_PORT`` environment variable when set,
        otherwise it is picked at random from 8050-8089. Sets ``self.port``
        and ``self.hostname``.
        """
        # set the port range
        port_range = range(8050, 8090)
        # get a random port from the range
        self.port = int(os.getenv("XIAOGPT_PORT", random.choice(port_range)))
        # create the server
        handler = functools.partial(HTTPRequestHandler, directory=self.dirname.name)
        httpd = ThreadingHTTPServer(("", self.port), handler)
        # start the server in a new thread
        server_thread = threading.Thread(target=httpd.serve_forever)
        server_thread.daemon = True
        server_thread.start()

        self.hostname = get_hostname()
        logger.info(f"Serving on {self.hostname}:{self.port}")
98 changes: 98 additions & 0 deletions xiaogpt/tts/live.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import asyncio
import os
import queue
import random
import threading
import uuid
from functools import lru_cache
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from typing import AsyncIterator

from miservice import MiNAService

from xiaogpt.config import Config
from xiaogpt.tts.base import TTS, logger
from xiaogpt.utils import get_hostname


@lru_cache(maxsize=64)
def get_queue(key: str) -> queue.Queue[bytes]:
    """Return the queue registered under *key*, creating it on first use.

    The ``lru_cache`` decorator acts as the registry: repeated calls with the
    same key yield the same queue object for as long as the entry stays among
    the 64 cached ones.
    """
    fresh = queue.Queue()
    return fresh


class HTTPRequestHandler(BaseHTTPRequestHandler):
    """Stream live-synthesized audio chunks to the client as ``audio/mpeg``."""

    def do_GET(self):
        """Write the chunks queued under the last path segment to the client.

        Streaming stops when the empty-bytes end marker is dequeued or when
        the client closes the connection.
        """
        self.send_response(200)
        self.send_header("Content-type", "audio/mpeg")
        self.end_headers()
        key = self.path.split("/")[-1]
        # Named `chunks` so the `queue` module is not shadowed.
        chunks = get_queue(key)
        try:
            while True:
                chunk = chunks.get()
                if chunk == b"":
                    break
                self.wfile.write(chunk)
        except ConnectionError:
            # The player went away mid-stream (e.g. playback was interrupted);
            # there is nobody left to send audio to, so just stop quietly.
            logger.debug("Client disconnected while streaming %s", key)

    def log_message(self, format, *args):
        """Send access records to ``logger.debug``, prefixed with the client address."""
        logger.debug(f"{self.address_string()} - {format}", *args)

    def log_error(self, format, *args):
        """Send error records to ``logger.error``, prefixed with the client address."""
        logger.error(f"{self.address_string()} - {format}", *args)


class TetosLiveTTS(TTS):
    """A TTS model that generates audio in real-time."""

    def __init__(
        self, mina_service: MiNAService, device_id: str, config: Config
    ) -> None:
        """Start the streaming server and build the tetos speaker named by ``config.tts``.

        Raises:
            ValueError: if the speaker cannot be constructed from
                ``config.tts_options`` (the speaker class raised ``TypeError``),
                or if the speaker has no ``live`` attribute.
        """
        from tetos import get_speaker

        super().__init__(mina_service, device_id, config)
        self._start_http_server()

        assert config.tts and config.tts != "mi"
        speaker_cls = get_speaker(config.tts)
        try:
            self.speaker = speaker_cls(**config.tts_options)
        except TypeError as e:
            raise ValueError(f"{e}. Please add them via `tts_options` config") from e
        if not hasattr(self.speaker, "live"):
            raise ValueError(f"{config.tts} Speaker does not support live synthesis")

    async def synthesize(self, lang: str, text_stream: AsyncIterator[str]) -> None:
        """Stream the speech for *text_stream* to the device while it is generated.

        A background task pushes audio chunks into a queue that the HTTP
        handler serves under a one-off URL. This coroutine asks the device to
        play that URL and returns once playback has stopped. Any exception
        raised by the background task is re-raised at the end.
        """
        key = str(uuid.uuid4())
        # Named `audio_queue` so the `queue` module is not shadowed.
        audio_queue = get_queue(key)

        async def worker():
            try:
                async for chunk in self.speaker.live(text_stream, lang):
                    # An empty chunk is the end-of-stream marker for the HTTP
                    # handler, so it must never be forwarded mid-stream.
                    if chunk:
                        audio_queue.put(chunk)
            finally:
                # Always terminate the stream, even when synthesis fails;
                # otherwise the handler thread would block on the queue forever.
                audio_queue.put(b"")

        task = asyncio.create_task(worker())
        await self.mina_service.play_by_url(
            self.device_id, f"http://{self.hostname}:{self.port}/{key}", _type=1
        )

        # Poll once per second until the device reports playback has ended.
        while await self.get_if_xiaoai_is_playing():
            await asyncio.sleep(1)
        # Propagates any exception raised inside the worker.
        await task

    def _start_http_server(self):
        """Run the streaming HTTP server in a daemon thread.

        The port comes from the ``XIAOGPT_PORT`` environment variable when set,
        otherwise it is picked at random from 8050-8089. Sets ``self.port``
        and ``self.hostname``.
        """
        # set the port range
        port_range = range(8050, 8090)
        # get a random port from the range
        self.port = int(os.getenv("XIAOGPT_PORT", random.choice(port_range)))
        # create the server
        handler = HTTPRequestHandler
        httpd = ThreadingHTTPServer(("", self.port), handler)
        # start the server in a new thread
        server_thread = threading.Thread(target=httpd.serve_forever)
        server_thread.daemon = True
        server_thread.start()

        self.hostname = get_hostname()
        logger.info(f"Serving on {self.hostname}:{self.port}")
31 changes: 0 additions & 31 deletions xiaogpt/tts/tetos.py

This file was deleted.

Loading

0 comments on commit 9e4a4b5

Please sign in to comment.