- Implement DiarizeService with pyannote.audio speaker detection - Build PipelineService combining transcribe → diarize → merge with overlap-based speaker assignment per segment - Add pipeline.start and diarize.start IPC handlers - Add run_pipeline Tauri command for full pipeline execution - Wire frontend to use pipeline: speakers auto-created with colors, segments assigned to detected speakers - Build SpeakerManager with rename support (double-click or edit button) - Add speaker color coding throughout transcript display - Add pyannote.audio dependency - Tests: 24 Python (including merge logic), 6 Rust, 0 Svelte errors Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
236 lines
7.3 KiB
Python
236 lines
7.3 KiB
Python
"""Combined transcription + diarization pipeline."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import sys
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from voice_to_notes.ipc.messages import progress_message
|
|
from voice_to_notes.ipc.protocol import write_message
|
|
from voice_to_notes.services.diarize import DiarizeService, SpeakerSegment
|
|
from voice_to_notes.services.transcribe import (
|
|
SegmentResult,
|
|
TranscribeService,
|
|
TranscriptionResult,
|
|
WordResult,
|
|
)
|
|
|
|
|
|
@dataclass()
class PipelineSegment:
    """One transcript segment together with the speaker assigned to it.

    ``speaker`` is None when no speaker could be assigned to the segment.
    """

    # Transcribed text of the segment.
    text: str
    # Segment boundaries, in milliseconds.
    start_ms: int
    end_ms: int
    # Assigned speaker label, or None when no speaker was assigned.
    speaker: str | None
    # Word-level results; every instance gets its own fresh list by default.
    words: list[WordResult] = field(default_factory=list)
|
|
|
|
|
|
@dataclass()
class PipelineResult:
    """Everything one pipeline run produces for a single audio file.

    Bundles transcription metadata with the diarization speaker summary.
    Every field has a default, so an empty instance can be created first and
    filled in afterwards.
    """

    # Transcript segments, each optionally labelled with a speaker.
    segments: list[PipelineSegment] = field(default_factory=list)
    # Language reported by the transcription step, and its probability.
    language: str = ""
    language_probability: float = 0.0
    # Duration reported by the transcription step, in milliseconds.
    duration_ms: int = 0
    # Speaker labels reported by diarization and their count.
    speakers: list[str] = field(default_factory=list)
    num_speakers: int = 0
|
|
|
|
|
|
class PipelineService:
    """Runs the full pipeline: transcribe -> diarize -> merge.

    Transcription and diarization are delegated to ``TranscribeService`` and
    ``DiarizeService``. This class sequences the two steps, reports progress
    over IPC, and assigns a speaker label to every transcript segment.
    """

    def __init__(self) -> None:
        self._transcribe_service = TranscribeService()
        self._diarize_service = DiarizeService()

    def run(
        self,
        request_id: str,
        file_path: str,
        model_name: str = "base",
        device: str = "cpu",
        compute_type: str = "int8",
        language: str | None = None,
        num_speakers: int | None = None,
        min_speakers: int | None = None,
        max_speakers: int | None = None,
        skip_diarization: bool = False,
    ) -> PipelineResult:
        """Run the full transcription + diarization pipeline.

        Progress messages are written at 0% (pipeline start), 50% (diarization
        start, only when diarization runs), 90% (merge start, likewise) and
        100% ("done"). The final "done" message is sent on both paths.

        Args:
            request_id: IPC request ID for progress messages.
            file_path: Path to audio file.
            model_name: Whisper model size.
            device: 'cpu' or 'cuda'.
            compute_type: Quantization type.
            language: Language code or None for auto-detect.
            num_speakers: Exact speaker count (if known).
            min_speakers: Minimum expected speakers.
            max_speakers: Maximum expected speakers.
            skip_diarization: If True, only transcribe (no speaker ID).

        Returns:
            A PipelineResult whose segments carry speaker labels. When
            ``skip_diarization`` is True, every segment has ``speaker=None``
            and ``speakers`` / ``num_speakers`` keep their defaults.
        """
        # perf_counter is monotonic, so the elapsed time is immune to clock changes.
        start_time = time.perf_counter()

        # Step 1: Transcribe
        write_message(
            progress_message(request_id, 0, "pipeline", "Starting transcription pipeline...")
        )

        transcription = self._transcribe_service.transcribe(
            request_id=request_id,
            file_path=file_path,
            model_name=model_name,
            device=device,
            compute_type=compute_type,
            language=language,
        )

        if skip_diarization:
            # With no speaker turns, the merge leaves every segment's
            # speaker as None. Going through the same code path keeps both
            # branches producing identically-shaped results.
            result = self._merge_results(transcription, [])
        else:
            # Step 2: Diarize
            write_message(
                progress_message(request_id, 50, "pipeline", "Starting speaker diarization...")
            )

            diarization = self._diarize_service.diarize(
                request_id=request_id,
                file_path=file_path,
                num_speakers=num_speakers,
                min_speakers=min_speakers,
                max_speakers=max_speakers,
            )

            # Step 3: Merge
            write_message(
                progress_message(request_id, 90, "pipeline", "Merging transcript with speakers...")
            )

            result = self._merge_results(transcription, diarization.speaker_segments)
            result.speakers = diarization.speakers
            result.num_speakers = diarization.num_speakers

        elapsed = time.perf_counter() - start_time
        print(
            f"[sidecar] Pipeline complete in {elapsed:.1f}s: "
            f"{len(result.segments)} segments, {result.num_speakers} speakers",
            file=sys.stderr,
            flush=True,
        )

        # Always signal completion, including on the transcription-only path.
        write_message(
            progress_message(request_id, 100, "done", "Pipeline complete")
        )

        return result

    def _merge_results(
        self,
        transcription: TranscriptionResult,
        speaker_segments: list[SpeakerSegment],
    ) -> PipelineResult:
        """Merge transcription segments with speaker assignments.

        Each transcript segment is labelled with the speaker whose turns
        overlap its time range the most (see ``_find_speaker_for_segment``).
        An empty ``speaker_segments`` list yields ``speaker=None`` for every
        segment.

        Args:
            transcription: Output of the transcription step.
            speaker_segments: Speaker turns from diarization; may be empty.

        Returns:
            A new PipelineResult with transcription metadata copied over.
            ``speakers`` and ``num_speakers`` are left at their defaults.
        """
        result = PipelineResult(
            language=transcription.language,
            language_probability=transcription.language_probability,
            duration_ms=transcription.duration_ms,
        )

        for seg in transcription.segments:
            result.segments.append(
                PipelineSegment(
                    text=seg.text,
                    start_ms=seg.start_ms,
                    end_ms=seg.end_ms,
                    speaker=self._find_speaker_for_segment(
                        seg.start_ms, seg.end_ms, speaker_segments
                    ),
                    # Copy so the result does not alias the transcription's
                    # word list. Speakers are assigned per segment only, not
                    # per word.
                    words=list(seg.words),
                )
            )

        return result

    def _find_speaker_for_segment(
        self,
        start_ms: int,
        end_ms: int,
        speaker_segments: list[SpeakerSegment],
    ) -> str | None:
        """Find the speaker with the most overlap for a given time range.

        Overlap is summed over all of a speaker's turns. This lets a speaker
        with several short turns inside the range outweigh a single longer
        turn by another speaker.

        Args:
            start_ms: Range start, in milliseconds.
            end_ms: Range end, in milliseconds.
            speaker_segments: Speaker turns, each with ``start_ms``,
                ``end_ms`` and ``speaker`` attributes.

        Returns:
            The best-overlapping speaker label. On a tie, the speaker whose
            first overlapping turn appears earliest in ``speaker_segments``
            wins. Returns None when no turn overlaps the range, which
            includes zero-length ranges.
        """
        overlap_by_speaker: dict[str, int] = {}

        for ss in speaker_segments:
            overlap = min(end_ms, ss.end_ms) - max(start_ms, ss.start_ms)
            if overlap > 0:
                overlap_by_speaker[ss.speaker] = (
                    overlap_by_speaker.get(ss.speaker, 0) + overlap
                )

        if not overlap_by_speaker:
            return None

        # max() returns the first maximal key in dict insertion order, which
        # is the order in which speakers first overlapped the range.
        return max(overlap_by_speaker, key=overlap_by_speaker.__getitem__)
|
|
|
|
|
|
def pipeline_result_to_payload(result: PipelineResult) -> dict[str, Any]:
    """Build the plain-dict IPC payload for a finished pipeline run.

    Args:
        result: The merged transcription/diarization output.

    Returns:
        A dict with a ``segments`` list (each entry carrying its ``words``),
        plus the language, duration and speaker summary fields of ``result``.
    """

    def word_to_dict(word: WordResult) -> dict[str, Any]:
        # Only these four attributes of each word are forwarded.
        return {
            "word": word.word,
            "start_ms": word.start_ms,
            "end_ms": word.end_ms,
            "confidence": word.confidence,
        }

    def segment_to_dict(segment: PipelineSegment) -> dict[str, Any]:
        return {
            "text": segment.text,
            "start_ms": segment.start_ms,
            "end_ms": segment.end_ms,
            "speaker": segment.speaker,
            "words": list(map(word_to_dict, segment.words)),
        }

    # Keys are inserted in the same order as the IPC payload has always used.
    payload: dict[str, Any] = {"segments": list(map(segment_to_dict, result.segments))}
    payload["language"] = result.language
    payload["language_probability"] = result.language_probability
    payload["duration_ms"] = result.duration_ms
    payload["speakers"] = result.speakers
    payload["num_speakers"] = result.num_speakers
    return payload
|