Files
local-transcription/client/remote_transcription.py

347 lines
11 KiB
Python
Raw Permalink Normal View History

"""
Remote Transcription Client
Handles streaming audio to a remote transcription service and receiving transcriptions.
Provides fallback to local transcription if the remote service is unavailable.
"""
import asyncio
import base64
import json
import logging
from datetime import datetime
from queue import Empty, Queue
from threading import Lock, Thread, current_thread
from typing import Callable, Optional

import numpy as np
# Module-level logger named after this module, so its output can be
# configured through the standard logging hierarchy.
logger: logging.Logger = logging.getLogger(name=__name__)
class RemoteTranscriptionClient:
    """
    Client for remote transcription service.

    Streams audio to a WebSocket server and receives transcriptions. All
    network I/O runs on a private asyncio event loop inside a background
    daemon thread created by ``start()``; ``send_audio()`` only puts audio on
    a thread-safe queue. Callbacks are invoked on that background thread.
    """

    # Sleep between polls of the audio queue while it is empty (seconds).
    _QUEUE_POLL_INTERVAL = 0.01
    # How long stop() waits for the "end" message to be sent (seconds).
    _END_MESSAGE_TIMEOUT = 2.0
    # How long stop() waits for the background thread to exit (seconds).
    _JOIN_TIMEOUT = 5.0

    def __init__(
        self,
        server_url: str,
        api_key: str,
        on_transcription: Optional[Callable[[str, bool], None]] = None,
        on_error: Optional[Callable[[str], None]] = None,
        on_connection_change: Optional[Callable[[bool], None]] = None,
        sample_rate: int = 16000
    ):
        """
        Initialize remote transcription client.

        Args:
            server_url: WebSocket URL of the transcription service
            api_key: API key sent in the "auth" message after connecting
            on_transcription: Callback for transcriptions (text, is_preview)
            on_error: Callback receiving an error description
            on_connection_change: Callback receiving True once authenticated
                and False when the session ends
            sample_rate: Audio sample rate reported with every audio chunk
        """
        self.server_url = server_url
        self.api_key = api_key
        self.sample_rate = sample_rate
        self.on_transcription = on_transcription
        self.on_error = on_error
        self.on_connection_change = on_connection_change
        self.websocket = None
        self.is_connected = False
        self.is_authenticated = False
        self.is_running = False
        self.audio_queue: Queue = Queue()
        # Worker thread running _run_async (both the send and receive loops).
        self.send_thread: Optional[Thread] = None
        # Not used by this class; kept for backward compatibility.
        self.receive_thread: Optional[Thread] = None
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self._lock = Lock()

    async def _connect(self) -> bool:
        """
        Open the WebSocket connection and authenticate with the API key.

        Returns:
            True when the server replied with a successful "auth_result";
            False on any failure (``on_error`` is called with the reason).
        """
        try:
            import websockets
            logger.info(f"Connecting to {self.server_url}")
            self.websocket = await websockets.connect(
                self.server_url,
                ping_interval=30,
                ping_timeout=10
            )
            # Authenticate
            auth_message = {
                "type": "auth",
                "api_key": self.api_key
            }
            await self.websocket.send(json.dumps(auth_message))
            # Wait at most 10 s for the server's auth response
            response = await asyncio.wait_for(
                self.websocket.recv(),
                timeout=10.0
            )
            auth_result = json.loads(response)
            if auth_result.get("type") == "auth_result" and auth_result.get("success"):
                self.is_connected = True
                self.is_authenticated = True
                logger.info("Connected and authenticated to remote transcription service")
                if self.on_connection_change:
                    self.on_connection_change(True)
                return True
            else:
                error_msg = auth_result.get("message", "Authentication failed")
                logger.error(f"Authentication failed: {error_msg}")
                if self.on_error:
                    self.on_error(f"Authentication failed: {error_msg}")
                return False
        except Exception as e:
            logger.error(f"Connection failed: {e}")
            if self.on_error:
                self.on_error(f"Connection failed: {e}")
            return False

    async def _send_loop(self):
        """Forward queued audio chunks to the server until stopped or disconnected."""
        while self.is_running and self.is_connected and self.websocket is not None:
            try:
                audio_data = self.audio_queue.get_nowait()
            except Empty:
                # Never block the event loop on the thread-safe queue: a
                # blocking get() here would starve the receive loop and the
                # websocket keepalive pings that share this loop.
                await asyncio.sleep(self._QUEUE_POLL_INTERVAL)
                continue
            if audio_data is None:
                continue
            # Encode the raw float32 samples as base64
            try:
                audio_bytes = np.asarray(audio_data, dtype=np.float32).tobytes()
            except (TypeError, ValueError) as e:
                logger.error(f"Dropping audio chunk that cannot be encoded: {e}")
                continue
            audio_b64 = base64.b64encode(audio_bytes).decode('utf-8')
            # Send to server
            message = {
                "type": "audio",
                "data": audio_b64,
                "sample_rate": self.sample_rate
            }
            try:
                await self.websocket.send(json.dumps(message))
            except Exception as e:
                if self.is_running:
                    logger.error(f"Send error: {e}")
                # Mark the session as over so the receive loop exits too.
                self.is_connected = False
                break

    def _handle_message(self, data: dict):
        """Route one decoded server message to the matching callback."""
        msg_type = data.get("type", "")
        if msg_type == "transcription":
            text = data.get("text", "")
            is_preview = data.get("is_preview", False)
            if text and self.on_transcription:
                self.on_transcription(text, is_preview)
        elif msg_type == "error":
            error_msg = data.get("message", "Unknown error")
            logger.error(f"Server error: {error_msg}")
            if self.on_error:
                self.on_error(error_msg)
        # "pong" (keep-alive) and unknown message types are ignored.

    async def _receive_loop(self):
        """Receive server messages and dispatch them until stopped or disconnected."""
        while self.is_running and self.is_connected and self.websocket is not None:
            try:
                message = await asyncio.wait_for(
                    self.websocket.recv(),
                    timeout=1.0
                )
            except asyncio.TimeoutError:
                # Periodic wake-up so a stop() request is noticed promptly.
                continue
            except Exception as e:
                if self.is_running:
                    logger.error(f"Receive error: {e}")
                break

            # A single malformed message must not tear down the session.
            try:
                data = json.loads(message)
            except ValueError:
                logger.warning("Ignoring non-JSON message from server")
                continue
            if not isinstance(data, dict):
                logger.warning("Ignoring non-object message from server")
                continue

            try:
                self._handle_message(data)
            except Exception as e:
                # A failing callback ends the session, as before.
                if self.is_running:
                    logger.error(f"Receive error: {e}")
                break

        # Session over: stopped, connection lost, or send failure
        self.is_connected = False
        self.is_authenticated = False
        if self.on_connection_change:
            self.on_connection_change(False)

    def _run_async(self):
        """Worker-thread target: connect, run both I/O loops, then clean up."""
        # Keep this session's loop/socket in locals so cleanup never touches
        # objects belonging to a later session started by start().
        loop = asyncio.new_event_loop()
        self.loop = loop
        self.websocket = None
        asyncio.set_event_loop(loop)
        websocket = None
        try:
            connected = loop.run_until_complete(self._connect())
            websocket = self.websocket
            if not connected:
                return
            loop.run_until_complete(
                asyncio.gather(self._send_loop(), self._receive_loop())
            )
        except Exception as e:
            logger.error(f"Async loop error: {e}")
        finally:
            if websocket is not None:
                try:
                    loop.run_until_complete(websocket.close())
                except Exception as e:
                    logger.debug(f"Error while closing websocket: {e}")
            loop.close()
            with self._lock:
                # Reset state so start() can reconnect, unless a newer
                # session has already replaced this worker thread.
                if self.send_thread is current_thread():
                    self.is_connected = False
                    self.is_authenticated = False
                    self.is_running = False

    def start(self):
        """Start the background connection thread (no-op if already running)."""
        with self._lock:
            if self.is_running:
                return
            self.is_running = True
            # Start async loop in background thread
            self.send_thread = Thread(target=self._run_async, daemon=True)
            self.send_thread.start()

    def stop(self):
        """
        Stop the remote transcription client.

        Clears ``is_running`` so both I/O loops exit, asks the server to end
        the stream with a ``{"type": "end"}`` message, and waits (bounded)
        for the worker thread to close the socket. Safe to call repeatedly
        and from the client's own callbacks; in that case it does not wait,
        since waiting would block the thread doing the work.
        """
        with self._lock:
            self.is_running = False
            thread = self.send_thread
            loop = self.loop
            websocket = self.websocket

        on_worker_thread = thread is not None and thread is current_thread()

        # Signal end to server
        if websocket is not None and loop is not None and not loop.is_closed():
            end_coro = websocket.send(json.dumps({"type": "end"}))
            try:
                future = asyncio.run_coroutine_threadsafe(end_coro, loop)
            except RuntimeError:
                # The loop closed between the check above and scheduling.
                end_coro.close()
            else:
                if not on_worker_thread:
                    try:
                        future.result(timeout=self._END_MESSAGE_TIMEOUT)
                    except Exception as e:
                        future.cancel()
                        logger.debug(f"End message not delivered: {e}")

        if thread is not None and not on_worker_thread:
            thread.join(timeout=self._JOIN_TIMEOUT)

        self.is_connected = False
        self.is_authenticated = False

    def send_audio(self, audio_data: np.ndarray):
        """
        Queue audio data for transcription.

        Audio is silently dropped unless the client is connected and
        authenticated.

        Args:
            audio_data: Audio samples (converted to float32 before sending)
        """
        if self.is_connected and self.is_authenticated:
            self.audio_queue.put(audio_data)

    @property
    def connected(self) -> bool:
        """True while connected and authenticated."""
        return self.is_connected and self.is_authenticated
class RemoteTranscriptionManager:
    """
    Owns a RemoteTranscriptionClient and routes its results to callbacks.

    Final transcriptions go to ``on_transcription`` and previews go to
    ``on_preview``. A ``local_engine`` can be stored, but this class never
    feeds audio to it (the local engine captures its own audio). It also
    does not yet switch to the local engine on errors or disconnects.
    """

    def __init__(
        self,
        server_url: str,
        api_key: str,
        local_engine=None,
        on_transcription: Optional[Callable] = None,
        on_preview: Optional[Callable] = None
    ):
        """
        Configure the manager without connecting.

        Args:
            server_url: Remote transcription service URL
            api_key: API key for authentication
            local_engine: Local transcription engine kept for fallback
            on_transcription: Called with the text of final transcriptions
            on_preview: Called with the text of preview transcriptions
        """
        # Connection settings
        self.server_url = server_url
        self.api_key = api_key
        self.local_engine = local_engine
        # Result callbacks
        self.on_transcription = on_transcription
        self.on_preview = on_preview
        # Runtime state
        self.client: Optional[RemoteTranscriptionClient] = None
        self.use_remote = True
        self.is_running = False

    def _handle_transcription(self, text: str, is_preview: bool):
        """Forward text to the preview or final callback, whichever applies."""
        callback = self.on_preview if is_preview else self.on_transcription
        if callback:
            callback(text)

    def _handle_error(self, error: str):
        """Log an error reported by the remote client."""
        # TODO: a switch to the local fallback could be triggered here.
        logger.error(f"Remote transcription error: {error}")

    def _handle_connection_change(self, connected: bool):
        """Log connection status changes reported by the remote client."""
        if not connected:
            # TODO: a switch to the local fallback could be triggered here.
            logger.warning("Remote transcription disconnected")
            return
        logger.info("Remote transcription connected")

    def start(self):
        """Start remote transcription if remote use is enabled and configured."""
        if self.is_running:
            return
        self.is_running = True

        remote_configured = self.use_remote and self.server_url and self.api_key
        if not remote_configured:
            return

        self.client = RemoteTranscriptionClient(
            server_url=self.server_url,
            api_key=self.api_key,
            on_transcription=self._handle_transcription,
            on_error=self._handle_error,
            on_connection_change=self._handle_connection_change
        )
        self.client.start()

    def stop(self):
        """Stop remote transcription and discard the client."""
        self.is_running = False
        client = self.client
        if client is None:
            return
        client.stop()
        self.client = None

    def send_audio(self, audio_data: np.ndarray):
        """Send audio to the remote client when it is connected."""
        client = self.client
        if client and client.connected:
            client.send_audio(audio_data)
        # Otherwise the audio is not used here, even when a local engine is
        # configured, because the local engine handles its own audio capture.

    @property
    def is_connected(self) -> bool:
        """True when a client exists and is connected and authenticated."""
        client = self.client
        return False if client is None else client.connected