feat: Add WebSocket server with multi-client support #263

Open · wants to merge 37 commits into base: develop

Commits
467997d
Add Test Suite (#237)
juanmc2005 May 25, 2024
43135ae
Bump up version to 0.9.1
juanmc2005 May 25, 2024
1b1bf73
Limit numpy version to < 2.0.0 (#243)
juanmc2005 Jun 28, 2024
ed77f7f
Fix embedding extraction example in README (#241)
hmehdi515 Jul 8, 2024
da080e7
add handler for multiple websocket streams
janaab11 Nov 21, 2024
e5fe109
move server instance to inference handler
janaab11 Nov 21, 2024
12df31b
update serve script to use inference handler
janaab11 Nov 21, 2024
a709d58
debug server run
janaab11 Nov 21, 2024
b9b1780
turn off progress bar for multiple connections
janaab11 Nov 21, 2024
c8127f7
add error handling for inference server
janaab11 Nov 22, 2024
b805575
refactor streaming server for readability
janaab11 Nov 22, 2024
1b4e9af
fix memory leak when clients share Pipeline instance
janaab11 Nov 27, 2024
b897991
reset Pipeline state when client disconnects
janaab11 Nov 27, 2024
d2566f1
test READY message to client after init
janaab11 Nov 27, 2024
f9f2c80
update client to stream audio after READY message
janaab11 Nov 29, 2024
1db7a8f
test CLOSE message to client after disconnect
janaab11 Nov 29, 2024
80ee316
added Dockerfile
janaab11 Dec 20, 2024
fb9fecf
expose custom params for Docker
janaab11 Dec 20, 2024
1975f75
apply styling with black and isort
janaab11 Dec 22, 2024
7ba2f55
refactor StreamingHandler to use LazyModel for resource mgmt
janaab11 Dec 31, 2024
3d3bb45
simplified websocket-server class and improved naming
janaab11 Jan 2, 2025
b2c9293
improved code quality and style
janaab11 Jan 2, 2025
bba43ae
updated Dockerfile for local builds and reduced image size
janaab11 Jan 3, 2025
f95d6ca
refactor close and send methods of WebSocketStreamingServer to separa…
janaab11 Jan 3, 2025
c6a2f92
refactor: improve error handling and reduce redundancy
janaab11 Jan 3, 2025
12f7ba9
refactor: make WebSocketAudioSource a proxy and handle audio decoding…
janaab11 Jan 3, 2025
0f6ac4e
apply styling with black and isort
janaab11 Jan 3, 2025
c88fbd7
correct styling with isort
janaab11 Jan 3, 2025
eb200c2
cleanup: remove deprecated output argument
janaab11 Jan 3, 2025
16dbbc0
fix typo in dockerignore
janaab11 Jan 3, 2025
dd1ff48
fix(websockets): update socket error handling
janaab11 Jan 3, 2025
e1db50d
fix(websockets): improve error logging for edge-cases
janaab11 Jan 3, 2025
d944e6f
fix(websockets): add retry backoff to server
janaab11 Jan 3, 2025
f2c3144
apply styling with black and isort
janaab11 Jan 3, 2025
b6d6bc6
fix(client): manage stop events and handle errors correctly
janaab11 Jan 3, 2025
74bd40b
fix(client): improve error handling and readability
janaab11 Jan 3, 2025
c263239
add documentation for updated websocket server and dockerfile usage
janaab11 Jan 6, 2025
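The commit messages above outline a small control protocol between server and client: a READY message once the pipeline is initialized, audio streamed only after READY, and a CLOSE message on disconnect. Below is a minimal, hypothetical client-side sketch of that handshake using the `websocket-client` library already listed in diart's dependencies. The message names come from the commits, but the payload encoding and control-message format are assumptions, not the PR's actual client implementation.

```python
# Hypothetical sketch of the handshake implied by the commit messages
# (READY after init, audio streamed only after READY, CLOSE on disconnect).
# The exact wire format is an assumption, not taken from this PR.
from websocket import create_connection  # provided by websocket-client


def stream_audio(host: str = "localhost", port: int = 7007, chunks=()):
    ws = create_connection(f"ws://{host}:{port}")
    try:
        # Wait for the server to finish initializing its pipeline
        if ws.recv() != "READY":  # assumed plain-text control message
            raise RuntimeError("Server did not report READY")
        for chunk in chunks:  # chunks: audio payloads pre-encoded as the server expects
            ws.send(chunk)
            print(ws.recv())  # e.g. an RTTM-formatted prediction
    finally:
        ws.close()  # server is expected to follow up with a CLOSE message
```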
19 changes: 19 additions & 0 deletions .dockerignore
@@ -0,0 +1,19 @@
# Development
.git/
.github/
.idea/
__pycache__/

# Data and examples
assets/
example/
expected_outputs/
tests/

# Documentation
docs/

# Build artifacts
*.egg-info/
dist/
build/
35 changes: 35 additions & 0 deletions .github/workflows/pytest.yml
@@ -0,0 +1,35 @@
name: Pytest

on:
  pull_request:
    branches:
      - main
      - develop

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: '3.10'

      - name: Install apt dependencies
        run: |
          sudo add-apt-repository ppa:savoury1/ffmpeg4
          sudo apt-get update
          sudo apt-get -y install ffmpeg libportaudio2=19.6.0-1.1

      - name: Install pip dependencies
        run: |
          python -m pip install --upgrade pip
          pip install .[tests]

      - name: Run tests
        run: |
          pytest
7 changes: 4 additions & 3 deletions .github/workflows/quick-runs.yml
@@ -38,6 +38,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install .
pip install onnxruntime==1.18.0
- name: Crop audio and rttm
run: |
sox audio/ES2002a_long.wav audio/ES2002a.wav trim 00:40 00:30
@@ -50,10 +51,10 @@
rm rttms/ES2002b_long.rttm
- name: Run stream
run: |
diart.stream audio/ES2002a.wav --output trash --no-plot --hf-token ${{ secrets.HUGGINGFACE }}
diart.stream audio/ES2002a.wav --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx --output trash --no-plot
- name: Run benchmark
run: |
diart.benchmark audio --reference rttms --batch-size 4 --hf-token ${{ secrets.HUGGINGFACE }}
diart.benchmark audio --reference rttms --batch-size 4 --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx
- name: Run tuning
run: |
diart.tune audio --reference rttms --batch-size 4 --num-iter 2 --output trash --hf-token ${{ secrets.HUGGINGFACE }}
diart.tune audio --reference rttms --batch-size 4 --num-iter 2 --output trash --segmentation assets/models/segmentation_uint8.onnx --embedding assets/models/embedding_uint8.onnx
75 changes: 75 additions & 0 deletions Dockerfile
@@ -0,0 +1,75 @@
# Use NVIDIA CUDA base image
FROM docker.io/nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04

# Install sudo, git, wget, gcc, g++, and other essential build tools
RUN apt-get update && \
apt-get install -y sudo git wget build-essential && \
apt-get clean && \
rm -rf /var/lib/apt/lists/*

# Install Miniconda
ENV CONDA_DIR=/opt/conda
ENV PATH=$CONDA_DIR/bin:$PATH
RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/miniconda.sh && \
bash /tmp/miniconda.sh -b -p $CONDA_DIR && \
rm /tmp/miniconda.sh

# Install Python 3.10 using Conda
RUN conda install python=3.10

# Upgrade pip and setuptools to avoid deprecation warnings
RUN pip install --upgrade pip setuptools

# Set Python 3.10 as default by creating a symbolic link
RUN ln -sf /opt/conda/bin/python3.10 /opt/conda/bin/python && \
ln -sf /opt/conda/bin/python3.10 /usr/bin/python

# Verify installations
RUN python --version && \
gcc --version && \
g++ --version && \
pip --version && \
conda --version

# Create app directory and copy files
WORKDIR /diart
COPY . .

# Install diart dependencies
RUN conda install portaudio pysoundfile ffmpeg -c conda-forge
RUN pip install -e .

# Expose the port the app runs on
EXPOSE 7007

# Define environment variable to prevent Python from buffering stdout/stderr
# and writing byte code to file
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1

# Define custom options as env variables with defaults
ENV HOST=0.0.0.0
ENV PORT=7007
ENV SEGMENTATION=pyannote/segmentation-3.0
ENV EMBEDDING=speechbrain/spkrec-resnet-voxceleb
ENV TAU_ACTIVE=0.45
ENV RHO_UPDATE=0.25
ENV DELTA_NEW=0.6
ENV LATENCY=5
ENV MAX_SPEAKERS=3

CMD ["sh", "-c", "python -m diart.console.serve --host ${HOST} --port ${PORT} --segmentation ${SEGMENTATION} --embedding ${EMBEDDING} --tau-active ${TAU_ACTIVE} --rho-update ${RHO_UPDATE} --delta-new ${DELTA_NEW} --latency ${LATENCY} --max-speakers ${MAX_SPEAKERS}"]

# Example run command with environment variables:
# docker run -p 7007:7007 --restart unless-stopped --gpus all \
# -e HF_TOKEN=<token> \
# -e HOST=0.0.0.0 \
# -e PORT=7007 \
# -e SEGMENTATION=pyannote/segmentation-3.0 \
# -e EMBEDDING=speechbrain/spkrec-resnet-voxceleb \
# -e TAU_ACTIVE=0.45 \
# -e RHO_UPDATE=0.25 \
# -e DELTA_NEW=0.6 \
# -e LATENCY=5 \
# -e MAX_SPEAKERS=3 \
# diart-image
70 changes: 57 additions & 13 deletions README.md
@@ -202,6 +202,7 @@ def embedding_loader():
segmentation = SegmentationModel(segmentation_loader)
embedding = EmbeddingModel(embedding_loader)
config = SpeakerDiarizationConfig(
# Set the segmentation model used in the paper
**Review comment (Owner):** This isn't correct. To remove.

segmentation=segmentation,
embedding=embedding,
)
@@ -284,21 +285,27 @@ Obtain overlap-aware speaker embeddings from a microphone stream:
```python
import rx.operators as ops
import diart.operators as dops
from diart.sources import MicrophoneAudioSource
from diart.sources import MicrophoneAudioSource, FileAudioSource
from diart.blocks import SpeakerSegmentation, OverlapAwareSpeakerEmbedding

segmentation = SpeakerSegmentation.from_pretrained("pyannote/segmentation")
embedding = OverlapAwareSpeakerEmbedding.from_pretrained("pyannote/embedding")
mic = MicrophoneAudioSource()

source = MicrophoneAudioSource()
# To take input from file:
# source = FileAudioSource("<filename>", sample_rate=16000)

# Make sure the models have been trained with this sample rate
print(source.sample_rate)

stream = source.stream.pipe(
# Reformat stream to 5s duration and 500ms shift
dops.rearrange_audio_stream(sample_rate=segmentation.model.sample_rate),
dops.rearrange_audio_stream(sample_rate=source.sample_rate),
ops.map(lambda wav: (wav, segmentation(wav))),
ops.starmap(embedding)
).subscribe(on_next=lambda emb: print(emb.shape))

mic.read()
source.read()
```

Output:
@@ -326,20 +333,57 @@ diart.client microphone --host <server-address> --port 7007

See `-h` for more options.

### From the Dockerfile
**Review comment (Owner), suggested change:** replace "### From the Dockerfile" with "### From a Docker container".


You can also run the server in a Docker container. First, build the image:
```shell
docker build -t diart -f Dockerfile .
```

**Review comment (Owner):** `-f Dockerfile` is not needed, as it will pick up the file with that name in the specified directory.

Run the server with default configuration:
```shell
docker run -p 7007:7007 --gpus all -e HF_TOKEN=<token> diart
```

**Review comment (Owner):** We should probably add a note somewhere saying that for GPU usage they need to install nvidia-container-toolkit.

Also, is there a way to pick up the HF token from the huggingface-cli config? That way we avoid passing it directly and keeping it in the terminal history. This is possible when running outside docker, and we shouldn't make it mandatory, as it's an important security feature.

Run with custom configuration:
**Review comment (Owner), suggested change:** replace "Run with custom configuration:" with "Example with a custom configuration:".

```shell
docker run -p 7007:7007 --restart unless-stopped --gpus all \
-e HF_TOKEN=<token> \
-e HOST=0.0.0.0 \
-e PORT=7007 \
-e SEGMENTATION=pyannote/segmentation-3.0 \
-e EMBEDDING=speechbrain/spkrec-resnet-voxceleb \
-e TAU_ACTIVE=0.45 \
-e RHO_UPDATE=0.25 \
-e DELTA_NEW=0.6 \
-e LATENCY=5 \
-e MAX_SPEAKERS=3 \
diart
```

The server can be configured at runtime using these environment variables:
- `HOST`: Server host (default: 0.0.0.0)
- `PORT`: Server port (default: 7007)
- `SEGMENTATION`: Segmentation model (default: pyannote/segmentation)
- `EMBEDDING`: Embedding model (default: pyannote/embedding)
- `TAU_ACTIVE`: Activity threshold (default: 0.5)
- `RHO_UPDATE`: Update threshold (default: 0.3)
- `DELTA_NEW`: New speaker threshold (default: 1.0)
- `LATENCY`: Processing latency in seconds (default: 0.5)
- `MAX_SPEAKERS`: Maximum number of speakers (default: 20)
**Review comment (Owner) on lines +364 to +373:** This should be moved up before the example is given.


### From python

For customized solutions, a server can also be created in python using the `WebSocketAudioSource`:
For customized solutions, a server can also be created in python using `WebSocketStreamingServer`:

```python
from diart import SpeakerDiarization
from diart.sources import WebSocketAudioSource
from diart.inference import StreamingInference
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.websockets import WebSocketStreamingServer

pipeline = SpeakerDiarization()
source = WebSocketAudioSource(pipeline.config.sample_rate, "localhost", 7007)
inference = StreamingInference(pipeline, source)
inference.attach_hooks(lambda ann_wav: source.send(ann_wav[0].to_rttm()))
prediction = inference()
pipeline_class = SpeakerDiarization
pipeline_config = SpeakerDiarizationConfig(step=0.5, sample_rate=16000)
server = WebSocketStreamingServer(pipeline_class, pipeline_config, host="localhost", port=7007)
server.run()
```

**Review comment (Owner), suggested change:** replace `server = WebSocketStreamingServer(pipeline_class, pipeline_config, host="localhost", port=7007)` with `server = WebSocketStreamingServer(SpeakerDiarization, pipeline_config, host="localhost", port=7007)`. I prefer to put the class name inline here, and to delete the definition of pipeline_class. Also, let's rename pipeline_config to just config, as in all other examples in the readme.
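For reference, here is a hedged sketch combining the example above with the default values from the Dockerfile; it assumes `SpeakerDiarizationConfig` accepts keyword arguments mirroring the CLI flags in the Dockerfile's CMD (`--tau-active` → `tau_active`, and so on), and follows the reviewer's suggestion of passing the pipeline class inline and naming the configuration `config`.

```python
# Hedged sketch: programmatic equivalent of the Docker defaults above,
# assuming SpeakerDiarizationConfig mirrors the CLI flags used in the CMD.
from diart import SpeakerDiarization, SpeakerDiarizationConfig
from diart.websockets import WebSocketStreamingServer

config = SpeakerDiarizationConfig(
    tau_active=0.45,   # activity threshold
    rho_update=0.25,   # update threshold
    delta_new=0.6,     # new-speaker threshold
    latency=5,         # processing latency in seconds
    max_speakers=3,    # maximum number of speakers
)

# Pipeline class passed inline, as suggested in the review above
server = WebSocketStreamingServer(SpeakerDiarization, config, host="0.0.0.0", port=7007)
server.run()
```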

## 🔬 Powered by research
Binary file added assets/models/embedding_uint8.onnx
Binary file added assets/models/segmentation_uint8.onnx
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,5 +1,5 @@
numpy>=1.20.2
matplotlib>=3.3.3
numpy>=1.20.2,<2.0.0
matplotlib>=3.3.3,<3.6.0
rx>=3.2.0
scipy>=1.6.0
sounddevice>=0.4.2
11 changes: 8 additions & 3 deletions setup.cfg
@@ -1,6 +1,6 @@
[metadata]
name=diart
version=0.9.0
version=0.9.1
author=Juan Manuel Coria
description=A python framework to build AI for real-time speech
long_description=file: README.md
@@ -20,8 +20,8 @@ package_dir=
=src
packages=find:
install_requires=
numpy>=1.20.2
matplotlib>=3.3.3
numpy>=1.20.2,<2.0.0
matplotlib>=3.3.3,<3.6.0
rx>=3.2.0
scipy>=1.6.0
sounddevice>=0.4.2
@@ -41,6 +41,11 @@ install_requires=
websocket-client>=0.58.0
rich>=12.5.1

[options.extras_require]
tests=
pytest>=7.4.0,<8.0.0
onnxruntime==1.18.0

[options.packages.find]
where=src
