2026-02-26 16:09:48 -08:00
|
|
|
"""Tests for diarization service data structures and payload conversion."""
|
|
|
|
|
|
2026-03-20 13:50:57 -07:00
|
|
|
import time
|
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
2026-02-26 16:09:48 -08:00
|
|
|
from voice_to_notes.services.diarize import (
|
|
|
|
|
DiarizationResult,
|
2026-03-20 13:50:57 -07:00
|
|
|
DiarizeService,
|
2026-02-26 16:09:48 -08:00
|
|
|
SpeakerSegment,
|
|
|
|
|
diarization_to_payload,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_diarization_to_payload():
    """A populated result converts to a payload with matching counts, labels and timings."""
    segments = [
        SpeakerSegment(speaker="SPEAKER_00", start_ms=0, end_ms=5000),
        SpeakerSegment(speaker="SPEAKER_01", start_ms=5000, end_ms=10000),
        SpeakerSegment(speaker="SPEAKER_00", start_ms=10000, end_ms=15000),
    ]
    source = DiarizationResult(
        speaker_segments=segments,
        num_speakers=2,
        speakers=["SPEAKER_00", "SPEAKER_01"],
    )

    payload = diarization_to_payload(source)
    converted_segments = payload["speaker_segments"]

    assert payload["num_speakers"] == 2
    assert len(converted_segments) == 3
    assert payload["speakers"] == ["SPEAKER_00", "SPEAKER_01"]
    # Spot-check that per-segment fields survive conversion.
    assert converted_segments[0]["speaker"] == "SPEAKER_00"
    assert converted_segments[1]["start_ms"] == 5000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_diarization_to_payload_empty():
    """A default-constructed result converts to a payload with no speakers or segments."""
    payload = diarization_to_payload(DiarizationResult())

    assert payload["num_speakers"] == 0
    assert payload["speaker_segments"] == []
    assert payload["speakers"] == []
|
2026-03-20 13:50:57 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_diarize_threading_progress(monkeypatch):
    """Test that diarization emits progress while running in background thread.

    The pipeline is replaced by a mock that blocks for ~5 seconds before
    returning a single-speaker result. The service therefore has time to emit
    "diarizing" progress messages (those whose text mentions "elapsed") while
    the pipeline runs.
    """
    # Every message the service writes is captured here instead of being sent.
    written_messages = []

    def slow_pipeline(file_path, **kwargs):
        """Stand-in pipeline: block for ~5 seconds, then report one speaker track."""
        time.sleep(5)
        # spec=[] gives the mock no attributes at all, so a
        # hasattr(result, "speaker_diarization") check returns False;
        # only the itertracks attribute assigned below exists.
        mock_result = MagicMock(spec=[])
        mock_track = MagicMock()
        mock_track.start = 0.0
        mock_track.end = 5.0
        mock_result.itertracks = MagicMock(
            return_value=[(mock_track, None, "SPEAKER_00")]
        )
        return mock_result

    service = DiarizeService()
    service._pipeline = MagicMock(side_effect=slow_pipeline)

    # Use the requested monkeypatch fixture; it restores write_message at teardown.
    monkeypatch.setattr(
        "voice_to_notes.services.diarize.write_message", written_messages.append
    )
    result = service.diarize(
        request_id="req-1",
        file_path="/fake/audio.wav",
        audio_duration_sec=60.0,
    )

    # Filter for diarizing progress messages (not loading_diarization or done)
    diarizing_msgs = [
        m
        for m in written_messages
        if m.type == "progress"
        and m.payload.get("stage") == "diarizing"
        and "elapsed" in m.payload.get("message", "")
    ]

    # Expect at least one progress message. Assuming a ~2 s progress interval,
    # a 5 s pipeline should yield about two.
    assert len(diarizing_msgs) >= 1, (
        f"Expected at least 1 diarizing progress message, got {len(diarizing_msgs)}"
    )

    # Progress percent should be between 20 and 85
    for msg in diarizing_msgs:
        pct = msg.payload["percent"]
        assert 20 <= pct <= 85, f"Progress {pct} out of expected range 20-85"

    # Result should be valid
    assert result.num_speakers == 1
    assert result.speakers == ["SPEAKER_00"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_diarize_threading_error_propagation(monkeypatch):
    """An exception raised by the pipeline during diarize() reaches the caller unchanged."""
    service = DiarizeService()
    service._pipeline = MagicMock(side_effect=RuntimeError("Pipeline crashed"))

    # Discard every outgoing message; only the raised error matters here.
    monkeypatch.setattr("voice_to_notes.services.diarize.write_message", lambda m: None)

    with pytest.raises(RuntimeError, match="Pipeline crashed"):
        service.diarize(
            request_id="req-1",
            file_path="/fake/audio.wav",
            audio_duration_sec=30.0,
        )
|