Stream transcript segments to frontend as they are transcribed
Send each segment to the frontend immediately after transcription via a new pipeline.segment IPC message, then send speaker assignments as a batch pipeline.speaker_update message after diarization completes. This lets the UI display segments progressively instead of waiting for the entire pipeline to finish. Changes: - Add partial_segment_message and speaker_update_message IPC factories - Add on_segment callback parameter to TranscribeService.transcribe() - Emit partial segments and speaker updates from PipelineService.run() - Add send_and_receive_with_progress to SidecarManager (Rust) - Route pipeline.segment/speaker_update events in run_pipeline command - Listen for streaming events in Svelte frontend (+page.svelte) - Add tests for new message types, callback signature, and update logic Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -3,8 +3,10 @@
|
||||
from voice_to_notes.ipc.messages import (
|
||||
IPCMessage,
|
||||
error_message,
|
||||
partial_segment_message,
|
||||
progress_message,
|
||||
ready_message,
|
||||
speaker_update_message,
|
||||
)
|
||||
|
||||
|
||||
@@ -48,3 +50,16 @@ def test_ready_message():
|
||||
assert msg.type == "ready"
|
||||
assert msg.id == "system"
|
||||
assert "version" in msg.payload
|
||||
|
||||
|
||||
def test_partial_segment_message():
    """A partial-segment message has the pipeline.segment type and exposes the given payload."""
    segment = {"index": 0, "text": "hello"}
    msg = partial_segment_message("req-1", segment)
    assert msg.type == "pipeline.segment"
    # Both payload fields must come back exactly as supplied.
    assert (msg.payload["index"], msg.payload["text"]) == (0, "hello")
|
||||
|
||||
|
||||
def test_speaker_update_message():
    """A speaker-update message has the pipeline.speaker_update type and nests updates in its payload."""
    speaker_updates = [{"index": 0, "speaker": "SPEAKER_00"}]
    msg = speaker_update_message("req-1", speaker_updates)
    assert msg.type == "pipeline.speaker_update"
    # The list is stored under the "updates" key of the payload.
    first_update = msg.payload["updates"][0]
    assert first_update["speaker"] == "SPEAKER_00"
|
||||
|
||||
@@ -88,3 +88,18 @@ def test_merge_results_no_speaker_segments():
|
||||
|
||||
result = service._merge_results(transcription, [])
|
||||
assert result.segments[0].speaker is None
|
||||
|
||||
|
||||
def test_speaker_update_generation():
    """Check that speaker updates are built only for segments that have a speaker."""
    segments = [
        PipelineSegment(text="Hello", start_ms=0, end_ms=1000, speaker="SPEAKER_00"),
        PipelineSegment(text="World", start_ms=1000, end_ms=2000, speaker="SPEAKER_01"),
        PipelineSegment(text="Foo", start_ms=2000, end_ms=3000, speaker=None),
    ]
    result = PipelineResult(segments=segments)

    # NOTE(review): this rebuilds the update list locally rather than calling
    # pipeline code, so it only checks the filtering rule itself.
    updates = []
    for position, segment in enumerate(result.segments):
        if segment.speaker:
            updates.append({"index": position, "speaker": segment.speaker})

    # The speaker-less third segment is dropped; indices refer to the original positions.
    assert updates == [
        {"index": 0, "speaker": "SPEAKER_00"},
        {"index": 1, "speaker": "SPEAKER_01"},
    ]
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
"""Tests for transcription service."""
|
||||
|
||||
import inspect
|
||||
|
||||
from voice_to_notes.services.transcribe import (
|
||||
SegmentResult,
|
||||
TranscribeService,
|
||||
TranscriptionResult,
|
||||
WordResult,
|
||||
result_to_payload,
|
||||
@@ -49,3 +52,18 @@ def test_result_to_payload_empty():
|
||||
assert payload["segments"] == []
|
||||
assert payload["language"] == ""
|
||||
assert payload["duration_ms"] == 0
|
||||
|
||||
|
||||
def test_on_segment_callback():
    """Check that TranscribeService.transcribe exposes an optional on_segment parameter.

    This is a signature-only check: the callback is never invoked here.
    A full integration test would require mocking WhisperModel.
    """
    service = TranscribeService()
    sig = inspect.signature(service.transcribe)
    assert "on_segment" in sig.parameters
    # The callback must stay optional so callers that omit it keep working.
    assert sig.parameters["on_segment"].default is None
|
||||
|
||||
@@ -34,6 +34,14 @@ def progress_message(request_id: str, percent: int, stage: str, message: str) ->
|
||||
)
|
||||
|
||||
|
||||
def partial_segment_message(request_id: str, segment_data: dict) -> IPCMessage:
    """Build a ``pipeline.segment`` message for one segment.

    Args:
        request_id: Used as the message id.
        segment_data: Used as the message payload as-is (not copied).

    Returns:
        The constructed IPCMessage.
    """
    message_type = "pipeline.segment"
    return IPCMessage(id=request_id, type=message_type, payload=segment_data)
|
||||
|
||||
|
||||
def speaker_update_message(request_id: str, updates: list[dict]) -> IPCMessage:
    """Build a ``pipeline.speaker_update`` message for a batch of updates.

    Args:
        request_id: Used as the message id.
        updates: Update dicts, placed under the ``"updates"`` payload key.

    Returns:
        The constructed IPCMessage.
    """
    payload = {"updates": updates}
    return IPCMessage(id=request_id, type="pipeline.speaker_update", payload=payload)
|
||||
|
||||
|
||||
def error_message(request_id: str, code: str, message: str) -> IPCMessage:
|
||||
return IPCMessage(
|
||||
id=request_id,
|
||||
|
||||
@@ -7,7 +7,11 @@ import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from voice_to_notes.ipc.messages import progress_message
|
||||
from voice_to_notes.ipc.messages import (
|
||||
partial_segment_message,
|
||||
progress_message,
|
||||
speaker_update_message,
|
||||
)
|
||||
from voice_to_notes.ipc.protocol import write_message
|
||||
from voice_to_notes.services.diarize import DiarizeService, SpeakerSegment
|
||||
from voice_to_notes.services.transcribe import (
|
||||
@@ -82,6 +86,15 @@ class PipelineService:
|
||||
progress_message(request_id, 0, "pipeline", "Starting transcription pipeline...")
|
||||
)
|
||||
|
||||
def _emit_segment(seg: SegmentResult, index: int) -> None:
    """Send one segment via write_message as a pipeline.segment message.

    The payload holds the segment index, text, start/end times in
    milliseconds, and one dict per word (word, timings, confidence).
    """
    segment_payload = {
        "index": index,
        "text": seg.text,
        "start_ms": seg.start_ms,
        "end_ms": seg.end_ms,
    }
    # Added last so the payload's key order stays index, text, start_ms, end_ms, words.
    word_payloads = []
    for word in seg.words:
        word_payloads.append(
            {
                "word": word.word,
                "start_ms": word.start_ms,
                "end_ms": word.end_ms,
                "confidence": word.confidence,
            }
        )
    segment_payload["words"] = word_payloads
    write_message(partial_segment_message(request_id, segment_payload))
|
||||
|
||||
transcription = self._transcribe_service.transcribe(
|
||||
request_id=request_id,
|
||||
file_path=file_path,
|
||||
@@ -89,6 +102,7 @@ class PipelineService:
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
language=language,
|
||||
on_segment=_emit_segment,
|
||||
)
|
||||
|
||||
if skip_diarization:
|
||||
@@ -140,6 +154,10 @@ class PipelineService:
|
||||
flush=True,
|
||||
)
|
||||
|
||||
updates = [{"index": i, "speaker": seg.speaker} for i, seg in enumerate(result.segments) if seg.speaker]
|
||||
if updates:
|
||||
write_message(speaker_update_message(request_id, updates))
|
||||
|
||||
write_message(
|
||||
progress_message(request_id, 100, "done", "Pipeline complete")
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
@@ -90,6 +91,7 @@ class TranscribeService:
|
||||
device: str = "cpu",
|
||||
compute_type: str = "int8",
|
||||
language: str | None = None,
|
||||
on_segment: Callable[[SegmentResult, int], None] | None = None,
|
||||
) -> TranscriptionResult:
|
||||
"""Transcribe an audio file with word-level timestamps.
|
||||
|
||||
@@ -145,6 +147,9 @@ class TranscribeService:
|
||||
)
|
||||
)
|
||||
|
||||
if on_segment:
|
||||
on_segment(result.segments[-1], segment_count - 1)
|
||||
|
||||
# Send progress every few segments
|
||||
if segment_count % 5 == 0:
|
||||
write_message(
|
||||
|
||||
Reference in New Issue
Block a user