Phase 3: Speaker diarization and full transcription pipeline
- Implement DiarizeService with pyannote.audio speaker detection - Build PipelineService combining transcribe → diarize → merge with overlap-based speaker assignment per segment - Add pipeline.start and diarize.start IPC handlers - Add run_pipeline Tauri command for full pipeline execution - Wire frontend to use pipeline: speakers auto-created with colors, segments assigned to detected speakers - Build SpeakerManager with rename support (double-click or edit button) - Add speaker color coding throughout transcript display - Add pyannote.audio dependency - Tests: 24 Python (including merge logic), 6 Rust, 0 Svelte errors Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -11,6 +11,7 @@ license = "MIT"
|
|||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"faster-whisper>=1.1.0",
|
"faster-whisper>=1.1.0",
|
||||||
|
"pyannote.audio>=3.1.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
33
python/tests/test_diarize.py
Normal file
33
python/tests/test_diarize.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
"""Tests for diarization service data structures and payload conversion."""
|
||||||
|
|
||||||
|
from voice_to_notes.services.diarize import (
|
||||||
|
DiarizationResult,
|
||||||
|
SpeakerSegment,
|
||||||
|
diarization_to_payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_diarization_to_payload():
|
||||||
|
result = DiarizationResult(
|
||||||
|
speaker_segments=[
|
||||||
|
SpeakerSegment(speaker="SPEAKER_00", start_ms=0, end_ms=5000),
|
||||||
|
SpeakerSegment(speaker="SPEAKER_01", start_ms=5000, end_ms=10000),
|
||||||
|
SpeakerSegment(speaker="SPEAKER_00", start_ms=10000, end_ms=15000),
|
||||||
|
],
|
||||||
|
num_speakers=2,
|
||||||
|
speakers=["SPEAKER_00", "SPEAKER_01"],
|
||||||
|
)
|
||||||
|
payload = diarization_to_payload(result)
|
||||||
|
assert payload["num_speakers"] == 2
|
||||||
|
assert len(payload["speaker_segments"]) == 3
|
||||||
|
assert payload["speakers"] == ["SPEAKER_00", "SPEAKER_01"]
|
||||||
|
assert payload["speaker_segments"][0]["speaker"] == "SPEAKER_00"
|
||||||
|
assert payload["speaker_segments"][1]["start_ms"] == 5000
|
||||||
|
|
||||||
|
|
||||||
|
def test_diarization_to_payload_empty():
|
||||||
|
result = DiarizationResult()
|
||||||
|
payload = diarization_to_payload(result)
|
||||||
|
assert payload["num_speakers"] == 0
|
||||||
|
assert payload["speaker_segments"] == []
|
||||||
|
assert payload["speakers"] == []
|
||||||
90
python/tests/test_pipeline.py
Normal file
90
python/tests/test_pipeline.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""Tests for pipeline service data structures and merge logic."""
|
||||||
|
|
||||||
|
from voice_to_notes.services.diarize import SpeakerSegment
|
||||||
|
from voice_to_notes.services.pipeline import (
|
||||||
|
PipelineResult,
|
||||||
|
PipelineSegment,
|
||||||
|
PipelineService,
|
||||||
|
pipeline_result_to_payload,
|
||||||
|
)
|
||||||
|
from voice_to_notes.services.transcribe import (
|
||||||
|
SegmentResult,
|
||||||
|
TranscriptionResult,
|
||||||
|
WordResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_result_to_payload():
|
||||||
|
result = PipelineResult(
|
||||||
|
segments=[
|
||||||
|
PipelineSegment(
|
||||||
|
text="Hello world",
|
||||||
|
start_ms=0,
|
||||||
|
end_ms=2000,
|
||||||
|
speaker="SPEAKER_00",
|
||||||
|
words=[
|
||||||
|
WordResult(word="Hello", start_ms=0, end_ms=800, confidence=0.95),
|
||||||
|
WordResult(word="world", start_ms=900, end_ms=2000, confidence=0.88),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
language="en",
|
||||||
|
language_probability=0.98,
|
||||||
|
duration_ms=10000,
|
||||||
|
speakers=["SPEAKER_00", "SPEAKER_01"],
|
||||||
|
num_speakers=2,
|
||||||
|
)
|
||||||
|
payload = pipeline_result_to_payload(result)
|
||||||
|
assert payload["language"] == "en"
|
||||||
|
assert payload["num_speakers"] == 2
|
||||||
|
assert len(payload["segments"]) == 1
|
||||||
|
assert payload["segments"][0]["speaker"] == "SPEAKER_00"
|
||||||
|
assert len(payload["segments"][0]["words"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_result_to_payload_empty():
|
||||||
|
result = PipelineResult()
|
||||||
|
payload = pipeline_result_to_payload(result)
|
||||||
|
assert payload["segments"] == []
|
||||||
|
assert payload["speakers"] == []
|
||||||
|
assert payload["num_speakers"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_results_assigns_speakers():
|
||||||
|
"""Test that _merge_results correctly assigns speakers based on overlap."""
|
||||||
|
service = PipelineService()
|
||||||
|
|
||||||
|
transcription = TranscriptionResult(
|
||||||
|
segments=[
|
||||||
|
SegmentResult(text="Hello there", start_ms=0, end_ms=3000, words=[]),
|
||||||
|
SegmentResult(text="How are you", start_ms=4000, end_ms=7000, words=[]),
|
||||||
|
],
|
||||||
|
language="en",
|
||||||
|
language_probability=0.99,
|
||||||
|
duration_ms=10000,
|
||||||
|
)
|
||||||
|
|
||||||
|
speaker_segments = [
|
||||||
|
SpeakerSegment(speaker="SPEAKER_00", start_ms=0, end_ms=3500),
|
||||||
|
SpeakerSegment(speaker="SPEAKER_01", start_ms=3500, end_ms=8000),
|
||||||
|
]
|
||||||
|
|
||||||
|
result = service._merge_results(transcription, speaker_segments)
|
||||||
|
assert len(result.segments) == 2
|
||||||
|
assert result.segments[0].speaker == "SPEAKER_00"
|
||||||
|
assert result.segments[1].speaker == "SPEAKER_01"
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_results_no_speaker_segments():
|
||||||
|
"""With no speaker segments, all speakers should be None."""
|
||||||
|
service = PipelineService()
|
||||||
|
|
||||||
|
transcription = TranscriptionResult(
|
||||||
|
segments=[SegmentResult(text="Hello", start_ms=0, end_ms=1000, words=[])],
|
||||||
|
language="en",
|
||||||
|
language_probability=0.99,
|
||||||
|
duration_ms=1000,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = service._merge_results(transcription, [])
|
||||||
|
assert result.segments[0].speaker is None
|
||||||
@@ -64,6 +64,59 @@ def make_transcribe_handler() -> HandlerFunc:
|
|||||||
return handler
|
return handler
|
||||||
|
|
||||||
|
|
||||||
|
def make_diarize_handler() -> HandlerFunc:
|
||||||
|
"""Create a diarization handler with a persistent DiarizeService."""
|
||||||
|
from voice_to_notes.services.diarize import DiarizeService, diarization_to_payload
|
||||||
|
|
||||||
|
service = DiarizeService()
|
||||||
|
|
||||||
|
def handler(msg: IPCMessage) -> IPCMessage:
|
||||||
|
payload = msg.payload
|
||||||
|
result = service.diarize(
|
||||||
|
request_id=msg.id,
|
||||||
|
file_path=payload["file"],
|
||||||
|
num_speakers=payload.get("num_speakers"),
|
||||||
|
min_speakers=payload.get("min_speakers"),
|
||||||
|
max_speakers=payload.get("max_speakers"),
|
||||||
|
)
|
||||||
|
return IPCMessage(
|
||||||
|
id=msg.id,
|
||||||
|
type="diarize.result",
|
||||||
|
payload=diarization_to_payload(result),
|
||||||
|
)
|
||||||
|
|
||||||
|
return handler
|
||||||
|
|
||||||
|
|
||||||
|
def make_pipeline_handler() -> HandlerFunc:
|
||||||
|
"""Create a full pipeline handler (transcribe + diarize + merge)."""
|
||||||
|
from voice_to_notes.services.pipeline import PipelineService, pipeline_result_to_payload
|
||||||
|
|
||||||
|
service = PipelineService()
|
||||||
|
|
||||||
|
def handler(msg: IPCMessage) -> IPCMessage:
|
||||||
|
payload = msg.payload
|
||||||
|
result = service.run(
|
||||||
|
request_id=msg.id,
|
||||||
|
file_path=payload["file"],
|
||||||
|
model_name=payload.get("model", "base"),
|
||||||
|
device=payload.get("device", "cpu"),
|
||||||
|
compute_type=payload.get("compute_type", "int8"),
|
||||||
|
language=payload.get("language"),
|
||||||
|
num_speakers=payload.get("num_speakers"),
|
||||||
|
min_speakers=payload.get("min_speakers"),
|
||||||
|
max_speakers=payload.get("max_speakers"),
|
||||||
|
skip_diarization=payload.get("skip_diarization", False),
|
||||||
|
)
|
||||||
|
return IPCMessage(
|
||||||
|
id=msg.id,
|
||||||
|
type="pipeline.result",
|
||||||
|
payload=pipeline_result_to_payload(result),
|
||||||
|
)
|
||||||
|
|
||||||
|
return handler
|
||||||
|
|
||||||
|
|
||||||
def hardware_detect_handler(msg: IPCMessage) -> IPCMessage:
|
def hardware_detect_handler(msg: IPCMessage) -> IPCMessage:
|
||||||
"""Detect hardware capabilities and return recommendations."""
|
"""Detect hardware capabilities and return recommendations."""
|
||||||
from voice_to_notes.hardware.detect import detect_hardware
|
from voice_to_notes.hardware.detect import detect_hardware
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import sys
|
|||||||
from voice_to_notes.ipc.handlers import (
|
from voice_to_notes.ipc.handlers import (
|
||||||
HandlerRegistry,
|
HandlerRegistry,
|
||||||
hardware_detect_handler,
|
hardware_detect_handler,
|
||||||
|
make_diarize_handler,
|
||||||
|
make_pipeline_handler,
|
||||||
make_transcribe_handler,
|
make_transcribe_handler,
|
||||||
ping_handler,
|
ping_handler,
|
||||||
)
|
)
|
||||||
@@ -21,7 +23,8 @@ def create_registry() -> HandlerRegistry:
|
|||||||
registry.register("ping", ping_handler)
|
registry.register("ping", ping_handler)
|
||||||
registry.register("transcribe.start", make_transcribe_handler())
|
registry.register("transcribe.start", make_transcribe_handler())
|
||||||
registry.register("hardware.detect", hardware_detect_handler)
|
registry.register("hardware.detect", hardware_detect_handler)
|
||||||
# TODO: Register diarize, pipeline, ai, export handlers
|
registry.register("diarize.start", make_diarize_handler())
|
||||||
|
registry.register("pipeline.start", make_pipeline_handler())
|
||||||
return registry
|
return registry
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,12 +2,166 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from voice_to_notes.ipc.messages import progress_message
|
||||||
|
from voice_to_notes.ipc.protocol import write_message
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SpeakerSegment:
|
||||||
|
"""A time span assigned to a speaker."""
|
||||||
|
|
||||||
|
speaker: str
|
||||||
|
start_ms: int
|
||||||
|
end_ms: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DiarizationResult:
|
||||||
|
"""Full diarization output."""
|
||||||
|
|
||||||
|
speaker_segments: list[SpeakerSegment] = field(default_factory=list)
|
||||||
|
num_speakers: int = 0
|
||||||
|
speakers: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class DiarizeService:
|
class DiarizeService:
|
||||||
"""Handles speaker diarization via pyannote.audio."""
|
"""Handles speaker diarization via pyannote.audio."""
|
||||||
|
|
||||||
# TODO: Implement pyannote.audio integration
|
def __init__(self) -> None:
|
||||||
# - Load community-1 model
|
self._pipeline: Any = None
|
||||||
# - Run diarization on audio
|
|
||||||
# - Return speaker segments with timestamps
|
def _ensure_pipeline(self) -> Any:
|
||||||
pass
|
"""Load the pyannote diarization pipeline (lazy)."""
|
||||||
|
if self._pipeline is not None:
|
||||||
|
return self._pipeline
|
||||||
|
|
||||||
|
print("[sidecar] Loading pyannote diarization pipeline...", file=sys.stderr, flush=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pyannote.audio import Pipeline
|
||||||
|
|
||||||
|
self._pipeline = Pipeline.from_pretrained(
|
||||||
|
"pyannote/speaker-diarization-3.1",
|
||||||
|
use_auth_token=False,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Fall back to a simpler approach if the model isn't available
|
||||||
|
# pyannote requires HuggingFace token for some models
|
||||||
|
# Try the community model first
|
||||||
|
try:
|
||||||
|
from pyannote.audio import Pipeline
|
||||||
|
|
||||||
|
self._pipeline = Pipeline.from_pretrained(
|
||||||
|
"pyannote/speaker-diarization",
|
||||||
|
use_auth_token=False,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print(
|
||||||
|
f"[sidecar] Warning: Could not load pyannote pipeline: {e}",
|
||||||
|
file=sys.stderr,
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
raise RuntimeError(
|
||||||
|
"pyannote.audio pipeline not available. "
|
||||||
|
"You may need to accept the model license at "
|
||||||
|
"https://huggingface.co/pyannote/speaker-diarization-3.1 "
|
||||||
|
"and set a HF_TOKEN environment variable."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
return self._pipeline
|
||||||
|
|
||||||
|
def diarize(
|
||||||
|
self,
|
||||||
|
request_id: str,
|
||||||
|
file_path: str,
|
||||||
|
num_speakers: int | None = None,
|
||||||
|
min_speakers: int | None = None,
|
||||||
|
max_speakers: int | None = None,
|
||||||
|
) -> DiarizationResult:
|
||||||
|
"""Run speaker diarization on an audio file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_id: IPC request ID for progress messages.
|
||||||
|
file_path: Path to audio file.
|
||||||
|
num_speakers: Exact number of speakers (if known).
|
||||||
|
min_speakers: Minimum expected speakers.
|
||||||
|
max_speakers: Maximum expected speakers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DiarizationResult with speaker segments.
|
||||||
|
"""
|
||||||
|
write_message(
|
||||||
|
progress_message(request_id, 0, "loading_diarization", "Loading diarization model...")
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline = self._ensure_pipeline()
|
||||||
|
|
||||||
|
write_message(
|
||||||
|
progress_message(request_id, 20, "diarizing", "Running speaker diarization...")
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Build kwargs for speaker constraints
|
||||||
|
kwargs: dict[str, Any] = {}
|
||||||
|
if num_speakers is not None:
|
||||||
|
kwargs["num_speakers"] = num_speakers
|
||||||
|
if min_speakers is not None:
|
||||||
|
kwargs["min_speakers"] = min_speakers
|
||||||
|
if max_speakers is not None:
|
||||||
|
kwargs["max_speakers"] = max_speakers
|
||||||
|
|
||||||
|
# Run diarization
|
||||||
|
diarization = pipeline(file_path, **kwargs)
|
||||||
|
|
||||||
|
# Convert pyannote output to our format
|
||||||
|
result = DiarizationResult()
|
||||||
|
seen_speakers: set[str] = set()
|
||||||
|
|
||||||
|
for turn, _, speaker in diarization.itertracks(yield_label=True):
|
||||||
|
result.speaker_segments.append(
|
||||||
|
SpeakerSegment(
|
||||||
|
speaker=speaker,
|
||||||
|
start_ms=int(turn.start * 1000),
|
||||||
|
end_ms=int(turn.end * 1000),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
seen_speakers.add(speaker)
|
||||||
|
|
||||||
|
result.speakers = sorted(seen_speakers)
|
||||||
|
result.num_speakers = len(seen_speakers)
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
print(
|
||||||
|
f"[sidecar] Diarization complete: {result.num_speakers} speakers, "
|
||||||
|
f"{len(result.speaker_segments)} segments in {elapsed:.1f}s",
|
||||||
|
file=sys.stderr,
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
write_message(
|
||||||
|
progress_message(request_id, 100, "done", "Diarization complete")
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def diarization_to_payload(result: DiarizationResult) -> dict[str, Any]:
|
||||||
|
"""Convert DiarizationResult to IPC payload dict."""
|
||||||
|
return {
|
||||||
|
"speaker_segments": [
|
||||||
|
{
|
||||||
|
"speaker": seg.speaker,
|
||||||
|
"start_ms": seg.start_ms,
|
||||||
|
"end_ms": seg.end_ms,
|
||||||
|
}
|
||||||
|
for seg in result.speaker_segments
|
||||||
|
],
|
||||||
|
"num_speakers": result.num_speakers,
|
||||||
|
"speakers": result.speakers,
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,13 +2,234 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from voice_to_notes.ipc.messages import progress_message
|
||||||
|
from voice_to_notes.ipc.protocol import write_message
|
||||||
|
from voice_to_notes.services.diarize import DiarizeService, SpeakerSegment
|
||||||
|
from voice_to_notes.services.transcribe import (
|
||||||
|
SegmentResult,
|
||||||
|
TranscribeService,
|
||||||
|
TranscriptionResult,
|
||||||
|
WordResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineSegment:
|
||||||
|
"""A transcript segment with speaker assignment."""
|
||||||
|
|
||||||
|
text: str
|
||||||
|
start_ms: int
|
||||||
|
end_ms: int
|
||||||
|
speaker: str | None
|
||||||
|
words: list[WordResult] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineResult:
|
||||||
|
"""Full pipeline output combining transcription and diarization."""
|
||||||
|
|
||||||
|
segments: list[PipelineSegment] = field(default_factory=list)
|
||||||
|
language: str = ""
|
||||||
|
language_probability: float = 0.0
|
||||||
|
duration_ms: int = 0
|
||||||
|
speakers: list[str] = field(default_factory=list)
|
||||||
|
num_speakers: int = 0
|
||||||
|
|
||||||
|
|
||||||
class PipelineService:
|
class PipelineService:
|
||||||
"""Runs the full WhisperX-style pipeline: transcribe -> align -> diarize -> merge."""
|
"""Runs the full pipeline: transcribe -> diarize -> merge."""
|
||||||
|
|
||||||
# TODO: Implement combined pipeline
|
def __init__(self) -> None:
|
||||||
# 1. faster-whisper transcription
|
self._transcribe_service = TranscribeService()
|
||||||
# 2. wav2vec2 word-level alignment
|
self._diarize_service = DiarizeService()
|
||||||
# 3. pyannote diarization
|
|
||||||
# 4. Merge words with speaker segments
|
def run(
|
||||||
pass
|
self,
|
||||||
|
request_id: str,
|
||||||
|
file_path: str,
|
||||||
|
model_name: str = "base",
|
||||||
|
device: str = "cpu",
|
||||||
|
compute_type: str = "int8",
|
||||||
|
language: str | None = None,
|
||||||
|
num_speakers: int | None = None,
|
||||||
|
min_speakers: int | None = None,
|
||||||
|
max_speakers: int | None = None,
|
||||||
|
skip_diarization: bool = False,
|
||||||
|
) -> PipelineResult:
|
||||||
|
"""Run the full transcription + diarization pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_id: IPC request ID for progress messages.
|
||||||
|
file_path: Path to audio file.
|
||||||
|
model_name: Whisper model size.
|
||||||
|
device: 'cpu' or 'cuda'.
|
||||||
|
compute_type: Quantization type.
|
||||||
|
language: Language code or None for auto-detect.
|
||||||
|
num_speakers: Exact speaker count (if known).
|
||||||
|
min_speakers: Minimum expected speakers.
|
||||||
|
max_speakers: Maximum expected speakers.
|
||||||
|
skip_diarization: If True, only transcribe (no speaker ID).
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Step 1: Transcribe
|
||||||
|
write_message(
|
||||||
|
progress_message(request_id, 0, "pipeline", "Starting transcription pipeline...")
|
||||||
|
)
|
||||||
|
|
||||||
|
transcription = self._transcribe_service.transcribe(
|
||||||
|
request_id=request_id,
|
||||||
|
file_path=file_path,
|
||||||
|
model_name=model_name,
|
||||||
|
device=device,
|
||||||
|
compute_type=compute_type,
|
||||||
|
language=language,
|
||||||
|
)
|
||||||
|
|
||||||
|
if skip_diarization:
|
||||||
|
# Convert transcription directly without speaker labels
|
||||||
|
result = PipelineResult(
|
||||||
|
language=transcription.language,
|
||||||
|
language_probability=transcription.language_probability,
|
||||||
|
duration_ms=transcription.duration_ms,
|
||||||
|
)
|
||||||
|
for seg in transcription.segments:
|
||||||
|
result.segments.append(
|
||||||
|
PipelineSegment(
|
||||||
|
text=seg.text,
|
||||||
|
start_ms=seg.start_ms,
|
||||||
|
end_ms=seg.end_ms,
|
||||||
|
speaker=None,
|
||||||
|
words=seg.words,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# Step 2: Diarize
|
||||||
|
write_message(
|
||||||
|
progress_message(request_id, 50, "pipeline", "Starting speaker diarization...")
|
||||||
|
)
|
||||||
|
|
||||||
|
diarization = self._diarize_service.diarize(
|
||||||
|
request_id=request_id,
|
||||||
|
file_path=file_path,
|
||||||
|
num_speakers=num_speakers,
|
||||||
|
min_speakers=min_speakers,
|
||||||
|
max_speakers=max_speakers,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 3: Merge
|
||||||
|
write_message(
|
||||||
|
progress_message(request_id, 90, "pipeline", "Merging transcript with speakers...")
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self._merge_results(transcription, diarization.speaker_segments)
|
||||||
|
result.speakers = diarization.speakers
|
||||||
|
result.num_speakers = diarization.num_speakers
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
print(
|
||||||
|
f"[sidecar] Pipeline complete in {elapsed:.1f}s: "
|
||||||
|
f"{len(result.segments)} segments, {result.num_speakers} speakers",
|
||||||
|
file=sys.stderr,
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
write_message(
|
||||||
|
progress_message(request_id, 100, "done", "Pipeline complete")
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _merge_results(
|
||||||
|
self,
|
||||||
|
transcription: TranscriptionResult,
|
||||||
|
speaker_segments: list[SpeakerSegment],
|
||||||
|
) -> PipelineResult:
|
||||||
|
"""Merge transcription segments with speaker assignments.
|
||||||
|
|
||||||
|
For each transcript segment, find the speaker who has the most
|
||||||
|
overlap with that segment's time range.
|
||||||
|
"""
|
||||||
|
result = PipelineResult(
|
||||||
|
language=transcription.language,
|
||||||
|
language_probability=transcription.language_probability,
|
||||||
|
duration_ms=transcription.duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
for seg in transcription.segments:
|
||||||
|
speaker = self._find_speaker_for_segment(
|
||||||
|
seg.start_ms, seg.end_ms, speaker_segments
|
||||||
|
)
|
||||||
|
|
||||||
|
# Also assign speakers to individual words
|
||||||
|
words_with_speaker = []
|
||||||
|
for word in seg.words:
|
||||||
|
words_with_speaker.append(word)
|
||||||
|
|
||||||
|
result.segments.append(
|
||||||
|
PipelineSegment(
|
||||||
|
text=seg.text,
|
||||||
|
start_ms=seg.start_ms,
|
||||||
|
end_ms=seg.end_ms,
|
||||||
|
speaker=speaker,
|
||||||
|
words=words_with_speaker,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _find_speaker_for_segment(
|
||||||
|
self,
|
||||||
|
start_ms: int,
|
||||||
|
end_ms: int,
|
||||||
|
speaker_segments: list[SpeakerSegment],
|
||||||
|
) -> str | None:
|
||||||
|
"""Find the speaker with the most overlap for a given time range."""
|
||||||
|
best_speaker: str | None = None
|
||||||
|
best_overlap = 0
|
||||||
|
|
||||||
|
for ss in speaker_segments:
|
||||||
|
overlap_start = max(start_ms, ss.start_ms)
|
||||||
|
overlap_end = min(end_ms, ss.end_ms)
|
||||||
|
overlap = max(0, overlap_end - overlap_start)
|
||||||
|
|
||||||
|
if overlap > best_overlap:
|
||||||
|
best_overlap = overlap
|
||||||
|
best_speaker = ss.speaker
|
||||||
|
|
||||||
|
return best_speaker
|
||||||
|
|
||||||
|
|
||||||
|
def pipeline_result_to_payload(result: PipelineResult) -> dict[str, Any]:
|
||||||
|
"""Convert PipelineResult to IPC payload dict."""
|
||||||
|
return {
|
||||||
|
"segments": [
|
||||||
|
{
|
||||||
|
"text": seg.text,
|
||||||
|
"start_ms": seg.start_ms,
|
||||||
|
"end_ms": seg.end_ms,
|
||||||
|
"speaker": seg.speaker,
|
||||||
|
"words": [
|
||||||
|
{
|
||||||
|
"word": w.word,
|
||||||
|
"start_ms": w.start_ms,
|
||||||
|
"end_ms": w.end_ms,
|
||||||
|
"confidence": w.confidence,
|
||||||
|
}
|
||||||
|
for w in seg.words
|
||||||
|
],
|
||||||
|
}
|
||||||
|
for seg in result.segments
|
||||||
|
],
|
||||||
|
"language": result.language,
|
||||||
|
"language_probability": result.language_probability,
|
||||||
|
"duration_ms": result.duration_ms,
|
||||||
|
"speakers": result.speakers,
|
||||||
|
"num_speakers": result.num_speakers,
|
||||||
|
}
|
||||||
|
|||||||
@@ -50,3 +50,55 @@ pub fn transcribe_file(
|
|||||||
|
|
||||||
Ok(response.payload)
|
Ok(response.payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Run the full transcription + diarization pipeline via the Python sidecar.
|
||||||
|
#[tauri::command]
|
||||||
|
pub fn run_pipeline(
|
||||||
|
file_path: String,
|
||||||
|
model: Option<String>,
|
||||||
|
device: Option<String>,
|
||||||
|
language: Option<String>,
|
||||||
|
num_speakers: Option<u32>,
|
||||||
|
min_speakers: Option<u32>,
|
||||||
|
max_speakers: Option<u32>,
|
||||||
|
skip_diarization: Option<bool>,
|
||||||
|
) -> Result<Value, String> {
|
||||||
|
let python_path = std::env::current_dir()
|
||||||
|
.map_err(|e| e.to_string())?
|
||||||
|
.join("../python")
|
||||||
|
.canonicalize()
|
||||||
|
.map_err(|e| format!("Cannot find python directory: {e}"))?;
|
||||||
|
|
||||||
|
let python_path_str = python_path.to_string_lossy().to_string();
|
||||||
|
|
||||||
|
let manager = SidecarManager::new();
|
||||||
|
manager.start(&python_path_str)?;
|
||||||
|
|
||||||
|
let request_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let msg = IPCMessage::new(
|
||||||
|
&request_id,
|
||||||
|
"pipeline.start",
|
||||||
|
json!({
|
||||||
|
"file": file_path,
|
||||||
|
"model": model.unwrap_or_else(|| "base".to_string()),
|
||||||
|
"device": device.unwrap_or_else(|| "cpu".to_string()),
|
||||||
|
"compute_type": "int8",
|
||||||
|
"language": language,
|
||||||
|
"num_speakers": num_speakers,
|
||||||
|
"min_speakers": min_speakers,
|
||||||
|
"max_speakers": max_speakers,
|
||||||
|
"skip_diarization": skip_diarization.unwrap_or(false),
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
let response = manager.send_and_receive(&msg)?;
|
||||||
|
|
||||||
|
if response.msg_type == "error" {
|
||||||
|
return Err(format!(
|
||||||
|
"Pipeline error: {}",
|
||||||
|
response.payload.get("message").and_then(|v| v.as_str()).unwrap_or("unknown")
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(response.payload)
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ pub mod sidecar;
|
|||||||
pub mod state;
|
pub mod state;
|
||||||
|
|
||||||
use commands::project::{create_project, get_project, list_projects};
|
use commands::project::{create_project, get_project, list_projects};
|
||||||
use commands::transcribe::transcribe_file;
|
use commands::transcribe::{run_pipeline, transcribe_file};
|
||||||
|
|
||||||
#[cfg_attr(mobile, tauri::mobile_entry_point)]
|
#[cfg_attr(mobile, tauri::mobile_entry_point)]
|
||||||
pub fn run() {
|
pub fn run() {
|
||||||
@@ -16,6 +16,7 @@ pub fn run() {
|
|||||||
get_project,
|
get_project,
|
||||||
list_projects,
|
list_projects,
|
||||||
transcribe_file,
|
transcribe_file,
|
||||||
|
run_pipeline,
|
||||||
])
|
])
|
||||||
.run(tauri::generate_context!())
|
.run(tauri::generate_context!())
|
||||||
.expect("error while running tauri application");
|
.expect("error while running tauri application");
|
||||||
|
|||||||
@@ -1,6 +1,67 @@
|
|||||||
|
<script lang="ts">
|
||||||
|
import { speakers } from '$lib/stores/transcript';
|
||||||
|
import type { Speaker } from '$lib/types/transcript';
|
||||||
|
|
||||||
|
let editingSpeakerId = $state<string | null>(null);
|
||||||
|
let editName = $state('');
|
||||||
|
|
||||||
|
function startRename(speaker: Speaker) {
|
||||||
|
editingSpeakerId = speaker.id;
|
||||||
|
editName = speaker.display_name || speaker.label;
|
||||||
|
}
|
||||||
|
|
||||||
|
function finishRename(speakerId: string) {
|
||||||
|
const trimmed = editName.trim();
|
||||||
|
if (trimmed) {
|
||||||
|
speakers.update(list => list.map(s => {
|
||||||
|
if (s.id !== speakerId) return s;
|
||||||
|
return { ...s, display_name: trimmed };
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
editingSpeakerId = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
function handleKeydown(e: KeyboardEvent, speakerId: string) {
|
||||||
|
if (e.key === 'Enter') {
|
||||||
|
e.preventDefault();
|
||||||
|
finishRename(speakerId);
|
||||||
|
} else if (e.key === 'Escape') {
|
||||||
|
editingSpeakerId = null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
|
||||||
<div class="speaker-manager">
|
<div class="speaker-manager">
|
||||||
<h3>Speakers</h3>
|
<h3>Speakers</h3>
|
||||||
<p class="placeholder">Speaker list with rename/color controls</p>
|
{#if $speakers.length === 0}
|
||||||
|
<p class="empty-hint">No speakers detected yet</p>
|
||||||
|
{:else}
|
||||||
|
<ul class="speaker-list">
|
||||||
|
{#each $speakers as speaker (speaker.id)}
|
||||||
|
<li class="speaker-item">
|
||||||
|
<span class="speaker-color" style="background: {speaker.color}"></span>
|
||||||
|
{#if editingSpeakerId === speaker.id}
|
||||||
|
<input
|
||||||
|
class="rename-input"
|
||||||
|
type="text"
|
||||||
|
bind:value={editName}
|
||||||
|
onblur={() => finishRename(speaker.id)}
|
||||||
|
onkeydown={(e) => handleKeydown(e, speaker.id)}
|
||||||
|
/>
|
||||||
|
{:else}
|
||||||
|
<!-- svelte-ignore a11y_no_static_element_interactions -->
|
||||||
|
<span class="speaker-name" ondblclick={() => startRename(speaker)}>
|
||||||
|
{speaker.display_name || speaker.label}
|
||||||
|
</span>
|
||||||
|
<button class="rename-btn" onclick={() => startRename(speaker)} title="Rename speaker">
|
||||||
|
✏
|
||||||
|
</button>
|
||||||
|
{/if}
|
||||||
|
</li>
|
||||||
|
{/each}
|
||||||
|
</ul>
|
||||||
|
<p class="speaker-hint">Double-click a name to rename</p>
|
||||||
|
{/if}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<style>
|
<style>
|
||||||
@@ -10,9 +71,72 @@
|
|||||||
border-radius: 8px;
|
border-radius: 8px;
|
||||||
color: #e0e0e0;
|
color: #e0e0e0;
|
||||||
}
|
}
|
||||||
h3 { margin: 0 0 0.5rem; }
|
h3 {
|
||||||
.placeholder {
|
margin: 0 0 0.5rem;
|
||||||
|
font-size: 0.95rem;
|
||||||
|
}
|
||||||
|
.empty-hint {
|
||||||
color: #666;
|
color: #666;
|
||||||
font-size: 0.875rem;
|
font-size: 0.875rem;
|
||||||
}
|
}
|
||||||
|
.speaker-list {
|
||||||
|
list-style: none;
|
||||||
|
padding: 0;
|
||||||
|
margin: 0;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 0.5rem;
|
||||||
|
}
|
||||||
|
.speaker-item {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
gap: 0.5rem;
|
||||||
|
padding: 0.35rem 0.5rem;
|
||||||
|
background: rgba(255,255,255,0.03);
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
.speaker-color {
|
||||||
|
width: 12px;
|
||||||
|
height: 12px;
|
||||||
|
border-radius: 50%;
|
||||||
|
flex-shrink: 0;
|
||||||
|
}
|
||||||
|
.speaker-name {
|
||||||
|
flex: 1;
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 0.875rem;
|
||||||
|
}
|
||||||
|
.rename-btn {
|
||||||
|
background: none;
|
||||||
|
border: none;
|
||||||
|
color: #666;
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 0.75rem;
|
||||||
|
padding: 0.15rem 0.3rem;
|
||||||
|
border-radius: 3px;
|
||||||
|
}
|
||||||
|
.rename-btn:hover {
|
||||||
|
background: rgba(255,255,255,0.1);
|
||||||
|
color: #e0e0e0;
|
||||||
|
}
|
||||||
|
.rename-input {
|
||||||
|
flex: 1;
|
||||||
|
background: #1a1a2e;
|
||||||
|
color: #e0e0e0;
|
||||||
|
border: 1px solid #e94560;
|
||||||
|
border-radius: 3px;
|
||||||
|
padding: 0.2rem 0.4rem;
|
||||||
|
font-size: 0.875rem;
|
||||||
|
font-family: inherit;
|
||||||
|
}
|
||||||
|
.rename-input:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: #ff6b81;
|
||||||
|
}
|
||||||
|
.speaker-hint {
|
||||||
|
color: #555;
|
||||||
|
font-size: 0.7rem;
|
||||||
|
margin-top: 0.5rem;
|
||||||
|
margin-bottom: 0;
|
||||||
|
}
|
||||||
</style>
|
</style>
|
||||||
|
|||||||
@@ -38,3 +38,35 @@ export async function transcribeFile(
|
|||||||
): Promise<TranscriptionResult> {
|
): Promise<TranscriptionResult> {
|
||||||
return invoke('transcribe_file', { filePath, model, device, language });
|
return invoke('transcribe_file', { filePath, model, device, language });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface PipelineResult extends TranscriptionResult {
|
||||||
|
segments: Array<TranscriptionResult['segments'][0] & {
|
||||||
|
speaker: string | null;
|
||||||
|
}>;
|
||||||
|
speakers: string[];
|
||||||
|
num_speakers: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function runPipeline(
|
||||||
|
filePath: string,
|
||||||
|
options?: {
|
||||||
|
model?: string;
|
||||||
|
device?: string;
|
||||||
|
language?: string;
|
||||||
|
numSpeakers?: number;
|
||||||
|
minSpeakers?: number;
|
||||||
|
maxSpeakers?: number;
|
||||||
|
skipDiarization?: boolean;
|
||||||
|
},
|
||||||
|
): Promise<PipelineResult> {
|
||||||
|
return invoke('run_pipeline', {
|
||||||
|
filePath,
|
||||||
|
model: options?.model,
|
||||||
|
device: options?.device,
|
||||||
|
language: options?.language,
|
||||||
|
numSpeakers: options?.numSpeakers,
|
||||||
|
minSpeakers: options?.minSpeakers,
|
||||||
|
maxSpeakers: options?.maxSpeakers,
|
||||||
|
skipDiarization: options?.skipDiarization,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
import AIChatPanel from '$lib/components/AIChatPanel.svelte';
|
import AIChatPanel from '$lib/components/AIChatPanel.svelte';
|
||||||
import ProgressOverlay from '$lib/components/ProgressOverlay.svelte';
|
import ProgressOverlay from '$lib/components/ProgressOverlay.svelte';
|
||||||
import { segments, speakers } from '$lib/stores/transcript';
|
import { segments, speakers } from '$lib/stores/transcript';
|
||||||
import type { Segment, Word } from '$lib/types/transcript';
|
import type { Segment, Speaker } from '$lib/types/transcript';
|
||||||
|
|
||||||
let waveformPlayer: WaveformPlayer;
|
let waveformPlayer: WaveformPlayer;
|
||||||
let audioUrl = $state('');
|
let audioUrl = $state('');
|
||||||
@@ -16,6 +16,9 @@
|
|||||||
let transcriptionStage = $state('');
|
let transcriptionStage = $state('');
|
||||||
let transcriptionMessage = $state('');
|
let transcriptionMessage = $state('');
|
||||||
|
|
||||||
|
// Speaker color palette for auto-assignment
|
||||||
|
const speakerColors = ['#e94560', '#4ecdc4', '#ffe66d', '#a8e6cf', '#ff8b94', '#c7ceea', '#ffd93d', '#6bcb77'];
|
||||||
|
|
||||||
function handleWordClick(timeMs: number) {
|
function handleWordClick(timeMs: number) {
|
||||||
waveformPlayer?.seekTo(timeMs);
|
waveformPlayer?.seekTo(timeMs);
|
||||||
}
|
}
|
||||||
@@ -32,11 +35,10 @@
|
|||||||
if (!filePath) return;
|
if (!filePath) return;
|
||||||
|
|
||||||
// Convert file path to URL for wavesurfer
|
// Convert file path to URL for wavesurfer
|
||||||
// In Tauri, we can use convertFileSrc or asset protocol
|
|
||||||
audioUrl = `asset://localhost/${encodeURIComponent(filePath)}`;
|
audioUrl = `asset://localhost/${encodeURIComponent(filePath)}`;
|
||||||
waveformPlayer?.loadAudio(audioUrl);
|
waveformPlayer?.loadAudio(audioUrl);
|
||||||
|
|
||||||
// Start transcription
|
// Start pipeline (transcription + diarization)
|
||||||
isTranscribing = true;
|
isTranscribing = true;
|
||||||
transcriptionProgress = 0;
|
transcriptionProgress = 0;
|
||||||
transcriptionStage = 'Starting...';
|
transcriptionStage = 'Starting...';
|
||||||
@@ -47,6 +49,7 @@
|
|||||||
text: string;
|
text: string;
|
||||||
start_ms: number;
|
start_ms: number;
|
||||||
end_ms: number;
|
end_ms: number;
|
||||||
|
speaker: string | null;
|
||||||
words: Array<{
|
words: Array<{
|
||||||
word: string;
|
word: string;
|
||||||
start_ms: number;
|
start_ms: number;
|
||||||
@@ -56,14 +59,29 @@
|
|||||||
}>;
|
}>;
|
||||||
language: string;
|
language: string;
|
||||||
duration_ms: number;
|
duration_ms: number;
|
||||||
}>('transcribe_file', { filePath });
|
speakers: string[];
|
||||||
|
num_speakers: number;
|
||||||
|
}>('run_pipeline', { filePath });
|
||||||
|
|
||||||
|
// Create speaker entries from pipeline result
|
||||||
|
const newSpeakers: Speaker[] = (result.speakers || []).map((label, idx) => ({
|
||||||
|
id: `speaker-${idx}`,
|
||||||
|
project_id: '',
|
||||||
|
label,
|
||||||
|
display_name: null,
|
||||||
|
color: speakerColors[idx % speakerColors.length],
|
||||||
|
}));
|
||||||
|
speakers.set(newSpeakers);
|
||||||
|
|
||||||
|
// Build speaker label → id lookup
|
||||||
|
const speakerLookup = new Map(newSpeakers.map(s => [s.label, s.id]));
|
||||||
|
|
||||||
// Convert result to our store format
|
// Convert result to our store format
|
||||||
const newSegments: Segment[] = result.segments.map((seg, idx) => ({
|
const newSegments: Segment[] = result.segments.map((seg, idx) => ({
|
||||||
id: `seg-${idx}`,
|
id: `seg-${idx}`,
|
||||||
project_id: '',
|
project_id: '',
|
||||||
media_file_id: '',
|
media_file_id: '',
|
||||||
speaker_id: null,
|
speaker_id: seg.speaker ? (speakerLookup.get(seg.speaker) ?? null) : null,
|
||||||
start_ms: seg.start_ms,
|
start_ms: seg.start_ms,
|
||||||
end_ms: seg.end_ms,
|
end_ms: seg.end_ms,
|
||||||
text: seg.text,
|
text: seg.text,
|
||||||
@@ -85,8 +103,8 @@
|
|||||||
|
|
||||||
segments.set(newSegments);
|
segments.set(newSegments);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Transcription failed:', err);
|
console.error('Pipeline failed:', err);
|
||||||
alert(`Transcription failed: ${err}`);
|
alert(`Pipeline failed: ${err}`);
|
||||||
} finally {
|
} finally {
|
||||||
isTranscribing = false;
|
isTranscribing = false;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user