perf/pipeline-improvements #1

Merged
jknapp merged 18 commits from perf/pipeline-improvements into main 2026-03-21 04:53:45 +00:00
2 changed files with 61 additions and 9 deletions
Showing only changes of commit 35af6e9e0c - Show all commits

View File

@@ -67,3 +67,57 @@ def test_on_segment_callback():
# Verify the parameter exists by checking the signature # Verify the parameter exists by checking the signature
sig = inspect.signature(service.transcribe) sig = inspect.signature(service.transcribe)
assert "on_segment" in sig.parameters assert "on_segment" in sig.parameters
def test_progress_every_segment(monkeypatch):
"""Verify a progress message is sent for every segment, not just every 5th."""
from unittest.mock import MagicMock, patch
from voice_to_notes.services.transcribe import TranscribeService
# Mock WhisperModel
mock_model = MagicMock()
# Create mock segments (8 of them to test > 5)
mock_segments = []
for i in range(8):
seg = MagicMock()
seg.start = i * 1.0
seg.end = (i + 1) * 1.0
seg.text = f"Segment {i}"
seg.words = []
mock_segments.append(seg)
# Mock info object
mock_info = MagicMock()
mock_info.language = "en"
mock_info.language_probability = 0.99
mock_info.duration = 8.0
mock_model.transcribe.return_value = (iter(mock_segments), mock_info)
# Track write_message calls
written_messages = []
def mock_write(msg):
written_messages.append(msg)
service = TranscribeService()
service._model = mock_model
service._current_model_name = "base"
service._current_device = "cpu"
service._current_compute_type = "int8"
with patch("voice_to_notes.services.transcribe.write_message", mock_write):
service.transcribe("req-1", "/fake/audio.wav")
# Filter for "transcribing" stage progress messages
transcribing_msgs = [
m for m in written_messages
if m.type == "progress" and m.payload.get("stage") == "transcribing"
]
# Should have one per segment (8) + the initial "Starting transcription..." message
# The initial "Starting transcription..." is also stage "transcribing" — so 8 + 1 = 9
assert len(transcribing_msgs) >= 8, (
f"Expected at least 8 transcribing progress messages (one per segment), got {len(transcribing_msgs)}"
)

View File

@@ -150,14 +150,12 @@ class TranscribeService:
if on_segment: if on_segment:
on_segment(result.segments[-1], segment_count - 1) on_segment(result.segments[-1], segment_count - 1)
# Send progress every few segments
if segment_count % 5 == 0:
write_message( write_message(
progress_message( progress_message(
request_id, request_id,
progress_pct, progress_pct,
"transcribing", "transcribing",
f"Processed {segment_count} segments...", f"Transcribing segment {segment_count} ({progress_pct}% of audio)...",
) )
) )