2026-02-26 15:53:09 -08:00
|
|
|
"""Tests for transcription service."""
|
|
|
|
|
|
2026-03-20 13:47:57 -07:00
|
|
|
import inspect
|
|
|
|
|
|
2026-02-26 15:53:09 -08:00
|
|
|
from voice_to_notes.services.transcribe import (
|
|
|
|
|
SegmentResult,
|
2026-03-20 13:47:57 -07:00
|
|
|
TranscribeService,
|
2026-02-26 15:53:09 -08:00
|
|
|
TranscriptionResult,
|
|
|
|
|
WordResult,
|
|
|
|
|
result_to_payload,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_result_to_payload():
    """Convert a one-segment TranscriptionResult and check the resulting IPC payload."""
    hello = WordResult(word="hello", start_ms=0, end_ms=500, confidence=0.95)
    world = WordResult(word="world", start_ms=600, end_ms=2000, confidence=0.92)
    only_segment = SegmentResult(
        text="hello world",
        start_ms=0,
        end_ms=2000,
        words=[hello, world],
    )
    transcription = TranscriptionResult(
        segments=[only_segment],
        language="en",
        language_probability=0.98,
        duration_ms=2000,
    )

    payload = result_to_payload(transcription)

    # Top-level metadata and segment count.
    assert payload["language"] == "en"
    assert payload["duration_ms"] == 2000
    segments = payload["segments"]
    assert len(segments) == 1

    # The single segment's timing/text, then its word list.
    first = segments[0]
    assert (first["text"], first["start_ms"], first["end_ms"]) == ("hello world", 0, 2000)
    words = first["words"]
    assert len(words) == 2
    assert words[0]["word"] == "hello"
    assert words[0]["confidence"] == 0.95
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_result_to_payload_empty():
    """A default-constructed TranscriptionResult serializes to an empty payload."""
    payload = result_to_payload(TranscriptionResult())

    # Checked in the same order as the fields appear in the payload assertions.
    expected = {"segments": [], "language": "", "duration_ms": 0}
    for key, value in expected.items():
        assert payload[key] == value, key
|
2026-03-20 13:47:57 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_on_segment_callback():
    """Test that on_segment callback is invoked with correct SegmentResult and index."""
    from unittest.mock import MagicMock, patch

    service = TranscribeService()

    # The callback must be part of transcribe()'s public signature.
    sig = inspect.signature(service.transcribe)
    assert "on_segment" in sig.parameters

    # Record every (segment, index) pair reported by the service.
    callback_args = []

    def mock_callback(seg: SegmentResult, index: int):
        callback_args.append((seg, index))

    # Fake model segments exposing the attributes the service reads.
    texts = [f"Segment {i}" for i in range(3)]
    mock_segments = []
    for i, text in enumerate(texts):
        seg = MagicMock()
        seg.start = i * 1.0
        seg.end = (i + 1) * 1.0
        seg.text = text
        seg.words = []
        mock_segments.append(seg)

    mock_info = MagicMock()
    mock_info.language = "en"
    mock_info.language_probability = 0.99
    mock_info.duration = float(len(texts))

    mock_model = MagicMock()
    mock_model.transcribe.return_value = (iter(mock_segments), mock_info)

    # Pre-seed the loaded-model state; presumably this makes transcribe()
    # reuse the mock instead of loading a real WhisperModel.
    service._model = mock_model
    service._current_model_name = "base"
    service._current_device = "cpu"
    service._current_compute_type = "int8"

    # Discard IPC output; only the callback is under test here.
    with patch("voice_to_notes.services.transcribe.write_message", lambda msg: None):
        service.transcribe("req-1", "/fake/audio.wav", on_segment=mock_callback)

    # One call per segment, each with a SegmentResult, in model order.
    assert len(callback_args) == len(texts)
    assert all(isinstance(seg, SegmentResult) for seg, _ in callback_args)
    assert [seg.text for seg, _ in callback_args] == texts

    # Indices must be consecutive; the starting base is not pinned here.
    indices = [index for _, index in callback_args]
    assert indices == list(range(indices[0], indices[0] + len(texts)))
|
2026-03-20 13:52:18 -07:00
|
|
|
|
|
|
|
|
|
2026-03-20 13:49:14 -07:00
|
|
|
def test_progress_every_segment(monkeypatch):
    """Verify a progress message is sent for every segment, not just every 5th."""
    from unittest.mock import MagicMock

    # More than 5, so a throttle that only reports every 5th segment is caught.
    segment_count = 8

    # Fake model segments exposing the attributes the service reads.
    mock_segments = []
    for i in range(segment_count):
        seg = MagicMock()
        seg.start = i * 1.0
        seg.end = (i + 1) * 1.0
        seg.text = f"Segment {i}"
        seg.words = []
        mock_segments.append(seg)

    # Fake transcription info returned alongside the segments.
    mock_info = MagicMock()
    mock_info.language = "en"
    mock_info.language_probability = 0.99
    mock_info.duration = float(segment_count)

    # Mock WhisperModel.
    mock_model = MagicMock()
    mock_model.transcribe.return_value = (iter(mock_segments), mock_info)

    # Pre-seed the loaded-model state; presumably this makes transcribe()
    # reuse the mock instead of loading a real model.
    service = TranscribeService()
    service._model = mock_model
    service._current_model_name = "base"
    service._current_device = "cpu"
    service._current_compute_type = "int8"

    # Capture every IPC message; monkeypatch restores write_message at teardown.
    written_messages = []
    monkeypatch.setattr(
        "voice_to_notes.services.transcribe.write_message",
        written_messages.append,
    )

    service.transcribe("req-1", "/fake/audio.wav")

    # Keep only progress messages for the "transcribing" stage.
    transcribing_msgs = [
        m for m in written_messages
        if m.type == "progress" and m.payload.get("stage") == "transcribing"
    ]

    # Require at least one per segment. The service may also emit other
    # "transcribing"-stage updates (e.g. an initial "Starting transcription..."),
    # so extra messages are tolerated: hence >= rather than ==.
    assert len(transcribing_msgs) >= segment_count, (
        f"Expected at least {segment_count} transcribing progress messages "
        f"(one per segment), got {len(transcribing_msgs)}"
    )
|