"""
Remote Transcription Service

A standalone FastAPI WebSocket server that accepts audio streams and returns
transcriptions. Designed to run on a GPU-equipped server for offloading
transcription processing.

Usage:
    python server.py [--host HOST] [--port PORT] [--model MODEL]

Environment variables:
    TRANSCRIPTION_API_KEY: API key accepted for authentication.
    TRANSCRIPTION_API_KEYS: Comma-separated list of accepted API keys.
        If neither variable is set, authentication is DISABLED and every
        client is accepted (insecure; development only).
    TRANSCRIPTION_MODEL: Whisper model to use (default: base.en). The
        --model command-line option takes precedence when given.
"""

import asyncio
import argparse
import os
import sys
import json
import base64
import logging
from datetime import datetime
from pathlib import Path
from typing import Optional, Dict, Set
from threading import Thread, Lock

import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException, Depends
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Sample rate (Hz) that transcribe() feeds to faster-whisper; client audio at
# any other rate is resampled to this first.
WHISPER_SAMPLE_RATE = 16000

# API Key authentication
API_KEYS: Set[str] = set()


def load_api_keys():
    """Load the accepted API keys from the environment.

    Reads the comma-separated ``TRANSCRIPTION_API_KEYS`` list (blank entries
    ignored) plus the single ``TRANSCRIPTION_API_KEY`` value. When no key is
    configured a warning is logged, because verify_api_key() then accepts
    every client.
    """
    global API_KEYS
    keys_env = os.environ.get('TRANSCRIPTION_API_KEYS', '')
    if keys_env:
        API_KEYS = set(key.strip() for key in keys_env.split(',') if key.strip())

    # Also support single key
    single_key = os.environ.get('TRANSCRIPTION_API_KEY', '')
    if single_key:
        API_KEYS.add(single_key)

    if not API_KEYS:
        logger.warning("No API keys configured. Set TRANSCRIPTION_API_KEY or TRANSCRIPTION_API_KEYS environment variable.")
        logger.warning("Service will accept all connections (INSECURE for production).")


def verify_api_key(api_key: str) -> bool:
    """Return True if *api_key* is accepted.

    With no keys configured, authentication is disabled and every value is
    accepted. Otherwise *api_key* must be a string equal to a configured key.
    """
    if not API_KEYS:
        return True  # No authentication if no keys configured
    # The value comes straight from client JSON and may be any type; an
    # unhashable one (list/dict) would raise TypeError in the set lookup.
    return isinstance(api_key, str) and api_key in API_KEYS


app = FastAPI(
    title="Remote Transcription Service",
    description="GPU-accelerated speech-to-text transcription service",
    version="1.0.0"
)

# Enable CORS for all origins (configure appropriately for production)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def _resample_audio(audio: np.ndarray, orig_rate: int,
                    target_rate: int = WHISPER_SAMPLE_RATE) -> np.ndarray:
    """Return *audio* as a float32 array sampled at *target_rate* Hz.

    Uses linear interpolation (numpy.interp) with no anti-aliasing filter,
    which is adequate for speech but is not high-fidelity resampling. Input
    already at *target_rate*, or empty input, is returned without
    interpolation.

    Raises:
        ValueError: if either rate is not positive, or *audio* is not 1-D
            (numpy.interp requires 1-D data).
    """
    if orig_rate <= 0 or target_rate <= 0:
        raise ValueError(f"Sample rates must be positive, got {orig_rate} -> {target_rate}")
    samples = np.asarray(audio, dtype=np.float32)
    if orig_rate == target_rate or samples.size == 0:
        return samples
    target_len = max(1, int(round(samples.size * target_rate / orig_rate)))
    # Sample timestamps in seconds on the source and destination grids.
    src_times = np.arange(samples.size) / orig_rate
    dst_times = np.arange(target_len) / target_rate
    return np.interp(dst_times, src_times, samples).astype(np.float32)


class TranscriptionEngine:
    """Manages the transcription engine with thread-safe access.

    transcribe() holds ``self.lock`` while using the model, so concurrent
    calls from worker threads are serialised.
    """

    def __init__(self, model: str = "base.en", device: str = "auto"):
        self.model_name = model
        self.device = device
        self.recorder = None
        self.lock = Lock()
        self.is_initialized = False
        # faster-whisper model used by transcribe(); loaded by initialize().
        self._whisper_model = None

    def _get_whisper_model(self):
        """Return the faster-whisper model, creating it on first use."""
        if self._whisper_model is None:
            from faster_whisper import WhisperModel
            self._whisper_model = WhisperModel(
                self.model_name,
                device=self.device,
                compute_type="default"
            )
        return self._whisper_model

    def initialize(self):
        """Initialize the transcription engine.

        Resolves ``device="auto"`` to "cuda" or "cpu" (via torch), creates
        the RealtimeSTT recorder and loads the faster-whisper model used by
        transcribe(). The engine is only marked initialized after the model
        loads, so /health reflects whether transcription can actually run.

        Returns:
            True on success or if already initialized; False if any step
            failed (the failure is logged with its traceback).
        """
        if self.is_initialized:
            return True

        try:
            from RealtimeSTT import AudioToTextRecorder

            # Determine device
            if self.device == "auto":
                import torch
                self.device = "cuda" if torch.cuda.is_available() else "cpu"

            logger.info(f"Initializing transcription engine with model={self.model_name}, device={self.device}")

            # NOTE(review): transcribe() never uses this recorder (it runs
            # faster-whisper directly), yet its model=/realtime_model_type=
            # arguments suggest it loads models of its own. Kept for
            # compatibility; consider removing it.
            self.recorder = AudioToTextRecorder(
                model=self.model_name,
                language="en",
                device=self.device,
                compute_type="default",
                # Intended as "no mic capture"; confirm RealtimeSTT does not
                # interpret None as "use the default input device".
                input_device_index=None,
                silero_sensitivity=0.4,
                webrtc_sensitivity=3,
                post_speech_silence_duration=0.3,
                min_length_of_recording=0.5,
                enable_realtime_transcription=True,
                realtime_model_type="tiny.en",
            )

            # Load the model transcribe() uses now rather than on the first
            # request, where it would hold the lock for the whole load.
            self._get_whisper_model()

            self.is_initialized = True
            logger.info("Transcription engine initialized successfully")
            return True

        except Exception as e:
            logger.exception(f"Failed to initialize transcription engine: {e}")
            return False

    def transcribe(self, audio_data: np.ndarray, sample_rate: int = 16000) -> Optional[str]:
        """
        Transcribe audio data.

        Blocking and compute-heavy: call it from a worker thread, not
        directly on the asyncio event loop.

        Args:
            audio_data: 1-D float32 audio samples (presumably mono in
                [-1, 1] - confirm against the client).
            sample_rate: Sample rate of ``audio_data`` in Hz; audio at any
                other rate than WHISPER_SAMPLE_RATE is resampled first.

        Returns:
            Transcribed text ("" for empty audio), or None if the engine is
            not initialized or transcription failed.
        """
        with self.lock:
            if not self.is_initialized:
                return None

            try:
                audio = _resample_audio(audio_data, sample_rate, WHISPER_SAMPLE_RATE)
                if audio.size == 0:
                    return ""

                segments, _info = self._get_whisper_model().transcribe(
                    audio,
                    beam_size=5,
                    language="en"
                )

                # faster-whisper yields segments lazily; joining them here
                # keeps the actual decoding inside the lock.
                text = " ".join(segment.text for segment in segments)
                return text.strip()

            except Exception as e:
                logger.exception(f"Transcription error: {e}")
                return None


# Global transcription engine
engine: Optional[TranscriptionEngine] = None


class ClientConnection:
    """Represents an active client connection."""

    def __init__(self, websocket: WebSocket, client_id: str):
        self.websocket = websocket
        self.client_id = client_id
        self.audio_buffer = []
        self.sample_rate = 16000
        self.connected_at = datetime.now()


# Active connections
active_connections: Dict[str, ClientConnection] = {}


@app.on_event("startup")
async def startup_event():
    """Load API keys and start initializing the engine in the background.

    Model loading runs on a daemon thread so startup is not blocked; until
    it finishes, /health reports "initializing".
    """
    load_api_keys()

    global engine
    model = os.environ.get('TRANSCRIPTION_MODEL', 'base.en')
    engine = TranscriptionEngine(model=model)

    # Initialize in background thread to not block startup
    Thread(target=engine.initialize, daemon=True).start()

    logger.info("Remote Transcription Service started")


@app.get("/")
async def root():
    """Health check endpoint."""
    return {
        "service": "Remote Transcription Service",
        "status": "running",
        "model": engine.model_name if engine else "not loaded",
        "device": engine.device if engine else "unknown",
        "active_connections": len(active_connections)
    }


@app.get("/health")
async def health():
    """Detailed health check."""
    return {
        "status": "healthy" if engine and engine.is_initialized else "initializing",
        "model": engine.model_name if engine else None,
        "device": engine.device if engine else None,
        "initialized": engine.is_initialized if engine else False,
        "connections": len(active_connections)
    }


async def _send_error(websocket: WebSocket, message: str) -> None:
    """Send a protocol-level error message to the client."""
    await websocket.send_json({
        "type": "error",
        "message": message
    })


async def _handle_audio(websocket: WebSocket, message: dict) -> None:
    """Decode one "audio" message, transcribe it off the event loop, reply.

    Messages with empty/missing ``data`` are ignored. Decoding, validation
    and transcription failures are reported to the client as errors; a
    client disconnect is re-raised so the caller's handler sees it.
    """
    audio_b64 = message.get("data", "")
    if not audio_b64:
        return

    try:
        sample_rate = int(message.get("sample_rate", 16000))
        if sample_rate <= 0:
            raise ValueError(f"invalid sample_rate {sample_rate}")

        audio_bytes = base64.b64decode(audio_b64)
        audio_data = np.frombuffer(audio_bytes, dtype=np.float32)

        # Transcribe
        if engine and engine.is_initialized:
            # transcribe() is blocking model inference; run it in the
            # default thread pool so the event loop keeps serving other
            # connections, pings and health checks meanwhile.
            loop = asyncio.get_running_loop()
            text = await loop.run_in_executor(
                None, engine.transcribe, audio_data, sample_rate
            )
            if text:
                await websocket.send_json({
                    "type": "transcription",
                    "text": text,
                    "is_preview": False,
                    "timestamp": datetime.now().isoformat()
                })
        else:
            await _send_error(websocket, "Transcription engine not ready")

    except WebSocketDisconnect:
        raise
    except Exception as e:
        logger.error(f"Audio processing error: {e}")
        await _send_error(websocket, f"Audio processing error: {str(e)}")


@app.websocket("/ws/transcribe")
async def websocket_transcribe(websocket: WebSocket):
    """
    WebSocket endpoint for audio transcription.

    Protocol:
    1. Client sends: {"type": "auth", "api_key": "your-key"}
    2. Server responds: {"type": "auth_result", "success": true/false}
    3. Client sends audio chunks: {"type": "audio", "data": base64_audio, "sample_rate": 16000}
    4. Server responds with transcription: {"type": "transcription", "text": "...", "is_preview": false}
    5. Client can send: {"type": "end"} to close connection

    ``data`` is base64 of raw float32 samples. Malformed messages are
    answered with {"type": "error", ...} and the connection stays open;
    a failed auth closes it with code 4001.
    """
    await websocket.accept()
    client_id = f"client_{id(websocket)}_{datetime.now().timestamp()}"
    authenticated = False

    logger.info(f"New WebSocket connection: {client_id}")

    try:
        while True:
            data = await websocket.receive_text()

            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                await _send_error(websocket, "Invalid JSON message")
                continue
            if not isinstance(message, dict):
                await _send_error(websocket, "Invalid message format")
                continue

            msg_type = message.get("type", "")

            if msg_type == "auth":
                # Authenticate client
                api_key = message.get("api_key", "")
                if verify_api_key(api_key):
                    authenticated = True
                    active_connections[client_id] = ClientConnection(websocket, client_id)
                    await websocket.send_json({
                        "type": "auth_result",
                        "success": True,
                        "message": "Authentication successful"
                    })
                    logger.info(f"Client {client_id} authenticated")
                else:
                    await websocket.send_json({
                        "type": "auth_result",
                        "success": False,
                        "message": "Invalid API key"
                    })
                    logger.warning(f"Client {client_id} failed authentication")
                    await websocket.close(code=4001, reason="Invalid API key")
                    return

            elif msg_type == "audio":
                if not authenticated:
                    await _send_error(websocket, "Not authenticated")
                    continue
                await _handle_audio(websocket, message)

            elif msg_type == "end":
                logger.info(f"Client {client_id} requested disconnect")
                # Close cleanly instead of leaving the server to drop the
                # connection when the handler returns.
                await websocket.close()
                break

            elif msg_type == "ping":
                await websocket.send_json({"type": "pong"})

    except WebSocketDisconnect:
        logger.info(f"Client {client_id} disconnected")
    except Exception as e:
        logger.exception(f"WebSocket error for {client_id}: {e}")
    finally:
        active_connections.pop(client_id, None)


def main():
    """Parse command-line options and run the service under uvicorn.

    ``--model`` defaults to $TRANSCRIPTION_MODEL (or "base.en") and, when
    given, takes precedence over it. The chosen value is exported to
    TRANSCRIPTION_MODEL because startup_event() reads the model from there.
    """
    parser = argparse.ArgumentParser(description="Remote Transcription Service")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=8765, help="Port to bind to")
    parser.add_argument(
        "--model",
        default=os.environ.get('TRANSCRIPTION_MODEL', 'base.en'),
        help="Whisper model to use (default: $TRANSCRIPTION_MODEL or base.en)",
    )

    args = parser.parse_args()

    # Export unconditionally so an explicit --model wins over a pre-set
    # environment variable and the model logged below is the one loaded.
    os.environ['TRANSCRIPTION_MODEL'] = args.model

    logger.info(f"Starting Remote Transcription Service on {args.host}:{args.port}")
    logger.info(f"Model: {args.model}")

    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="info"
    )


if __name__ == "__main__":
    main()