Merge perf/diarize-threading: diarization progress via background thread
This commit is contained in:
@@ -1,7 +1,13 @@
|
|||||||
"""Tests for diarization service data structures and payload conversion."""
|
"""Tests for diarization service data structures and payload conversion."""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from voice_to_notes.services.diarize import (
|
from voice_to_notes.services.diarize import (
|
||||||
DiarizationResult,
|
DiarizationResult,
|
||||||
|
DiarizeService,
|
||||||
SpeakerSegment,
|
SpeakerSegment,
|
||||||
diarization_to_payload,
|
diarization_to_payload,
|
||||||
)
|
)
|
||||||
@@ -31,3 +37,73 @@ def test_diarization_to_payload_empty():
|
|||||||
assert payload["num_speakers"] == 0
|
assert payload["num_speakers"] == 0
|
||||||
assert payload["speaker_segments"] == []
|
assert payload["speaker_segments"] == []
|
||||||
assert payload["speakers"] == []
|
assert payload["speakers"] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_diarize_threading_progress(monkeypatch):
|
||||||
|
"""Test that diarization emits progress while running in background thread."""
|
||||||
|
# Track written messages
|
||||||
|
written_messages = []
|
||||||
|
def mock_write(msg):
|
||||||
|
written_messages.append(msg)
|
||||||
|
|
||||||
|
# Mock pipeline that takes ~5 seconds
|
||||||
|
def slow_pipeline(file_path, **kwargs):
|
||||||
|
time.sleep(5)
|
||||||
|
# Return a mock diarization result
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_track = MagicMock()
|
||||||
|
mock_track.start = 0.0
|
||||||
|
mock_track.end = 5.0
|
||||||
|
mock_result.itertracks = MagicMock(return_value=[(mock_track, None, "SPEAKER_00")])
|
||||||
|
return mock_result
|
||||||
|
|
||||||
|
mock_pipeline_obj = MagicMock()
|
||||||
|
mock_pipeline_obj.side_effect = slow_pipeline
|
||||||
|
|
||||||
|
service = DiarizeService()
|
||||||
|
service._pipeline = mock_pipeline_obj
|
||||||
|
|
||||||
|
with patch("voice_to_notes.services.diarize.write_message", mock_write):
|
||||||
|
result = service.diarize(
|
||||||
|
request_id="req-1",
|
||||||
|
file_path="/fake/audio.wav",
|
||||||
|
audio_duration_sec=60.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter for diarizing progress messages (not loading_diarization or done)
|
||||||
|
diarizing_msgs = [
|
||||||
|
m for m in written_messages
|
||||||
|
if m.type == "progress" and m.payload.get("stage") == "diarizing"
|
||||||
|
and "elapsed" in m.payload.get("message", "")
|
||||||
|
]
|
||||||
|
|
||||||
|
# Should have at least 1 progress message (5s sleep / 2s interval = ~2 messages)
|
||||||
|
assert len(diarizing_msgs) >= 1, (
|
||||||
|
f"Expected at least 1 diarizing progress message, got {len(diarizing_msgs)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Progress percent should be between 20 and 85
|
||||||
|
for msg in diarizing_msgs:
|
||||||
|
pct = msg.payload["percent"]
|
||||||
|
assert 20 <= pct <= 85, f"Progress {pct} out of expected range 20-85"
|
||||||
|
|
||||||
|
# Result should be valid
|
||||||
|
assert result.num_speakers == 1
|
||||||
|
assert result.speakers == ["SPEAKER_00"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_diarize_threading_error_propagation(monkeypatch):
|
||||||
|
"""Test that errors from the background thread are properly raised."""
|
||||||
|
mock_pipeline_obj = MagicMock()
|
||||||
|
mock_pipeline_obj.side_effect = RuntimeError("Pipeline crashed")
|
||||||
|
|
||||||
|
service = DiarizeService()
|
||||||
|
service._pipeline = mock_pipeline_obj
|
||||||
|
|
||||||
|
with patch("voice_to_notes.services.diarize.write_message", lambda m: None):
|
||||||
|
with pytest.raises(RuntimeError, match="Pipeline crashed"):
|
||||||
|
service.diarize(
|
||||||
|
request_id="req-1",
|
||||||
|
file_path="/fake/audio.wav",
|
||||||
|
audio_duration_sec=30.0,
|
||||||
|
)
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import os
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -141,6 +142,7 @@ class DiarizeService:
|
|||||||
min_speakers: int | None = None,
|
min_speakers: int | None = None,
|
||||||
max_speakers: int | None = None,
|
max_speakers: int | None = None,
|
||||||
hf_token: str | None = None,
|
hf_token: str | None = None,
|
||||||
|
audio_duration_sec: float | None = None,
|
||||||
) -> DiarizationResult:
|
) -> DiarizationResult:
|
||||||
"""Run speaker diarization on an audio file.
|
"""Run speaker diarization on an audio file.
|
||||||
|
|
||||||
@@ -184,12 +186,40 @@ class DiarizeService:
|
|||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run diarization
|
# Run diarization in background thread for progress reporting
|
||||||
try:
|
result_holder: list = [None]
|
||||||
raw_result = pipeline(audio_path, **kwargs)
|
error_holder: list[Exception | None] = [None]
|
||||||
finally:
|
done_event = threading.Event()
|
||||||
if temp_wav:
|
|
||||||
os.unlink(temp_wav)
|
def _run():
|
||||||
|
try:
|
||||||
|
result_holder[0] = pipeline(audio_path, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
error_holder[0] = e
|
||||||
|
finally:
|
||||||
|
done_event.set()
|
||||||
|
|
||||||
|
thread = threading.Thread(target=_run, daemon=True)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
elapsed = 0.0
|
||||||
|
estimated_total = max(audio_duration_sec * 0.5, 30.0) if audio_duration_sec else 120.0
|
||||||
|
while not done_event.wait(timeout=2.0):
|
||||||
|
elapsed += 2.0
|
||||||
|
pct = min(20 + int((elapsed / estimated_total) * 65), 85)
|
||||||
|
write_message(progress_message(
|
||||||
|
request_id, pct, "diarizing",
|
||||||
|
f"Analyzing speakers ({int(elapsed)}s elapsed)..."))
|
||||||
|
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
# Clean up temp file
|
||||||
|
if temp_wav:
|
||||||
|
os.unlink(temp_wav)
|
||||||
|
|
||||||
|
if error_holder[0] is not None:
|
||||||
|
raise error_holder[0]
|
||||||
|
raw_result = result_holder[0]
|
||||||
|
|
||||||
# pyannote 4.0+ returns DiarizeOutput; older versions return Annotation directly
|
# pyannote 4.0+ returns DiarizeOutput; older versions return Annotation directly
|
||||||
if hasattr(raw_result, "speaker_diarization"):
|
if hasattr(raw_result, "speaker_diarization"):
|
||||||
|
|||||||
@@ -139,6 +139,7 @@ class PipelineService:
|
|||||||
min_speakers=min_speakers,
|
min_speakers=min_speakers,
|
||||||
max_speakers=max_speakers,
|
max_speakers=max_speakers,
|
||||||
hf_token=hf_token,
|
hf_token=hf_token,
|
||||||
|
audio_duration_sec=transcription.duration_ms / 1000.0,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
|
|||||||
Reference in New Issue
Block a user