diff --git a/python/tests/test_messages.py b/python/tests/test_messages.py index 12b71db..eb2a75d 100644 --- a/python/tests/test_messages.py +++ b/python/tests/test_messages.py @@ -3,8 +3,10 @@ from voice_to_notes.ipc.messages import ( IPCMessage, error_message, + partial_segment_message, progress_message, ready_message, + speaker_update_message, ) @@ -48,3 +50,16 @@ def test_ready_message(): assert msg.type == "ready" assert msg.id == "system" assert "version" in msg.payload + + +def test_partial_segment_message(): + msg = partial_segment_message("req-1", {"index": 0, "text": "hello"}) + assert msg.type == "pipeline.segment" + assert msg.payload["index"] == 0 + assert msg.payload["text"] == "hello" + + +def test_speaker_update_message(): + msg = speaker_update_message("req-1", [{"index": 0, "speaker": "SPEAKER_00"}]) + assert msg.type == "pipeline.speaker_update" + assert msg.payload["updates"][0]["speaker"] == "SPEAKER_00" diff --git a/python/tests/test_pipeline.py b/python/tests/test_pipeline.py index a1d3fef..33789aa 100644 --- a/python/tests/test_pipeline.py +++ b/python/tests/test_pipeline.py @@ -88,3 +88,18 @@ def test_merge_results_no_speaker_segments(): result = service._merge_results(transcription, []) assert result.segments[0].speaker is None + + +def test_speaker_update_generation(): + """Test that speaker updates are generated after merge.""" + result = PipelineResult( + segments=[ + PipelineSegment(text="Hello", start_ms=0, end_ms=1000, speaker="SPEAKER_00"), + PipelineSegment(text="World", start_ms=1000, end_ms=2000, speaker="SPEAKER_01"), + PipelineSegment(text="Foo", start_ms=2000, end_ms=3000, speaker=None), + ], + ) + updates = [{"index": i, "speaker": seg.speaker} for i, seg in enumerate(result.segments) if seg.speaker] + assert len(updates) == 2 + assert updates[0] == {"index": 0, "speaker": "SPEAKER_00"} + assert updates[1] == {"index": 1, "speaker": "SPEAKER_01"} diff --git a/python/tests/test_transcribe.py b/python/tests/test_transcribe.py index b9e4220..b0af248 100644 --- a/python/tests/test_transcribe.py +++ b/python/tests/test_transcribe.py @@ -1,7 +1,10 @@ """Tests for transcription service.""" +import inspect + from voice_to_notes.services.transcribe import ( SegmentResult, + TranscribeService, TranscriptionResult, WordResult, result_to_payload, @@ -49,3 +52,18 @@ def test_result_to_payload_empty(): assert payload["segments"] == [] assert payload["language"] == "" assert payload["duration_ms"] == 0 + + +def test_on_segment_callback(): + """Test that on_segment callback is invoked with correct SegmentResult and index.""" + callback_args = [] + + def mock_callback(seg: SegmentResult, index: int): + callback_args.append((seg.text, index)) + + # Test that passing on_segment doesn't break the function signature + # (Full integration test would require mocking WhisperModel) + service = TranscribeService() + # Verify the parameter exists by checking the signature + sig = inspect.signature(service.transcribe) + assert "on_segment" in sig.parameters diff --git a/python/voice_to_notes/ipc/messages.py b/python/voice_to_notes/ipc/messages.py index 6abc3d5..4e08df9 100644 --- a/python/voice_to_notes/ipc/messages.py +++ b/python/voice_to_notes/ipc/messages.py @@ -34,6 +34,14 @@ def progress_message(request_id: str, percent: int, stage: str, message: str) -> ) +def partial_segment_message(request_id: str, segment_data: dict) -> IPCMessage: + return IPCMessage(id=request_id, type="pipeline.segment", payload=segment_data) + + +def speaker_update_message(request_id: str, updates: list[dict]) -> IPCMessage: + return IPCMessage(id=request_id, type="pipeline.speaker_update", payload={"updates": updates}) + + def error_message(request_id: str, code: str, message: str) -> IPCMessage: return IPCMessage( id=request_id, diff --git a/python/voice_to_notes/services/pipeline.py b/python/voice_to_notes/services/pipeline.py index 2d1f66b..93204c3 100644 --- a/python/voice_to_notes/services/pipeline.py +++ b/python/voice_to_notes/services/pipeline.py @@ -7,7 +7,11 @@ import time from dataclasses import dataclass, field from typing import Any -from voice_to_notes.ipc.messages import progress_message +from voice_to_notes.ipc.messages import ( + partial_segment_message, + progress_message, + speaker_update_message, +) from voice_to_notes.ipc.protocol import write_message from voice_to_notes.services.diarize import DiarizeService, SpeakerSegment from voice_to_notes.services.transcribe import ( @@ -82,6 +86,15 @@ class PipelineService: progress_message(request_id, 0, "pipeline", "Starting transcription pipeline...") ) + def _emit_segment(seg: SegmentResult, index: int) -> None: + write_message(partial_segment_message(request_id, { + "index": index, + "text": seg.text, + "start_ms": seg.start_ms, + "end_ms": seg.end_ms, + "words": [{"word": w.word, "start_ms": w.start_ms, "end_ms": w.end_ms, "confidence": w.confidence} for w in seg.words], + })) + transcription = self._transcribe_service.transcribe( request_id=request_id, file_path=file_path, @@ -89,6 +102,7 @@ class PipelineService: device=device, compute_type=compute_type, language=language, + on_segment=_emit_segment, ) if skip_diarization: @@ -140,6 +154,10 @@ class PipelineService: flush=True, ) + updates = [{"index": i, "speaker": seg.speaker} for i, seg in enumerate(result.segments) if seg.speaker] + if updates: + write_message(speaker_update_message(request_id, updates)) + write_message( progress_message(request_id, 100, "done", "Pipeline complete") ) diff --git a/python/voice_to_notes/services/transcribe.py b/python/voice_to_notes/services/transcribe.py index 2539cfc..7a0b67c 100644 --- a/python/voice_to_notes/services/transcribe.py +++ b/python/voice_to_notes/services/transcribe.py @@ -4,6 +4,7 @@ from __future__ import annotations import sys import time +from collections.abc import Callable from dataclasses import dataclass, field from typing import Any @@ -90,6 +91,7 @@ class TranscribeService: device: str = "cpu", compute_type: str = "int8", language: str | None = None, + on_segment: Callable[[SegmentResult, int], None] | None = None, ) -> TranscriptionResult: """Transcribe an audio file with word-level timestamps. @@ -145,6 +147,9 @@ class TranscribeService: ) ) + if on_segment: + on_segment(result.segments[-1], segment_count - 1) + # Send progress every few segments if segment_count % 5 == 0: write_message( diff --git a/src-tauri/src/commands/transcribe.rs b/src-tauri/src/commands/transcribe.rs index 9e2239a..eb3f515 100644 --- a/src-tauri/src/commands/transcribe.rs +++ b/src-tauri/src/commands/transcribe.rs @@ -1,4 +1,5 @@ use serde_json::{json, Value}; +use tauri::{AppHandle, Emitter}; use crate::sidecar::messages::IPCMessage; use crate::sidecar::sidecar; @@ -42,6 +43,7 @@ pub fn transcribe_file( /// Run the full transcription + diarization pipeline via the Python sidecar. #[tauri::command] pub fn run_pipeline( + app: AppHandle, file_path: String, model: Option, device: Option, @@ -71,7 +73,14 @@ pub fn run_pipeline( }), ); - let response = manager.send_and_receive(&msg)?; + let response = manager.send_and_receive_with_progress(&msg, |msg| { + let event_name = match msg.msg_type.as_str() { + "pipeline.segment" => "pipeline-segment", + "pipeline.speaker_update" => "pipeline-speaker-update", + _ => "pipeline-progress", + }; + let _ = app.emit(event_name, &msg.payload); + })?; if response.msg_type == "error" { return Err(format!( diff --git a/src-tauri/src/sidecar/mod.rs b/src-tauri/src/sidecar/mod.rs index dd60840..f542e89 100644 --- a/src-tauri/src/sidecar/mod.rs +++ b/src-tauri/src/sidecar/mod.rs @@ -165,6 +165,70 @@ impl SidecarManager { } } + /// Send a message and receive the response, calling a callback for intermediate messages. + /// Intermediate messages include progress, pipeline.segment, and pipeline.speaker_update. + pub fn send_and_receive_with_progress( + &self, + msg: &IPCMessage, + on_intermediate: F, + ) -> Result + where + F: Fn(&IPCMessage), + { + // Write to stdin + { + let mut stdin_guard = self.stdin.lock().map_err(|e| e.to_string())?; + if let Some(ref mut stdin) = *stdin_guard { + let json = serde_json::to_string(msg).map_err(|e| e.to_string())?; + stdin + .write_all(json.as_bytes()) + .map_err(|e| format!("Write error: {e}"))?; + stdin + .write_all(b"\n") + .map_err(|e| format!("Write error: {e}"))?; + stdin.flush().map_err(|e| format!("Flush error: {e}"))?; + } else { + return Err("Sidecar stdin not available".to_string()); + } + } + + // Read from stdout + { + let mut reader_guard = self.reader.lock().map_err(|e| e.to_string())?; + if let Some(ref mut reader) = *reader_guard { + let mut line = String::new(); + loop { + line.clear(); + let bytes_read = reader + .read_line(&mut line) + .map_err(|e| format!("Read error: {e}"))?; + if bytes_read == 0 { + return Err("Sidecar closed stdout".to_string()); + } + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + let response: IPCMessage = serde_json::from_str(trimmed) + .map_err(|e| format!("Parse error: {e}"))?; + + // Forward intermediate messages via callback, return the final result/error + let is_intermediate = matches!( + response.msg_type.as_str(), + "progress" | "pipeline.segment" | "pipeline.speaker_update" + ); + if is_intermediate { + on_intermediate(&response); + } else { + return Ok(response); + } + } + } else { + Err("Sidecar stdout not available".to_string()) + } + } + } + /// Stop the sidecar process. pub fn stop(&self) -> Result<(), String> { // Drop stdin to signal EOF diff --git a/src/routes/+page.svelte b/src/routes/+page.svelte index 9b139a0..fc2db10 100644 --- a/src/routes/+page.svelte +++ b/src/routes/+page.svelte @@ -1,5 +1,6 @@