"""Combined transcription + diarization pipeline."""

from __future__ import annotations

import sys
import time
from dataclasses import dataclass, field
from typing import Any

from voice_to_notes.ipc.messages import progress_message
from voice_to_notes.ipc.protocol import write_message
from voice_to_notes.services.diarize import DiarizeService, SpeakerSegment
from voice_to_notes.services.transcribe import (
    SegmentResult,
    TranscribeService,
    TranscriptionResult,
    WordResult,
)


@dataclass
class PipelineSegment:
    """A transcript segment with speaker assignment.

    Attributes:
        text: Transcribed text of the segment.
        start_ms: Segment start offset, copied from the transcription segment.
        end_ms: Segment end offset, copied from the transcription segment.
        speaker: Speaker label chosen for the segment, or None when no
            speaker was assigned (diarization skipped or no overlap found).
        words: Word-level results carried over from the transcription.
    """

    text: str
    start_ms: int
    end_ms: int
    speaker: str | None
    words: list[WordResult] = field(default_factory=list)


@dataclass
class PipelineResult:
    """Full pipeline output combining transcription and diarization.

    Attributes:
        segments: Transcript segments in transcription order.
        language: Language reported by the transcription result.
        language_probability: Probability reported alongside ``language``.
        duration_ms: Duration reported by the transcription result.
        speakers: Speaker labels reported by diarization (empty if skipped).
        num_speakers: Speaker count reported by diarization (0 if skipped).
    """

    segments: list[PipelineSegment] = field(default_factory=list)
    language: str = ""
    language_probability: float = 0.0
    duration_ms: int = 0
    speakers: list[str] = field(default_factory=list)
    num_speakers: int = 0


class PipelineService:
    """Runs the full pipeline: transcribe -> diarize -> merge."""

    def __init__(self) -> None:
        self._transcribe_service = TranscribeService()
        self._diarize_service = DiarizeService()

    def run(
        self,
        request_id: str,
        file_path: str,
        model_name: str = "base",
        device: str = "cpu",
        compute_type: str = "int8",
        language: str | None = None,
        num_speakers: int | None = None,
        min_speakers: int | None = None,
        max_speakers: int | None = None,
        skip_diarization: bool = False,
    ) -> PipelineResult:
        """Run the full transcription + diarization pipeline.

        Progress messages are written via ``write_message`` at the start, before
        diarization, before merging, and on completion. The completion message
        and the stderr summary are emitted on every path, including when
        diarization is skipped.

        Args:
            request_id: IPC request ID for progress messages.
            file_path: Path to audio file.
            model_name: Whisper model size.
            device: 'cpu' or 'cuda'.
            compute_type: Quantization type.
            language: Language code or None for auto-detect.
            num_speakers: Exact speaker count (if known).
            min_speakers: Minimum expected speakers.
            max_speakers: Maximum expected speakers.
            skip_diarization: If True, only transcribe (no speaker ID).

        Returns:
            The merged PipelineResult. When ``skip_diarization`` is True every
            segment has ``speaker=None`` and the speaker fields keep their
            defaults.
        """
        # Monotonic clock: elapsed time must not be affected by wall-clock jumps.
        start_time = time.monotonic()

        # Step 1: Transcribe
        write_message(
            progress_message(request_id, 0, "pipeline", "Starting transcription pipeline...")
        )
        transcription = self._transcribe_service.transcribe(
            request_id=request_id,
            file_path=file_path,
            model_name=model_name,
            device=device,
            compute_type=compute_type,
            language=language,
        )

        if skip_diarization:
            # With no speaker turns, merging leaves every segment unlabelled.
            result = self._merge_results(transcription, [])
        else:
            # Step 2: Diarize
            write_message(
                progress_message(request_id, 50, "pipeline", "Starting speaker diarization...")
            )
            diarization = self._diarize_service.diarize(
                request_id=request_id,
                file_path=file_path,
                num_speakers=num_speakers,
                min_speakers=min_speakers,
                max_speakers=max_speakers,
            )

            # Step 3: Merge
            write_message(
                progress_message(request_id, 90, "pipeline", "Merging transcript with speakers...")
            )
            result = self._merge_results(transcription, diarization.speaker_segments)
            result.speakers = diarization.speakers
            result.num_speakers = diarization.num_speakers

        elapsed = time.monotonic() - start_time
        print(
            f"[sidecar] Pipeline complete in {elapsed:.1f}s: "
            f"{len(result.segments)} segments, {result.num_speakers} speakers",
            file=sys.stderr,
            flush=True,
        )
        write_message(
            progress_message(request_id, 100, "done", "Pipeline complete")
        )
        return result

    def _merge_results(
        self,
        transcription: TranscriptionResult,
        speaker_segments: list[SpeakerSegment],
    ) -> PipelineResult:
        """Merge transcription segments with speaker assignments.

        For each transcript segment, find the speaker who has the most overlap
        with that segment's time range.

        Args:
            transcription: Transcription output whose segments are converted.
            speaker_segments: Diarization turns; may be empty, in which case
                every resulting segment has ``speaker=None``.

        Returns:
            A PipelineResult with language/duration metadata copied from
            ``transcription`` and one PipelineSegment per transcript segment.
            ``speakers`` and ``num_speakers`` are left at their defaults.
        """
        result = PipelineResult(
            language=transcription.language,
            language_probability=transcription.language_probability,
            duration_ms=transcription.duration_ms,
        )
        result.segments = [
            PipelineSegment(
                text=seg.text,
                start_ms=seg.start_ms,
                end_ms=seg.end_ms,
                speaker=self._find_speaker_for_segment(
                    seg.start_ms, seg.end_ms, speaker_segments
                ),
                # Words are carried over unchanged (no per-word speaker); the
                # list is copied so the result does not alias the transcription.
                words=list(seg.words),
            )
            for seg in transcription.segments
        ]
        return result

    def _find_speaker_for_segment(
        self,
        start_ms: int,
        end_ms: int,
        speaker_segments: list[SpeakerSegment],
    ) -> str | None:
        """Find the speaker with the most overlap for a given time range.

        Overlap is summed per speaker across all of that speaker's turns, so a
        speaker with several short turns inside the range can outweigh a single
        longer turn by someone else. Ties go to the speaker encountered first
        in ``speaker_segments``.

        Args:
            start_ms: Start of the time range.
            end_ms: End of the time range.
            speaker_segments: Diarization turns to match against.

        Returns:
            The best-matching speaker label, or None when no turn overlaps the
            range. A zero-length range is matched to the first turn that
            contains that instant.
        """
        overlap_by_speaker: dict[str, int] = {}
        for ss in speaker_segments:
            overlap = min(end_ms, ss.end_ms) - max(start_ms, ss.start_ms)
            if overlap > 0:
                overlap_by_speaker[ss.speaker] = (
                    overlap_by_speaker.get(ss.speaker, 0) + overlap
                )

        if overlap_by_speaker:
            # max() returns the first maximal key; dicts preserve insertion order.
            return max(overlap_by_speaker, key=overlap_by_speaker.__getitem__)

        if start_ms == end_ms:
            # A zero-length range can never have positive overlap; fall back to
            # the turn that contains the instant.
            for ss in speaker_segments:
                if ss.start_ms <= start_ms <= ss.end_ms:
                    return ss.speaker

        return None


def pipeline_result_to_payload(result: PipelineResult) -> dict[str, Any]:
    """Convert PipelineResult to IPC payload dict.

    Args:
        result: Pipeline output to serialize.

    Returns:
        A dict of plain values: a ``segments`` list (each with ``text``,
        ``start_ms``, ``end_ms``, ``speaker`` and a ``words`` list of
        ``word``/``start_ms``/``end_ms``/``confidence`` dicts) plus the
        top-level ``language``, ``language_probability``, ``duration_ms``,
        ``speakers`` and ``num_speakers`` fields.
    """
    return {
        "segments": [
            {
                "text": seg.text,
                "start_ms": seg.start_ms,
                "end_ms": seg.end_ms,
                "speaker": seg.speaker,
                "words": [
                    {
                        "word": w.word,
                        "start_ms": w.start_ms,
                        "end_ms": w.end_ms,
                        "confidence": w.confidence,
                    }
                    for w in seg.words
                ],
            }
            for seg in result.segments
        ],
        "language": result.language,
        "language_probability": result.language_probability,
        "duration_ms": result.duration_ms,
        "speakers": result.speakers,
        "num_speakers": result.num_speakers,
    }