Files
local-transcription/client/deepgram_transcription.py

529 lines
19 KiB
Python
Raw Permalink Normal View History

"""Deepgram-based transcription engine using WebSocket streaming.

Supports two modes:

- Managed mode: connects to a proxy server that handles Deepgram credentials
- BYOK mode: connects directly to the Deepgram API with a user-provided key

Audio is captured with ``sounddevice`` and streamed as 16 kHz mono 16-bit
linear PCM.

Implements the same duck-type interface as RealtimeTranscriptionEngine so
MainWindow can use it as a drop-in replacement.
"""
import asyncio
import json
import logging
import numpy as np
import threading
from datetime import datetime
from queue import Queue, Empty
from typing import Optional, Callable
from client.transcription_engine_realtime import TranscriptionResult
# Module-level logger; every diagnostic emitted by the engine goes through it.
logger = logging.getLogger(__name__)
class DeepgramTranscriptionEngine:
    """
    Transcription engine that streams audio to Deepgram via WebSocket.

    In managed mode the connection goes through a proxy at
    ``wss://<server>/ws/transcribe`` which handles authentication and
    Deepgram credentials. In BYOK (bring-your-own-key) mode the
    connection goes directly to the Deepgram API.

    Threading model: a ``sounddevice`` callback converts captured audio to
    16-bit PCM and pushes it onto a thread-safe queue; a daemon thread runs a
    private asyncio event loop that drains the queue to the WebSocket and
    dispatches incoming messages. Result, low-credit and WebSocket error
    callbacks are therefore invoked on that daemon thread, not on the thread
    that called :meth:`start_recording`.
    """

    # Seconds the send loop sleeps when no audio is queued. Polling with a
    # short async sleep keeps the event loop free for the receive loop and
    # the websocket keep-alive pings.
    _SEND_POLL_INTERVAL = 0.02
    # Seconds each recv() waits before re-checking the stop flag.
    _RECV_POLL_TIMEOUT = 1.0
    # Seconds to wait for the proxy's ``ready`` reply after sending auth.
    _HANDSHAKE_TIMEOUT = 15
    # Seconds stop_recording() waits for the event-loop thread to exit.
    _THREAD_JOIN_TIMEOUT = 5

    # ------------------------------------------------------------------ #
    # Construction / configuration
    # ------------------------------------------------------------------ #

    def __init__(self, config, user_name: str = "User", input_device_index: Optional[int] = None):
        """
        Initialise the engine from a :class:`client.config.Config` object.

        Args:
            config: Application ``Config`` instance (read via ``config.get``).
            user_name: Display name attached to transcriptions.
            input_device_index: Index of the audio input device to use
                (``None`` for the system default).
        """
        self.config = config
        self.user_name = user_name
        self.input_device_index = input_device_index

        # Mode: 'managed' (proxy) or 'byok' (direct Deepgram)
        self.mode: str = config.get("remote.mode", "managed")

        # Managed-mode settings
        self.server_url: str = config.get("remote.server_url", "")
        self.auth_token: str = config.get("remote.auth_token", "")

        # BYOK-mode settings
        self.byok_api_key: str = config.get("remote.byok_api_key", "")

        # Deepgram model / language (used in both modes)
        self.deepgram_model: str = config.get("remote.deepgram_model", "nova-2")
        self.language: str = config.get("remote.language", "en-US")

        # Audio parameters
        self.sample_rate: int = 16000
        self.channels: int = 1
        self.blocksize: int = 4096

        # Callbacks
        self.realtime_callback: Optional[Callable[["TranscriptionResult"], None]] = None
        self.final_callback: Optional[Callable[["TranscriptionResult"], None]] = None
        self._on_error: Optional[Callable[[str], None]] = None
        self._on_credits_low: Optional[Callable[[int], None]] = None

        # Internal state
        self._is_initialized: bool = False
        self._is_recording: bool = False
        # Set when the current session should wind down (by stop_recording()
        # or by the loop thread when the WebSocket session ends on its own).
        self._stop_event: threading.Event = threading.Event()
        # PCM byte chunks produced by the audio callback, consumed by _send_loop.
        self._audio_queue: Queue = Queue()

        # Asyncio event loop running in a daemon thread
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread: Optional[threading.Thread] = None

        # WebSocket handle (set inside the async context)
        self._ws = None

        # sounddevice InputStream
        self._stream = None

    # ------------------------------------------------------------------ #
    # Callback setters
    # ------------------------------------------------------------------ #

    def set_callbacks(
        self,
        realtime_callback: Optional[Callable[["TranscriptionResult"], None]] = None,
        final_callback: Optional[Callable[["TranscriptionResult"], None]] = None,
    ):
        """Set transcription result callbacks (matches RealtimeTranscriptionEngine API).

        ``realtime_callback`` receives interim results, ``final_callback``
        receives final ones. Both are invoked on the event-loop thread.
        """
        self.realtime_callback = realtime_callback
        self.final_callback = final_callback

    def set_error_callback(self, fn: Optional[Callable[[str], None]]):
        """Set a callback invoked on errors. ``fn`` receives a string message."""
        self._on_error = fn

    def set_credits_low_callback(self, fn: Optional[Callable[[int], None]]):
        """Set a callback for low-credit warnings. ``fn`` receives seconds remaining."""
        self._on_credits_low = fn

    # ------------------------------------------------------------------ #
    # Public interface (duck-typed with RealtimeTranscriptionEngine)
    # ------------------------------------------------------------------ #

    def initialize(self) -> bool:
        """Validate configuration and mark the engine as ready.

        Returns:
            ``True`` when the engine is ready to start recording, ``False``
            if the configuration for the selected mode is incomplete or the
            mode is unknown.
        """
        if self._is_initialized:
            return True

        if self.mode == "managed":
            if not self.server_url:
                logger.error("Managed mode requires a server URL (remote.server_url)")
                return False
            if not self.auth_token:
                logger.error("Managed mode requires an auth token (remote.auth_token)")
                return False
        elif self.mode == "byok":
            if not self.byok_api_key:
                logger.error("BYOK mode requires an API key (remote.byok_api_key)")
                return False
        else:
            logger.error("Unknown remote mode: %s (expected 'managed' or 'byok')", self.mode)
            return False

        self._is_initialized = True
        logger.info("DeepgramTranscriptionEngine initialised in %s mode", self.mode)
        return True

    def start_recording(self) -> bool:
        """Open the audio stream and connect the WebSocket.

        The audio stream is opened *before* the network thread is started, so
        a failure to open the input device leaves no background thread behind.
        Audio captured while the WebSocket is still connecting waits in the
        queue and is sent once the connection is up.

        Returns:
            ``True`` on success (or if already recording); ``False`` if the
            engine is not initialised or the audio stream could not be opened.
        """
        if not self._is_initialized:
            logger.error("Engine not initialised -- call initialize() first")
            return False
        if self._is_recording:
            return True

        # Discard audio left over from a previous session so it is not
        # streamed at the start of this one.
        self._drain_audio_queue()
        self._stop_event.clear()
        self._is_recording = True

        # Start the audio capture stream
        try:
            self._start_audio_stream()
        except Exception as exc:
            logger.error("Failed to open audio stream: %s", exc)
            self._is_recording = False
            self._stop_event.set()
            # The InputStream may have been created even though start() failed.
            self._stop_audio_stream()
            return False

        # Start the asyncio event-loop thread (handles WS send/receive).
        # self._thread is assigned before start() so the thread can recognise
        # itself as the current session thread.
        self._thread = threading.Thread(target=self._run_event_loop, daemon=True)
        self._thread.start()

        logger.info("Recording started")
        return True

    def stop_recording(self):
        """Stop audio capture and close the WebSocket.

        No-op when not recording. Waits up to ``_THREAD_JOIN_TIMEOUT`` seconds
        for the event-loop thread to exit (unless called from that thread).
        """
        if not self._is_recording:
            return

        self._is_recording = False
        self._stop_event.set()

        # Stop audio stream
        self._stop_audio_stream()

        # Close the WebSocket from outside the event-loop thread so a pending
        # recv() returns immediately. Snapshot the loop: the loop thread may
        # clear or close it concurrently.
        loop = self._loop
        if self._ws is not None and loop is not None:
            close_coro = self._close_ws()
            try:
                asyncio.run_coroutine_threadsafe(close_coro, loop)
            except RuntimeError:
                # Loop already closed -- the thread is shutting down by itself.
                close_coro.close()

        # Wait for the thread to finish. Joining is skipped when called from a
        # callback running on that very thread (join() would raise).
        thread = self._thread
        self._thread = None
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout=self._THREAD_JOIN_TIMEOUT)
            if thread.is_alive():
                logger.warning(
                    "Event-loop thread did not exit within %s seconds",
                    self._THREAD_JOIN_TIMEOUT,
                )

        logger.info("Recording stopped")

    def stop(self):
        """Full shutdown -- stop recording and release all resources."""
        self.stop_recording()
        self._is_initialized = False
        logger.info("DeepgramTranscriptionEngine shut down")

    def is_ready(self) -> bool:
        """Return ``True`` if the engine has been successfully initialised."""
        return self._is_initialized

    # ------------------------------------------------------------------ #
    # Audio capture (sounddevice)
    # ------------------------------------------------------------------ #

    @staticmethod
    def _float_to_pcm16(samples) -> bytes:
        """Convert float samples (nominal range [-1.0, 1.0]) to 16-bit PCM bytes.

        Samples are clipped first: values outside [-1.0, 1.0] would otherwise
        overflow int16. The output is signed little-endian, written explicitly
        rather than relying on the host byte order.
        """
        clipped = np.clip(samples, -1.0, 1.0)
        return (clipped * 32767).astype("<i2").tobytes()

    def _drain_audio_queue(self):
        """Discard every audio chunk currently waiting in the queue."""
        while True:
            try:
                self._audio_queue.get_nowait()
            except Empty:
                return

    def _start_audio_stream(self):
        """Open a ``sounddevice.InputStream`` that feeds the audio queue."""
        import sounddevice as sd

        def _audio_callback(indata, frames, time_info, status):  # noqa: ARG001
            if status:
                logger.warning("Audio stream status: %s", status)
            # Stop queueing once the session is winding down (including when
            # the WebSocket side ended on its own) so the queue cannot grow
            # without bound.
            if self._is_recording and not self._stop_event.is_set():
                self._audio_queue.put(self._float_to_pcm16(indata))

        self._stream = sd.InputStream(
            samplerate=self.sample_rate,
            blocksize=self.blocksize,
            channels=self.channels,
            dtype="float32",
            device=self.input_device_index,
            callback=_audio_callback,
        )
        self._stream.start()

    def _stop_audio_stream(self):
        """Stop and close the audio input stream, if one is open.

        ``close()`` is attempted even if ``stop()`` fails, so the device is
        not leaked; errors from either call are logged at debug level.
        """
        stream, self._stream = self._stream, None
        if stream is None:
            return
        try:
            stream.stop()
        except Exception as exc:
            logger.debug("Error stopping audio stream: %s", exc)
        try:
            stream.close()
        except Exception as exc:
            logger.debug("Error closing audio stream: %s", exc)

    # ------------------------------------------------------------------ #
    # Asyncio event-loop (runs in daemon thread)
    # ------------------------------------------------------------------ #

    def _run_event_loop(self):
        """Entry point for the daemon thread -- runs the async event loop."""
        loop = asyncio.new_event_loop()
        self._loop = loop
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self._ws_lifecycle())
        except Exception as exc:
            logger.error("Event-loop error: %s", exc)
        finally:
            try:
                self._cancel_pending_tasks(loop)
                loop.run_until_complete(loop.shutdown_asyncgens())
            except Exception as exc:
                logger.debug("Error during event-loop cleanup: %s", exc)
            loop.close()
            # Only touch shared state if no newer session has replaced ours.
            if self._loop is loop:
                self._loop = None
            if self._thread is threading.current_thread():
                # The WebSocket session ended without stop_recording() (e.g.
                # connection or auth failure): stop queueing captured audio.
                self._stop_event.set()

    @staticmethod
    def _cancel_pending_tasks(loop):
        """Cancel every task still pending on *loop* and wait for them to finish."""
        pending = asyncio.all_tasks(loop)
        if not pending:
            return
        for task in pending:
            task.cancel()
        loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))

    async def _ws_lifecycle(self):
        """Connect, authenticate (if managed), then run send/receive loops.

        Any failure (including a missing ``websockets`` package) is logged and
        forwarded to the error callback; the socket is always closed on exit.
        """
        try:
            import websockets

            ws_url, extra_headers = self._build_ws_url_and_headers()
            logger.info("Connecting to %s", ws_url)

            self._ws = await websockets.connect(
                ws_url,
                additional_headers=extra_headers,
                ping_interval=20,
                ping_timeout=10,
            )

            # Managed mode: send auth message and wait for ready
            if self.mode == "managed":
                if not await self._managed_handshake():
                    return

            await self._run_send_receive()
        except asyncio.CancelledError:
            pass
        except Exception as exc:
            self._report_error(f"WebSocket error: {exc}")
        finally:
            await self._close_ws()

    async def _run_send_receive(self):
        """Run the send and receive loops until either one finishes.

        When one loop ends (stop requested, connection closed, or an error),
        the other is cancelled so no task is left running on the loop. An
        exception raised by the loop that finished is re-raised.
        """
        tasks = (
            asyncio.ensure_future(self._send_loop()),
            asyncio.ensure_future(self._receive_loop()),
        )
        try:
            done, _pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
        for task in done:
            task.result()

    def _build_ws_url_and_headers(self):
        """Return ``(url, headers)`` depending on the current mode.

        Managed mode: ``server_url`` may be a bare host, an ``http(s)://`` URL
        or a ``ws(s)://`` URL; ``http``/``https`` map to ``ws``/``wss`` and a
        bare host defaults to ``wss://``. ``/ws/transcribe`` is appended and
        no extra headers are sent (the token travels in the auth message).

        BYOK mode: Deepgram's ``/v1/listen`` endpoint with URL-encoded query
        parameters, authenticated with an ``Authorization: Token`` header.
        """
        if self.mode == "managed":
            url = self.server_url.strip().rstrip("/")
            if url.startswith("https://"):
                url = "wss://" + url[len("https://"):]
            elif url.startswith("http://"):
                url = "ws://" + url[len("http://"):]
            elif not url.startswith(("ws://", "wss://")):
                url = f"wss://{url}"
            return f"{url}/ws/transcribe", {}

        # BYOK -- connect directly to Deepgram
        from urllib.parse import urlencode

        params = urlencode(
            {
                "model": self.deepgram_model,
                "language": self.language,
                "interim_results": "true",
                "encoding": "linear16",
                "sample_rate": self.sample_rate,
                "channels": self.channels,
            }
        )
        url = f"wss://api.deepgram.com/v1/listen?{params}"
        headers = {"Authorization": f"Token {self.byok_api_key}"}
        return url, headers

    # -- managed-mode handshake ---------------------------------------- #

    async def _managed_handshake(self) -> bool:
        """Send the auth message and wait for ``ready`` (managed mode).

        Returns:
            ``True`` once the proxy replies with ``{"type": "ready"}``;
            ``False`` on an ``error`` reply, any other reply, or a timeout
            (each reported through the error callback). Non-JSON replies and
            connection errors propagate to the caller.
        """
        auth_msg = {
            "type": "auth",
            "token": self.auth_token,
            "config": {
                "model": self.deepgram_model,
                "language": self.language,
                "sample_rate": self.sample_rate,
                "channels": self.channels,
                "encoding": "linear16",
                "interim_results": True,
            },
        }
        await self._ws.send(json.dumps(auth_msg))

        try:
            raw = await asyncio.wait_for(self._ws.recv(), timeout=self._HANDSHAKE_TIMEOUT)
        except asyncio.TimeoutError:
            logger.error("Timed out waiting for proxy ready message")
            self._safe_call(self._on_error, "Timed out waiting for proxy ready message")
            return False

        data = json.loads(raw)
        msg_type = data.get("type") if isinstance(data, dict) else None
        if msg_type == "ready":
            logger.info("Managed proxy is ready")
            return True
        if msg_type == "error":
            err = data.get("message", "unknown error")
            logger.error("Auth error from proxy: %s", err)
            self._safe_call(self._on_error, f"Proxy auth error: {err}")
            return False

        logger.warning("Unexpected handshake message: %s", data)
        self._safe_call(
            self._on_error,
            f"Unexpected handshake message from proxy (type={msg_type!r})",
        )
        return False

    # -- send loop ----------------------------------------------------- #

    async def _send_loop(self):
        """Drain the audio queue and push raw PCM bytes over the WebSocket.

        The queue is polled without blocking: a blocking ``Queue.get`` here
        would stall the whole event loop (receive loop and keep-alive pings
        included) because it never yields control.
        """
        while not self._stop_event.is_set():
            try:
                pcm_bytes = self._audio_queue.get_nowait()
            except Empty:
                await asyncio.sleep(self._SEND_POLL_INTERVAL)
                continue
            try:
                await self._ws.send(pcm_bytes)
            except Exception as exc:
                if not self._stop_event.is_set():
                    self._report_error(f"Send error: {exc}")
                break

    # -- receive loop -------------------------------------------------- #

    async def _receive_loop(self):
        """Listen for messages from the WebSocket and dispatch them.

        Non-JSON messages and JSON values that are not objects are ignored.
        An unexpected connection error (one not caused by a stop request) is
        reported through the error callback and ends the loop.
        """
        while not self._stop_event.is_set():
            try:
                raw = await asyncio.wait_for(self._ws.recv(), timeout=self._RECV_POLL_TIMEOUT)
            except asyncio.TimeoutError:
                continue
            except Exception as exc:
                if not self._stop_event.is_set():
                    self._report_error(f"Receive error: {exc}")
                break

            try:
                data = json.loads(raw)
            except (ValueError, TypeError):
                # ValueError covers JSONDecodeError and undecodable bytes.
                logger.debug("Non-JSON message received, ignoring")
                continue
            if not isinstance(data, dict):
                logger.debug("Non-object JSON message received, ignoring")
                continue

            if self.mode == "managed":
                self._handle_managed_message(data)
            else:
                self._handle_byok_message(data)

    # ------------------------------------------------------------------ #
    # Message handlers
    # ------------------------------------------------------------------ #

    def _handle_managed_message(self, data: dict):
        """Process a message from the managed proxy."""
        msg_type = data.get("type", "")

        if msg_type == "transcript":
            self._emit_result(data.get("text", ""), bool(data.get("is_final", False)))

        elif msg_type == "credits_low":
            seconds_remaining = self._as_int(data.get("seconds_remaining", 0))
            logger.warning("Credits low -- %d seconds remaining", seconds_remaining)
            self._safe_call(self._on_credits_low, seconds_remaining)

        elif msg_type == "error":
            code = data.get("code", "")
            message = data.get("message", "Unknown error")
            logger.error("Proxy error [%s]: %s", code, message)
            self._safe_call(self._on_error, f"[{code}] {message}" if code else message)

        elif msg_type == "session_end":
            seconds_used = self._as_int(data.get("seconds_used", 0))
            logger.info("Session ended -- %d seconds used", seconds_used)

        elif msg_type == "ready":
            # May arrive again after reconnects; safe to ignore.
            logger.debug("Received ready message (already connected)")

        else:
            logger.debug("Unhandled managed message type: %s", msg_type)

    def _handle_byok_message(self, data: dict):
        """Process a message received directly from the Deepgram API."""
        msg_type = data.get("type", "")

        if msg_type == "Results":
            channel = data.get("channel") or {}
            alternatives = channel.get("alternatives") or []
            if not alternatives:
                return
            self._emit_result(
                alternatives[0].get("transcript", ""),
                bool(data.get("is_final", False)),
            )

        elif msg_type == "Metadata":
            logger.debug("Deepgram metadata: %s", data)

        elif msg_type == "UtteranceEnd":
            logger.debug("Deepgram utterance end")

        else:
            logger.debug("Unhandled Deepgram message type: %s", msg_type)

    def _emit_result(self, text, is_final: bool):
        """Wrap *text* in a ``TranscriptionResult`` and route it to a callback.

        Blank or non-string text is dropped. Final results go to
        ``final_callback``; interim results go to ``realtime_callback``.
        """
        if not isinstance(text, str) or not text.strip():
            return
        result = TranscriptionResult(
            text=text,
            is_final=is_final,
            timestamp=datetime.now(),
            user_name=self.user_name,
        )
        callback = self.final_callback if is_final else self.realtime_callback
        self._safe_call(callback, result)

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _as_int(value) -> int:
        """Convert *value* to ``int``, returning 0 if it is not numeric."""
        try:
            return int(value)
        except (TypeError, ValueError):
            return 0

    @staticmethod
    def _safe_call(fn, *args):
        """Invoke optional callback *fn*; log (don't propagate) its exceptions.

        A faulty UI callback must not tear down the streaming session.
        """
        if fn is None:
            return
        try:
            fn(*args)
        except Exception:
            logger.exception("Callback %r raised", fn)

    def _report_error(self, msg: str):
        """Log *msg* as an error and forward it to the error callback, if set."""
        logger.error("%s", msg)
        self._safe_call(self._on_error, msg)

    async def _close_ws(self):
        """Close the WebSocket connection if open (safe to call concurrently)."""
        ws = self._ws
        if ws is None:
            return
        try:
            await ws.close()
        except Exception as exc:
            logger.debug("Error closing WebSocket: %s", exc)
        finally:
            if self._ws is ws:
                self._ws = None

    def set_user_name(self, user_name: str):
        """Update the user name attached to future transcriptions."""
        self.user_name = user_name

    def is_recording_active(self) -> bool:
        """Return ``True`` if audio is currently being captured."""
        return self._is_recording

    def __repr__(self) -> str:
        return (
            f"DeepgramTranscriptionEngine(mode={self.mode}, "
            f"recording={self._is_recording})"
        )

    def __del__(self):
        """Best-effort cleanup."""
        try:
            self.stop()
        except Exception:
            pass