perf/pipeline-improvements #1

Merged
jknapp merged 18 commits from perf/pipeline-improvements into main 2026-03-21 04:53:45 +00:00
3 changed files with 113 additions and 6 deletions
Showing only changes of commit c23b9a90dd - Show all commits

View File

@@ -1,7 +1,13 @@
"""Tests for diarization service data structures and payload conversion.""" """Tests for diarization service data structures and payload conversion."""
import time
from unittest.mock import MagicMock, patch
import pytest
from voice_to_notes.services.diarize import ( from voice_to_notes.services.diarize import (
DiarizationResult, DiarizationResult,
DiarizeService,
SpeakerSegment, SpeakerSegment,
diarization_to_payload, diarization_to_payload,
) )
@@ -31,3 +37,73 @@ def test_diarization_to_payload_empty():
assert payload["num_speakers"] == 0 assert payload["num_speakers"] == 0
assert payload["speaker_segments"] == [] assert payload["speaker_segments"] == []
assert payload["speakers"] == [] assert payload["speakers"] == []
def test_diarize_threading_progress(monkeypatch):
    """Check that diarize() reports progress while the pipeline runs in a worker thread.

    A fake pipeline blocks for about five seconds so the service's progress
    loop (which wakes every two seconds) has time to emit at least one
    "diarizing" message. The test then checks that every such message carries
    a percentage inside the 20-85 band and that the single fake speaker track
    makes it into the returned result.
    """
    # Every message the service writes is captured here instead of being emitted.
    written_messages = []

    def mock_write(msg):
        written_messages.append(msg)

    def slow_pipeline(file_path, **kwargs):
        # Block long enough for the 2-second progress loop to fire at least once.
        time.sleep(5)

        mock_track = MagicMock()
        mock_track.start = 0.0
        mock_track.end = 5.0

        mock_result = MagicMock()
        # A bare MagicMock answers True to every hasattr() probe, so the
        # service's `hasattr(raw_result, "speaker_diarization")` check would
        # treat it as a pyannote 4.0+ output and read an auto-generated child
        # mock that yields no tracks. Deleting the attribute makes the mock
        # look like a plain Annotation, so the itertracks() below is used.
        del mock_result.speaker_diarization
        mock_result.itertracks = MagicMock(
            return_value=[(mock_track, None, "SPEAKER_00")]
        )
        return mock_result

    mock_pipeline_obj = MagicMock()
    mock_pipeline_obj.side_effect = slow_pipeline

    service = DiarizeService()
    # Inject the fake pipeline directly onto the service.
    service._pipeline = mock_pipeline_obj

    with patch("voice_to_notes.services.diarize.write_message", mock_write):
        result = service.diarize(
            request_id="req-1",
            file_path="/fake/audio.wav",
            audio_duration_sec=60.0,
        )

    # Keep only the periodic "diarizing" updates (not loading_diarization or done).
    diarizing_msgs = [
        m
        for m in written_messages
        if m.type == "progress"
        and m.payload.get("stage") == "diarizing"
        and "elapsed" in m.payload.get("message", "")
    ]

    # 5s sleep / 2s interval gives roughly two messages; require at least one.
    assert len(diarizing_msgs) >= 1, (
        f"Expected at least 1 diarizing progress message, got {len(diarizing_msgs)}"
    )

    # Progress percent should stay between 20 and 85.
    for msg in diarizing_msgs:
        pct = msg.payload["percent"]
        assert 20 <= pct <= 85, f"Progress {pct} out of expected range 20-85"

    # The single fake track must come back as exactly one speaker.
    assert result.num_speakers == 1
    assert result.speakers == ["SPEAKER_00"]
def test_diarize_threading_error_propagation(monkeypatch):
    """Verify that an exception raised by the pipeline surfaces from diarize()."""

    def _discard(m):
        # Stand-in for write_message: drop every message.
        return None

    # Calling this mock raises immediately, simulating a pipeline failure.
    failing_pipeline = MagicMock(side_effect=RuntimeError("Pipeline crashed"))

    service = DiarizeService()
    service._pipeline = failing_pipeline

    silence_writes = patch("voice_to_notes.services.diarize.write_message", _discard)
    expect_crash = pytest.raises(RuntimeError, match="Pipeline crashed")

    # The same RuntimeError must reach the caller of diarize().
    with silence_writes, expect_crash:
        service.diarize(
            request_id="req-1",
            file_path="/fake/audio.wav",
            audio_duration_sec=30.0,
        )

View File

@@ -6,6 +6,7 @@ import os
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
import threading
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
@@ -141,6 +142,7 @@ class DiarizeService:
min_speakers: int | None = None, min_speakers: int | None = None,
max_speakers: int | None = None, max_speakers: int | None = None,
hf_token: str | None = None, hf_token: str | None = None,
audio_duration_sec: float | None = None,
) -> DiarizationResult: ) -> DiarizationResult:
"""Run speaker diarization on an audio file. """Run speaker diarization on an audio file.
@@ -184,13 +186,41 @@ class DiarizeService:
flush=True, flush=True,
) )
# Run diarization # Run diarization in background thread for progress reporting
result_holder: list = [None]
error_holder: list[Exception | None] = [None]
done_event = threading.Event()
def _run():
try: try:
raw_result = pipeline(audio_path, **kwargs) result_holder[0] = pipeline(audio_path, **kwargs)
except Exception as e:
error_holder[0] = e
finally: finally:
done_event.set()
thread = threading.Thread(target=_run, daemon=True)
thread.start()
elapsed = 0.0
estimated_total = max(audio_duration_sec * 0.5, 30.0) if audio_duration_sec else 120.0
while not done_event.wait(timeout=2.0):
elapsed += 2.0
pct = min(20 + int((elapsed / estimated_total) * 65), 85)
write_message(progress_message(
request_id, pct, "diarizing",
f"Analyzing speakers ({int(elapsed)}s elapsed)..."))
thread.join()
# Clean up temp file
if temp_wav: if temp_wav:
os.unlink(temp_wav) os.unlink(temp_wav)
if error_holder[0] is not None:
raise error_holder[0]
raw_result = result_holder[0]
# pyannote 4.0+ returns DiarizeOutput; older versions return Annotation directly # pyannote 4.0+ returns DiarizeOutput; older versions return Annotation directly
if hasattr(raw_result, "speaker_diarization"): if hasattr(raw_result, "speaker_diarization"):
diarization = raw_result.speaker_diarization diarization = raw_result.speaker_diarization

View File

@@ -139,6 +139,7 @@ class PipelineService:
min_speakers=min_speakers, min_speakers=min_speakers,
max_speakers=max_speakers, max_speakers=max_speakers,
hf_token=hf_token, hf_token=hf_token,
audio_duration_sec=transcription.duration_ms / 1000.0,
) )
except Exception as e: except Exception as e:
import traceback import traceback