diff --git a/python/tests/test_transcribe.py b/python/tests/test_transcribe.py index b0af248..3d035b5 100644 --- a/python/tests/test_transcribe.py +++ b/python/tests/test_transcribe.py @@ -67,3 +67,57 @@ def test_on_segment_callback(): # Verify the parameter exists by checking the signature sig = inspect.signature(service.transcribe) assert "on_segment" in sig.parameters + + +def test_progress_every_segment(monkeypatch): + """Verify a progress message is sent for every segment, not just every 5th.""" + from unittest.mock import MagicMock, patch + from voice_to_notes.services.transcribe import TranscribeService + + # Mock WhisperModel + mock_model = MagicMock() + + # Create mock segments (8 of them to test > 5) + mock_segments = [] + for i in range(8): + seg = MagicMock() + seg.start = i * 1.0 + seg.end = (i + 1) * 1.0 + seg.text = f"Segment {i}" + seg.words = [] + mock_segments.append(seg) + + # Mock info object + mock_info = MagicMock() + mock_info.language = "en" + mock_info.language_probability = 0.99 + mock_info.duration = 8.0 + + mock_model.transcribe.return_value = (iter(mock_segments), mock_info) + + # Track write_message calls + written_messages = [] + + def mock_write(msg): + written_messages.append(msg) + + service = TranscribeService() + service._model = mock_model + service._current_model_name = "base" + service._current_device = "cpu" + service._current_compute_type = "int8" + + with patch("voice_to_notes.services.transcribe.write_message", mock_write): + service.transcribe("req-1", "/fake/audio.wav") + + # Filter for "transcribing" stage progress messages + transcribing_msgs = [ + m for m in written_messages + if m.type == "progress" and m.payload.get("stage") == "transcribing" + ] + + # Should have one per segment (8) + the initial "Starting transcription..." message + # The initial "Starting transcription..." is also stage "transcribing" — so 8 + 1 = 9 + assert len(transcribing_msgs) >= 8, ( + f"Expected at least 8 transcribing progress messages (one per segment), got {len(transcribing_msgs)}" + ) diff --git a/python/voice_to_notes/services/transcribe.py b/python/voice_to_notes/services/transcribe.py index 7a0b67c..1ed1904 100644 --- a/python/voice_to_notes/services/transcribe.py +++ b/python/voice_to_notes/services/transcribe.py @@ -150,16 +150,14 @@ class TranscribeService: if on_segment: on_segment(result.segments[-1], segment_count - 1) - # Send progress every few segments - if segment_count % 5 == 0: - write_message( - progress_message( - request_id, - progress_pct, - "transcribing", - f"Processed {segment_count} segments...", - ) + write_message( + progress_message( + request_id, + progress_pct, + "transcribing", + f"Transcribing segment {segment_count} ({progress_pct}% of audio)...", ) + ) elapsed = time.time() - start_time print(