def _friendly_download_error(error_msg: str) -> str:
    """Translate a raw model-download error into actionable user guidance.

    Args:
        error_msg: The stringified exception raised while loading the
            diarization pipeline.

    Returns:
        A user-facing message. Gated-model rejections (text containing "403"
        and, case-insensitively, "gated") name the repository whose license
        must be accepted; text containing "401" is reported as an invalid
        token; anything else is returned unchanged.
    """
    if "403" in error_msg and "gated" in error_msg.lower():
        # The diarization pipeline depends on a separately gated segmentation
        # model, so point the user at whichever repository rejected them.
        if "segmentation" in error_msg:
            return (
                "Access denied for pyannote/segmentation-3.0. "
                "Please visit huggingface.co/pyannote/segmentation-3.0 "
                "and accept the license agreement."
            )
        if "speaker-diarization" in error_msg:
            return (
                "Access denied for pyannote/speaker-diarization-3.1. "
                "Please visit huggingface.co/pyannote/speaker-diarization-3.1 "
                "and accept the license agreement."
            )
        return (
            "Access denied. Please accept the license agreements at: "
            "huggingface.co/pyannote/speaker-diarization-3.1 and "
            "huggingface.co/pyannote/segmentation-3.0"
        )
    if "401" in error_msg:
        return "Invalid token. Please check your HuggingFace token."
    return error_msg


def make_diarize_download_handler() -> HandlerFunc:
    """Create a handler that downloads/validates the diarization model.

    The returned handler reads ``hf_token`` from the message payload (``None``
    when the key is absent), loads ``pyannote/speaker-diarization-3.1`` with
    it, and replies with a ``diarize.download.result`` message whose payload
    is ``{"ok": True}``. Any failure is logged to stderr and reported through
    ``error_message`` with code ``"download_error"`` and a user-friendly text.
    """

    def handler(msg: IPCMessage) -> IPCMessage:
        hf_token = msg.payload.get("hf_token")

        # Broad catch is deliberate: this is the IPC boundary, and every
        # failure (missing dependency, network, auth) must become a reply.
        try:
            # Imported lazily so the import only runs when a download is requested.
            from pyannote.audio import Pipeline

            print("[sidecar] Downloading diarization model...", file=sys.stderr, flush=True)
            # Only the load itself matters here; the pipeline object is not kept.
            # NOTE(review): `token=` is the pyannote.audio 4.x keyword; 3.x
            # used `use_auth_token=` -- confirm against the pinned version.
            Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                token=hf_token,
            )
            print("[sidecar] Diarization model downloaded successfully", file=sys.stderr, flush=True)
        except Exception as e:
            # Keep the raw cause in the sidecar log; the reply below may
            # replace it with friendlier wording.
            print(f"[sidecar] Diarization model download failed: {e}", file=sys.stderr, flush=True)
            return error_message(msg.id, "download_error", _friendly_download_error(str(e)))

        return IPCMessage(
            id=msg.id,
            type="diarize.download.result",
            payload={"ok": True},
        )

    return handler
+#[tauri::command] +pub fn download_diarize_model( + hf_token: String, +) -> Result { + let manager = sidecar(); + manager.ensure_running()?; + + let request_id = uuid::Uuid::new_v4().to_string(); + let msg = IPCMessage::new( + &request_id, + "diarize.download", + json!({ + "hf_token": hf_token, + }), + ); + + let response = manager.send_and_receive(&msg)?; + + if response.msg_type == "error" { + return Ok(json!({ + "ok": false, + "error": response.payload.get("message").and_then(|v| v.as_str()).unwrap_or("unknown"), + })); + } + + Ok(json!({ "ok": true })) +} + /// Run the full transcription + diarization pipeline via the Python sidecar. #[tauri::command] pub fn run_pipeline( diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index e13b003..0f3e476 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -12,7 +12,7 @@ use commands::export::export_transcript; use commands::project::{create_project, get_project, list_projects}; use commands::settings::{load_settings, save_settings}; use commands::system::{get_data_dir, llama_list_models, llama_start, llama_status, llama_stop}; -use commands::transcribe::{run_pipeline, transcribe_file}; +use commands::transcribe::{download_diarize_model, run_pipeline, transcribe_file}; use state::AppState; #[cfg_attr(mobile, tauri::mobile_entry_point)] @@ -36,6 +36,7 @@ pub fn run() { list_projects, transcribe_file, run_pipeline, + download_diarize_model, export_transcript, ai_chat, ai_list_providers, diff --git a/src/lib/components/SettingsModal.svelte b/src/lib/components/SettingsModal.svelte index 1794071..85ed628 100644 --- a/src/lib/components/SettingsModal.svelte +++ b/src/lib/components/SettingsModal.svelte @@ -1,4 +1,6 @@