Files
voice-to-notes/python/voice_to_notes/services/diarize.py
Claude 879a1f3fd6
All checks were successful
Build Sidecars / Bump sidecar version and tag (push) Successful in 7s
Release / Bump version and tag (push) Successful in 5s
Build Sidecars / Build Sidecar (macOS) (push) Successful in 4m32s
Release / Build App (macOS) (push) Successful in 1m16s
Build Sidecars / Build Sidecar (Linux) (push) Successful in 16m28s
Release / Build App (Linux) (push) Successful in 4m26s
Build Sidecars / Build Sidecar (Windows) (push) Successful in 33m5s
Release / Build App (Windows) (push) Successful in 3m29s
Fix diarization tensor mismatch + fix sidecar build triggers
Diarization: Audio.crop patch now pads short segments with zeros to
match the expected duration. pyannote batches embeddings with vstack
which requires uniform tensor sizes — the last segment of a file can
be shorter than the 10s window.

CI: Reordered sidecar workflow to check for python/ changes FIRST,
before bumping version or configuring git. All subsequent steps are
gated on has_changes. This prevents unnecessary version bumps and
build runs when only app code changes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-22 18:30:43 -07:00

352 lines
12 KiB
Python

"""Diarization service — pyannote.audio speaker identification."""
from __future__ import annotations
import os
import subprocess
import sys
import tempfile
import threading
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
# Disable pyannote telemetry — it has a bug in v4.0.4 where
# np.isfinite(None) crashes when max_speakers is not set.
os.environ.setdefault("PYANNOTE_METRICS_ENABLED", "false")
from voice_to_notes.utils.ffmpeg import get_ffmpeg_path
from voice_to_notes.ipc.messages import progress_message
from voice_to_notes.ipc.protocol import write_message
_patched = False


def _patch_pyannote_audio() -> None:
    """Monkey-patch pyannote.audio.core.io.Audio to load audio via soundfile.

    pyannote.audio has a bug where AudioDecoder (from torchcodec) is used
    unconditionally even when torchcodec is not installed, causing NameError.
    This replaces ``Audio.__call__`` and ``Audio.crop`` with soundfile-based
    versions returning ``(waveform, sample_rate)``, where ``waveform`` is a
    float32 tensor shaped ``(channels, samples)``.

    The replacements return audio at the file's own sample rate and channel
    count; they do not resample or downmix.
    NOTE(review): the diarization model presumably expects 16 kHz mono, so
    callers should hand it such a file.

    The patch is attempted at most once per process (guarded by
    ``_patched``). Any failure is logged to stderr and otherwise ignored,
    leaving pyannote's own Audio methods in place.
    """
    global _patched
    if _patched:
        return
    _patched = True

    try:
        import numpy as np
        import soundfile as sf
        import torch
        from pyannote.audio.core.io import Audio

        def _to_channels_first(data):
            """Convert a soundfile (frames, channels) array to a (channels, frames) tensor."""
            return torch.from_numpy(np.ascontiguousarray(data.T))

        def _sf_load(audio_path) -> tuple:
            """Load a whole file via soundfile; return (channels, samples) tensor + sample_rate."""
            # always_2d makes mono come back as (frames, 1), so mono and
            # multi-channel files share one code path.
            data, sample_rate = sf.read(str(audio_path), dtype="float32", always_2d=True)
            return _to_channels_first(data), sample_rate

        def _soundfile_call(self, file: dict) -> tuple:
            """Replacement for Audio.__call__: load the full file."""
            return _sf_load(file["audio"])

        def _soundfile_crop(self, file: dict, segment, **kwargs) -> tuple:
            """Replacement for Audio.crop: read only the frames covering ``segment``.

            Seeking avoids re-reading the whole file for every chunk.

            When ``duration`` is passed, the output always has exactly
            ``int(duration * sample_rate)`` samples. Short reads (e.g. the last
            chunk of a file) are zero-padded. The end index is derived from
            the start plus that count rather than from ``segment.end``, so
            float rounding cannot add an extra sample. pyannote batches
            embedding crops and needs a uniform size.
            """
            duration = kwargs.get("duration", None)
            with sf.SoundFile(str(file["audio"])) as audio_file:
                sample_rate = audio_file.samplerate
                total_frames = audio_file.frames

                start_sample = max(0, int(segment.start * sample_rate))
                if duration is not None:
                    # Fixed sample count from the start index, so every crop
                    # requested with the same duration has the same length.
                    expected_samples = int(duration * sample_rate)
                    end_sample = start_sample + expected_samples
                else:
                    expected_samples = int((segment.end - segment.start) * sample_rate)
                    end_sample = int(segment.end * sample_rate)
                end_sample = min(total_frames, end_sample)

                num_frames = max(0, end_sample - start_sample)
                if num_frames > 0:
                    audio_file.seek(start_sample)
                    data = audio_file.read(num_frames, dtype="float32", always_2d=True)
                else:
                    # Segment lies entirely past the end of the file: nothing to
                    # read; seeking there could fail.
                    data = np.zeros((0, audio_file.channels), dtype=np.float32)

            cropped = _to_channels_first(data)
            missing = expected_samples - cropped.shape[-1]
            if missing > 0:
                pad = torch.zeros(cropped.shape[0], missing, dtype=cropped.dtype)
                cropped = torch.cat([cropped, pad], dim=-1)
            return cropped, sample_rate

        Audio.__call__ = _soundfile_call  # type: ignore[assignment]
        Audio.crop = _soundfile_crop  # type: ignore[assignment]
        print("[sidecar] Patched pyannote Audio to use soundfile", file=sys.stderr, flush=True)
    except Exception as e:
        print(f"[sidecar] Warning: Could not patch pyannote Audio: {e}", file=sys.stderr, flush=True)
def _is_16k_mono_wav(file_path: str) -> bool:
    """Return True if ``file_path`` is a ``.wav`` file already at 16 kHz mono.

    Any failure to inspect the file (soundfile not installed, missing file,
    unreadable header) returns False so the caller converts it with ffmpeg.
    """
    if Path(file_path).suffix.lower() != ".wav":
        return False
    try:
        import soundfile as sf

        info = sf.info(file_path)
    except (ImportError, RuntimeError, OSError):
        return False
    return info.samplerate == 16000 and info.channels == 1


def _ensure_wav(file_path: str) -> tuple[str, str | None]:
    """Convert audio to 16kHz mono WAV if needed.

    pyannote.audio v4.0.4 has a bug where its AudioDecoder returns
    duration=None for some formats (FLAC, etc.), causing crashes.
    Converting to WAV ensures the duration header is always present.

    WAV inputs are also converted unless they are already 16 kHz mono. The
    soundfile-based loader installed by ``_patch_pyannote_audio`` does not
    resample or downmix, so a 44.1 kHz or stereo WAV would otherwise reach
    the model unchanged.

    Returns:
        (path_to_use, temp_path_or_None)
        If conversion was needed, temp_path is the WAV file to clean up.
        If ffmpeg is missing or fails, the original path is returned with None.
    """
    if _is_16k_mono_wav(file_path):
        return file_path, None

    ext = Path(file_path).suffix.lower()
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()
    try:
        subprocess.run(
            [
                get_ffmpeg_path(), "-y", "-i", file_path,
                "-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le",
                tmp.name,
            ],
            check=True,
            capture_output=True,
        )
    except (subprocess.CalledProcessError, OSError) as e:
        # ffmpeg not available or failed — try original file and hope for the best.
        # OSError (not just FileNotFoundError) so e.g. PermissionError also
        # reaches this branch and the temp file is not leaked.
        print(
            f"[sidecar] WAV conversion failed ({e}), using original file",
            file=sys.stderr,
            flush=True,
        )
        os.unlink(tmp.name)
        return file_path, None

    print(
        f"[sidecar] Converted {ext} to WAV for diarization",
        file=sys.stderr,
        flush=True,
    )
    return tmp.name, tmp.name
@dataclass
class SpeakerSegment:
    """A single stretch of audio attributed to one speaker.

    Start and end are integer millisecond offsets.
    """

    speaker: str  # speaker label for this stretch
    start_ms: int  # where the stretch begins, in milliseconds
    end_ms: int  # where the stretch ends, in milliseconds
@dataclass
class DiarizationResult:
    """Everything produced by one diarization run."""

    # Speaker turns (each a SpeakerSegment); a fresh list per instance.
    speaker_segments: list[SpeakerSegment] = field(default_factory=lambda: [])
    # How many distinct speakers were found.
    num_speakers: int = 0
    # The distinct speaker labels; a fresh list per instance.
    speakers: list[str] = field(default_factory=lambda: [])
class DiarizeService:
    """Handles speaker diarization via pyannote.audio.

    The pyannote pipeline is loaded lazily on the first ``diarize`` call and
    cached on the instance for later calls.
    """

    # Diarization models tried in order; the first that loads is used.
    _MODEL_NAMES = (
        "pyannote/speaker-diarization-3.1",
        "pyannote/speaker-diarization",
    )
    # Seconds between progress messages while the pipeline is running.
    _PROGRESS_INTERVAL_SEC = 2.0

    def __init__(self) -> None:
        self._pipeline: Any = None

    def _ensure_pipeline(self, hf_token: str | None = None) -> Any:
        """Load the pyannote diarization pipeline (lazy) and cache it.

        Args:
            hf_token: Hugging Face access token. Falls back to the ``HF_TOKEN``
                or ``HUGGING_FACE_HUB_TOKEN`` environment variables.

        Returns:
            The loaded pipeline, moved to CUDA when ``torch.cuda.is_available()``.

        Raises:
            RuntimeError: If none of the candidate models could be loaded; the
                last loading error is chained as the cause.
        """
        if self._pipeline is not None:
            return self._pipeline

        print("[sidecar] Loading pyannote diarization pipeline...", file=sys.stderr, flush=True)

        # Use token from argument, fall back to environment variable
        if not hf_token:
            hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") or None

        # Persist token globally so ALL huggingface_hub sub-downloads use auth.
        # Pyannote has internal dependencies that don't forward the token= param.
        if hf_token:
            os.environ["HF_TOKEN"] = hf_token
            import huggingface_hub

            huggingface_hub.login(token=hf_token, add_to_git_credential=False)

        last_error: Exception | None = None
        _patch_pyannote_audio()

        for model_name in self._MODEL_NAMES:
            try:
                from pyannote.audio import Pipeline

                pipeline = Pipeline.from_pretrained(model_name, token=hf_token)
                if pipeline is None:
                    # from_pretrained may return None rather than raise when a
                    # model cannot be fetched; caching None would make the later
                    # pipeline(...) call fail with an unhelpful TypeError.
                    raise RuntimeError(f"Pipeline.from_pretrained returned None for {model_name}")
            except Exception as e:
                last_error = e
                print(
                    f"[sidecar] Warning: Could not load {model_name}: {e}",
                    file=sys.stderr,
                    flush=True,
                )
                continue

            print(f"[sidecar] Loaded diarization model: {model_name}", file=sys.stderr, flush=True)
            self._pipeline = self._move_to_gpu(pipeline)
            return self._pipeline

        raise RuntimeError(
            "pyannote.audio pipeline not available. "
            "You may need to accept the model license at "
            "https://huggingface.co/pyannote/speaker-diarization-3.1 "
            "and set a HF_TOKEN environment variable."
        ) from last_error

    @staticmethod
    def _move_to_gpu(pipeline: Any) -> Any:
        """Return ``pipeline`` moved to CUDA when available; otherwise return it unchanged."""
        try:
            import torch

            if torch.cuda.is_available():
                pipeline = pipeline.to(torch.device("cuda"))
                print("[sidecar] Diarization pipeline moved to GPU", file=sys.stderr, flush=True)
        except Exception as e:
            print(f"[sidecar] GPU not available for diarization: {e}", file=sys.stderr, flush=True)
        return pipeline

    @staticmethod
    def _speaker_kwargs(
        num_speakers: int | None,
        min_speakers: int | None,
        max_speakers: int | None,
    ) -> dict[str, Any]:
        """Build the pipeline's speaker-constraint kwargs, omitting unset (None) values."""
        constraints = {
            "num_speakers": num_speakers,
            "min_speakers": min_speakers,
            "max_speakers": max_speakers,
        }
        return {name: value for name, value in constraints.items() if value is not None}

    @staticmethod
    def _estimate_runtime_sec(audio_duration_sec: float | None) -> float:
        """Guess the pipeline runtime: half the audio length (at least 30s), or 120s if unknown."""
        return max(audio_duration_sec * 0.5, 30.0) if audio_duration_sec else 120.0

    @staticmethod
    def _progress_percent(elapsed: float, estimated_total: float) -> int:
        """Map elapsed time onto the 20-85% band reserved for the diarization step."""
        return min(20 + int((elapsed / estimated_total) * 65), 85)

    def _run_with_progress(
        self,
        pipeline: Any,
        audio_path: str,
        kwargs: dict[str, Any],
        request_id: str,
        audio_duration_sec: float | None,
    ) -> Any:
        """Run ``pipeline(audio_path, **kwargs)`` in a worker thread, emitting progress.

        A progress message is written every ``_PROGRESS_INTERVAL_SEC`` seconds
        until the worker finishes.

        Returns:
            Whatever the pipeline returned.

        Raises:
            Exception: The pipeline's own exception, re-raised in this thread.
        """
        outcome: dict[str, Any] = {}
        done_event = threading.Event()

        def _worker() -> None:
            try:
                outcome["result"] = pipeline(audio_path, **kwargs)
            except Exception as e:
                outcome["error"] = e
            finally:
                done_event.set()

        worker = threading.Thread(target=_worker, daemon=True)
        worker.start()

        estimated_total = self._estimate_runtime_sec(audio_duration_sec)
        loop_start = time.monotonic()
        while not done_event.wait(timeout=self._PROGRESS_INTERVAL_SEC):
            # Measured wall time (rather than summing the wait interval) keeps
            # the "elapsed" figure honest when writing progress is slow.
            elapsed = time.monotonic() - loop_start
            write_message(progress_message(
                request_id, self._progress_percent(elapsed, estimated_total), "diarizing",
                f"Analyzing speakers ({int(elapsed)}s elapsed)..."))
        worker.join()

        if "error" in outcome:
            raise outcome["error"]
        return outcome.get("result")

    @staticmethod
    def _remove_temp_file(path: str) -> None:
        """Best-effort delete of a temporary file.

        Runs from a ``finally`` block, so a failed delete is logged instead of
        raised; this keeps a cleanup error from masking a pipeline error.
        """
        try:
            os.unlink(path)
        except OSError as e:
            print(f"[sidecar] Warning: could not remove temp file {path}: {e}", file=sys.stderr, flush=True)

    @staticmethod
    def _to_result(raw_result: Any) -> DiarizationResult:
        """Convert the pipeline output into a DiarizationResult."""
        # pyannote 4.0+ returns DiarizeOutput; older versions return Annotation directly
        diarization = getattr(raw_result, "speaker_diarization", raw_result)

        result = DiarizationResult()
        seen_speakers: set[str] = set()
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            result.speaker_segments.append(
                SpeakerSegment(
                    speaker=speaker,
                    start_ms=int(turn.start * 1000),
                    end_ms=int(turn.end * 1000),
                )
            )
            seen_speakers.add(speaker)
        result.speakers = sorted(seen_speakers)
        result.num_speakers = len(seen_speakers)
        return result

    def diarize(
        self,
        request_id: str,
        file_path: str,
        num_speakers: int | None = None,
        min_speakers: int | None = None,
        max_speakers: int | None = None,
        hf_token: str | None = None,
        audio_duration_sec: float | None = None,
    ) -> DiarizationResult:
        """Run speaker diarization on an audio file.

        Args:
            request_id: IPC request ID for progress messages.
            file_path: Path to audio file.
            num_speakers: Exact number of speakers (if known).
            min_speakers: Minimum expected speakers.
            max_speakers: Maximum expected speakers.
            hf_token: Hugging Face token used to fetch the model; when empty,
                ``_ensure_pipeline`` falls back to environment variables.
            audio_duration_sec: Audio length in seconds; used only to scale
                the progress-percentage estimate.

        Returns:
            DiarizationResult with speaker segments.

        Raises:
            RuntimeError: If no diarization model could be loaded.
            Exception: Any exception raised by the pipeline is re-raised.
        """
        write_message(
            progress_message(request_id, 0, "loading_diarization", "Loading diarization model...")
        )
        pipeline = self._ensure_pipeline(hf_token=hf_token)
        write_message(
            progress_message(request_id, 20, "diarizing", "Running speaker diarization...")
        )
        start_time = time.monotonic()
        kwargs = self._speaker_kwargs(num_speakers, min_speakers, max_speakers)

        # Convert to WAV to work around pyannote v4.0.4 duration bug
        audio_path, temp_wav = _ensure_wav(file_path)
        try:
            print(
                f"[sidecar] Running diarization on {audio_path} with kwargs: {kwargs}",
                file=sys.stderr,
                flush=True,
            )
            raw_result = self._run_with_progress(
                pipeline, audio_path, kwargs, request_id, audio_duration_sec
            )
        finally:
            # Always remove the converted WAV, even if diarization or a
            # progress write failed.
            if temp_wav:
                self._remove_temp_file(temp_wav)

        result = self._to_result(raw_result)

        elapsed = time.monotonic() - start_time
        print(
            f"[sidecar] Diarization complete: {result.num_speakers} speakers, "
            f"{len(result.speaker_segments)} segments in {elapsed:.1f}s",
            file=sys.stderr,
            flush=True,
        )
        write_message(
            progress_message(request_id, 100, "done", "Diarization complete")
        )
        return result
def diarization_to_payload(result: DiarizationResult) -> dict[str, Any]:
    """Flatten a DiarizationResult into a plain dict suitable for the IPC channel.

    Each segment becomes a ``{"speaker", "start_ms", "end_ms"}`` dict. The
    speaker count and the ``speakers`` list are copied across as-is.
    """
    segment_dicts = [
        dict(speaker=seg.speaker, start_ms=seg.start_ms, end_ms=seg.end_ms)
        for seg in result.speaker_segments
    ]
    return dict(
        speaker_segments=segment_dicts,
        num_speakers=result.num_speakers,
        speakers=result.speakers,
    )