# File-viewer header captured along with the source (not part of the module):
# 111 lines, 3.8 KiB, Python.

"""Tests for diarization service data structures and payload conversion."""
import time
from unittest.mock import MagicMock, patch
import pytest
from voice_to_notes.services.diarize import (
DiarizationResult,
DiarizeService,
SpeakerSegment,
diarization_to_payload,
)
def test_diarization_to_payload():
    """A populated result converts to a payload keeping counts, speakers and segments."""
    # (speaker label, start in ms, end in ms) for three back-to-back turns.
    spans = [
        ("SPEAKER_00", 0, 5000),
        ("SPEAKER_01", 5000, 10000),
        ("SPEAKER_00", 10000, 15000),
    ]
    segments = [
        SpeakerSegment(speaker=label, start_ms=begin, end_ms=finish)
        for label, begin, finish in spans
    ]

    converted = diarization_to_payload(
        DiarizationResult(
            speaker_segments=segments,
            num_speakers=2,
            speakers=["SPEAKER_00", "SPEAKER_01"],
        )
    )

    assert converted["num_speakers"] == 2
    emitted = converted["speaker_segments"]
    assert len(emitted) == 3
    assert converted["speakers"] == ["SPEAKER_00", "SPEAKER_01"]
    # Spot-check individual fields of the converted segments.
    assert emitted[0]["speaker"] == "SPEAKER_00"
    assert emitted[1]["start_ms"] == 5000
def test_diarization_to_payload_empty():
    """A default-constructed result converts to a payload with zero speakers and no segments."""
    converted = diarization_to_payload(DiarizationResult())

    assert converted["num_speakers"] == 0
    assert converted["speaker_segments"] == []
    assert converted["speakers"] == []
def test_diarize_threading_progress(monkeypatch):
    """Check that diarization emits progress while the pipeline runs in the background.

    A stand-in pipeline blocks for about five seconds. Every message handed to the
    patched ``write_message`` is recorded and inspected once ``diarize`` returns.
    """
    captured = []

    def record(msg):
        # Replacement for write_message: just remember what was written.
        captured.append(msg)

    def fake_pipeline(file_path, **kwargs):
        # Block long enough for progress updates to be written in the meantime.
        time.sleep(5)
        # spec=[] gives the mock no attributes of its own, so
        # hasattr(annotation, "speaker_diarization") is False.
        annotation = MagicMock(spec=[])
        turn = MagicMock(start=0.0, end=5.0)
        annotation.itertracks = MagicMock(return_value=[(turn, None, "SPEAKER_00")])
        return annotation

    svc = DiarizeService()
    svc._pipeline = MagicMock(side_effect=fake_pipeline)

    with patch("voice_to_notes.services.diarize.write_message", record):
        outcome = svc.diarize(
            request_id="req-1",
            file_path="/fake/audio.wav",
            audio_duration_sec=60.0,
        )

    # Keep only the "diarizing" progress updates that report elapsed time;
    # other stages (e.g. loading or done) are ignored.
    elapsed_updates = []
    for entry in captured:
        if entry.type != "progress":
            continue
        if entry.payload.get("stage") != "diarizing":
            continue
        if "elapsed" in entry.payload.get("message", ""):
            elapsed_updates.append(entry)

    # A 5s sleep with a ~2s reporting interval should yield about two updates;
    # require at least one to leave room for timing jitter.
    assert len(elapsed_updates) >= 1, (
        f"Expected at least 1 diarizing progress message, got {len(elapsed_updates)}"
    )

    # Every reported percentage must stay within the 20-85 window.
    for update in elapsed_updates:
        pct = update.payload["percent"]
        assert 20 <= pct <= 85, f"Progress {pct} out of expected range 20-85"

    # The single mocked track must surface as exactly one speaker.
    assert outcome.num_speakers == 1
    assert outcome.speakers == ["SPEAKER_00"]
def test_diarize_threading_error_propagation(monkeypatch):
    """Check that an exception raised by the pipeline surfaces from ``diarize``."""
    # Calling the pipeline mock raises immediately.
    failing_pipeline = MagicMock(side_effect=RuntimeError("Pipeline crashed"))

    svc = DiarizeService()
    svc._pipeline = failing_pipeline

    # Discard all written messages; only the raised error matters here.
    muted_writer = patch(
        "voice_to_notes.services.diarize.write_message", lambda m: None
    )
    with muted_writer, pytest.raises(RuntimeError, match="Pipeline crashed"):
        svc.diarize(
            request_id="req-1",
            file_path="/fake/audio.wav",
            audio_duration_sec=30.0,
        )