Fix diarization performance for long files + better progress
Some checks failed
Build Sidecars / Bump sidecar version and tag (push) Successful in 11s
Release / Bump version and tag (push) Successful in 10s
Build Sidecars / Build Sidecar (macOS) (push) Successful in 4m0s
Release / Build App (macOS) (push) Successful in 1m16s
Release / Build App (Linux) (push) Has been cancelled
Release / Build App (Windows) (push) Has been cancelled
Build Sidecars / Build Sidecar (Linux) (push) Successful in 17m34s
Build Sidecars / Build Sidecar (Windows) (push) Successful in 28m9s

- Cache loaded audio in _sf_load() — previously the entire WAV file was
  re-read from disk for every 10s crop call. For a 3-hour file with
  1000+ chunks, this meant ~345GB of disk reads. Now read once, cached.
- Better progress messages for long files: show elapsed time in m:ss
  format, warn "(180min audio, this may take a while)" for files >10min
- Increased progress poll interval from 2s to 5s (less noise)
- Better time estimate: use 0.8x audio duration (was 0.5x)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Claude
2026-03-23 06:24:18 -07:00
parent 999bdaa671
commit 806586ae3d

View File

@@ -41,14 +41,23 @@ def _patch_pyannote_audio() -> None:
    import torch
    from pyannote.audio.core.io import Audio
# Cache loaded audio to avoid re-reading the entire file for every crop call.
# For a 3-hour file, crop is called 1000+ times — without caching, each call
# reads ~345MB from disk.
_audio_cache: dict[str, tuple] = {}
    def _sf_load(audio_path: str) -> tuple:
        """Load audio via soundfile with caching."""
        key = str(audio_path)
        if key in _audio_cache:
            return _audio_cache[key]
        data, sample_rate = sf.read(key, dtype="float32")
        waveform = torch.from_numpy(np.array(data))
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)
        else:
            waveform = waveform.T
        _audio_cache[key] = (waveform, sample_rate)
        return waveform, sample_rate
    def _soundfile_call(self, file: dict) -> tuple:
@@ -56,7 +65,7 @@ def _patch_pyannote_audio() -> None:
        return _sf_load(file["audio"])
    def _soundfile_crop(self, file: dict, segment, **kwargs) -> tuple:
        """Replacement for Audio.crop — load file once (cached) then slice.

        Pads short segments with zeros to match the expected duration,
        which pyannote requires for batched embedding extraction.
@@ -279,13 +288,20 @@ class DiarizeService:
        thread.start()

        elapsed = 0.0
        estimated_total = max(audio_duration_sec * 0.8, 30.0) if audio_duration_sec else 120.0
        duration_str = ""
        if audio_duration_sec and audio_duration_sec > 600:
            mins = int(audio_duration_sec / 60)
            duration_str = f" ({mins}min audio, this may take a while)"
        while not done_event.wait(timeout=5.0):
            elapsed += 5.0
            pct = min(20 + int((elapsed / estimated_total) * 65), 85)
            elapsed_min = int(elapsed / 60)
            elapsed_sec = int(elapsed % 60)
            time_str = f"{elapsed_min}m{elapsed_sec:02d}s" if elapsed_min > 0 else f"{int(elapsed)}s"
            write_message(progress_message(
                request_id, pct, "diarizing",
                f"Analyzing speakers ({time_str} elapsed){duration_str}"))
        thread.join()