From 03af5a189cbeb1ab876031fa399159beb29f59f7 Mon Sep 17 00:00:00 2001
From: Claude
Date: Fri, 20 Mar 2026 13:50:57 -0700
Subject: [PATCH] Run pyannote diarization in background thread with progress reporting

Move the blocking pipeline() call to a daemon thread and emit estimated
progress messages every 2 seconds from the main thread. The progress
estimate uses audio duration to calibrate the expected total time. Also
pass audio_duration_sec from PipelineService to DiarizeService.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
Note: line breaks and indentation in this patch were restored from a
whitespace-collapsed copy. Hunk line counts and the diffstat agree with the
restored hunks. The indentation of the pipeline.py context lines is assumed,
and the author / co-author e-mail addresses were not present in that copy --
confirm both against the target tree before applying.

 python/tests/test_diarize.py               | 76 ++++++++++++++++++++++
 python/voice_to_notes/services/diarize.py  | 35 +++++++++-
 python/voice_to_notes/services/pipeline.py |  1 +
 3 files changed, 110 insertions(+), 2 deletions(-)

diff --git a/python/tests/test_diarize.py b/python/tests/test_diarize.py
index 3ceccc6..377e297 100644
--- a/python/tests/test_diarize.py
+++ b/python/tests/test_diarize.py
@@ -1,7 +1,13 @@
 """Tests for diarization service data structures and payload conversion."""
 
+import time
+from unittest.mock import MagicMock, patch
+
+import pytest
+
 from voice_to_notes.services.diarize import (
     DiarizationResult,
+    DiarizeService,
     SpeakerSegment,
     diarization_to_payload,
 )
@@ -31,3 +37,73 @@ def test_diarization_to_payload_empty():
     assert payload["num_speakers"] == 0
     assert payload["speaker_segments"] == []
     assert payload["speakers"] == []
+
+
+def test_diarize_threading_progress(monkeypatch):
+    """Test that diarization emits progress while running in background thread."""
+    # Track written messages
+    written_messages = []
+    def mock_write(msg):
+        written_messages.append(msg)
+
+    # Mock pipeline that takes ~5 seconds
+    def slow_pipeline(file_path, **kwargs):
+        time.sleep(5)
+        # Return a mock diarization result
+        mock_result = MagicMock()
+        mock_track = MagicMock()
+        mock_track.start = 0.0
+        mock_track.end = 5.0
+        mock_result.itertracks = MagicMock(return_value=[(mock_track, None, "SPEAKER_00")])
+        return mock_result
+
+    mock_pipeline_obj = MagicMock()
+    mock_pipeline_obj.side_effect = slow_pipeline
+
+    service = DiarizeService()
+    service._pipeline = mock_pipeline_obj
+
+    with patch("voice_to_notes.services.diarize.write_message", mock_write):
+        result = service.diarize(
+            request_id="req-1",
+            file_path="/fake/audio.wav",
+            audio_duration_sec=60.0,
+        )
+
+    # Filter for diarizing progress messages (not loading_diarization or done)
+    diarizing_msgs = [
+        m for m in written_messages
+        if m.type == "progress" and m.payload.get("stage") == "diarizing"
+        and "elapsed" in m.payload.get("message", "")
+    ]
+
+    # Should have at least 1 progress message (5s sleep / 2s interval = ~2 messages)
+    assert len(diarizing_msgs) >= 1, (
+        f"Expected at least 1 diarizing progress message, got {len(diarizing_msgs)}"
+    )
+
+    # Progress percent should be between 20 and 85
+    for msg in diarizing_msgs:
+        pct = msg.payload["percent"]
+        assert 20 <= pct <= 85, f"Progress {pct} out of expected range 20-85"
+
+    # Result should be valid
+    assert result.num_speakers == 1
+    assert result.speakers == ["SPEAKER_00"]
+
+
+def test_diarize_threading_error_propagation(monkeypatch):
+    """Test that errors from the background thread are properly raised."""
+    mock_pipeline_obj = MagicMock()
+    mock_pipeline_obj.side_effect = RuntimeError("Pipeline crashed")
+
+    service = DiarizeService()
+    service._pipeline = mock_pipeline_obj
+
+    with patch("voice_to_notes.services.diarize.write_message", lambda m: None):
+        with pytest.raises(RuntimeError, match="Pipeline crashed"):
+            service.diarize(
+                request_id="req-1",
+                file_path="/fake/audio.wav",
+                audio_duration_sec=30.0,
+            )
diff --git a/python/voice_to_notes/services/diarize.py b/python/voice_to_notes/services/diarize.py
index 201ca9c..f693f58 100644
--- a/python/voice_to_notes/services/diarize.py
+++ b/python/voice_to_notes/services/diarize.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import sys
+import threading
 import time
 from dataclasses import dataclass, field
 from typing import Any
@@ -82,6 +83,8 @@ class DiarizeService:
         num_speakers: int | None = None,
         min_speakers: int | None = None,
         max_speakers: int | None = None,
+        hf_token: str | None = None,
+        audio_duration_sec: float | None = None,
     ) -> DiarizationResult:
         """Run speaker diarization on an audio file.
 
@@ -116,8 +119,36 @@ class DiarizeService:
         if max_speakers is not None:
             kwargs["max_speakers"] = max_speakers
 
-        # Run diarization
-        diarization = pipeline(file_path, **kwargs)
+        # Run diarization in background thread for progress reporting
+        result_holder: list = [None]
+        error_holder: list[Exception | None] = [None]
+        done_event = threading.Event()
+
+        def _run():
+            try:
+                result_holder[0] = pipeline(file_path, **kwargs)
+            except Exception as e:
+                error_holder[0] = e
+            finally:
+                done_event.set()
+
+        thread = threading.Thread(target=_run, daemon=True)
+        thread.start()
+
+        elapsed = 0.0
+        estimated_total = max(audio_duration_sec * 0.5, 30.0) if audio_duration_sec else 120.0
+        while not done_event.wait(timeout=2.0):
+            elapsed += 2.0
+            pct = min(20 + int((elapsed / estimated_total) * 65), 85)
+            write_message(progress_message(
+                request_id, pct, "diarizing",
+                f"Analyzing speakers ({int(elapsed)}s elapsed)..."))
+
+        thread.join()
+
+        if error_holder[0] is not None:
+            raise error_holder[0]
+        diarization = result_holder[0]
 
         # Convert pyannote output to our format
         result = DiarizationResult()
diff --git a/python/voice_to_notes/services/pipeline.py b/python/voice_to_notes/services/pipeline.py
index 2d1f66b..110be37 100644
--- a/python/voice_to_notes/services/pipeline.py
+++ b/python/voice_to_notes/services/pipeline.py
@@ -121,6 +121,7 @@ class PipelineService:
                 num_speakers=num_speakers,
                 min_speakers=min_speakers,
                 max_speakers=max_speakers,
+                audio_duration_sec=transcription.duration_ms / 1000.0,
             )
 
         # Step 3: Merge