perf/pipeline-improvements #1

Merged
jknapp merged 18 commits from perf/pipeline-improvements into main 2026-03-21 04:53:45 +00:00
14 changed files with 245 additions and 4 deletions
Showing only changes of commit c3b6ad38fd - Show all commits

View File

@@ -0,0 +1,21 @@
{
"permissions": {
"allow": [
"Bash(git init:*)",
"Bash(git:*)",
"WebSearch",
"Bash(npm create:*)",
"Bash(cp:*)",
"Bash(npm install:*)",
"Bash(/home/jknapp/.cargo/bin/cargo test:*)",
"Bash(ruff:*)",
"Bash(npm run:*)",
"Bash(npx svelte-check:*)",
"Bash(pip install:*)",
"Bash(python3:*)",
"Bash(/home/jknapp/.cargo/bin/cargo check:*)",
"Bash(cargo check:*)",
"Bash(npm ls:*)"
]
}
}

Submodule .claude/worktrees/agent-a0bd87d1 added at 67ed69df00

Submodule .claude/worktrees/agent-a198b5f8 added at 6eb13bce63

Submodule .claude/worktrees/agent-ad3d6fca added at 03af5a189c

Submodule .claude/worktrees/agent-aefe2597 added at 16f4b57771

View File

@@ -3,8 +3,10 @@
from voice_to_notes.ipc.messages import ( from voice_to_notes.ipc.messages import (
IPCMessage, IPCMessage,
error_message, error_message,
partial_segment_message,
progress_message, progress_message,
ready_message, ready_message,
speaker_update_message,
) )
@@ -48,3 +50,16 @@ def test_ready_message():
assert msg.type == "ready" assert msg.type == "ready"
assert msg.id == "system" assert msg.id == "system"
assert "version" in msg.payload assert "version" in msg.payload
def test_partial_segment_message():
    """``partial_segment_message`` keeps the request id and uses the dict as payload."""
    segment = {"index": 0, "text": "hello"}
    msg = partial_segment_message("req-1", segment)
    # The id must round-trip so the receiver can correlate the segment with its request.
    assert msg.id == "req-1"
    assert msg.type == "pipeline.segment"
    assert msg.payload["index"] == 0
    assert msg.payload["text"] == "hello"
def test_speaker_update_message():
    """``speaker_update_message`` keeps the request id and nests updates under ``"updates"``."""
    updates = [{"index": 0, "speaker": "SPEAKER_00"}]
    msg = speaker_update_message("req-1", updates)
    # The id must round-trip so the receiver can correlate the update with its request.
    assert msg.id == "req-1"
    assert msg.type == "pipeline.speaker_update"
    assert msg.payload["updates"][0]["speaker"] == "SPEAKER_00"

View File

@@ -88,3 +88,18 @@ def test_merge_results_no_speaker_segments():
result = service._merge_results(transcription, []) result = service._merge_results(transcription, [])
assert result.segments[0].speaker is None assert result.segments[0].speaker is None
def test_speaker_update_generation():
    """Speaker updates cover only segments that have a speaker, keyed by position.

    The update list is built inline here from a hand-made ``PipelineResult``;
    the test does not call into the pipeline service itself.
    """
    first = PipelineSegment(text="Hello", start_ms=0, end_ms=1000, speaker="SPEAKER_00")
    second = PipelineSegment(text="World", start_ms=1000, end_ms=2000, speaker="SPEAKER_01")
    unlabeled = PipelineSegment(text="Foo", start_ms=2000, end_ms=3000, speaker=None)
    result = PipelineResult(segments=[first, second, unlabeled])

    updates = []
    for position, segment in enumerate(result.segments):
        # Segments without a speaker label produce no update entry.
        if segment.speaker:
            updates.append({"index": position, "speaker": segment.speaker})

    assert len(updates) == 2
    assert updates[0] == {"index": 0, "speaker": "SPEAKER_00"}
    assert updates[1] == {"index": 1, "speaker": "SPEAKER_01"}

View File

@@ -1,7 +1,10 @@
"""Tests for transcription service.""" """Tests for transcription service."""
import inspect
from voice_to_notes.services.transcribe import ( from voice_to_notes.services.transcribe import (
SegmentResult, SegmentResult,
TranscribeService,
TranscriptionResult, TranscriptionResult,
WordResult, WordResult,
result_to_payload, result_to_payload,
@@ -49,3 +52,18 @@ def test_result_to_payload_empty():
assert payload["segments"] == [] assert payload["segments"] == []
assert payload["language"] == "" assert payload["language"] == ""
assert payload["duration_ms"] == 0 assert payload["duration_ms"] == 0
def test_on_segment_callback():
    """``TranscribeService.transcribe`` exposes an optional ``on_segment`` parameter.

    Only the signature is checked. Verifying that the callback is actually
    invoked per segment would need a full integration test with a mocked
    WhisperModel.
    """
    service = TranscribeService()
    params = inspect.signature(service.transcribe).parameters
    assert "on_segment" in params
    # The callback has to stay optional so callers that do not stream keep working.
    assert params["on_segment"].default is None

View File

@@ -34,6 +34,14 @@ def progress_message(request_id: str, percent: int, stage: str, message: str) ->
) )
def partial_segment_message(request_id: str, segment_data: dict) -> IPCMessage:
    """Build a ``pipeline.segment`` message.

    Args:
        request_id: Used as the message ``id``.
        segment_data: Used unchanged as the message payload.

    Returns:
        An ``IPCMessage`` of type ``"pipeline.segment"``.
    """
    message_type = "pipeline.segment"
    return IPCMessage(id=request_id, type=message_type, payload=segment_data)
def speaker_update_message(request_id: str, updates: list[dict]) -> IPCMessage:
    """Build a ``pipeline.speaker_update`` message.

    Args:
        request_id: Used as the message ``id``.
        updates: Update entries; placed under the ``"updates"`` payload key.

    Returns:
        An ``IPCMessage`` of type ``"pipeline.speaker_update"``.
    """
    payload = {"updates": updates}
    return IPCMessage(id=request_id, type="pipeline.speaker_update", payload=payload)
def error_message(request_id: str, code: str, message: str) -> IPCMessage: def error_message(request_id: str, code: str, message: str) -> IPCMessage:
return IPCMessage( return IPCMessage(
id=request_id, id=request_id,

View File

@@ -7,7 +7,11 @@ import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
from voice_to_notes.ipc.messages import progress_message from voice_to_notes.ipc.messages import (
partial_segment_message,
progress_message,
speaker_update_message,
)
from voice_to_notes.ipc.protocol import write_message from voice_to_notes.ipc.protocol import write_message
from voice_to_notes.services.diarize import DiarizeService, SpeakerSegment from voice_to_notes.services.diarize import DiarizeService, SpeakerSegment
from voice_to_notes.services.transcribe import ( from voice_to_notes.services.transcribe import (
@@ -83,6 +87,15 @@ class PipelineService:
progress_message(request_id, 0, "pipeline", "Starting transcription pipeline...") progress_message(request_id, 0, "pipeline", "Starting transcription pipeline...")
) )
def _emit_segment(seg: SegmentResult, index: int) -> None:
    """Write one segment, including its words, as a ``pipeline.segment`` message.

    Handed to the transcribe call as its ``on_segment`` callback; ``index`` is
    forwarded as received.
    """
    words = []
    for w in seg.words:
        words.append({
            "word": w.word,
            "start_ms": w.start_ms,
            "end_ms": w.end_ms,
            "confidence": w.confidence,
        })
    payload = {
        "index": index,
        "text": seg.text,
        "start_ms": seg.start_ms,
        "end_ms": seg.end_ms,
        "words": words,
    }
    write_message(partial_segment_message(request_id, payload))
transcription = self._transcribe_service.transcribe( transcription = self._transcribe_service.transcribe(
request_id=request_id, request_id=request_id,
file_path=file_path, file_path=file_path,
@@ -90,6 +103,7 @@ class PipelineService:
device=device, device=device,
compute_type=compute_type, compute_type=compute_type,
language=language, language=language,
on_segment=_emit_segment,
) )
if skip_diarization: if skip_diarization:
@@ -174,6 +188,10 @@ class PipelineService:
flush=True, flush=True,
) )
updates = [{"index": i, "speaker": seg.speaker} for i, seg in enumerate(result.segments) if seg.speaker]
if updates:
write_message(speaker_update_message(request_id, updates))
write_message( write_message(
progress_message(request_id, 100, "done", "Pipeline complete") progress_message(request_id, 100, "done", "Pipeline complete")
) )

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import sys import sys
import time import time
from collections.abc import Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
@@ -90,6 +91,7 @@ class TranscribeService:
device: str = "cpu", device: str = "cpu",
compute_type: str = "int8", compute_type: str = "int8",
language: str | None = None, language: str | None = None,
on_segment: Callable[[SegmentResult, int], None] | None = None,
) -> TranscriptionResult: ) -> TranscriptionResult:
"""Transcribe an audio file with word-level timestamps. """Transcribe an audio file with word-level timestamps.
@@ -145,6 +147,9 @@ class TranscribeService:
) )
) )
if on_segment:
on_segment(result.segments[-1], segment_count - 1)
# Send progress every few segments # Send progress every few segments
if segment_count % 5 == 0: if segment_count % 5 == 0:
write_message( write_message(

View File

@@ -104,8 +104,13 @@ pub fn run_pipeline(
}), }),
); );
let response = manager.send_and_receive_with_progress(&msg, |progress| { let response = manager.send_and_receive_with_progress(&msg, |msg| {
let _ = app.emit("pipeline-progress", &progress.payload); let event_name = match msg.msg_type.as_str() {
"pipeline.segment" => "pipeline-segment",
"pipeline.speaker_update" => "pipeline-speaker-update",
_ => "pipeline-progress",
};
let _ = app.emit(event_name, &msg.payload);
})?; })?;
if response.msg_type == "error" { if response.msg_type == "error" {

View File

@@ -182,6 +182,70 @@ impl SidecarManager {
} }
} }
/// Send a message and receive the response, calling a callback for intermediate messages.
/// Intermediate messages include progress, pipeline.segment, and pipeline.speaker_update.
pub fn send_and_receive_with_progress<F>(
&self,
msg: &IPCMessage,
on_intermediate: F,
) -> Result<IPCMessage, String>
where
F: Fn(&IPCMessage),
{
// Write to stdin
{
let mut stdin_guard = self.stdin.lock().map_err(|e| e.to_string())?;
if let Some(ref mut stdin) = *stdin_guard {
let json = serde_json::to_string(msg).map_err(|e| e.to_string())?;
stdin
.write_all(json.as_bytes())
.map_err(|e| format!("Write error: {e}"))?;
stdin
.write_all(b"\n")
.map_err(|e| format!("Write error: {e}"))?;
stdin.flush().map_err(|e| format!("Flush error: {e}"))?;
} else {
return Err("Sidecar stdin not available".to_string());
}
}
// Read from stdout
{
let mut reader_guard = self.reader.lock().map_err(|e| e.to_string())?;
if let Some(ref mut reader) = *reader_guard {
let mut line = String::new();
loop {
line.clear();
let bytes_read = reader
.read_line(&mut line)
.map_err(|e| format!("Read error: {e}"))?;
if bytes_read == 0 {
return Err("Sidecar closed stdout".to_string());
}
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let response: IPCMessage = serde_json::from_str(trimmed)
.map_err(|e| format!("Parse error: {e}"))?;
// Forward intermediate messages via callback, return the final result/error
let is_intermediate = matches!(
response.msg_type.as_str(),
"progress" | "pipeline.segment" | "pipeline.speaker_update"
);
if is_intermediate {
on_intermediate(&response);
} else {
return Ok(response);
}
}
} else {
Err("Sidecar stdout not available".to_string())
}
}
}
/// Stop the sidecar process. /// Stop the sidecar process.
pub fn stop(&self) -> Result<(), String> { pub fn stop(&self) -> Result<(), String> {
// Drop stdin to signal EOF // Drop stdin to signal EOF

View File

@@ -87,7 +87,11 @@
audioUrl = convertFileSrc(filePath); audioUrl = convertFileSrc(filePath);
waveformPlayer?.loadAudio(audioUrl); waveformPlayer?.loadAudio(audioUrl);
// Start pipeline — show overlay immediately before heavy processing // Clear previous results
segments.set([]);
speakers.set([]);
// Start pipeline (transcription + diarization)
isTranscribing = true; isTranscribing = true;
transcriptionProgress = 0; transcriptionProgress = 0;
transcriptionStage = 'Starting...'; transcriptionStage = 'Starting...';
@@ -109,6 +113,68 @@
if (typeof message === 'string') transcriptionMessage = message; if (typeof message === 'string') transcriptionMessage = message;
}); });
// Append each segment to the store as its 'pipeline-segment' event arrives.
// Speaker assignment is left null here; ids are derived from the segment index.
const unlistenSegment = await listen<{
  index: number;
  text: string;
  start_ms: number;
  end_ms: number;
  words: Array<{ word: string; start_ms: number; end_ms: number; confidence: number }>;
}>('pipeline-segment', ({ payload }) => {
  const segmentId = `seg-${payload.index}`;
  const newSeg: Segment = {
    id: segmentId,
    project_id: '',
    media_file_id: '',
    speaker_id: null,
    start_ms: payload.start_ms,
    end_ms: payload.end_ms,
    text: payload.text,
    original_text: null,
    confidence: null,
    is_edited: false,
    edited_at: null,
    segment_index: payload.index,
    words: payload.words.map((entry, wordIndex) => ({
      id: `word-${payload.index}-${wordIndex}`,
      segment_id: segmentId,
      word: entry.word,
      start_ms: entry.start_ms,
      end_ms: entry.end_ms,
      confidence: entry.confidence,
      word_index: wordIndex,
    })),
  };
  // New array instance so store subscribers see the change.
  segments.update((current) => [...current, newSeg]);
});
// Handle 'pipeline-speaker-update': rebuild the speaker list from the labels in
// the event and attach speaker ids to the segments already in the store.
const unlistenSpeaker = await listen<{
  updates: Array<{ index: number; speaker: string }>;
}>('pipeline-speaker-update', (event) => {
  const { updates } = event.payload;

  // Build speakers from unique labels; sorting keeps id/color assignment
  // deterministic for a given set of labels.
  const uniqueLabels = [...new Set(updates.map(u => u.speaker))].sort();
  const newSpeakers: Speaker[] = uniqueLabels.map((label, idx) => ({
    id: `speaker-${idx}`,
    project_id: '',
    label,
    display_name: null,
    color: speakerColors[idx % speakerColors.length],
  }));
  speakers.set(newSpeakers);

  const speakerLookup = new Map(newSpeakers.map(s => [s.label, s.id]));

  // Index the updates by segment position once, instead of scanning the whole
  // update list for every segment. The first entry for a position wins, which
  // matches what a linear search over `updates` would return.
  const labelByPosition = new Map<number, string>();
  for (const { index, speaker } of updates) {
    if (!labelByPosition.has(index)) {
      labelByPosition.set(index, speaker);
    }
  }

  // Update existing segments with speaker assignments (matched by array position).
  segments.update(segs =>
    segs.map((seg, i) => {
      const label = labelByPosition.get(i);
      if (label === undefined) {
        return seg;
      }
      return { ...seg, speaker_id: speakerLookup.get(label) ?? null };
    })
  );
});
try { try {
const result = await invoke<{ const result = await invoke<{
segments: Array<{ segments: Array<{
@@ -180,6 +246,8 @@
alert(`Pipeline failed: ${err}`); alert(`Pipeline failed: ${err}`);
} finally { } finally {
unlisten(); unlisten();
unlistenSegment();
unlistenSpeaker();
isTranscribing = false; isTranscribing = false;
} }
} }