"""Tests for diarization service data structures and payload conversion.""" import time from unittest.mock import MagicMock, patch import pytest from voice_to_notes.services.diarize import ( DiarizationResult, DiarizeService, SpeakerSegment, diarization_to_payload, ) def test_diarization_to_payload(): result = DiarizationResult( speaker_segments=[ SpeakerSegment(speaker="SPEAKER_00", start_ms=0, end_ms=5000), SpeakerSegment(speaker="SPEAKER_01", start_ms=5000, end_ms=10000), SpeakerSegment(speaker="SPEAKER_00", start_ms=10000, end_ms=15000), ], num_speakers=2, speakers=["SPEAKER_00", "SPEAKER_01"], ) payload = diarization_to_payload(result) assert payload["num_speakers"] == 2 assert len(payload["speaker_segments"]) == 3 assert payload["speakers"] == ["SPEAKER_00", "SPEAKER_01"] assert payload["speaker_segments"][0]["speaker"] == "SPEAKER_00" assert payload["speaker_segments"][1]["start_ms"] == 5000 def test_diarization_to_payload_empty(): result = DiarizationResult() payload = diarization_to_payload(result) assert payload["num_speakers"] == 0 assert payload["speaker_segments"] == [] assert payload["speakers"] == [] def test_diarize_threading_progress(monkeypatch): """Test that diarization emits progress while running in background thread.""" # Track written messages written_messages = [] def mock_write(msg): written_messages.append(msg) # Mock pipeline that takes ~5 seconds def slow_pipeline(file_path, **kwargs): time.sleep(5) # Return a mock diarization result (use spec=object to prevent # hasattr returning True for speaker_diarization) mock_result = MagicMock(spec=[]) mock_track = MagicMock() mock_track.start = 0.0 mock_track.end = 5.0 mock_result.itertracks = MagicMock(return_value=[(mock_track, None, "SPEAKER_00")]) return mock_result mock_pipeline_obj = MagicMock() mock_pipeline_obj.side_effect = slow_pipeline service = DiarizeService() service._pipeline = mock_pipeline_obj with patch("voice_to_notes.services.diarize.write_message", mock_write): result = service.diarize( request_id="req-1", file_path="/fake/audio.wav", audio_duration_sec=60.0, ) # Filter for diarizing progress messages (not loading_diarization or done) diarizing_msgs = [ m for m in written_messages if m.type == "progress" and m.payload.get("stage") == "diarizing" and "elapsed" in m.payload.get("message", "") ] # Should have at least 1 progress message (5s sleep / 2s interval = ~2 messages) assert len(diarizing_msgs) >= 1, ( f"Expected at least 1 diarizing progress message, got {len(diarizing_msgs)}" ) # Progress percent should be between 20 and 85 for msg in diarizing_msgs: pct = msg.payload["percent"] assert 20 <= pct <= 85, f"Progress {pct} out of expected range 20-85" # Result should be valid assert result.num_speakers == 1 assert result.speakers == ["SPEAKER_00"] def test_diarize_threading_error_propagation(monkeypatch): """Test that errors from the background thread are properly raised.""" mock_pipeline_obj = MagicMock() mock_pipeline_obj.side_effect = RuntimeError("Pipeline crashed") service = DiarizeService() service._pipeline = mock_pipeline_obj with patch("voice_to_notes.services.diarize.write_message", lambda m: None): with pytest.raises(RuntimeError, match="Pipeline crashed"): service.diarize( request_id="req-1", file_path="/fake/audio.wav", audio_duration_sec=30.0, )