Phase 3: Speaker diarization and full transcription pipeline
- Implement DiarizeService with pyannote.audio speaker detection.
- Build PipelineService combining transcribe → diarize → merge, with overlap-based speaker assignment per segment.
- Add pipeline.start and diarize.start IPC handlers, and a run_pipeline Tauri command for full pipeline execution.
- Wire the frontend to the pipeline: speakers are auto-created with colors, and segments are assigned to detected speakers.
- Build SpeakerManager with rename support (double-click or edit button); add speaker color coding throughout the transcript display.
- Add the pyannote.audio dependency.
- Tests: 24 Python (including merge logic), 6 Rust, 0 Svelte errors.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -64,6 +64,59 @@ def make_transcribe_handler() -> HandlerFunc:
|
||||
return handler
|
||||
|
||||
|
||||
def make_diarize_handler() -> HandlerFunc:
    """Create a diarization handler with a persistent DiarizeService.

    The service instance is captured by the returned closure, so its
    (lazily loaded) diarization model is reused across requests instead
    of being reloaded per message.
    """
    # Local import keeps the heavy ML dependency chain off module import.
    from voice_to_notes.services.diarize import DiarizeService, diarization_to_payload

    service = DiarizeService()

    def handler(msg: IPCMessage) -> IPCMessage:
        # Payload contract: "file" is required; the three speaker-count
        # hints are optional and forwarded as None when absent.
        payload = msg.payload
        result = service.diarize(
            request_id=msg.id,
            file_path=payload["file"],
            num_speakers=payload.get("num_speakers"),
            min_speakers=payload.get("min_speakers"),
            max_speakers=payload.get("max_speakers"),
        )
        # Reply correlates to the request via the same message id.
        return IPCMessage(
            id=msg.id,
            type="diarize.result",
            payload=diarization_to_payload(result),
        )

    return handler
|
||||
|
||||
|
||||
def make_pipeline_handler() -> HandlerFunc:
    """Create a full pipeline handler (transcribe + diarize + merge).

    A single PipelineService is captured by the closure so loaded models
    persist between requests.
    """
    # Deferred import: only pay for the pipeline stack when registered.
    from voice_to_notes.services.pipeline import PipelineService, pipeline_result_to_payload

    service = PipelineService()

    # Maps PipelineService.run keyword -> (payload key, default value).
    optional_args = {
        "model_name": ("model", "base"),
        "device": ("device", "cpu"),
        "compute_type": ("compute_type", "int8"),
        "language": ("language", None),
        "num_speakers": ("num_speakers", None),
        "min_speakers": ("min_speakers", None),
        "max_speakers": ("max_speakers", None),
        "skip_diarization": ("skip_diarization", False),
    }

    def handler(msg: IPCMessage) -> IPCMessage:
        opts = msg.payload
        run_kwargs = {
            arg: opts.get(key, fallback)
            for arg, (key, fallback) in optional_args.items()
        }
        result = service.run(
            request_id=msg.id,
            file_path=opts["file"],
            **run_kwargs,
        )
        return IPCMessage(
            id=msg.id,
            type="pipeline.result",
            payload=pipeline_result_to_payload(result),
        )

    return handler
|
||||
|
||||
|
||||
def hardware_detect_handler(msg: IPCMessage) -> IPCMessage:
|
||||
"""Detect hardware capabilities and return recommendations."""
|
||||
from voice_to_notes.hardware.detect import detect_hardware
|
||||
|
||||
@@ -8,6 +8,8 @@ import sys
|
||||
from voice_to_notes.ipc.handlers import (
|
||||
HandlerRegistry,
|
||||
hardware_detect_handler,
|
||||
make_diarize_handler,
|
||||
make_pipeline_handler,
|
||||
make_transcribe_handler,
|
||||
ping_handler,
|
||||
)
|
||||
@@ -21,7 +23,8 @@ def create_registry() -> HandlerRegistry:
|
||||
registry.register("ping", ping_handler)
|
||||
registry.register("transcribe.start", make_transcribe_handler())
|
||||
registry.register("hardware.detect", hardware_detect_handler)
|
||||
# TODO: Register diarize, pipeline, ai, export handlers
|
||||
registry.register("diarize.start", make_diarize_handler())
|
||||
registry.register("pipeline.start", make_pipeline_handler())
|
||||
return registry
|
||||
|
||||
|
||||
|
||||
@@ -2,12 +2,166 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from voice_to_notes.ipc.messages import progress_message
|
||||
from voice_to_notes.ipc.protocol import write_message
|
||||
|
||||
|
||||
@dataclass
class SpeakerSegment:
    """A time span assigned to a speaker."""

    # Speaker label as produced by the diarization backend.
    speaker: str
    # Span start, in milliseconds from the beginning of the audio.
    start_ms: int
    # Span end, in milliseconds from the beginning of the audio.
    end_ms: int
|
||||
|
||||
|
||||
@dataclass
class DiarizationResult:
    """Full diarization output."""

    # Speaker turns in the order emitted by the diarization pipeline.
    speaker_segments: list[SpeakerSegment] = field(default_factory=list)
    # Count of distinct speakers detected.
    num_speakers: int = 0
    # Sorted list of distinct speaker labels.
    speakers: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class DiarizeService:
    """Handles speaker diarization via pyannote.audio.

    The heavyweight pyannote pipeline is loaded lazily on first use and
    cached on the instance, so one long-lived service can serve many
    requests without reloading the model.
    """

    def __init__(self) -> None:
        # Loaded pyannote Pipeline, or None until the first diarize() call.
        self._pipeline: Any = None

    def _ensure_pipeline(self) -> Any:
        """Load the pyannote diarization pipeline (lazy).

        Tries "pyannote/speaker-diarization-3.1" first, then falls back to
        the older "pyannote/speaker-diarization" model. If the HF_TOKEN
        environment variable is set it is used to authenticate, since
        several pyannote models are gated behind an accepted license.

        Raises:
            RuntimeError: If no pipeline could be loaded.
        """
        if self._pipeline is not None:
            return self._pipeline

        import os

        print("[sidecar] Loading pyannote diarization pipeline...", file=sys.stderr, flush=True)

        # Honor the token the error message below tells users to set;
        # fall back to an unauthenticated load when it is absent.
        auth_token = os.environ.get("HF_TOKEN") or False

        try:
            from pyannote.audio import Pipeline

            self._pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=auth_token,
            )
        except Exception:
            # Fall back to the older model name if the 3.1 model isn't
            # available (missing license acceptance, offline cache, ...).
            try:
                from pyannote.audio import Pipeline

                self._pipeline = Pipeline.from_pretrained(
                    "pyannote/speaker-diarization",
                    use_auth_token=auth_token,
                )
            except Exception as e:
                print(
                    f"[sidecar] Warning: Could not load pyannote pipeline: {e}",
                    file=sys.stderr,
                    flush=True,
                )
                raise RuntimeError(
                    "pyannote.audio pipeline not available. "
                    "You may need to accept the model license at "
                    "https://huggingface.co/pyannote/speaker-diarization-3.1 "
                    "and set a HF_TOKEN environment variable."
                ) from e

        return self._pipeline

    def diarize(
        self,
        request_id: str,
        file_path: str,
        num_speakers: int | None = None,
        min_speakers: int | None = None,
        max_speakers: int | None = None,
    ) -> DiarizationResult:
        """Run speaker diarization on an audio file.

        Args:
            request_id: IPC request ID for progress messages.
            file_path: Path to audio file.
            num_speakers: Exact number of speakers (if known).
            min_speakers: Minimum expected speakers.
            max_speakers: Maximum expected speakers.

        Returns:
            DiarizationResult with speaker segments.
        """
        write_message(
            progress_message(request_id, 0, "loading_diarization", "Loading diarization model...")
        )

        pipeline = self._ensure_pipeline()

        write_message(
            progress_message(request_id, 20, "diarizing", "Running speaker diarization...")
        )

        start_time = time.time()

        # Only forward the speaker-count constraints the caller actually set.
        kwargs: dict[str, Any] = {}
        if num_speakers is not None:
            kwargs["num_speakers"] = num_speakers
        if min_speakers is not None:
            kwargs["min_speakers"] = min_speakers
        if max_speakers is not None:
            kwargs["max_speakers"] = max_speakers

        # Run diarization.
        diarization = pipeline(file_path, **kwargs)

        # Convert pyannote output (turn times in seconds) to milliseconds.
        result = DiarizationResult()
        seen_speakers: set[str] = set()

        for turn, _, speaker in diarization.itertracks(yield_label=True):
            result.speaker_segments.append(
                SpeakerSegment(
                    speaker=speaker,
                    start_ms=int(turn.start * 1000),
                    end_ms=int(turn.end * 1000),
                )
            )
            seen_speakers.add(speaker)

        result.speakers = sorted(seen_speakers)
        result.num_speakers = len(seen_speakers)

        elapsed = time.time() - start_time
        print(
            f"[sidecar] Diarization complete: {result.num_speakers} speakers, "
            f"{len(result.speaker_segments)} segments in {elapsed:.1f}s",
            file=sys.stderr,
            flush=True,
        )

        write_message(
            progress_message(request_id, 100, "done", "Diarization complete")
        )

        return result
|
||||
|
||||
|
||||
def diarization_to_payload(result: DiarizationResult) -> dict[str, Any]:
    """Serialize a DiarizationResult into a JSON-safe IPC payload dict."""

    def _segment(seg: SpeakerSegment) -> dict[str, Any]:
        # Flatten one speaker turn into primitive fields.
        return {
            "speaker": seg.speaker,
            "start_ms": seg.start_ms,
            "end_ms": seg.end_ms,
        }

    return {
        "speaker_segments": [_segment(seg) for seg in result.speaker_segments],
        "num_speakers": result.num_speakers,
        "speakers": result.speakers,
    }
|
||||
|
||||
@@ -2,13 +2,234 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from voice_to_notes.ipc.messages import progress_message
|
||||
from voice_to_notes.ipc.protocol import write_message
|
||||
from voice_to_notes.services.diarize import DiarizeService, SpeakerSegment
|
||||
from voice_to_notes.services.transcribe import (
|
||||
SegmentResult,
|
||||
TranscribeService,
|
||||
TranscriptionResult,
|
||||
WordResult,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class PipelineSegment:
    """A transcript segment with speaker assignment."""

    # Transcribed text for this segment.
    text: str
    # Segment start, milliseconds.
    start_ms: int
    # Segment end, milliseconds.
    end_ms: int
    # Assigned speaker label, or None when diarization was skipped or no
    # speaker turn overlapped this segment.
    speaker: str | None
    # Word-level timing entries carried over from transcription.
    words: list[WordResult] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
class PipelineResult:
    """Full pipeline output combining transcription and diarization."""

    # Transcript segments, each carrying its speaker assignment.
    segments: list[PipelineSegment] = field(default_factory=list)
    # Detected (or requested) language code.
    language: str = ""
    # Confidence of the language detection.
    language_probability: float = 0.0
    # Total audio duration in milliseconds.
    duration_ms: int = 0
    # Distinct speaker labels; empty when diarization was skipped.
    speakers: list[str] = field(default_factory=list)
    # Count of distinct speakers; 0 when diarization was skipped.
    num_speakers: int = 0
|
||||
|
||||
|
||||
class PipelineService:
    """Runs the full pipeline: transcribe -> diarize -> merge."""

    def __init__(self) -> None:
        # Both services cache their loaded models, so one PipelineService
        # instance serves many requests without reloading.
        self._transcribe_service = TranscribeService()
        self._diarize_service = DiarizeService()

    def run(
        self,
        request_id: str,
        file_path: str,
        model_name: str = "base",
        device: str = "cpu",
        compute_type: str = "int8",
        language: str | None = None,
        num_speakers: int | None = None,
        min_speakers: int | None = None,
        max_speakers: int | None = None,
        skip_diarization: bool = False,
    ) -> PipelineResult:
        """Run the full transcription + diarization pipeline.

        Args:
            request_id: IPC request ID for progress messages.
            file_path: Path to audio file.
            model_name: Whisper model size.
            device: 'cpu' or 'cuda'.
            compute_type: Quantization type.
            language: Language code or None for auto-detect.
            num_speakers: Exact speaker count (if known).
            min_speakers: Minimum expected speakers.
            max_speakers: Maximum expected speakers.
            skip_diarization: If True, only transcribe (no speaker ID).

        Returns:
            PipelineResult; every segment's speaker is None when
            skip_diarization is True.
        """
        start_time = time.time()

        # Step 1: Transcribe
        write_message(
            progress_message(request_id, 0, "pipeline", "Starting transcription pipeline...")
        )

        transcription = self._transcribe_service.transcribe(
            request_id=request_id,
            file_path=file_path,
            model_name=model_name,
            device=device,
            compute_type=compute_type,
            language=language,
        )

        if skip_diarization:
            # Convert transcription directly without speaker labels.
            result = PipelineResult(
                language=transcription.language,
                language_probability=transcription.language_probability,
                duration_ms=transcription.duration_ms,
            )
            for seg in transcription.segments:
                result.segments.append(
                    PipelineSegment(
                        text=seg.text,
                        start_ms=seg.start_ms,
                        end_ms=seg.end_ms,
                        speaker=None,
                        words=seg.words,
                    )
                )
            return result

        # Step 2: Diarize
        write_message(
            progress_message(request_id, 50, "pipeline", "Starting speaker diarization...")
        )

        diarization = self._diarize_service.diarize(
            request_id=request_id,
            file_path=file_path,
            num_speakers=num_speakers,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
        )

        # Step 3: Merge
        write_message(
            progress_message(request_id, 90, "pipeline", "Merging transcript with speakers...")
        )

        result = self._merge_results(transcription, diarization.speaker_segments)
        result.speakers = diarization.speakers
        result.num_speakers = diarization.num_speakers

        elapsed = time.time() - start_time
        print(
            f"[sidecar] Pipeline complete in {elapsed:.1f}s: "
            f"{len(result.segments)} segments, {result.num_speakers} speakers",
            file=sys.stderr,
            flush=True,
        )

        write_message(
            progress_message(request_id, 100, "done", "Pipeline complete")
        )

        return result

    def _merge_results(
        self,
        transcription: TranscriptionResult,
        speaker_segments: list[SpeakerSegment],
    ) -> PipelineResult:
        """Merge transcription segments with speaker assignments.

        For each transcript segment, find the speaker who has the most
        overlap with that segment's time range. Words are copied over
        unchanged; per-word speaker assignment is not implemented.
        """
        result = PipelineResult(
            language=transcription.language,
            language_probability=transcription.language_probability,
            duration_ms=transcription.duration_ms,
        )

        for seg in transcription.segments:
            speaker = self._find_speaker_for_segment(
                seg.start_ms, seg.end_ms, speaker_segments
            )

            result.segments.append(
                PipelineSegment(
                    text=seg.text,
                    start_ms=seg.start_ms,
                    end_ms=seg.end_ms,
                    speaker=speaker,
                    # Shallow defensive copy of the word list.
                    words=list(seg.words),
                )
            )

        return result

    def _find_speaker_for_segment(
        self,
        start_ms: int,
        end_ms: int,
        speaker_segments: list[SpeakerSegment],
    ) -> str | None:
        """Find the speaker with the most overlap for a given time range.

        Returns None when no speaker segment overlaps the range at all.
        """
        best_speaker: str | None = None
        best_overlap = 0

        for ss in speaker_segments:
            # Length of the intersection of [start_ms, end_ms] and
            # [ss.start_ms, ss.end_ms]; zero when disjoint.
            overlap_start = max(start_ms, ss.start_ms)
            overlap_end = min(end_ms, ss.end_ms)
            overlap = max(0, overlap_end - overlap_start)

            if overlap > best_overlap:
                best_overlap = overlap
                best_speaker = ss.speaker

        return best_speaker
|
||||
|
||||
|
||||
def pipeline_result_to_payload(result: PipelineResult) -> dict[str, Any]:
    """Serialize a PipelineResult into a JSON-safe IPC payload dict."""

    def _word(w: WordResult) -> dict[str, Any]:
        # One word-level timing entry.
        return {
            "word": w.word,
            "start_ms": w.start_ms,
            "end_ms": w.end_ms,
            "confidence": w.confidence,
        }

    def _segment(seg: PipelineSegment) -> dict[str, Any]:
        # One transcript segment with its speaker and nested words.
        return {
            "text": seg.text,
            "start_ms": seg.start_ms,
            "end_ms": seg.end_ms,
            "speaker": seg.speaker,
            "words": [_word(w) for w in seg.words],
        }

    return {
        "segments": [_segment(seg) for seg in result.segments],
        "language": result.language,
        "language_probability": result.language_probability,
        "duration_ms": result.duration_ms,
        "speakers": result.speakers,
        "num_speakers": result.num_speakers,
    }
|
||||
|
||||
Reference in New Issue
Block a user