347 lines
11 KiB
Python
347 lines
11 KiB
Python
|
|
"""
Remote Transcription Client

Streams audio to a remote transcription service over a WebSocket and receives
transcriptions back.

RemoteTranscriptionManager accepts a local transcription engine intended as a
fallback when the remote service is unavailable. Switching to that engine is
not implemented yet: while the remote client is disconnected, audio passed to
the manager is not transcribed.
"""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import base64
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
import numpy as np
|
||
|
|
from datetime import datetime
|
||
|
|
from threading import Thread, Lock
|
||
|
|
from typing import Optional, Callable
|
||
|
|
from queue import Queue, Empty
|
||
|
|
|
||
|
|
# Module-level logger shared by both classes in this file.
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
class RemoteTranscriptionClient:
    """
    Client for remote transcription service.

    Streams audio to a remote server over a WebSocket and receives
    transcriptions. All network I/O runs on a private asyncio event loop
    inside one background daemon thread (``send_thread``). The public
    methods (``start``, ``stop``, ``send_audio``) are meant to be called from
    other threads; they communicate with the worker through ``audio_queue``
    and the ``is_running`` / ``is_connected`` flags.

    JSON text messages used by this client:
      -> {"type": "auth", "api_key": ...}
      <- {"type": "auth_result", "success": ..., "message": ...}
      -> {"type": "audio", "data": <base64 of native float32 samples>,
          "sample_rate": ...}
      <- {"type": "transcription", "text": ..., "is_preview": ...}
      <- {"type": "error", "message": ...}
      <- {"type": "pong"}
      -> {"type": "end"}   (sent when stop() ends an established session)
    """

    # Sleep between polls of the thread-safe (non-async) audio queue. Polling
    # with get_nowait() keeps the event loop free for the receiver and pings.
    _QUEUE_POLL_INTERVAL = 0.01
    # Upper bound, in seconds, on how long stop() waits for the worker thread.
    _STOP_JOIN_TIMEOUT = 2.0

    def __init__(
        self,
        server_url: str,
        api_key: str,
        on_transcription: Optional[Callable[[str, bool], None]] = None,
        on_error: Optional[Callable[[str], None]] = None,
        on_connection_change: Optional[Callable[[bool], None]] = None,
        sample_rate: int = 16000
    ):
        """
        Initialize remote transcription client.

        Args:
            server_url: WebSocket URL of the transcription service
            api_key: API key sent in the initial "auth" message
            on_transcription: Callback for transcriptions (text, is_preview)
            on_error: Callback for error messages
            on_connection_change: Callback for connection status changes
            sample_rate: Audio sample rate, sent with every audio chunk

        Callbacks are invoked from the background worker thread.
        """
        self.server_url = server_url
        self.api_key = api_key
        self.sample_rate = sample_rate
        self.on_transcription = on_transcription
        self.on_error = on_error
        self.on_connection_change = on_connection_change

        self.websocket = None
        self.is_connected = False
        self.is_authenticated = False
        self.is_running = False

        self.audio_queue: Queue = Queue()
        # Worker thread that owns the event loop and both send/receive loops.
        self.send_thread: Optional[Thread] = None
        # Unused: kept for backward compatibility; receiving also runs on send_thread.
        self.receive_thread: Optional[Thread] = None
        self.loop: Optional[asyncio.AbstractEventLoop] = None

        self._lock = Lock()

    @staticmethod
    def _notify(callback, *args):
        """Invoke an optional user callback, logging (not propagating) any exception."""
        if callback is None:
            return
        try:
            callback(*args)
        except Exception:
            # A faulty user callback must not tear down the connection.
            logger.exception("Callback %r raised an exception", callback)

    async def _connect(self) -> bool:
        """
        Open the WebSocket connection and authenticate.

        Sends {"type": "auth", "api_key": ...} and waits up to 10 s for a reply.
        On {"type": "auth_result"} with a truthy "success", sets is_connected
        and is_authenticated, fires on_connection_change(True) and returns True.
        Any other reply, a timeout or a connection error is logged, reported
        through on_error, and returns False. A socket that was opened is left
        in self.websocket for the caller to close.
        """
        try:
            # Imported lazily so this module can be imported without the
            # websockets package being installed.
            import websockets

            logger.info(f"Connecting to {self.server_url}")
            self.websocket = await websockets.connect(
                self.server_url,
                ping_interval=30,
                ping_timeout=10,
            )

            # Authenticate
            auth_message = {"type": "auth", "api_key": self.api_key}
            await self.websocket.send(json.dumps(auth_message))

            # Wait for auth response
            response = await asyncio.wait_for(self.websocket.recv(), timeout=10.0)
            auth_result = json.loads(response)
            if not isinstance(auth_result, dict):
                auth_result = {}

            if auth_result.get("type") == "auth_result" and auth_result.get("success"):
                self.is_connected = True
                self.is_authenticated = True
                logger.info("Connected and authenticated to remote transcription service")
                self._notify(self.on_connection_change, True)
                return True

            error_msg = auth_result.get("message", "Authentication failed")
            logger.error(f"Authentication failed: {error_msg}")
            self._notify(self.on_error, f"Authentication failed: {error_msg}")
            return False

        except asyncio.TimeoutError:
            # str() of a TimeoutError is empty, so name the failure explicitly.
            failure = "Connection failed: timed out"
        except Exception as e:
            failure = f"Connection failed: {e}"

        logger.error(failure)
        self._notify(self.on_error, failure)
        return False

    async def _send_loop(self):
        """
        Send queued audio chunks to the server as "audio" messages.

        Runs until stop() clears is_running or is_connected, or a send fails.
        It never blocks the event loop: an empty queue is polled with
        get_nowait() followed by a short asyncio.sleep().
        """
        while self.is_running and self.is_connected and self.websocket:
            try:
                audio_data = self.audio_queue.get_nowait()
            except Empty:
                await asyncio.sleep(self._QUEUE_POLL_INTERVAL)
                continue

            if audio_data is None:
                continue

            try:
                # Encode the raw float32 samples as base64.
                audio_bytes = np.asarray(audio_data, dtype=np.float32).tobytes()
                audio_b64 = base64.b64encode(audio_bytes).decode('utf-8')
            except (TypeError, ValueError) as e:
                # One bad chunk should not end the whole session.
                logger.error(f"Dropping unencodable audio chunk: {e}")
                continue

            message = {
                "type": "audio",
                "data": audio_b64,
                "sample_rate": self.sample_rate
            }
            try:
                await self.websocket.send(json.dumps(message))
            except Exception as e:
                if self.is_running:
                    logger.error(f"Send error: {e}")
                break

    async def _receive_loop(self):
        """
        Receive server messages and dispatch them until stopped or recv() fails.

        recv() is wrapped in a 1 s timeout so is_running is re-checked
        periodically. Messages that are not JSON objects are logged and
        skipped instead of ending the session.
        """
        while self.is_running and self.websocket:
            try:
                message = await asyncio.wait_for(self.websocket.recv(), timeout=1.0)
            except asyncio.TimeoutError:
                continue
            except Exception as e:
                if self.is_running:
                    logger.error(f"Receive error: {e}")
                break

            try:
                data = json.loads(message)
            except ValueError:
                logger.warning("Ignoring non-JSON message from server")
                continue
            if not isinstance(data, dict):
                logger.warning("Ignoring non-object JSON message from server")
                continue

            self._handle_message(data)

    def _handle_message(self, data: dict):
        """Route one decoded server message to the matching callback."""
        msg_type = data.get("type", "")
        if msg_type == "transcription":
            text = data.get("text", "")
            is_preview = data.get("is_preview", False)
            if text:
                self._notify(self.on_transcription, text, is_preview)
        elif msg_type == "error":
            error_msg = data.get("message", "Unknown error")
            logger.error(f"Server error: {error_msg}")
            self._notify(self.on_error, error_msg)
        # "pong" (keep-alive reply) and unknown types need no action.

    async def _run_session(self):
        """
        Run the send and receive loops concurrently.

        When either loop finishes (stop requested, connection lost, send
        failure), the other is cancelled so the worker thread can shut down
        instead of lingering.
        """
        tasks = [
            asyncio.ensure_future(self._send_loop()),
            asyncio.ensure_future(self._receive_loop()),
        ]
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            for task in tasks:
                task.cancel()
            results = await asyncio.gather(*tasks, return_exceptions=True)
            for result in results:
                if isinstance(result, Exception) and not isinstance(result, asyncio.CancelledError):
                    logger.error(f"Async loop error: {result}")

    async def _close_websocket(self, websocket, send_end: bool):
        """Best effort: optionally tell the server the stream ended, then close."""
        if send_end:
            try:
                # Signal end to server
                await websocket.send(json.dumps({"type": "end"}))
            except Exception as e:
                logger.debug(f"Could not send end message: {e}")
        try:
            await websocket.close()
        except Exception as e:
            logger.debug(f"Error while closing websocket: {e}")

    def _run_async(self):
        """
        Worker-thread body: connect, stream until stopped or disconnected, clean up.

        Owns a private event loop for its whole lifetime. On exit it:
        - sends {"type": "end"} if stop() ended an established session;
        - closes the socket and the loop;
        - if this thread is still the client's worker, clears the connection
          flags and is_running (so start() can be called again) and fires
          on_connection_change(False) when a session had been established.
        """
        from threading import current_thread

        loop = asyncio.new_event_loop()
        self.loop = loop
        asyncio.set_event_loop(loop)
        self.websocket = None

        websocket = None
        session_started = False
        try:
            connected = loop.run_until_complete(self._connect())
            # Capture our socket so cleanup never touches a newer worker's one.
            websocket = self.websocket
            if connected:
                session_started = True
                loop.run_until_complete(self._run_session())
        except Exception as e:
            logger.error(f"Async loop error: {e}")
        finally:
            if websocket is not None:
                stop_requested = not self.is_running
                loop.run_until_complete(
                    self._close_websocket(websocket, send_end=session_started and stop_requested)
                )
            loop.close()

            with self._lock:
                # After stop() + start(), a newer worker may own the client state.
                owns_state = self.send_thread is current_thread()
                if owns_state:
                    self.is_connected = False
                    self.is_authenticated = False
                    self.is_running = False
                    if self.websocket is websocket:
                        self.websocket = None

            if owns_state and session_started:
                self._notify(self.on_connection_change, False)

    def _drain_audio_queue(self):
        """Discard every chunk currently waiting in audio_queue."""
        while True:
            try:
                self.audio_queue.get_nowait()
            except Empty:
                return

    def start(self):
        """Start the client in a background daemon thread; no-op if already running."""
        with self._lock:
            if self.is_running:
                return

            self.is_running = True

            # Start async loop in background thread
            self.send_thread = Thread(target=self._run_async, daemon=True)
            self.send_thread.start()

    def stop(self):
        """
        Stop the remote transcription client.

        Clears is_running and the connection flags, so send_audio stops
        queueing immediately. Waits up to _STOP_JOIN_TIMEOUT seconds for the
        worker thread to send {"type": "end"} and close the socket, then
        discards any audio left in the queue. Safe to call when not running;
        when called from a callback (i.e. on the worker thread) the worker is
        not joined.
        """
        from threading import current_thread

        with self._lock:
            self.is_running = False
            self.is_connected = False
            self.is_authenticated = False
            worker = self.send_thread

        # Join outside the lock: the worker takes the lock during cleanup.
        if worker is not None and worker is not current_thread() and worker.is_alive():
            worker.join(timeout=self._STOP_JOIN_TIMEOUT)

        # Stale audio must not be replayed by the next session.
        self._drain_audio_queue()

    def send_audio(self, audio_data: np.ndarray):
        """
        Queue audio data for transcription.

        The chunk is dropped unless the client is connected and authenticated.

        Args:
            audio_data: Audio data as numpy array (float32, mono, sample_rate)
        """
        if self.is_connected and self.is_authenticated:
            self.audio_queue.put(audio_data)

    @property
    def connected(self) -> bool:
        """Check if connected and authenticated."""
        return self.is_connected and self.is_authenticated
|
||
|
|
|
||
|
|
|
||
|
|
class RemoteTranscriptionManager:
    """
    Manages remote transcription, with a placeholder for local fallback.

    While running, owns a RemoteTranscriptionClient and splits its
    (text, is_preview) callback into separate on_preview / on_transcription
    callbacks. A local engine may be supplied for fallback, but this class
    does not yet route audio to it.
    """

    def __init__(
        self,
        server_url: str,
        api_key: str,
        local_engine=None,
        on_transcription: Optional[Callable[[str], None]] = None,
        on_preview: Optional[Callable[[str], None]] = None,
        sample_rate: int = 16000
    ):
        """
        Initialize the remote transcription manager.

        Args:
            server_url: Remote transcription service URL
            api_key: API key for authentication
            local_engine: Local transcription engine for fallback (stored only)
            on_transcription: Callback for final transcriptions
            on_preview: Callback for preview transcriptions
            sample_rate: Sample rate of the audio passed to send_audio; forwarded
                to the remote client (defaults to the client's own default)
        """
        self.server_url = server_url
        self.api_key = api_key
        self.local_engine = local_engine
        self.on_transcription = on_transcription
        self.on_preview = on_preview
        self.sample_rate = sample_rate

        self.client: Optional[RemoteTranscriptionClient] = None
        self.use_remote = True
        self.is_running = False

    def _handle_transcription(self, text: str, is_preview: bool):
        """Route a remote transcription to on_preview or on_transcription."""
        if is_preview:
            if self.on_preview:
                self.on_preview(text)
        elif self.on_transcription:
            self.on_transcription(text)

    def _handle_error(self, error: str):
        """Log an error reported by the remote client."""
        logger.error(f"Remote transcription error: {error}")
        # Could switch to local fallback here

    def _handle_connection_change(self, connected: bool):
        """Log remote connection status changes."""
        if connected:
            logger.info("Remote transcription connected")
        else:
            logger.warning("Remote transcription disconnected")
            # Could switch to local fallback here

    def start(self):
        """
        Start remote transcription; no-op if already running.

        A remote client is created and started only when use_remote is set and
        both server_url and api_key are non-empty.
        """
        if self.is_running:
            return

        self.is_running = True

        if self.use_remote and self.server_url and self.api_key:
            self.client = RemoteTranscriptionClient(
                server_url=self.server_url,
                api_key=self.api_key,
                on_transcription=self._handle_transcription,
                on_error=self._handle_error,
                on_connection_change=self._handle_connection_change,
                sample_rate=self.sample_rate
            )
            self.client.start()

    def stop(self):
        """Stop remote transcription and discard the remote client."""
        self.is_running = False
        if self.client:
            self.client.stop()
            self.client = None

    def send_audio(self, audio_data: np.ndarray):
        """
        Send audio for transcription via the remote client when it is connected.

        Otherwise the audio is dropped: the local engine (if any) handles its
        own audio capture, so nothing is forwarded to it here.
        """
        if self.client and self.client.connected:
            self.client.send_audio(audio_data)

    @property
    def is_connected(self) -> bool:
        """Check if remote service is connected."""
        return self.client is not None and self.client.connected
|