"""Tests for transcription service."""

import inspect
from unittest.mock import MagicMock

from voice_to_notes.services.transcribe import (
    SegmentResult,
    TranscribeService,
    TranscriptionResult,
    WordResult,
    result_to_payload,
)


def test_result_to_payload():
    """Test converting TranscriptionResult to IPC payload."""
    result = TranscriptionResult(
        segments=[
            SegmentResult(
                text="hello world",
                start_ms=0,
                end_ms=2000,
                words=[
                    WordResult(word="hello", start_ms=0, end_ms=500, confidence=0.95),
                    WordResult(word="world", start_ms=600, end_ms=2000, confidence=0.92),
                ],
            ),
        ],
        language="en",
        language_probability=0.98,
        duration_ms=2000,
    )

    payload = result_to_payload(result)

    assert payload["language"] == "en"
    assert payload["duration_ms"] == 2000
    assert len(payload["segments"]) == 1

    seg = payload["segments"][0]
    assert seg["text"] == "hello world"
    assert seg["start_ms"] == 0
    assert seg["end_ms"] == 2000
    assert len(seg["words"]) == 2
    assert seg["words"][0]["word"] == "hello"
    assert seg["words"][0]["confidence"] == 0.95
    assert seg["words"][1]["word"] == "world"
    assert seg["words"][1]["confidence"] == 0.92


def test_result_to_payload_empty():
    """Test empty transcription result."""
    result = TranscriptionResult()

    payload = result_to_payload(result)

    assert payload["segments"] == []
    assert payload["language"] == ""
    assert payload["duration_ms"] == 0


def test_on_segment_callback():
    """Test that ``TranscribeService.transcribe`` accepts an ``on_segment`` parameter.

    This is a signature-level check only. Verifying that the callback is
    actually invoked with each SegmentResult and its index would require
    running ``transcribe`` against a mocked WhisperModel (as done in
    ``test_progress_every_segment``).
    """
    service = TranscribeService()

    sig = inspect.signature(service.transcribe)

    assert "on_segment" in sig.parameters


def _make_mock_segment(index: int) -> MagicMock:
    """Build a fake Whisper segment spanning [index, index + 1) seconds with no words."""
    seg = MagicMock()
    seg.start = index * 1.0
    seg.end = (index + 1) * 1.0
    seg.text = f"Segment {index}"
    seg.words = []
    return seg


def test_progress_every_segment(monkeypatch):
    """Verify a progress message is sent for every segment, not just every 5th."""
    # More than 5 segments, so a "report every 5th segment" throttle would be caught.
    num_segments = 8

    # Mock WhisperModel: transcribe() yields the fake segments plus an info object.
    mock_segments = [_make_mock_segment(i) for i in range(num_segments)]

    mock_info = MagicMock()
    mock_info.language = "en"
    mock_info.language_probability = 0.99
    mock_info.duration = float(num_segments)

    mock_model = MagicMock()
    mock_model.transcribe.return_value = (iter(mock_segments), mock_info)

    # Capture every IPC message the service writes instead of emitting it.
    written_messages = []
    monkeypatch.setattr(
        "voice_to_notes.services.transcribe.write_message",
        written_messages.append,
    )

    # Pre-populate the loaded-model state so transcribe() uses the mock
    # rather than loading a real model.
    service = TranscribeService()
    service._model = mock_model
    service._current_model_name = "base"
    service._current_device = "cpu"
    service._current_compute_type = "int8"

    service.transcribe("req-1", "/fake/audio.wav")

    transcribing_msgs = [
        m
        for m in written_messages
        if m.type == "progress" and m.payload.get("stage") == "transcribing"
    ]

    # Expect one "transcribing" progress message per segment. An initial
    # "Starting transcription..." message may share the same stage, so only a
    # lower bound is asserted.
    assert len(transcribing_msgs) >= num_segments, (
        f"Expected at least {num_segments} transcribing progress messages "
        f"(one per segment), got {len(transcribing_msgs)}"
    )