"""Diarization service — pyannote.audio speaker identification.""" from __future__ import annotations import os import subprocess import sys import tempfile import threading import time from dataclasses import dataclass, field from pathlib import Path from typing import Any # Disable pyannote telemetry — it has a bug in v4.0.4 where # np.isfinite(None) crashes when max_speakers is not set. os.environ.setdefault("PYANNOTE_METRICS_ENABLED", "false") from voice_to_notes.utils.ffmpeg import get_ffmpeg_path from voice_to_notes.ipc.messages import progress_message from voice_to_notes.ipc.protocol import write_message _patched = False def _patch_pyannote_audio() -> None: """Monkey-patch pyannote.audio.core.io.Audio to use torchaudio. pyannote.audio has a bug where AudioDecoder (from torchcodec) is used unconditionally even when torchcodec is not installed, causing NameError. This replaces the Audio.__call__ method with a torchaudio-based version. """ global _patched if _patched: return _patched = True try: import numpy as np import soundfile as sf import torch from pyannote.audio.core.io import Audio def _sf_load(audio_path: str) -> tuple: """Load audio via soundfile, return (channels, samples) tensor + sample_rate.""" data, sample_rate = sf.read(str(audio_path), dtype="float32") waveform = torch.from_numpy(np.array(data)) if waveform.ndim == 1: waveform = waveform.unsqueeze(0) else: waveform = waveform.T return waveform, sample_rate def _soundfile_call(self, file: dict) -> tuple: """Replacement for Audio.__call__.""" return _sf_load(file["audio"]) def _soundfile_crop(self, file: dict, segment, **kwargs) -> tuple: """Replacement for Audio.crop — load full file then slice.""" waveform, sample_rate = _sf_load(file["audio"]) # Convert segment (seconds) to sample indices start_sample = int(segment.start * sample_rate) end_sample = int(segment.end * sample_rate) # Clamp to bounds start_sample = max(0, start_sample) end_sample = min(waveform.shape[-1], end_sample) cropped = waveform[:, start_sample:end_sample] return cropped, sample_rate Audio.__call__ = _soundfile_call # type: ignore[assignment] Audio.crop = _soundfile_crop # type: ignore[assignment] print("[sidecar] Patched pyannote Audio to use soundfile", file=sys.stderr, flush=True) except Exception as e: print(f"[sidecar] Warning: Could not patch pyannote Audio: {e}", file=sys.stderr, flush=True) def _ensure_wav(file_path: str) -> tuple[str, str | None]: """Convert audio to 16kHz mono WAV if needed. pyannote.audio v4.0.4 has a bug where its AudioDecoder returns duration=None for some formats (FLAC, etc.), causing crashes. Converting to WAV ensures the duration header is always present. Returns: (path_to_use, temp_path_or_None) If conversion was needed, temp_path is the WAV file to clean up. """ ext = Path(file_path).suffix.lower() if ext == ".wav": return file_path, None tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) tmp.close() try: subprocess.run( [ get_ffmpeg_path(), "-y", "-i", file_path, "-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le", tmp.name, ], check=True, capture_output=True, ) print( f"[sidecar] Converted {ext} to WAV for diarization", file=sys.stderr, flush=True, ) return tmp.name, tmp.name except (subprocess.CalledProcessError, FileNotFoundError) as e: # ffmpeg not available or failed — try original file and hope for the best print( f"[sidecar] WAV conversion failed ({e}), using original file", file=sys.stderr, flush=True, ) os.unlink(tmp.name) return file_path, None @dataclass class SpeakerSegment: """A time span assigned to a speaker.""" speaker: str start_ms: int end_ms: int @dataclass class DiarizationResult: """Full diarization output.""" speaker_segments: list[SpeakerSegment] = field(default_factory=list) num_speakers: int = 0 speakers: list[str] = field(default_factory=list) class DiarizeService: """Handles speaker diarization via pyannote.audio.""" def __init__(self) -> None: self._pipeline: Any = None def _ensure_pipeline(self, hf_token: str | None = None) -> Any: """Load the pyannote diarization pipeline (lazy).""" if self._pipeline is not None: return self._pipeline print("[sidecar] Loading pyannote diarization pipeline...", file=sys.stderr, flush=True) # Use token from argument, fall back to environment variable if not hf_token: hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") or None # Persist token globally so ALL huggingface_hub sub-downloads use auth. # Pyannote has internal dependencies that don't forward the token= param. if hf_token: os.environ["HF_TOKEN"] = hf_token import huggingface_hub huggingface_hub.login(token=hf_token, add_to_git_credential=False) models = [ "pyannote/speaker-diarization-3.1", "pyannote/speaker-diarization", ] last_error: Exception | None = None _patch_pyannote_audio() for model_name in models: try: from pyannote.audio import Pipeline self._pipeline = Pipeline.from_pretrained(model_name, token=hf_token) print(f"[sidecar] Loaded diarization model: {model_name}", file=sys.stderr, flush=True) # Move pipeline to GPU if available try: import torch if torch.cuda.is_available(): self._pipeline = self._pipeline.to(torch.device("cuda")) print(f"[sidecar] Diarization pipeline moved to GPU", file=sys.stderr, flush=True) except Exception as e: print(f"[sidecar] GPU not available for diarization: {e}", file=sys.stderr, flush=True) return self._pipeline except Exception as e: last_error = e print( f"[sidecar] Warning: Could not load {model_name}: {e}", file=sys.stderr, flush=True, ) raise RuntimeError( "pyannote.audio pipeline not available. " "You may need to accept the model license at " "https://huggingface.co/pyannote/speaker-diarization-3.1 " "and set a HF_TOKEN environment variable." ) from last_error def diarize( self, request_id: str, file_path: str, num_speakers: int | None = None, min_speakers: int | None = None, max_speakers: int | None = None, hf_token: str | None = None, audio_duration_sec: float | None = None, ) -> DiarizationResult: """Run speaker diarization on an audio file. Args: request_id: IPC request ID for progress messages. file_path: Path to audio file. num_speakers: Exact number of speakers (if known). min_speakers: Minimum expected speakers. max_speakers: Maximum expected speakers. Returns: DiarizationResult with speaker segments. """ write_message( progress_message(request_id, 0, "loading_diarization", "Loading diarization model...") ) pipeline = self._ensure_pipeline(hf_token=hf_token) write_message( progress_message(request_id, 20, "diarizing", "Running speaker diarization...") ) start_time = time.time() # Build kwargs for speaker constraints kwargs: dict[str, Any] = {} if num_speakers is not None: kwargs["num_speakers"] = num_speakers if min_speakers is not None: kwargs["min_speakers"] = min_speakers if max_speakers is not None: kwargs["max_speakers"] = max_speakers # Convert to WAV to work around pyannote v4.0.4 duration bug audio_path, temp_wav = _ensure_wav(file_path) print( f"[sidecar] Running diarization on {audio_path} with kwargs: {kwargs}", file=sys.stderr, flush=True, ) # Run diarization in background thread for progress reporting result_holder: list = [None] error_holder: list[Exception | None] = [None] done_event = threading.Event() def _run(): try: result_holder[0] = pipeline(audio_path, **kwargs) except Exception as e: error_holder[0] = e finally: done_event.set() thread = threading.Thread(target=_run, daemon=True) thread.start() elapsed = 0.0 estimated_total = max(audio_duration_sec * 0.5, 30.0) if audio_duration_sec else 120.0 while not done_event.wait(timeout=2.0): elapsed += 2.0 pct = min(20 + int((elapsed / estimated_total) * 65), 85) write_message(progress_message( request_id, pct, "diarizing", f"Analyzing speakers ({int(elapsed)}s elapsed)...")) thread.join() # Clean up temp file if temp_wav: os.unlink(temp_wav) if error_holder[0] is not None: raise error_holder[0] raw_result = result_holder[0] # pyannote 4.0+ returns DiarizeOutput; older versions return Annotation directly if hasattr(raw_result, "speaker_diarization"): diarization = raw_result.speaker_diarization else: diarization = raw_result # Convert pyannote output to our format result = DiarizationResult() seen_speakers: set[str] = set() for turn, _, speaker in diarization.itertracks(yield_label=True): result.speaker_segments.append( SpeakerSegment( speaker=speaker, start_ms=int(turn.start * 1000), end_ms=int(turn.end * 1000), ) ) seen_speakers.add(speaker) result.speakers = sorted(seen_speakers) result.num_speakers = len(seen_speakers) elapsed = time.time() - start_time print( f"[sidecar] Diarization complete: {result.num_speakers} speakers, " f"{len(result.speaker_segments)} segments in {elapsed:.1f}s", file=sys.stderr, flush=True, ) write_message( progress_message(request_id, 100, "done", "Diarization complete") ) return result def diarization_to_payload(result: DiarizationResult) -> dict[str, Any]: """Convert DiarizationResult to IPC payload dict.""" return { "speaker_segments": [ { "speaker": seg.speaker, "start_ms": seg.start_ms, "end_ms": seg.end_ms, } for seg in result.speaker_segments ], "num_speakers": result.num_speakers, "speakers": result.speakers, }