perf/pipeline-improvements #1

Merged
jknapp merged 18 commits from perf/pipeline-improvements into main 2026-03-21 04:53:45 +00:00
3 changed files with 110 additions and 2 deletions
Showing only changes of commit 03af5a189c - Show all commits

View File

@@ -1,7 +1,13 @@
"""Tests for diarization service data structures and payload conversion.""" """Tests for diarization service data structures and payload conversion."""
import time
from unittest.mock import MagicMock, patch
import pytest
from voice_to_notes.services.diarize import ( from voice_to_notes.services.diarize import (
DiarizationResult, DiarizationResult,
DiarizeService,
SpeakerSegment, SpeakerSegment,
diarization_to_payload, diarization_to_payload,
) )
@@ -31,3 +37,73 @@ def test_diarization_to_payload_empty():
assert payload["num_speakers"] == 0 assert payload["num_speakers"] == 0
assert payload["speaker_segments"] == [] assert payload["speaker_segments"] == []
assert payload["speakers"] == [] assert payload["speakers"] == []
def test_diarize_threading_progress(monkeypatch):
    """Check that progress updates are written while the pipeline is busy.

    The pipeline is swapped for a stub that blocks for about five seconds,
    and ``write_message`` in the diarize module is patched to record every
    message so the test can inspect what was emitted during the wait.
    """
    captured = []

    def blocking_pipeline(file_path, **kwargs):
        # Block long enough for progress messages to be written meanwhile.
        time.sleep(5)
        # Fake diarization output: a single track labelled SPEAKER_00.
        segment = MagicMock(start=0.0, end=5.0)
        annotation = MagicMock()
        annotation.itertracks.return_value = [(segment, None, "SPEAKER_00")]
        return annotation

    def is_elapsed_update(message):
        # Keep only "diarizing" progress messages that report elapsed time
        # (this excludes loading_diarization and done messages).
        if message.type != "progress":
            return False
        if message.payload.get("stage") != "diarizing":
            return False
        return "elapsed" in message.payload.get("message", "")

    service = DiarizeService()
    service._pipeline = MagicMock(side_effect=blocking_pipeline)

    with patch("voice_to_notes.services.diarize.write_message", captured.append):
        result = service.diarize(
            request_id="req-1",
            file_path="/fake/audio.wav",
            audio_duration_sec=60.0,
        )

    elapsed_updates = list(filter(is_elapsed_update, captured))

    # 5s sleep / 2s interval = ~2 messages, so at least one must be present.
    assert len(elapsed_updates) >= 1, (
        f"Expected at least 1 diarizing progress message, got {len(elapsed_updates)}"
    )

    # Each reported percentage has to stay within the 20-85 band.
    for update in elapsed_updates:
        pct = update.payload["percent"]
        assert 20 <= pct <= 85, f"Progress {pct} out of expected range 20-85"

    # The stubbed track must come back as a one-speaker result.
    assert result.num_speakers == 1
    assert result.speakers == ["SPEAKER_00"]
def test_diarize_threading_error_propagation(monkeypatch):
    """Check that an exception raised by the pipeline reaches the caller.

    The pipeline stub raises ``RuntimeError`` when called; ``diarize`` is
    expected to re-raise that same error rather than swallow it.
    """
    failing_pipeline = MagicMock(side_effect=RuntimeError("Pipeline crashed"))

    service = DiarizeService()
    service._pipeline = failing_pipeline

    # Silence message output; only the raised exception matters here.
    silence_output = patch(
        "voice_to_notes.services.diarize.write_message", lambda m: None
    )
    with silence_output, pytest.raises(RuntimeError, match="Pipeline crashed"):
        service.diarize(
            request_id="req-1",
            file_path="/fake/audio.wav",
            audio_duration_sec=30.0,
        )

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import sys import sys
import threading
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
@@ -82,6 +83,8 @@ class DiarizeService:
num_speakers: int | None = None, num_speakers: int | None = None,
min_speakers: int | None = None, min_speakers: int | None = None,
max_speakers: int | None = None, max_speakers: int | None = None,
hf_token: str | None = None,
audio_duration_sec: float | None = None,
) -> DiarizationResult: ) -> DiarizationResult:
"""Run speaker diarization on an audio file. """Run speaker diarization on an audio file.
@@ -116,8 +119,36 @@ class DiarizeService:
if max_speakers is not None: if max_speakers is not None:
kwargs["max_speakers"] = max_speakers kwargs["max_speakers"] = max_speakers
# Run diarization # Run diarization in background thread for progress reporting
diarization = pipeline(file_path, **kwargs) result_holder: list = [None]
error_holder: list[Exception | None] = [None]
done_event = threading.Event()
def _run():
try:
result_holder[0] = pipeline(file_path, **kwargs)
except Exception as e:
error_holder[0] = e
finally:
done_event.set()
thread = threading.Thread(target=_run, daemon=True)
thread.start()
elapsed = 0.0
estimated_total = max(audio_duration_sec * 0.5, 30.0) if audio_duration_sec else 120.0
while not done_event.wait(timeout=2.0):
elapsed += 2.0
pct = min(20 + int((elapsed / estimated_total) * 65), 85)
write_message(progress_message(
request_id, pct, "diarizing",
f"Analyzing speakers ({int(elapsed)}s elapsed)..."))
thread.join()
if error_holder[0] is not None:
raise error_holder[0]
diarization = result_holder[0]
# Convert pyannote output to our format # Convert pyannote output to our format
result = DiarizationResult() result = DiarizationResult()

View File

@@ -121,6 +121,7 @@ class PipelineService:
num_speakers=num_speakers, num_speakers=num_speakers,
min_speakers=min_speakers, min_speakers=min_speakers,
max_speakers=max_speakers, max_speakers=max_speakers,
audio_duration_sec=transcription.duration_ms / 1000.0,
) )
# Step 3: Merge # Step 3: Merge