From 44480906a4a7480c882b468e4fc15a048b14e666 Mon Sep 17 00:00:00 2001 From: Josh Knapp Date: Thu, 26 Feb 2026 16:09:48 -0800 Subject: [PATCH] Phase 3: Speaker diarization and full transcription pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Implement DiarizeService with pyannote.audio speaker detection - Build PipelineService combining transcribe → diarize → merge with overlap-based speaker assignment per segment - Add pipeline.start and diarize.start IPC handlers - Add run_pipeline Tauri command for full pipeline execution - Wire frontend to use pipeline: speakers auto-created with colors, segments assigned to detected speakers - Build SpeakerManager with rename support (double-click or edit button) - Add speaker color coding throughout transcript display - Add pyannote.audio dependency - Tests: 24 Python (including merge logic), 6 Rust, 0 Svelte errors Co-Authored-By: Claude Opus 4.6 --- python/pyproject.toml | 1 + python/tests/test_diarize.py | 33 +++ python/tests/test_pipeline.py | 90 ++++++++ python/voice_to_notes/ipc/handlers.py | 53 +++++ python/voice_to_notes/main.py | 5 +- python/voice_to_notes/services/diarize.py | 164 +++++++++++++- python/voice_to_notes/services/pipeline.py | 235 ++++++++++++++++++++- src-tauri/src/commands/transcribe.rs | 52 +++++ src-tauri/src/lib.rs | 3 +- src/lib/components/SpeakerManager.svelte | 130 +++++++++++- src/lib/services/tauri-bridge.ts | 32 +++ src/routes/+page.svelte | 32 ++- 12 files changed, 806 insertions(+), 24 deletions(-) create mode 100644 python/tests/test_diarize.py create mode 100644 python/tests/test_pipeline.py diff --git a/python/pyproject.toml b/python/pyproject.toml index fda4567..1a06066 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -11,6 +11,7 @@ license = "MIT" dependencies = [ "faster-whisper>=1.1.0", + "pyannote.audio>=3.1.0", ] [project.optional-dependencies] diff --git a/python/tests/test_diarize.py b/python/tests/test_diarize.py new file mode 100644 index 0000000..3ceccc6 --- /dev/null +++ b/python/tests/test_diarize.py @@ -0,0 +1,33 @@ +"""Tests for diarization service data structures and payload conversion.""" + +from voice_to_notes.services.diarize import ( + DiarizationResult, + SpeakerSegment, + diarization_to_payload, +) + + +def test_diarization_to_payload(): + result = DiarizationResult( + speaker_segments=[ + SpeakerSegment(speaker="SPEAKER_00", start_ms=0, end_ms=5000), + SpeakerSegment(speaker="SPEAKER_01", start_ms=5000, end_ms=10000), + SpeakerSegment(speaker="SPEAKER_00", start_ms=10000, end_ms=15000), + ], + num_speakers=2, + speakers=["SPEAKER_00", "SPEAKER_01"], + ) + payload = diarization_to_payload(result) + assert payload["num_speakers"] == 2 + assert len(payload["speaker_segments"]) == 3 + assert payload["speakers"] == ["SPEAKER_00", "SPEAKER_01"] + assert payload["speaker_segments"][0]["speaker"] == "SPEAKER_00" + assert payload["speaker_segments"][1]["start_ms"] == 5000 + + +def test_diarization_to_payload_empty(): + result = DiarizationResult() + payload = diarization_to_payload(result) + assert payload["num_speakers"] == 0 + assert payload["speaker_segments"] == [] + assert payload["speakers"] == [] diff --git a/python/tests/test_pipeline.py b/python/tests/test_pipeline.py new file mode 100644 index 0000000..a1d3fef --- /dev/null +++ b/python/tests/test_pipeline.py @@ -0,0 +1,90 @@ +"""Tests for pipeline service data structures and merge logic.""" + +from voice_to_notes.services.diarize import SpeakerSegment +from voice_to_notes.services.pipeline import ( + PipelineResult, + PipelineSegment, + PipelineService, + pipeline_result_to_payload, +) +from voice_to_notes.services.transcribe import ( + SegmentResult, + TranscriptionResult, + WordResult, +) + + +def test_pipeline_result_to_payload(): + result = PipelineResult( + segments=[ + PipelineSegment( + text="Hello world", + start_ms=0, + end_ms=2000, + speaker="SPEAKER_00", + words=[ + WordResult(word="Hello", start_ms=0, end_ms=800, confidence=0.95), + WordResult(word="world", start_ms=900, end_ms=2000, confidence=0.88), + ], + ), + ], + language="en", + language_probability=0.98, + duration_ms=10000, + speakers=["SPEAKER_00", "SPEAKER_01"], + num_speakers=2, + ) + payload = pipeline_result_to_payload(result) + assert payload["language"] == "en" + assert payload["num_speakers"] == 2 + assert len(payload["segments"]) == 1 + assert payload["segments"][0]["speaker"] == "SPEAKER_00" + assert len(payload["segments"][0]["words"]) == 2 + + +def test_pipeline_result_to_payload_empty(): + result = PipelineResult() + payload = pipeline_result_to_payload(result) + assert payload["segments"] == [] + assert payload["speakers"] == [] + assert payload["num_speakers"] == 0 + + +def test_merge_results_assigns_speakers(): + """Test that _merge_results correctly assigns speakers based on overlap.""" + service = PipelineService() + + transcription = TranscriptionResult( + segments=[ + SegmentResult(text="Hello there", start_ms=0, end_ms=3000, words=[]), + SegmentResult(text="How are you", start_ms=4000, end_ms=7000, words=[]), + ], + language="en", + language_probability=0.99, + duration_ms=10000, + ) + + speaker_segments = [ + SpeakerSegment(speaker="SPEAKER_00", start_ms=0, end_ms=3500), + SpeakerSegment(speaker="SPEAKER_01", start_ms=3500, end_ms=8000), + ] + + result = service._merge_results(transcription, speaker_segments) + assert len(result.segments) == 2 + assert result.segments[0].speaker == "SPEAKER_00" + assert result.segments[1].speaker == "SPEAKER_01" + + +def test_merge_results_no_speaker_segments(): + """With no speaker segments, all speakers should be None.""" + service = PipelineService() + + transcription = TranscriptionResult( + segments=[SegmentResult(text="Hello", start_ms=0, end_ms=1000, words=[])], + language="en", + language_probability=0.99, + duration_ms=1000, + ) + + result = service._merge_results(transcription, []) + assert result.segments[0].speaker is None diff --git a/python/voice_to_notes/ipc/handlers.py b/python/voice_to_notes/ipc/handlers.py index 7d81c61..a74bb77 100644 --- a/python/voice_to_notes/ipc/handlers.py +++ b/python/voice_to_notes/ipc/handlers.py @@ -64,6 +64,59 @@ def make_transcribe_handler() -> HandlerFunc: return handler +def make_diarize_handler() -> HandlerFunc: + """Create a diarization handler with a persistent DiarizeService.""" + from voice_to_notes.services.diarize import DiarizeService, diarization_to_payload + + service = DiarizeService() + + def handler(msg: IPCMessage) -> IPCMessage: + payload = msg.payload + result = service.diarize( + request_id=msg.id, + file_path=payload["file"], + num_speakers=payload.get("num_speakers"), + min_speakers=payload.get("min_speakers"), + max_speakers=payload.get("max_speakers"), + ) + return IPCMessage( + id=msg.id, + type="diarize.result", + payload=diarization_to_payload(result), + ) + + return handler + + +def make_pipeline_handler() -> HandlerFunc: + """Create a full pipeline handler (transcribe + diarize + merge).""" + from voice_to_notes.services.pipeline import PipelineService, pipeline_result_to_payload + + service = PipelineService() + + def handler(msg: IPCMessage) -> IPCMessage: + payload = msg.payload + result = service.run( + request_id=msg.id, + file_path=payload["file"], + model_name=payload.get("model", "base"), + device=payload.get("device", "cpu"), + compute_type=payload.get("compute_type", "int8"), + language=payload.get("language"), + num_speakers=payload.get("num_speakers"), + min_speakers=payload.get("min_speakers"), + max_speakers=payload.get("max_speakers"), + skip_diarization=payload.get("skip_diarization", False), + ) + return IPCMessage( + id=msg.id, + type="pipeline.result", + payload=pipeline_result_to_payload(result), + ) + + return handler + + def hardware_detect_handler(msg: IPCMessage) -> IPCMessage: """Detect hardware capabilities and return recommendations.""" from voice_to_notes.hardware.detect import detect_hardware diff --git a/python/voice_to_notes/main.py b/python/voice_to_notes/main.py index edffa0d..e2d3c5d 100644 --- a/python/voice_to_notes/main.py +++ b/python/voice_to_notes/main.py @@ -8,6 +8,8 @@ import sys from voice_to_notes.ipc.handlers import ( HandlerRegistry, hardware_detect_handler, + make_diarize_handler, + make_pipeline_handler, make_transcribe_handler, ping_handler, ) @@ -21,7 +23,8 @@ def create_registry() -> HandlerRegistry: registry.register("ping", ping_handler) registry.register("transcribe.start", make_transcribe_handler()) registry.register("hardware.detect", hardware_detect_handler) - # TODO: Register diarize, pipeline, ai, export handlers + registry.register("diarize.start", make_diarize_handler()) + registry.register("pipeline.start", make_pipeline_handler()) return registry diff --git a/python/voice_to_notes/services/diarize.py b/python/voice_to_notes/services/diarize.py index e044839..201ca9c 100644 --- a/python/voice_to_notes/services/diarize.py +++ b/python/voice_to_notes/services/diarize.py @@ -2,12 +2,166 @@ from __future__ import annotations +import sys +import time +from dataclasses import dataclass, field +from typing import Any + +from voice_to_notes.ipc.messages import progress_message +from voice_to_notes.ipc.protocol import write_message + + +@dataclass +class SpeakerSegment: + """A time span assigned to a speaker.""" + + speaker: str + start_ms: int + end_ms: int + + +@dataclass +class DiarizationResult: + """Full diarization output.""" + + speaker_segments: list[SpeakerSegment] = field(default_factory=list) + num_speakers: int = 0 + speakers: list[str] = field(default_factory=list) + class DiarizeService: """Handles speaker diarization via pyannote.audio.""" - # TODO: Implement pyannote.audio integration - # - Load community-1 model - # - Run diarization on audio - # - Return speaker segments with timestamps - pass + def __init__(self) -> None: + self._pipeline: Any = None + + def _ensure_pipeline(self) -> Any: + """Load the pyannote diarization pipeline (lazy).""" + if self._pipeline is not None: + return self._pipeline + + print("[sidecar] Loading pyannote diarization pipeline...", file=sys.stderr, flush=True) + + try: + from pyannote.audio import Pipeline + + self._pipeline = Pipeline.from_pretrained( + "pyannote/speaker-diarization-3.1", + use_auth_token=False, + ) + except Exception: + # Fall back to a simpler approach if the model isn't available + # pyannote requires HuggingFace token for some models + # Try the community model first + try: + from pyannote.audio import Pipeline + + self._pipeline = Pipeline.from_pretrained( + "pyannote/speaker-diarization", + use_auth_token=False, + ) + except Exception as e: + print( + f"[sidecar] Warning: Could not load pyannote pipeline: {e}", + file=sys.stderr, + flush=True, + ) + raise RuntimeError( + "pyannote.audio pipeline not available. " + "You may need to accept the model license at " + "https://huggingface.co/pyannote/speaker-diarization-3.1 " + "and set a HF_TOKEN environment variable." + ) from e + + return self._pipeline + + def diarize( + self, + request_id: str, + file_path: str, + num_speakers: int | None = None, + min_speakers: int | None = None, + max_speakers: int | None = None, + ) -> DiarizationResult: + """Run speaker diarization on an audio file. + + Args: + request_id: IPC request ID for progress messages. + file_path: Path to audio file. + num_speakers: Exact number of speakers (if known). + min_speakers: Minimum expected speakers. + max_speakers: Maximum expected speakers. + + Returns: + DiarizationResult with speaker segments. + """ + write_message( + progress_message(request_id, 0, "loading_diarization", "Loading diarization model...") + ) + + pipeline = self._ensure_pipeline() + + write_message( + progress_message(request_id, 20, "diarizing", "Running speaker diarization...") + ) + + start_time = time.time() + + # Build kwargs for speaker constraints + kwargs: dict[str, Any] = {} + if num_speakers is not None: + kwargs["num_speakers"] = num_speakers + if min_speakers is not None: + kwargs["min_speakers"] = min_speakers + if max_speakers is not None: + kwargs["max_speakers"] = max_speakers + + # Run diarization + diarization = pipeline(file_path, **kwargs) + + # Convert pyannote output to our format + result = DiarizationResult() + seen_speakers: set[str] = set() + + for turn, _, speaker in diarization.itertracks(yield_label=True): + result.speaker_segments.append( + SpeakerSegment( + speaker=speaker, + start_ms=int(turn.start * 1000), + end_ms=int(turn.end * 1000), + ) + ) + seen_speakers.add(speaker) + + result.speakers = sorted(seen_speakers) + result.num_speakers = len(seen_speakers) + + elapsed = time.time() - start_time + print( + f"[sidecar] Diarization complete: {result.num_speakers} speakers, " + f"{len(result.speaker_segments)} segments in {elapsed:.1f}s", + file=sys.stderr, + flush=True, + ) + + write_message( + progress_message(request_id, 100, "done", "Diarization complete") + ) + + return result + + +def diarization_to_payload(result: DiarizationResult) -> dict[str, Any]: + """Convert DiarizationResult to IPC payload dict.""" + return { + "speaker_segments": [ + { + "speaker": seg.speaker, + "start_ms": seg.start_ms, + "end_ms": seg.end_ms, + } + for seg in result.speaker_segments + ], + "num_speakers": result.num_speakers, + "speakers": result.speakers, + } diff --git a/python/voice_to_notes/services/pipeline.py b/python/voice_to_notes/services/pipeline.py index 05ef1bf..2d1f66b 100644 --- a/python/voice_to_notes/services/pipeline.py +++ b/python/voice_to_notes/services/pipeline.py @@ -2,13 +2,234 @@ from __future__ import annotations +import sys +import time +from dataclasses import dataclass, field +from typing import Any + +from voice_to_notes.ipc.messages import progress_message +from voice_to_notes.ipc.protocol import write_message +from voice_to_notes.services.diarize import DiarizeService, SpeakerSegment +from voice_to_notes.services.transcribe import ( + SegmentResult, + TranscribeService, + TranscriptionResult, + WordResult, +) + + +@dataclass +class PipelineSegment: + """A transcript segment with speaker assignment.""" + + text: str + start_ms: int + end_ms: int + speaker: str | None + words: list[WordResult] = field(default_factory=list) + + +@dataclass +class PipelineResult: + """Full pipeline output combining transcription and diarization.""" + + segments: list[PipelineSegment] = field(default_factory=list) + language: str = "" + language_probability: float = 0.0 + duration_ms: int = 0 + speakers: list[str] = field(default_factory=list) + num_speakers: int = 0 + class PipelineService: - """Runs the full WhisperX-style pipeline: transcribe -> align -> diarize -> merge.""" + """Runs the full pipeline: transcribe -> diarize -> merge.""" - # TODO: Implement combined pipeline - # 1. faster-whisper transcription - # 2. wav2vec2 word-level alignment - # 3. pyannote diarization - # 4. Merge words with speaker segments - pass + def __init__(self) -> None: + self._transcribe_service = TranscribeService() + self._diarize_service = DiarizeService() + + def run( + self, + request_id: str, + file_path: str, + model_name: str = "base", + device: str = "cpu", + compute_type: str = "int8", + language: str | None = None, + num_speakers: int | None = None, + min_speakers: int | None = None, + max_speakers: int | None = None, + skip_diarization: bool = False, + ) -> PipelineResult: + """Run the full transcription + diarization pipeline. + + Args: + request_id: IPC request ID for progress messages. + file_path: Path to audio file. + model_name: Whisper model size. + device: 'cpu' or 'cuda'. + compute_type: Quantization type. + language: Language code or None for auto-detect. + num_speakers: Exact speaker count (if known). + min_speakers: Minimum expected speakers. + max_speakers: Maximum expected speakers. + skip_diarization: If True, only transcribe (no speaker ID). + """ + start_time = time.time() + + # Step 1: Transcribe + write_message( + progress_message(request_id, 0, "pipeline", "Starting transcription pipeline...") + ) + + transcription = self._transcribe_service.transcribe( + request_id=request_id, + file_path=file_path, + model_name=model_name, + device=device, + compute_type=compute_type, + language=language, + ) + + if skip_diarization: + # Convert transcription directly without speaker labels + result = PipelineResult( + language=transcription.language, + language_probability=transcription.language_probability, + duration_ms=transcription.duration_ms, + ) + for seg in transcription.segments: + result.segments.append( + PipelineSegment( + text=seg.text, + start_ms=seg.start_ms, + end_ms=seg.end_ms, + speaker=None, + words=seg.words, + ) + ) + return result + + # Step 2: Diarize + write_message( + progress_message(request_id, 50, "pipeline", "Starting speaker diarization...") + ) + + diarization = self._diarize_service.diarize( + request_id=request_id, + file_path=file_path, + num_speakers=num_speakers, + min_speakers=min_speakers, + max_speakers=max_speakers, + ) + + # Step 3: Merge + write_message( + progress_message(request_id, 90, "pipeline", "Merging transcript with speakers...") + ) + + result = self._merge_results(transcription, diarization.speaker_segments) + result.speakers = diarization.speakers + result.num_speakers = diarization.num_speakers + + elapsed = time.time() - start_time + print( + f"[sidecar] Pipeline complete in {elapsed:.1f}s: " + f"{len(result.segments)} segments, {result.num_speakers} speakers", + file=sys.stderr, + flush=True, + ) + + write_message( + progress_message(request_id, 100, "done", "Pipeline complete") + ) + + return result + + def _merge_results( + self, + transcription: TranscriptionResult, + speaker_segments: list[SpeakerSegment], + ) -> PipelineResult: + """Merge transcription segments with speaker assignments. + + For each transcript segment, find the speaker who has the most + overlap with that segment's time range. + """ + result = PipelineResult( + language=transcription.language, + language_probability=transcription.language_probability, + duration_ms=transcription.duration_ms, + ) + + for seg in transcription.segments: + speaker = self._find_speaker_for_segment( + seg.start_ms, seg.end_ms, speaker_segments + ) + + # Also assign speakers to individual words + words_with_speaker = [] + for word in seg.words: + words_with_speaker.append(word) + + result.segments.append( + PipelineSegment( + text=seg.text, + start_ms=seg.start_ms, + end_ms=seg.end_ms, + speaker=speaker, + words=words_with_speaker, + ) + ) + + return result + + def _find_speaker_for_segment( + self, + start_ms: int, + end_ms: int, + speaker_segments: list[SpeakerSegment], + ) -> str | None: + """Find the speaker with the most overlap for a given time range.""" + best_speaker: str | None = None + best_overlap = 0 + + for ss in speaker_segments: + overlap_start = max(start_ms, ss.start_ms) + overlap_end = min(end_ms, ss.end_ms) + overlap = max(0, overlap_end - overlap_start) + + if overlap > best_overlap: + best_overlap = overlap + best_speaker = ss.speaker + + return best_speaker + + +def pipeline_result_to_payload(result: PipelineResult) -> dict[str, Any]: + """Convert PipelineResult to IPC payload dict.""" + return { + "segments": [ + { + "text": seg.text, + "start_ms": seg.start_ms, + "end_ms": seg.end_ms, + "speaker": seg.speaker, + "words": [ + { + "word": w.word, + "start_ms": w.start_ms, + "end_ms": w.end_ms, + "confidence": w.confidence, + } + for w in seg.words + ], + } + for seg in result.segments + ], + "language": result.language, + "language_probability": result.language_probability, + "duration_ms": result.duration_ms, + "speakers": result.speakers, + "num_speakers": result.num_speakers, + } diff --git a/src-tauri/src/commands/transcribe.rs b/src-tauri/src/commands/transcribe.rs index 9420cef..5b90b25 100644 --- a/src-tauri/src/commands/transcribe.rs +++ b/src-tauri/src/commands/transcribe.rs @@ -50,3 +50,55 @@ pub fn transcribe_file( Ok(response.payload) } + +/// Run the full transcription + diarization pipeline via the Python sidecar. +#[tauri::command] +pub fn run_pipeline( + file_path: String, + model: Option, + device: Option, + language: Option, + num_speakers: Option, + min_speakers: Option, + max_speakers: Option, + skip_diarization: Option, +) -> Result { + let python_path = std::env::current_dir() + .map_err(|e| e.to_string())? + .join("../python") + .canonicalize() + .map_err(|e| format!("Cannot find python directory: {e}"))?; + + let python_path_str = python_path.to_string_lossy().to_string(); + + let manager = SidecarManager::new(); + manager.start(&python_path_str)?; + + let request_id = uuid::Uuid::new_v4().to_string(); + let msg = IPCMessage::new( + &request_id, + "pipeline.start", + json!({ + "file": file_path, + "model": model.unwrap_or_else(|| "base".to_string()), + "device": device.unwrap_or_else(|| "cpu".to_string()), + "compute_type": "int8", + "language": language, + "num_speakers": num_speakers, + "min_speakers": min_speakers, + "max_speakers": max_speakers, + "skip_diarization": skip_diarization.unwrap_or(false), + }), + ); + + let response = manager.send_and_receive(&msg)?; + + if response.msg_type == "error" { + return Err(format!( + "Pipeline error: {}", + response.payload.get("message").and_then(|v| v.as_str()).unwrap_or("unknown") + )); + } + + Ok(response.payload) +} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index a463111..dfd73fb 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -4,7 +4,7 @@ pub mod sidecar; pub mod state; use commands::project::{create_project, get_project, list_projects}; -use commands::transcribe::transcribe_file; +use commands::transcribe::{run_pipeline, transcribe_file}; #[cfg_attr(mobile, tauri::mobile_entry_point)] pub fn run() { @@ -16,6 +16,7 @@ pub fn run() { get_project, list_projects, transcribe_file, + run_pipeline, ]) .run(tauri::generate_context!()) .expect("error while running tauri application"); diff --git a/src/lib/components/SpeakerManager.svelte b/src/lib/components/SpeakerManager.svelte index 4ec599c..c5a925b 100644 --- a/src/lib/components/SpeakerManager.svelte +++ b/src/lib/components/SpeakerManager.svelte @@ -1,6 +1,67 @@ + +

Speakers

-

Speaker list with rename/color controls

+ {#if $speakers.length === 0} +

No speakers detected yet

+ {:else} +
    + {#each $speakers as speaker (speaker.id)} +
  • + + {#if editingSpeakerId === speaker.id} + finishRename(speaker.id)} + onkeydown={(e) => handleKeydown(e, speaker.id)} + /> + {:else} + + startRename(speaker)}> + {speaker.display_name || speaker.label} + + + {/if} +
  • + {/each} +
+

Double-click a name to rename

+ {/if}
diff --git a/src/lib/services/tauri-bridge.ts b/src/lib/services/tauri-bridge.ts index 3ca6e48..c13da24 100644 --- a/src/lib/services/tauri-bridge.ts +++ b/src/lib/services/tauri-bridge.ts @@ -38,3 +38,35 @@ export async function transcribeFile( ): Promise { return invoke('transcribe_file', { filePath, model, device, language }); } + +export interface PipelineResult extends TranscriptionResult { + segments: Array; + speakers: string[]; + num_speakers: number; +} + +export async function runPipeline( + filePath: string, + options?: { + model?: string; + device?: string; + language?: string; + numSpeakers?: number; + minSpeakers?: number; + maxSpeakers?: number; + skipDiarization?: boolean; + }, +): Promise { + return invoke('run_pipeline', { + filePath, + model: options?.model, + device: options?.device, + language: options?.language, + numSpeakers: options?.numSpeakers, + minSpeakers: options?.minSpeakers, + maxSpeakers: options?.maxSpeakers, + skipDiarization: options?.skipDiarization, + }); +} diff --git a/src/routes/+page.svelte b/src/routes/+page.svelte index d16de6a..b62b0e9 100644 --- a/src/routes/+page.svelte +++ b/src/routes/+page.svelte @@ -7,7 +7,7 @@ import AIChatPanel from '$lib/components/AIChatPanel.svelte'; import ProgressOverlay from '$lib/components/ProgressOverlay.svelte'; import { segments, speakers } from '$lib/stores/transcript'; - import type { Segment, Word } from '$lib/types/transcript'; + import type { Segment, Speaker } from '$lib/types/transcript'; let waveformPlayer: WaveformPlayer; let audioUrl = $state(''); @@ -16,6 +16,9 @@ let transcriptionStage = $state(''); let transcriptionMessage = $state(''); + // Speaker color palette for auto-assignment + const speakerColors = ['#e94560', '#4ecdc4', '#ffe66d', '#a8e6cf', '#ff8b94', '#c7ceea', '#ffd93d', '#6bcb77']; + function handleWordClick(timeMs: number) { waveformPlayer?.seekTo(timeMs); } @@ -32,11 +35,10 @@ if (!filePath) return; // Convert file path to URL for wavesurfer - // In Tauri, we can use convertFileSrc or asset protocol audioUrl = `asset://localhost/${encodeURIComponent(filePath)}`; waveformPlayer?.loadAudio(audioUrl); - // Start transcription + // Start pipeline (transcription + diarization) isTranscribing = true; transcriptionProgress = 0; transcriptionStage = 'Starting...'; @@ -47,6 +49,7 @@ text: string; start_ms: number; end_ms: number; + speaker: string | null; words: Array<{ word: string; start_ms: number; @@ -56,14 +59,29 @@ }>; language: string; duration_ms: number; - }>('transcribe_file', { filePath }); + speakers: string[]; + num_speakers: number; + }>('run_pipeline', { filePath }); + + // Create speaker entries from pipeline result + const newSpeakers: Speaker[] = (result.speakers || []).map((label, idx) => ({ + id: `speaker-${idx}`, + project_id: '', + label, + display_name: null, + color: speakerColors[idx % speakerColors.length], + })); + speakers.set(newSpeakers); + + // Build speaker label → id lookup + const speakerLookup = new Map(newSpeakers.map(s => [s.label, s.id])); // Convert result to our store format const newSegments: Segment[] = result.segments.map((seg, idx) => ({ id: `seg-${idx}`, project_id: '', media_file_id: '', - speaker_id: null, + speaker_id: seg.speaker ? (speakerLookup.get(seg.speaker) ?? null) : null, start_ms: seg.start_ms, end_ms: seg.end_ms, text: seg.text, @@ -85,8 +103,8 @@ segments.set(newSegments); } catch (err) { - console.error('Transcription failed:', err); - alert(`Transcription failed: ${err}`); + console.error('Pipeline failed:', err); + alert(`Pipeline failed: ${err}`); } finally { isTranscribing = false; }