Run pyannote diarization in background thread with progress reporting

Move the blocking pipeline() call to a daemon thread and emit estimated
progress messages every 2 seconds from the main thread. The progress
estimate uses audio duration to calibrate the expected total time.
Also pass audio_duration_sec from PipelineService to DiarizeService.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Claude
2026-03-20 13:50:57 -07:00
parent d00281f0c7
commit 03af5a189c
3 changed files with 110 additions and 2 deletions

View File

@@ -1,7 +1,13 @@
"""Tests for diarization service data structures and payload conversion."""
import time
from unittest.mock import MagicMock, patch
import pytest
from voice_to_notes.services.diarize import (
DiarizationResult,
DiarizeService,
SpeakerSegment,
diarization_to_payload,
)
@@ -31,3 +37,73 @@ def test_diarization_to_payload_empty():
assert payload["num_speakers"] == 0
assert payload["speaker_segments"] == []
assert payload["speakers"] == []
def test_diarize_threading_progress(monkeypatch):
"""Test that diarization emits progress while running in background thread."""
# Track written messages
written_messages = []
def mock_write(msg):
written_messages.append(msg)
# Mock pipeline that takes ~5 seconds
def slow_pipeline(file_path, **kwargs):
time.sleep(5)
# Return a mock diarization result
mock_result = MagicMock()
mock_track = MagicMock()
mock_track.start = 0.0
mock_track.end = 5.0
mock_result.itertracks = MagicMock(return_value=[(mock_track, None, "SPEAKER_00")])
return mock_result
mock_pipeline_obj = MagicMock()
mock_pipeline_obj.side_effect = slow_pipeline
service = DiarizeService()
service._pipeline = mock_pipeline_obj
with patch("voice_to_notes.services.diarize.write_message", mock_write):
result = service.diarize(
request_id="req-1",
file_path="/fake/audio.wav",
audio_duration_sec=60.0,
)
# Filter for diarizing progress messages (not loading_diarization or done)
diarizing_msgs = [
m for m in written_messages
if m.type == "progress" and m.payload.get("stage") == "diarizing"
and "elapsed" in m.payload.get("message", "")
]
# Should have at least 1 progress message (5s sleep / 2s interval = ~2 messages)
assert len(diarizing_msgs) >= 1, (
f"Expected at least 1 diarizing progress message, got {len(diarizing_msgs)}"
)
# Progress percent should be between 20 and 85
for msg in diarizing_msgs:
pct = msg.payload["percent"]
assert 20 <= pct <= 85, f"Progress {pct} out of expected range 20-85"
# Result should be valid
assert result.num_speakers == 1
assert result.speakers == ["SPEAKER_00"]
def test_diarize_threading_error_propagation(monkeypatch):
"""Test that errors from the background thread are properly raised."""
mock_pipeline_obj = MagicMock()
mock_pipeline_obj.side_effect = RuntimeError("Pipeline crashed")
service = DiarizeService()
service._pipeline = mock_pipeline_obj
with patch("voice_to_notes.services.diarize.write_message", lambda m: None):
with pytest.raises(RuntimeError, match="Pipeline crashed"):
service.diarize(
request_id="req-1",
file_path="/fake/audio.wav",
audio_duration_sec=30.0,
)

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import sys
import threading
import time
from dataclasses import dataclass, field
from typing import Any
@@ -82,6 +83,8 @@ class DiarizeService:
num_speakers: int | None = None,
min_speakers: int | None = None,
max_speakers: int | None = None,
hf_token: str | None = None,
audio_duration_sec: float | None = None,
) -> DiarizationResult:
"""Run speaker diarization on an audio file.
@@ -116,8 +119,36 @@ class DiarizeService:
if max_speakers is not None:
kwargs["max_speakers"] = max_speakers
# Run diarization
diarization = pipeline(file_path, **kwargs)
# Run diarization in background thread for progress reporting
result_holder: list = [None]
error_holder: list[Exception | None] = [None]
done_event = threading.Event()
def _run():
try:
result_holder[0] = pipeline(file_path, **kwargs)
except Exception as e:
error_holder[0] = e
finally:
done_event.set()
thread = threading.Thread(target=_run, daemon=True)
thread.start()
elapsed = 0.0
estimated_total = max(audio_duration_sec * 0.5, 30.0) if audio_duration_sec else 120.0
while not done_event.wait(timeout=2.0):
elapsed += 2.0
pct = min(20 + int((elapsed / estimated_total) * 65), 85)
write_message(progress_message(
request_id, pct, "diarizing",
f"Analyzing speakers ({int(elapsed)}s elapsed)..."))
thread.join()
if error_holder[0] is not None:
raise error_holder[0]
diarization = result_holder[0]
# Convert pyannote output to our format
result = DiarizationResult()

View File

@@ -121,6 +121,7 @@ class PipelineService:
num_speakers=num_speakers,
min_speakers=min_speakers,
max_speakers=max_speakers,
audio_duration_sec=transcription.duration_ms / 1000.0,
)
# Step 3: Merge