Skip to content

Commit

Permalink
fix: openai api transcriber (#652)
Browse files Browse the repository at this point in the history
  • Loading branch information
chidiwilliams authored Dec 23, 2023
1 parent f163aab commit 1120723
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 44 deletions.
2 changes: 2 additions & 0 deletions buzz/settings/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class Key(enum.Enum):
"transcription-tasks-table/column-visibility"
)

MAIN_WINDOW = "main-window"

def set_value(self, key: Key, value: typing.Any) -> None:
    """Persist *value* in the underlying QSettings store under *key*.

    *key* is a ``Key`` enum member; its ``.value`` string is used as the
    settings path.
    """
    self.settings.setValue(key.value, value)

Expand Down
115 changes: 80 additions & 35 deletions buzz/transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import enum
import json
import logging
import math
import multiprocessing
import os
import subprocess
Expand Down Expand Up @@ -286,56 +287,100 @@ def transcribe(self) -> List[Segment]:
self.task,
)

wav_file = tempfile.mktemp() + ".wav"
mp3_file = tempfile.mktemp() + ".mp3"

# fmt: off
cmd = [
"ffmpeg",
"-nostdin",
"-threads", "0",
"-i", self.file_path,
"-f", "s16le",
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(whisper.audio.SAMPLE_RATE),
wav_file,
]
# fmt: on
cmd = ["ffmpeg", "-i", self.file_path, mp3_file]

try:
subprocess.run(cmd, capture_output=True, check=True)
except subprocess.CalledProcessError as exc:
logging.exception("")
raise Exception(exc.stderr.decode("utf-8"))

# TODO: Check if file size is more than 25MB (2.5 minutes), then chunk
audio_file = open(wav_file, "rb")
# fmt: off
cmd = [
"ffprobe",
"-v", "error",
"-show_entries", "format=duration",
"-of", "default=noprint_wrappers=1:nokey=1",
mp3_file,
]
# fmt: on
duration_secs = float(
subprocess.run(cmd, capture_output=True, check=True).stdout.decode("utf-8")
)

total_size = os.path.getsize(mp3_file)
max_chunk_size = 25 * 1024 * 1024

openai.api_key = (
self.transcription_task.transcription_options.openai_access_token
)
language = self.transcription_task.transcription_options.language
response_format = "verbose_json"
if self.transcription_task.transcription_options.task == Task.TRANSLATE:
transcript = openai.Audio.translate(
"whisper-1",
audio_file,
response_format=response_format,
language=language,
)
else:
transcript = openai.Audio.transcribe(
"whisper-1",
audio_file,
response_format=response_format,
language=language,

self.progress.emit((0, 100))

if total_size < max_chunk_size:
return self.get_segments_for_file(mp3_file)

# If the file is larger than 25MB, split into chunks
# and transcribe each chunk separately
num_chunks = math.ceil(total_size / max_chunk_size)
chunk_duration = duration_secs / num_chunks

segments = []

for i in range(num_chunks):
chunk_start = i * chunk_duration
chunk_end = min((i + 1) * chunk_duration, duration_secs)

chunk_file = tempfile.mktemp() + ".mp3"

# fmt: off
cmd = [
"ffmpeg",
"-i", mp3_file,
"-ss", str(chunk_start),
"-to", str(chunk_end),
"-c", "copy",
chunk_file,
]
# fmt: on
subprocess.run(cmd, capture_output=True, check=True)
logging.debug('Created chunk file "%s"', chunk_file)

segments.extend(
self.get_segments_for_file(
chunk_file, offset_ms=int(chunk_start * 1000)
)
)
os.remove(chunk_file)
self.progress.emit((i + 1, num_chunks))

segments = [
Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"])
for segment in transcript["segments"]
]
return segments

def get_segments_for_file(self, file: str, offset_ms: int = 0):
    """Transcribe (or translate) one audio file via the OpenAI Whisper API.

    Returns a list of ``Segment`` objects with millisecond timestamps,
    each shifted by *offset_ms* — used when *file* is one chunk of a
    longer recording so its segments line up with the full audio.
    """
    options = self.transcription_task.transcription_options
    with open(file, "rb") as audio_file:
        request = dict(
            model="whisper-1",
            file=audio_file,
            response_format="verbose_json",
            language=options.language,
        )
        # Translation and transcription take the same request; only the
        # endpoint differs.
        if options.task == Task.TRANSLATE:
            transcript = openai.Audio.translate(**request)
        else:
            transcript = openai.Audio.transcribe(**request)

        results = []
        for segment in transcript["segments"]:
            # The API reports start/end in seconds (floats); Segment
            # stores integer milliseconds.
            start_ms = int(segment["start"] * 1000 + offset_ms)
            end_ms = int(segment["end"] * 1000 + offset_ms)
            results.append(Segment(start_ms, end_ms, segment["text"]))
        return results

def stop(self):
    """Stop the transcription. Intentionally a no-op for this transcriber."""
    pass

Expand Down
19 changes: 19 additions & 0 deletions buzz/widgets/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ def __init__(self, tasks_cache=TasksCache()):

self.load_tasks_from_cache()

self.load_geometry()

def dragEnterEvent(self, event):
# Accept file drag events
if event.mimeData().hasUrls():
Expand Down Expand Up @@ -314,10 +316,27 @@ def on_shortcuts_changed(self, shortcuts: dict):
self.toolbar.set_shortcuts(shortcuts=self.shortcuts)
self.shortcut_settings.save(shortcuts=self.shortcuts)

def resizeEvent(self, event):
    """Persist the window geometry whenever the window is resized.

    Calls the base-class handler first, as Qt recommends for event
    overrides, then saves the new geometry to settings.
    """
    super().resizeEvent(event)
    self.save_geometry()

def closeEvent(self, event: QtGui.QCloseEvent) -> None:
    """Save window state and shut down the worker thread before closing."""
    # Save geometry first, while the window still has its final size.
    self.save_geometry()

    # Ask the worker to stop, then quit its thread and block until the
    # thread has actually exited before continuing teardown.
    self.transcriber_worker.stop()
    self.transcriber_thread.quit()
    self.transcriber_thread.wait()
    self.save_tasks_to_cache()
    self.shortcut_settings.save(shortcuts=self.shortcuts)
    super().closeEvent(event)

def save_geometry(self):
    """Persist the window geometry under the MAIN_WINDOW settings group."""
    self.settings.begin_group(Settings.Key.MAIN_WINDOW)
    try:
        # saveGeometry() snapshots the window's size/position as bytes.
        self.settings.settings.setValue("geometry", self.saveGeometry())
    finally:
        # Always close the group — even if setValue raises — so later
        # settings access is not accidentally scoped to this group.
        self.settings.end_group()

def load_geometry(self):
    """Restore the window geometry saved by save_geometry(), if present."""
    self.settings.begin_group(Settings.Key.MAIN_WINDOW)
    try:
        geometry = self.settings.settings.value("geometry")
        # On first run there is no saved geometry; keep the default size.
        if geometry is not None:
            self.restoreGeometry(geometry)
    finally:
        # Always close the group so later settings access is unaffected.
        self.settings.end_group()
12 changes: 7 additions & 5 deletions buzz/widgets/transcriber/file_transcriber_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ def __init__(
self.setLayout(layout)
self.setFixedSize(self.sizeHint())

self.reset_transcriber_controls()

def get_on_checkbox_state_changed_callback(self, output_format: OutputFormat):
def on_checkbox_state_changed(state: int):
if state == Qt.CheckState.Checked.value:
Expand All @@ -158,11 +160,6 @@ def on_transcription_options_changed(
self, transcription_options: TranscriptionOptions
):
self.transcription_options = transcription_options
self.word_level_timings_checkbox.setDisabled(
self.transcription_options.model.model_type == ModelType.HUGGING_FACE
or self.transcription_options.model.model_type
== ModelType.OPEN_AI_WHISPER_API
)
if self.transcription_options.openai_access_token != "":
self.openai_access_token_changed.emit(
self.transcription_options.openai_access_token
Expand Down Expand Up @@ -213,6 +210,11 @@ def on_download_model_error(self, error: str):

def reset_transcriber_controls(self):
    """Re-enable the run button and refresh the word-level timings toggle."""
    self.run_button.setEnabled(True)
    # The word-level timings checkbox is disabled for the Hugging Face
    # and OpenAI Whisper API model types.
    model_type = self.transcription_options.model.model_type
    timings_unavailable = model_type in (
        ModelType.HUGGING_FACE,
        ModelType.OPEN_AI_WHISPER_API,
    )
    self.word_level_timings_checkbox.setDisabled(timings_unavailable)

def on_cancel_model_progress_dialog(self):
if self.model_loader is not None:
Expand Down
33 changes: 29 additions & 4 deletions buzz/widgets/transcription_tasks_table_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class TableColDef:
id: str
header: str
column_index: int
value_getter: Callable[..., str]
value_getter: Callable[[FileTranscriptionTask], str]
width: Optional[int] = None
hidden: bool = False
hidden_toggleable: bool = True
Expand All @@ -37,6 +37,8 @@ class Column(enum.Enum):
MODEL = auto()
TASK = auto()
STATUS = auto()
DATE_ADDED = auto()
DATE_COMPLETED = auto()

return_clicked = pyqtSignal()

Expand Down Expand Up @@ -78,7 +80,7 @@ def __init__(self, parent: Optional[QWidget] = None):
header=_("Task"),
column_index=self.Column.TASK.value,
value_getter=lambda task: self.get_task_label(task),
width=180,
width=120,
hidden=True,
),
TableColDef(
Expand All @@ -89,6 +91,28 @@ def __init__(self, parent: Optional[QWidget] = None):
width=180,
hidden_toggleable=False,
),
TableColDef(
id="date_added",
header=_("Date Added"),
column_index=self.Column.DATE_ADDED.value,
value_getter=lambda task: task.queued_at.strftime("%Y-%m-%d %H:%M:%S")
if task.queued_at is not None
else "",
width=180,
hidden=True,
),
TableColDef(
id="date_completed",
header=_("Date Completed"),
column_index=self.Column.DATE_COMPLETED.value,
value_getter=lambda task: task.completed_at.strftime(
"%Y-%m-%d %H:%M:%S"
)
if task.completed_at is not None
else "",
width=180,
hidden=True,
),
]

self.setColumnCount(len(self.column_definitions))
Expand Down Expand Up @@ -155,8 +179,9 @@ def upsert_task(self, task: FileTranscriptionTask):
item.setFlags(item.flags() & ~Qt.ItemFlag.ItemIsEditable)
self.setItem(row_index, definition.column_index, item)
else:
status_widget = self.item(task_row_index, self.Column.STATUS.value)
status_widget.setText(task.status_text())
for definition in self.column_definitions:
item = self.item(task_row_index, definition.column_index)
item.setText(definition.value_getter(task))

@staticmethod
def get_task_label(task: FileTranscriptionTask) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def set_segment_text(self, text: str):
self.task_changed.emit()


# TODO: Fix player duration and add spacer below
class TranscriptionViewerWidget(QWidget):
transcription_task: FileTranscriptionTask
task_changed = pyqtSignal()
Expand Down
32 changes: 32 additions & 0 deletions tests/transcriber_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
whisper_cpp_params,
write_output,
TranscriptionOptions,
OpenAIWhisperAPIFileTranscriber,
)
from buzz.recording_transcriber import RecordingTranscriber
from tests.mock_sounddevice import MockInputStream
Expand Down Expand Up @@ -70,6 +71,37 @@ def test_should_transcribe(self, qtbot):
assert "Bienvenue dans Passe" in text


class TestOpenAIWhisperAPIFileTranscriber:
    """Tests for the OpenAI-API-backed file transcriber."""

    def test_transcribe(self):
        """Segments from the API response are emitted with millisecond times."""
        file_path = "testdata/whisper-french.mp3"
        transcriber = OpenAIWhisperAPIFileTranscriber(
            task=FileTranscriptionTask(
                file_path=file_path,
                transcription_options=(
                    TranscriptionOptions(
                        openai_access_token=os.getenv("OPENAI_ACCESS_TOKEN"),
                    )
                ),
                file_transcription_options=(
                    FileTranscriptionOptions(file_paths=[file_path])
                ),
                model_path="",
            )
        )
        # Capture what the transcriber emits on its `completed` signal.
        mock_completed = Mock()
        transcriber.completed.connect(mock_completed)
        # The OpenAI API reports segment start/end in seconds (floats).
        mock_openai_result = {"segments": [{"start": 0, "end": 6.56, "text": "Hello"}]}
        with patch("openai.Audio.transcribe", return_value=mock_openai_result):
            transcriber.run()

        # First positional argument of the signal payload: list of Segments.
        called_segments = mock_completed.call_args[0][0]

        # 6.56 s from the API should become 6560 ms on the Segment.
        assert len(called_segments) == 1
        assert called_segments[0].start == 0
        assert called_segments[0].end == 6560
        assert called_segments[0].text == "Hello"


class TestWhisperCppFileTranscriber:
@pytest.mark.parametrize(
"word_level_timings,expected_segments",
Expand Down

0 comments on commit 1120723

Please sign in to comment.