Some checks failed
Build Sidecars / Bump sidecar version and tag (push) Successful in 3s
Release / Bump version and tag (push) Successful in 3s
Build Sidecars / Build Sidecar (macOS) (push) Successful in 3m58s
Release / Build App (macOS) (push) Successful in 1m20s
Release / Build App (Linux) (push) Has been cancelled
Release / Build App (Windows) (push) Has been cancelled
Build Sidecars / Build Sidecar (Linux) (push) Successful in 13m41s
Build Sidecars / Build Sidecar (Windows) (push) Successful in 34m33s
torchaudio 2.10 unconditionally delegates load() to torchcodec, ignoring the backend parameter. Since torchcodec is excluded from PyInstaller, this broke our pyannote Audio monkey-patch. Fix: replace torchaudio.load() with soundfile.read() + torch.from_numpy(). soundfile handles WAV natively (audio is pre-converted to WAV), has no torchcodec dependency, and is already a transitive dependency. Also added soundfile to PyInstaller hiddenimports. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
325 lines
11 KiB
Python
325 lines
11 KiB
Python
"""Diarization service — pyannote.audio speaker identification."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
import threading
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
# Disable pyannote telemetry — it has a bug in v4.0.4 where
# np.isfinite(None) crashes when max_speakers is not set.
# setdefault() (not assignment) so an explicit user override in the
# environment still wins.
os.environ.setdefault("PYANNOTE_METRICS_ENABLED", "false")
|
|
|
|
from voice_to_notes.utils.ffmpeg import get_ffmpeg_path
|
|
from voice_to_notes.ipc.messages import progress_message
|
|
from voice_to_notes.ipc.protocol import write_message
|
|
|
|
_patched = False
|
|
|
|
|
|
def _patch_pyannote_audio() -> None:
|
|
"""Monkey-patch pyannote.audio.core.io.Audio to use torchaudio.
|
|
|
|
pyannote.audio has a bug where AudioDecoder (from torchcodec) is used
|
|
unconditionally even when torchcodec is not installed, causing NameError.
|
|
This replaces the Audio.__call__ method with a torchaudio-based version.
|
|
"""
|
|
global _patched
|
|
if _patched:
|
|
return
|
|
_patched = True
|
|
|
|
try:
|
|
import numpy as np
|
|
import soundfile as sf
|
|
import torch
|
|
from pyannote.audio.core.io import Audio
|
|
|
|
def _soundfile_call(self: Audio, file: dict) -> tuple:
|
|
"""Load audio via soundfile (bypasses torchaudio/torchcodec entirely)."""
|
|
audio_path = str(file["audio"])
|
|
data, sample_rate = sf.read(audio_path, dtype="float32")
|
|
# soundfile returns (samples,) for mono, (samples, channels) for stereo
|
|
# pyannote expects (channels, samples) torch tensor
|
|
waveform = torch.from_numpy(np.array(data))
|
|
if waveform.ndim == 1:
|
|
waveform = waveform.unsqueeze(0) # (samples,) -> (1, samples)
|
|
else:
|
|
waveform = waveform.T # (samples, channels) -> (channels, samples)
|
|
return waveform, sample_rate
|
|
|
|
Audio.__call__ = _soundfile_call # type: ignore[assignment]
|
|
print("[sidecar] Patched pyannote Audio to use soundfile", file=sys.stderr, flush=True)
|
|
except Exception as e:
|
|
print(f"[sidecar] Warning: Could not patch pyannote Audio: {e}", file=sys.stderr, flush=True)
|
|
|
|
|
|
def _ensure_wav(file_path: str) -> tuple[str, str | None]:
|
|
"""Convert audio to 16kHz mono WAV if needed.
|
|
|
|
pyannote.audio v4.0.4 has a bug where its AudioDecoder returns
|
|
duration=None for some formats (FLAC, etc.), causing crashes.
|
|
Converting to WAV ensures the duration header is always present.
|
|
|
|
Returns:
|
|
(path_to_use, temp_path_or_None)
|
|
If conversion was needed, temp_path is the WAV file to clean up.
|
|
"""
|
|
ext = Path(file_path).suffix.lower()
|
|
if ext == ".wav":
|
|
return file_path, None
|
|
|
|
tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
|
tmp.close()
|
|
try:
|
|
subprocess.run(
|
|
[
|
|
get_ffmpeg_path(), "-y", "-i", file_path,
|
|
"-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le",
|
|
tmp.name,
|
|
],
|
|
check=True,
|
|
capture_output=True,
|
|
)
|
|
print(
|
|
f"[sidecar] Converted {ext} to WAV for diarization",
|
|
file=sys.stderr,
|
|
flush=True,
|
|
)
|
|
return tmp.name, tmp.name
|
|
except (subprocess.CalledProcessError, FileNotFoundError) as e:
|
|
# ffmpeg not available or failed — try original file and hope for the best
|
|
print(
|
|
f"[sidecar] WAV conversion failed ({e}), using original file",
|
|
file=sys.stderr,
|
|
flush=True,
|
|
)
|
|
os.unlink(tmp.name)
|
|
return file_path, None
|
|
|
|
|
|
@dataclass
class SpeakerSegment:
    """A time span assigned to a speaker."""

    # Speaker label as emitted by the diarization pipeline.
    speaker: str
    # Turn start, in milliseconds.
    start_ms: int
    # Turn end, in milliseconds.
    end_ms: int
|
|
|
|
|
|
@dataclass
class DiarizationResult:
    """Full diarization output."""

    # All detected speaker turns, in pipeline iteration order.
    speaker_segments: list[SpeakerSegment] = field(default_factory=list)
    # Number of distinct speaker labels (== len(speakers)).
    num_speakers: int = 0
    # Sorted list of distinct speaker labels.
    speakers: list[str] = field(default_factory=list)
|
|
|
|
|
|
class DiarizeService:
    """Handles speaker diarization via pyannote.audio."""

    def __init__(self) -> None:
        # Lazily-loaded pyannote Pipeline; see _ensure_pipeline().
        self._pipeline: Any = None

    def _ensure_pipeline(self, hf_token: str | None = None) -> Any:
        """Load the pyannote diarization pipeline (lazy, cached).

        Args:
            hf_token: Hugging Face access token. Falls back to the
                HF_TOKEN / HUGGING_FACE_HUB_TOKEN environment variables.

        Returns:
            The loaded pyannote Pipeline.

        Raises:
            RuntimeError: if none of the candidate models could be loaded.
        """
        if self._pipeline is not None:
            return self._pipeline

        print("[sidecar] Loading pyannote diarization pipeline...", file=sys.stderr, flush=True)

        # Use token from argument, fall back to environment variable
        if not hf_token:
            hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") or None

        # Persist token globally so ALL huggingface_hub sub-downloads use auth.
        # Pyannote has internal dependencies that don't forward the token= param.
        if hf_token:
            os.environ["HF_TOKEN"] = hf_token
            try:
                import huggingface_hub
                huggingface_hub.login(token=hf_token, add_to_git_credential=False)
            except Exception as e:
                # login() is best-effort: HF_TOKEN in the environment already
                # covers downloads, so a failed login must not abort loading.
                print(f"[sidecar] Warning: huggingface_hub login failed: {e}", file=sys.stderr, flush=True)

        models = [
            "pyannote/speaker-diarization-3.1",
            "pyannote/speaker-diarization",
        ]

        last_error: Exception | None = None
        _patch_pyannote_audio()
        for model_name in models:
            try:
                from pyannote.audio import Pipeline

                self._pipeline = Pipeline.from_pretrained(model_name, token=hf_token)
                print(f"[sidecar] Loaded diarization model: {model_name}", file=sys.stderr, flush=True)
                # Move pipeline to GPU if available
                try:
                    import torch
                    if torch.cuda.is_available():
                        self._pipeline = self._pipeline.to(torch.device("cuda"))
                        print("[sidecar] Diarization pipeline moved to GPU", file=sys.stderr, flush=True)
                except Exception as e:
                    # CPU fallback is fine — only log why GPU was skipped.
                    print(f"[sidecar] GPU not available for diarization: {e}", file=sys.stderr, flush=True)
                return self._pipeline
            except Exception as e:
                last_error = e
                print(
                    f"[sidecar] Warning: Could not load {model_name}: {e}",
                    file=sys.stderr,
                    flush=True,
                )

        raise RuntimeError(
            "pyannote.audio pipeline not available. "
            "You may need to accept the model license at "
            "https://huggingface.co/pyannote/speaker-diarization-3.1 "
            "and set a HF_TOKEN environment variable."
        ) from last_error

    def diarize(
        self,
        request_id: str,
        file_path: str,
        num_speakers: int | None = None,
        min_speakers: int | None = None,
        max_speakers: int | None = None,
        hf_token: str | None = None,
        audio_duration_sec: float | None = None,
    ) -> DiarizationResult:
        """Run speaker diarization on an audio file.

        Args:
            request_id: IPC request ID for progress messages.
            file_path: Path to audio file.
            num_speakers: Exact number of speakers (if known).
            min_speakers: Minimum expected speakers.
            max_speakers: Maximum expected speakers.
            hf_token: Hugging Face token forwarded to _ensure_pipeline.
            audio_duration_sec: Known audio length, used only to estimate
                progress percentages while the pipeline runs.

        Returns:
            DiarizationResult with speaker segments.
        """
        write_message(
            progress_message(request_id, 0, "loading_diarization", "Loading diarization model...")
        )

        pipeline = self._ensure_pipeline(hf_token=hf_token)

        write_message(
            progress_message(request_id, 20, "diarizing", "Running speaker diarization...")
        )

        start_time = time.time()

        # Build kwargs for speaker constraints
        kwargs: dict[str, Any] = {}
        if num_speakers is not None:
            kwargs["num_speakers"] = num_speakers
        if min_speakers is not None:
            kwargs["min_speakers"] = min_speakers
        if max_speakers is not None:
            kwargs["max_speakers"] = max_speakers

        # Convert to WAV to work around pyannote v4.0.4 duration bug
        audio_path, temp_wav = _ensure_wav(file_path)

        print(
            f"[sidecar] Running diarization on {audio_path} with kwargs: {kwargs}",
            file=sys.stderr,
            flush=True,
        )

        # Run diarization in background thread for progress reporting
        result_holder: list = [None]
        error_holder: list[Exception | None] = [None]
        done_event = threading.Event()

        def _run() -> None:
            try:
                result_holder[0] = pipeline(audio_path, **kwargs)
            except Exception as e:
                error_holder[0] = e
            finally:
                done_event.set()

        try:
            thread = threading.Thread(target=_run, daemon=True)
            thread.start()

            elapsed = 0.0
            # Rough heuristic: diarization takes about half of real time;
            # default to 120s when the duration is unknown.
            estimated_total = max(audio_duration_sec * 0.5, 30.0) if audio_duration_sec else 120.0
            while not done_event.wait(timeout=2.0):
                elapsed += 2.0
                # Ramp 20% -> 85% over the estimated runtime, capped at 85.
                pct = min(20 + int((elapsed / estimated_total) * 65), 85)
                write_message(progress_message(
                    request_id, pct, "diarizing",
                    f"Analyzing speakers ({int(elapsed)}s elapsed)..."))

            thread.join()
        finally:
            # Clean up temp file even if a progress write raised mid-loop.
            if temp_wav:
                try:
                    os.unlink(temp_wav)
                except OSError:
                    pass  # best-effort cleanup; don't mask the real error

        if error_holder[0] is not None:
            raise error_holder[0]
        raw_result = result_holder[0]

        # pyannote 4.0+ returns DiarizeOutput; older versions return Annotation directly
        if hasattr(raw_result, "speaker_diarization"):
            diarization = raw_result.speaker_diarization
        else:
            diarization = raw_result

        # Convert pyannote output to our format
        result = DiarizationResult()
        seen_speakers: set[str] = set()

        for turn, _, speaker in diarization.itertracks(yield_label=True):
            result.speaker_segments.append(
                SpeakerSegment(
                    speaker=speaker,
                    start_ms=int(turn.start * 1000),
                    end_ms=int(turn.end * 1000),
                )
            )
            seen_speakers.add(speaker)

        result.speakers = sorted(seen_speakers)
        result.num_speakers = len(seen_speakers)

        elapsed = time.time() - start_time
        print(
            f"[sidecar] Diarization complete: {result.num_speakers} speakers, "
            f"{len(result.speaker_segments)} segments in {elapsed:.1f}s",
            file=sys.stderr,
            flush=True,
        )

        write_message(
            progress_message(request_id, 100, "done", "Diarization complete")
        )

        return result
|
|
|
|
|
|
def diarization_to_payload(result: DiarizationResult) -> dict[str, Any]:
    """Convert DiarizationResult to IPC payload dict."""

    def _segment_dict(segment: SpeakerSegment) -> dict[str, Any]:
        # One JSON-friendly dict per speaker turn.
        return {
            "speaker": segment.speaker,
            "start_ms": segment.start_ms,
            "end_ms": segment.end_ms,
        }

    payload: dict[str, Any] = {
        "speaker_segments": [_segment_dict(s) for s in result.speaker_segments],
        "num_speakers": result.num_speakers,
        "speakers": result.speakers,
    }
    return payload
|