Merge perf/diarize-threading: diarization progress via background thread
This commit is contained in:
@@ -1,7 +1,13 @@
|
||||
"""Tests for diarization service data structures and payload conversion."""
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from voice_to_notes.services.diarize import (
|
||||
DiarizationResult,
|
||||
DiarizeService,
|
||||
SpeakerSegment,
|
||||
diarization_to_payload,
|
||||
)
|
||||
@@ -31,3 +37,73 @@ def test_diarization_to_payload_empty():
|
||||
assert payload["num_speakers"] == 0
|
||||
assert payload["speaker_segments"] == []
|
||||
assert payload["speakers"] == []
|
||||
|
||||
|
||||
def test_diarize_threading_progress(monkeypatch):
|
||||
"""Test that diarization emits progress while running in background thread."""
|
||||
# Track written messages
|
||||
written_messages = []
|
||||
def mock_write(msg):
|
||||
written_messages.append(msg)
|
||||
|
||||
# Mock pipeline that takes ~5 seconds
|
||||
def slow_pipeline(file_path, **kwargs):
|
||||
time.sleep(5)
|
||||
# Return a mock diarization result
|
||||
mock_result = MagicMock()
|
||||
mock_track = MagicMock()
|
||||
mock_track.start = 0.0
|
||||
mock_track.end = 5.0
|
||||
mock_result.itertracks = MagicMock(return_value=[(mock_track, None, "SPEAKER_00")])
|
||||
return mock_result
|
||||
|
||||
mock_pipeline_obj = MagicMock()
|
||||
mock_pipeline_obj.side_effect = slow_pipeline
|
||||
|
||||
service = DiarizeService()
|
||||
service._pipeline = mock_pipeline_obj
|
||||
|
||||
with patch("voice_to_notes.services.diarize.write_message", mock_write):
|
||||
result = service.diarize(
|
||||
request_id="req-1",
|
||||
file_path="/fake/audio.wav",
|
||||
audio_duration_sec=60.0,
|
||||
)
|
||||
|
||||
# Filter for diarizing progress messages (not loading_diarization or done)
|
||||
diarizing_msgs = [
|
||||
m for m in written_messages
|
||||
if m.type == "progress" and m.payload.get("stage") == "diarizing"
|
||||
and "elapsed" in m.payload.get("message", "")
|
||||
]
|
||||
|
||||
# Should have at least 1 progress message (5s sleep / 2s interval = ~2 messages)
|
||||
assert len(diarizing_msgs) >= 1, (
|
||||
f"Expected at least 1 diarizing progress message, got {len(diarizing_msgs)}"
|
||||
)
|
||||
|
||||
# Progress percent should be between 20 and 85
|
||||
for msg in diarizing_msgs:
|
||||
pct = msg.payload["percent"]
|
||||
assert 20 <= pct <= 85, f"Progress {pct} out of expected range 20-85"
|
||||
|
||||
# Result should be valid
|
||||
assert result.num_speakers == 1
|
||||
assert result.speakers == ["SPEAKER_00"]
|
||||
|
||||
|
||||
def test_diarize_threading_error_propagation(monkeypatch):
|
||||
"""Test that errors from the background thread are properly raised."""
|
||||
mock_pipeline_obj = MagicMock()
|
||||
mock_pipeline_obj.side_effect = RuntimeError("Pipeline crashed")
|
||||
|
||||
service = DiarizeService()
|
||||
service._pipeline = mock_pipeline_obj
|
||||
|
||||
with patch("voice_to_notes.services.diarize.write_message", lambda m: None):
|
||||
with pytest.raises(RuntimeError, match="Pipeline crashed"):
|
||||
service.diarize(
|
||||
request_id="req-1",
|
||||
file_path="/fake/audio.wav",
|
||||
audio_duration_sec=30.0,
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
@@ -141,6 +142,7 @@ class DiarizeService:
|
||||
min_speakers: int | None = None,
|
||||
max_speakers: int | None = None,
|
||||
hf_token: str | None = None,
|
||||
audio_duration_sec: float | None = None,
|
||||
) -> DiarizationResult:
|
||||
"""Run speaker diarization on an audio file.
|
||||
|
||||
@@ -184,13 +186,41 @@ class DiarizeService:
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Run diarization
|
||||
# Run diarization in background thread for progress reporting
|
||||
result_holder: list = [None]
|
||||
error_holder: list[Exception | None] = [None]
|
||||
done_event = threading.Event()
|
||||
|
||||
def _run():
|
||||
try:
|
||||
raw_result = pipeline(audio_path, **kwargs)
|
||||
result_holder[0] = pipeline(audio_path, **kwargs)
|
||||
except Exception as e:
|
||||
error_holder[0] = e
|
||||
finally:
|
||||
done_event.set()
|
||||
|
||||
thread = threading.Thread(target=_run, daemon=True)
|
||||
thread.start()
|
||||
|
||||
elapsed = 0.0
|
||||
estimated_total = max(audio_duration_sec * 0.5, 30.0) if audio_duration_sec else 120.0
|
||||
while not done_event.wait(timeout=2.0):
|
||||
elapsed += 2.0
|
||||
pct = min(20 + int((elapsed / estimated_total) * 65), 85)
|
||||
write_message(progress_message(
|
||||
request_id, pct, "diarizing",
|
||||
f"Analyzing speakers ({int(elapsed)}s elapsed)..."))
|
||||
|
||||
thread.join()
|
||||
|
||||
# Clean up temp file
|
||||
if temp_wav:
|
||||
os.unlink(temp_wav)
|
||||
|
||||
if error_holder[0] is not None:
|
||||
raise error_holder[0]
|
||||
raw_result = result_holder[0]
|
||||
|
||||
# pyannote 4.0+ returns DiarizeOutput; older versions return Annotation directly
|
||||
if hasattr(raw_result, "speaker_diarization"):
|
||||
diarization = raw_result.speaker_diarization
|
||||
|
||||
@@ -139,6 +139,7 @@ class PipelineService:
|
||||
min_speakers=min_speakers,
|
||||
max_speakers=max_speakers,
|
||||
hf_token=hf_token,
|
||||
audio_duration_sec=transcription.duration_ms / 1000.0,
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
Reference in New Issue
Block a user