diff --git a/python/tests/test_diarize.py b/python/tests/test_diarize.py index 3ceccc6..377e297 100644 --- a/python/tests/test_diarize.py +++ b/python/tests/test_diarize.py @@ -1,7 +1,13 @@ """Tests for diarization service data structures and payload conversion.""" +import time +from unittest.mock import MagicMock, patch + +import pytest + from voice_to_notes.services.diarize import ( DiarizationResult, + DiarizeService, SpeakerSegment, diarization_to_payload, ) @@ -31,3 +37,73 @@ def test_diarization_to_payload_empty(): assert payload["num_speakers"] == 0 assert payload["speaker_segments"] == [] assert payload["speakers"] == [] + + +def test_diarize_threading_progress(monkeypatch): + """Test that diarization emits progress while running in background thread.""" + # Track written messages + written_messages = [] + def mock_write(msg): + written_messages.append(msg) + + # Mock pipeline that takes ~5 seconds + def slow_pipeline(file_path, **kwargs): + time.sleep(5) + # Return a mock diarization result + mock_result = MagicMock() + mock_track = MagicMock() + mock_track.start = 0.0 + mock_track.end = 5.0 + mock_result.itertracks = MagicMock(return_value=[(mock_track, None, "SPEAKER_00")]) + return mock_result + + mock_pipeline_obj = MagicMock() + mock_pipeline_obj.side_effect = slow_pipeline + + service = DiarizeService() + service._pipeline = mock_pipeline_obj + + with patch("voice_to_notes.services.diarize.write_message", mock_write): + result = service.diarize( + request_id="req-1", + file_path="/fake/audio.wav", + audio_duration_sec=60.0, + ) + + # Filter for diarizing progress messages (not loading_diarization or done) + diarizing_msgs = [ + m for m in written_messages + if m.type == "progress" and m.payload.get("stage") == "diarizing" + and "elapsed" in m.payload.get("message", "") + ] + + # Should have at least 1 progress message (5s sleep / 2s interval = ~2 messages) + assert len(diarizing_msgs) >= 1, ( + f"Expected at least 1 diarizing progress message, got {len(diarizing_msgs)}" + ) + + # Progress percent should be between 20 and 85 + for msg in diarizing_msgs: + pct = msg.payload["percent"] + assert 20 <= pct <= 85, f"Progress {pct} out of expected range 20-85" + + # Result should be valid + assert result.num_speakers == 1 + assert result.speakers == ["SPEAKER_00"] + + +def test_diarize_threading_error_propagation(monkeypatch): + """Test that errors from the background thread are properly raised.""" + mock_pipeline_obj = MagicMock() + mock_pipeline_obj.side_effect = RuntimeError("Pipeline crashed") + + service = DiarizeService() + service._pipeline = mock_pipeline_obj + + with patch("voice_to_notes.services.diarize.write_message", lambda m: None): + with pytest.raises(RuntimeError, match="Pipeline crashed"): + service.diarize( + request_id="req-1", + file_path="/fake/audio.wav", + audio_duration_sec=30.0, + ) diff --git a/python/voice_to_notes/services/diarize.py b/python/voice_to_notes/services/diarize.py index 692b32f..b55ff65 100644 --- a/python/voice_to_notes/services/diarize.py +++ b/python/voice_to_notes/services/diarize.py @@ -6,6 +6,7 @@ import os import subprocess import sys import tempfile +import threading import time from dataclasses import dataclass, field from pathlib import Path @@ -141,6 +142,7 @@ class DiarizeService: min_speakers: int | None = None, max_speakers: int | None = None, hf_token: str | None = None, + audio_duration_sec: float | None = None, ) -> DiarizationResult: """Run speaker diarization on an audio file. @@ -184,12 +186,40 @@ class DiarizeService: flush=True, ) - # Run diarization - try: - raw_result = pipeline(audio_path, **kwargs) - finally: - if temp_wav: - os.unlink(temp_wav) + # Run diarization in background thread for progress reporting + result_holder: list = [None] + error_holder: list[Exception | None] = [None] + done_event = threading.Event() + + def _run(): + try: + result_holder[0] = pipeline(audio_path, **kwargs) + except Exception as e: + error_holder[0] = e + finally: + done_event.set() + + thread = threading.Thread(target=_run, daemon=True) + thread.start() + + elapsed = 0.0 + estimated_total = max(audio_duration_sec * 0.5, 30.0) if audio_duration_sec else 120.0 + while not done_event.wait(timeout=2.0): + elapsed += 2.0 + pct = min(20 + int((elapsed / estimated_total) * 65), 85) + write_message(progress_message( + request_id, pct, "diarizing", + f"Analyzing speakers ({int(elapsed)}s elapsed)...")) + + thread.join() + + # Clean up temp file + if temp_wav: + os.unlink(temp_wav) + + if error_holder[0] is not None: + raise error_holder[0] + raw_result = result_holder[0] # pyannote 4.0+ returns DiarizeOutput; older versions return Annotation directly if hasattr(raw_result, "speaker_diarization"): diff --git a/python/voice_to_notes/services/pipeline.py b/python/voice_to_notes/services/pipeline.py index 3d84a81..281db16 100644 --- a/python/voice_to_notes/services/pipeline.py +++ b/python/voice_to_notes/services/pipeline.py @@ -139,6 +139,7 @@ class PipelineService: min_speakers=min_speakers, max_speakers=max_speakers, hf_token=hf_token, + audio_duration_sec=transcription.duration_ms / 1000.0, ) except Exception as e: import traceback