perf/pipeline-improvements #1

Merged
jknapp merged 18 commits from perf/pipeline-improvements into main 2026-03-21 04:53:45 +00:00
9 changed files with 223 additions and 2 deletions
Showing only changes of commit 67ed69df00 - Show all commits

View File

@@ -3,8 +3,10 @@
from voice_to_notes.ipc.messages import ( from voice_to_notes.ipc.messages import (
IPCMessage, IPCMessage,
error_message, error_message,
partial_segment_message,
progress_message, progress_message,
ready_message, ready_message,
speaker_update_message,
) )
@@ -48,3 +50,16 @@ def test_ready_message():
assert msg.type == "ready" assert msg.type == "ready"
assert msg.id == "system" assert msg.id == "system"
assert "version" in msg.payload assert "version" in msg.payload
def test_partial_segment_message():
    """A partial-segment message has type ``pipeline.segment`` and exposes the segment fields in its payload."""
    segment = {"index": 0, "text": "hello"}
    message = partial_segment_message("req-1", segment)

    assert message.type == "pipeline.segment"
    # Both fields passed in must be readable back from the payload.
    assert (message.payload["index"], message.payload["text"]) == (0, "hello")
def test_speaker_update_message():
    """A speaker-update message has type ``pipeline.speaker_update`` and lists its updates under ``payload["updates"]``."""
    updates = [{"index": 0, "speaker": "SPEAKER_00"}]
    message = speaker_update_message("req-1", updates)

    assert message.type == "pipeline.speaker_update"
    first_update = message.payload["updates"][0]
    assert first_update["speaker"] == "SPEAKER_00"

View File

@@ -88,3 +88,18 @@ def test_merge_results_no_speaker_segments():
result = service._merge_results(transcription, []) result = service._merge_results(transcription, [])
assert result.segments[0].speaker is None assert result.segments[0].speaker is None
def test_speaker_update_generation():
    """Only segments that carry a speaker label yield an update, keyed by position."""
    labelled_segments = [
        PipelineSegment(text="Hello", start_ms=0, end_ms=1000, speaker="SPEAKER_00"),
        PipelineSegment(text="World", start_ms=1000, end_ms=2000, speaker="SPEAKER_01"),
        PipelineSegment(text="Foo", start_ms=2000, end_ms=3000, speaker=None),
    ]
    result = PipelineResult(segments=labelled_segments)

    # NOTE: the update list is rebuilt here with the same filter the test
    # expects; the test does not call into the pipeline service itself.
    updates = []
    for position, segment in enumerate(result.segments):
        if segment.speaker:
            updates.append({"index": position, "speaker": segment.speaker})

    # The unlabelled third segment must be skipped entirely.
    assert updates == [
        {"index": 0, "speaker": "SPEAKER_00"},
        {"index": 1, "speaker": "SPEAKER_01"},
    ]

View File

@@ -1,7 +1,10 @@
"""Tests for transcription service.""" """Tests for transcription service."""
import inspect
from voice_to_notes.services.transcribe import ( from voice_to_notes.services.transcribe import (
SegmentResult, SegmentResult,
TranscribeService,
TranscriptionResult, TranscriptionResult,
WordResult, WordResult,
result_to_payload, result_to_payload,
@@ -49,3 +52,18 @@ def test_result_to_payload_empty():
assert payload["segments"] == [] assert payload["segments"] == []
assert payload["language"] == "" assert payload["language"] == ""
assert payload["duration_ms"] == 0 assert payload["duration_ms"] == 0
def test_on_segment_callback():
    """Check that ``TranscribeService.transcribe`` accepts an ``on_segment`` parameter.

    This inspects the method signature only: no transcription is run and no
    callback is invoked. (A full integration test would require mocking
    WhisperModel.)
    """
    service = TranscribeService()
    # Verify the parameter exists by checking the signature.
    signature = inspect.signature(service.transcribe)
    assert "on_segment" in signature.parameters

View File

@@ -34,6 +34,14 @@ def progress_message(request_id: str, percent: int, stage: str, message: str) ->
) )
def partial_segment_message(request_id: str, segment_data: dict) -> IPCMessage:
    """Build a ``pipeline.segment`` message for one segment.

    ``segment_data`` becomes the message payload as-is (it is not copied).
    """
    message_fields = {
        "id": request_id,
        "type": "pipeline.segment",
        "payload": segment_data,
    }
    return IPCMessage(**message_fields)
def speaker_update_message(request_id: str, updates: list[dict]) -> IPCMessage:
    """Build a ``pipeline.speaker_update`` message.

    The given list is wrapped under the ``"updates"`` key of the payload.
    """
    payload = {"updates": updates}
    return IPCMessage(
        id=request_id,
        type="pipeline.speaker_update",
        payload=payload,
    )
def error_message(request_id: str, code: str, message: str) -> IPCMessage: def error_message(request_id: str, code: str, message: str) -> IPCMessage:
return IPCMessage( return IPCMessage(
id=request_id, id=request_id,

View File

@@ -7,7 +7,11 @@ import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
from voice_to_notes.ipc.messages import progress_message from voice_to_notes.ipc.messages import (
partial_segment_message,
progress_message,
speaker_update_message,
)
from voice_to_notes.ipc.protocol import write_message from voice_to_notes.ipc.protocol import write_message
from voice_to_notes.services.diarize import DiarizeService, SpeakerSegment from voice_to_notes.services.diarize import DiarizeService, SpeakerSegment
from voice_to_notes.services.transcribe import ( from voice_to_notes.services.transcribe import (
@@ -82,6 +86,15 @@ class PipelineService:
progress_message(request_id, 0, "pipeline", "Starting transcription pipeline...") progress_message(request_id, 0, "pipeline", "Starting transcription pipeline...")
) )
def _emit_segment(seg: SegmentResult, index: int) -> None:
    """Write one segment, including its word timings, as a ``pipeline.segment`` message.

    ``request_id`` is taken from the enclosing scope.
    """
    payload = {
        "index": index,
        "text": seg.text,
        "start_ms": seg.start_ms,
        "end_ms": seg.end_ms,
    }
    # One flat dict per word, in the order the words appear on the segment.
    payload["words"] = [
        {
            "word": w.word,
            "start_ms": w.start_ms,
            "end_ms": w.end_ms,
            "confidence": w.confidence,
        }
        for w in seg.words
    ]
    write_message(partial_segment_message(request_id, payload))
transcription = self._transcribe_service.transcribe( transcription = self._transcribe_service.transcribe(
request_id=request_id, request_id=request_id,
file_path=file_path, file_path=file_path,
@@ -89,6 +102,7 @@ class PipelineService:
device=device, device=device,
compute_type=compute_type, compute_type=compute_type,
language=language, language=language,
on_segment=_emit_segment,
) )
if skip_diarization: if skip_diarization:
@@ -140,6 +154,10 @@ class PipelineService:
flush=True, flush=True,
) )
updates = [{"index": i, "speaker": seg.speaker} for i, seg in enumerate(result.segments) if seg.speaker]
if updates:
write_message(speaker_update_message(request_id, updates))
write_message( write_message(
progress_message(request_id, 100, "done", "Pipeline complete") progress_message(request_id, 100, "done", "Pipeline complete")
) )

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import sys import sys
import time import time
from collections.abc import Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
@@ -90,6 +91,7 @@ class TranscribeService:
device: str = "cpu", device: str = "cpu",
compute_type: str = "int8", compute_type: str = "int8",
language: str | None = None, language: str | None = None,
on_segment: Callable[[SegmentResult, int], None] | None = None,
) -> TranscriptionResult: ) -> TranscriptionResult:
"""Transcribe an audio file with word-level timestamps. """Transcribe an audio file with word-level timestamps.
@@ -145,6 +147,9 @@ class TranscribeService:
) )
) )
if on_segment:
on_segment(result.segments[-1], segment_count - 1)
# Send progress every few segments # Send progress every few segments
if segment_count % 5 == 0: if segment_count % 5 == 0:
write_message( write_message(

View File

@@ -1,4 +1,5 @@
use serde_json::{json, Value}; use serde_json::{json, Value};
use tauri::{AppHandle, Emitter};
use crate::sidecar::messages::IPCMessage; use crate::sidecar::messages::IPCMessage;
use crate::sidecar::sidecar; use crate::sidecar::sidecar;
@@ -42,6 +43,7 @@ pub fn transcribe_file(
/// Run the full transcription + diarization pipeline via the Python sidecar. /// Run the full transcription + diarization pipeline via the Python sidecar.
#[tauri::command] #[tauri::command]
pub fn run_pipeline( pub fn run_pipeline(
app: AppHandle,
file_path: String, file_path: String,
model: Option<String>, model: Option<String>,
device: Option<String>, device: Option<String>,
@@ -71,7 +73,14 @@ pub fn run_pipeline(
}), }),
); );
let response = manager.send_and_receive(&msg)?; let response = manager.send_and_receive_with_progress(&msg, |msg| {
let event_name = match msg.msg_type.as_str() {
"pipeline.segment" => "pipeline-segment",
"pipeline.speaker_update" => "pipeline-speaker-update",
_ => "pipeline-progress",
};
let _ = app.emit(event_name, &msg.payload);
})?;
if response.msg_type == "error" { if response.msg_type == "error" {
return Err(format!( return Err(format!(

View File

@@ -165,6 +165,70 @@ impl SidecarManager {
} }
} }
/// Send a message and receive the response, calling a callback for intermediate messages.
/// Intermediate messages include progress, pipeline.segment, and pipeline.speaker_update.
pub fn send_and_receive_with_progress<F>(
&self,
msg: &IPCMessage,
on_intermediate: F,
) -> Result<IPCMessage, String>
where
F: Fn(&IPCMessage),
{
// Write to stdin
{
let mut stdin_guard = self.stdin.lock().map_err(|e| e.to_string())?;
if let Some(ref mut stdin) = *stdin_guard {
let json = serde_json::to_string(msg).map_err(|e| e.to_string())?;
stdin
.write_all(json.as_bytes())
.map_err(|e| format!("Write error: {e}"))?;
stdin
.write_all(b"\n")
.map_err(|e| format!("Write error: {e}"))?;
stdin.flush().map_err(|e| format!("Flush error: {e}"))?;
} else {
return Err("Sidecar stdin not available".to_string());
}
}
// Read from stdout
{
let mut reader_guard = self.reader.lock().map_err(|e| e.to_string())?;
if let Some(ref mut reader) = *reader_guard {
let mut line = String::new();
loop {
line.clear();
let bytes_read = reader
.read_line(&mut line)
.map_err(|e| format!("Read error: {e}"))?;
if bytes_read == 0 {
return Err("Sidecar closed stdout".to_string());
}
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let response: IPCMessage = serde_json::from_str(trimmed)
.map_err(|e| format!("Parse error: {e}"))?;
// Forward intermediate messages via callback, return the final result/error
let is_intermediate = matches!(
response.msg_type.as_str(),
"progress" | "pipeline.segment" | "pipeline.speaker_update"
);
if is_intermediate {
on_intermediate(&response);
} else {
return Ok(response);
}
}
} else {
Err("Sidecar stdout not available".to_string())
}
}
}
/// Stop the sidecar process. /// Stop the sidecar process.
pub fn stop(&self) -> Result<(), String> { pub fn stop(&self) -> Result<(), String> {
// Drop stdin to signal EOF // Drop stdin to signal EOF

View File

@@ -1,5 +1,6 @@
<script lang="ts"> <script lang="ts">
import { invoke, convertFileSrc } from '@tauri-apps/api/core'; import { invoke, convertFileSrc } from '@tauri-apps/api/core';
import { listen } from '@tauri-apps/api/event';
import { open, save } from '@tauri-apps/plugin-dialog'; import { open, save } from '@tauri-apps/plugin-dialog';
import WaveformPlayer from '$lib/components/WaveformPlayer.svelte'; import WaveformPlayer from '$lib/components/WaveformPlayer.svelte';
import TranscriptEditor from '$lib/components/TranscriptEditor.svelte'; import TranscriptEditor from '$lib/components/TranscriptEditor.svelte';
@@ -85,11 +86,77 @@
audioUrl = convertFileSrc(filePath); audioUrl = convertFileSrc(filePath);
waveformPlayer?.loadAudio(audioUrl); waveformPlayer?.loadAudio(audioUrl);
// Clear previous results
segments.set([]);
speakers.set([]);
// Start pipeline (transcription + diarization) // Start pipeline (transcription + diarization)
isTranscribing = true; isTranscribing = true;
transcriptionProgress = 0; transcriptionProgress = 0;
transcriptionStage = 'Starting...'; transcriptionStage = 'Starting...';
// Append each segment to the store as soon as a 'pipeline-segment' event arrives.
const unlistenSegment = await listen<{
  index: number;
  text: string;
  start_ms: number;
  end_ms: number;
  words: Array<{ word: string; start_ms: number; end_ms: number; confidence: number }>;
}>('pipeline-segment', ({ payload }) => {
  // Ids are derived from the segment index carried in the event.
  const segmentId = `seg-${payload.index}`;

  const words = payload.words.map((w, position) => ({
    id: `word-${payload.index}-${position}`,
    segment_id: segmentId,
    word: w.word,
    start_ms: w.start_ms,
    end_ms: w.end_ms,
    confidence: w.confidence,
    word_index: position,
  }));

  // Fields the event does not provide are filled with empty/null placeholders.
  const incoming: Segment = {
    id: segmentId,
    project_id: '',
    media_file_id: '',
    speaker_id: null,
    start_ms: payload.start_ms,
    end_ms: payload.end_ms,
    text: payload.text,
    original_text: null,
    confidence: null,
    is_edited: false,
    edited_at: null,
    segment_index: payload.index,
    words,
  };

  segments.update((current) => [...current, incoming]);
});
// Apply speaker labels to the segments already in the store when a
// 'pipeline-speaker-update' event arrives.
const unlistenSpeaker = await listen<{
  updates: Array<{ index: number; speaker: string }>;
}>('pipeline-speaker-update', (event) => {
  const { updates } = event.payload;

  // Build speakers from unique labels; sorting makes ids and colors follow label order.
  const uniqueLabels = [...new Set(updates.map(u => u.speaker))].sort();
  const newSpeakers: Speaker[] = uniqueLabels.map((label, idx) => ({
    id: `speaker-${idx}`,
    project_id: '',
    label,
    display_name: null,
    color: speakerColors[idx % speakerColors.length],
  }));
  speakers.set(newSpeakers);

  const speakerLookup = new Map(newSpeakers.map(s => [s.label, s.id]));

  // Index the updates once so each segment is resolved in O(1) instead of
  // scanning the whole update list per segment. The first update for an
  // index wins, matching the behaviour of Array.prototype.find.
  const labelByIndex = new Map<number, string>();
  for (const u of updates) {
    if (!labelByIndex.has(u.index)) {
      labelByIndex.set(u.index, u.speaker);
    }
  }

  // Update existing segments with speaker assignments (matched by array position).
  segments.update(segs =>
    segs.map((seg, i) => {
      const label = labelByIndex.get(i);
      if (label === undefined) {
        return seg;
      }
      return { ...seg, speaker_id: speakerLookup.get(label) ?? null };
    })
  );
});
try { try {
const result = await invoke<{ const result = await invoke<{
segments: Array<{ segments: Array<{
@@ -159,6 +226,8 @@
console.error('Pipeline failed:', err); console.error('Pipeline failed:', err);
alert(`Pipeline failed: ${err}`); alert(`Pipeline failed: ${err}`);
} finally { } finally {
unlistenSegment();
unlistenSpeaker();
isTranscribing = false; isTranscribing = false;
} }
} }