Merge perf/stream-segments: streaming partial transcript segments and speaker updates

This commit is contained in:
Claude
2026-03-20 13:51:51 -07:00
14 changed files with 245 additions and 4 deletions

View File

@@ -3,8 +3,10 @@
from voice_to_notes.ipc.messages import (
IPCMessage,
error_message,
partial_segment_message,
progress_message,
ready_message,
speaker_update_message,
)
@@ -48,3 +50,16 @@ def test_ready_message():
assert msg.type == "ready"
assert msg.id == "system"
assert "version" in msg.payload
def test_partial_segment_message():
    """A partial-segment message keeps the request id and uses the segment dict as payload."""
    msg = partial_segment_message("req-1", {"index": 0, "text": "hello"})
    # The request id must be carried through so the receiver can correlate
    # streamed segments with the request that produced them.
    assert msg.id == "req-1"
    assert msg.type == "pipeline.segment"
    assert msg.payload["index"] == 0
    assert msg.payload["text"] == "hello"
def test_speaker_update_message():
    """A speaker-update message keeps the request id and wraps the updates under ``"updates"``."""
    updates = [{"index": 0, "speaker": "SPEAKER_00"}]
    msg = speaker_update_message("req-1", updates)
    assert msg.id == "req-1"
    assert msg.type == "pipeline.speaker_update"
    # Exactly the updates passed in are forwarded, nothing added or dropped.
    assert len(msg.payload["updates"]) == 1
    assert msg.payload["updates"][0]["speaker"] == "SPEAKER_00"

View File

@@ -88,3 +88,18 @@ def test_merge_results_no_speaker_segments():
result = service._merge_results(transcription, [])
assert result.segments[0].speaker is None
def test_speaker_update_generation():
    """Speaker updates list only the segments that have a speaker, keyed by segment index.

    NOTE(review): the update list is rebuilt here with the same expression the
    pipeline service uses; the service itself is not called by this test.
    """
    segments = [
        PipelineSegment(text="Hello", start_ms=0, end_ms=1000, speaker="SPEAKER_00"),
        PipelineSegment(text="World", start_ms=1000, end_ms=2000, speaker="SPEAKER_01"),
        PipelineSegment(text="Foo", start_ms=2000, end_ms=3000, speaker=None),
    ]
    result = PipelineResult(segments=segments)

    updates = []
    for position, segment in enumerate(result.segments):
        # Segments without a speaker label produce no update entry.
        if segment.speaker:
            updates.append({"index": position, "speaker": segment.speaker})

    assert len(updates) == 2
    first_update, second_update = updates[0], updates[1]
    assert first_update == {"index": 0, "speaker": "SPEAKER_00"}
    assert second_update == {"index": 1, "speaker": "SPEAKER_01"}

View File

@@ -1,7 +1,10 @@
"""Tests for transcription service."""
import inspect
from voice_to_notes.services.transcribe import (
SegmentResult,
TranscribeService,
TranscriptionResult,
WordResult,
result_to_payload,
@@ -49,3 +52,18 @@ def test_result_to_payload_empty():
assert payload["segments"] == []
assert payload["language"] == ""
assert payload["duration_ms"] == 0
def test_on_segment_callback():
    """Check that ``TranscribeService.transcribe`` exposes an optional ``on_segment`` parameter.

    Only the signature is inspected. The callback is never invoked here,
    because a full integration test would require mocking WhisperModel.
    """
    service = TranscribeService()
    sig = inspect.signature(service.transcribe)
    assert "on_segment" in sig.parameters
    # Callers that do not pass a callback must keep working, so the
    # parameter has to default to None.
    assert sig.parameters["on_segment"].default is None

View File

@@ -34,6 +34,14 @@ def progress_message(request_id: str, percent: int, stage: str, message: str) ->
)
def partial_segment_message(request_id: str, segment_data: dict) -> IPCMessage:
    """Build the ``pipeline.segment`` message for one transcript segment.

    ``segment_data`` becomes the message payload as-is (it is not copied).
    """
    message_type = "pipeline.segment"
    return IPCMessage(id=request_id, type=message_type, payload=segment_data)
def speaker_update_message(request_id: str, updates: list[dict]) -> IPCMessage:
    """Build the ``pipeline.speaker_update`` message.

    The ``updates`` list is placed in the payload under the ``"updates"`` key.
    """
    payload = {"updates": updates}
    return IPCMessage(
        id=request_id,
        type="pipeline.speaker_update",
        payload=payload,
    )
def error_message(request_id: str, code: str, message: str) -> IPCMessage:
return IPCMessage(
id=request_id,

View File

@@ -7,7 +7,11 @@ import time
from dataclasses import dataclass, field
from typing import Any
from voice_to_notes.ipc.messages import progress_message
from voice_to_notes.ipc.messages import (
partial_segment_message,
progress_message,
speaker_update_message,
)
from voice_to_notes.ipc.protocol import write_message
from voice_to_notes.services.diarize import DiarizeService, SpeakerSegment
from voice_to_notes.services.transcribe import (
@@ -83,6 +87,15 @@ class PipelineService:
progress_message(request_id, 0, "pipeline", "Starting transcription pipeline...")
)
def _emit_segment(seg: SegmentResult, index: int) -> None:
    """Write one transcript segment to the IPC channel as a partial-segment message.

    Handed to the transcribe service as its ``on_segment`` callback; ``index``
    is forwarded unchanged. ``request_id`` comes from the enclosing scope.
    """
    # Keys are inserted in the order index, text, start_ms, end_ms, words.
    segment_payload = {
        "index": index,
        "text": seg.text,
        "start_ms": seg.start_ms,
        "end_ms": seg.end_ms,
    }
    word_entries = []
    for w in seg.words:
        word_entries.append(
            {
                "word": w.word,
                "start_ms": w.start_ms,
                "end_ms": w.end_ms,
                "confidence": w.confidence,
            }
        )
    segment_payload["words"] = word_entries
    message = partial_segment_message(request_id, segment_payload)
    write_message(message)
transcription = self._transcribe_service.transcribe(
request_id=request_id,
file_path=file_path,
@@ -90,6 +103,7 @@ class PipelineService:
device=device,
compute_type=compute_type,
language=language,
on_segment=_emit_segment,
)
if skip_diarization:
@@ -174,6 +188,10 @@ class PipelineService:
flush=True,
)
updates = [{"index": i, "speaker": seg.speaker} for i, seg in enumerate(result.segments) if seg.speaker]
if updates:
write_message(speaker_update_message(request_id, updates))
write_message(
progress_message(request_id, 100, "done", "Pipeline complete")
)

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import sys
import time
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any
@@ -90,6 +91,7 @@ class TranscribeService:
device: str = "cpu",
compute_type: str = "int8",
language: str | None = None,
on_segment: Callable[[SegmentResult, int], None] | None = None,
) -> TranscriptionResult:
"""Transcribe an audio file with word-level timestamps.
@@ -145,6 +147,9 @@ class TranscribeService:
)
)
if on_segment:
on_segment(result.segments[-1], segment_count - 1)
# Send progress every few segments
if segment_count % 5 == 0:
write_message(