diff --git a/python/voice_to_notes/services/diarize.py b/python/voice_to_notes/services/diarize.py index 0f8691a..49a88b9 100644 --- a/python/voice_to_notes/services/diarize.py +++ b/python/voice_to_notes/services/diarize.py @@ -41,20 +41,34 @@ def _patch_pyannote_audio() -> None: import torch from pyannote.audio.core.io import Audio - def _soundfile_call(self: Audio, file: dict) -> tuple: - """Load audio via soundfile (bypasses torchaudio/torchcodec entirely).""" - audio_path = str(file["audio"]) - data, sample_rate = sf.read(audio_path, dtype="float32") - # soundfile returns (samples,) for mono, (samples, channels) for stereo - # pyannote expects (channels, samples) torch tensor + def _sf_load(audio_path: str) -> tuple: + """Load audio via soundfile, return (channels, samples) tensor + sample_rate.""" + data, sample_rate = sf.read(str(audio_path), dtype="float32") waveform = torch.from_numpy(np.array(data)) if waveform.ndim == 1: - waveform = waveform.unsqueeze(0) # (samples,) -> (1, samples) + waveform = waveform.unsqueeze(0) else: - waveform = waveform.T # (samples, channels) -> (channels, samples) + waveform = waveform.T return waveform, sample_rate + def _soundfile_call(self, file: dict) -> tuple: + """Replacement for Audio.__call__.""" + return _sf_load(file["audio"]) + + def _soundfile_crop(self, file: dict, segment, **kwargs) -> tuple: + """Replacement for Audio.crop — load full file then slice.""" + waveform, sample_rate = _sf_load(file["audio"]) + # Convert segment (seconds) to sample indices + start_sample = int(segment.start * sample_rate) + end_sample = int(segment.end * sample_rate) + # Clamp to bounds + start_sample = max(0, start_sample) + end_sample = min(waveform.shape[-1], end_sample) + cropped = waveform[:, start_sample:end_sample] + return cropped, sample_rate + Audio.__call__ = _soundfile_call # type: ignore[assignment] + Audio.crop = _soundfile_crop # type: ignore[assignment] print("[sidecar] Patched pyannote Audio to use soundfile", file=sys.stderr, flush=True) except Exception as e: print(f"[sidecar] Warning: Could not patch pyannote Audio: {e}", file=sys.stderr, flush=True)