"""FastAPI control API server for the headless transcription backend. Extends the existing OBS display server with REST endpoints and a control WebSocket channel so that a Tauri (or any other) frontend can drive the application. """ import asyncio import json from datetime import datetime from typing import List, Optional from fastapi import FastAPI, WebSocket, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from backend.app_controller import AppController # ── Request / Response Models ────────────────────────────────────── class ConfigUpdate(BaseModel): """Batch config update payload. Keys use dot-notation.""" settings: dict # e.g. {"user.name": "Alice", "transcription.model": "small.en"} class LoginRequest(BaseModel): email: str password: str server_url: str class RegisterRequest(BaseModel): email: str password: str server_url: str class SkipVersionRequest(BaseModel): version: str class SaveFileRequest(BaseModel): path: str text: str # ── API Server ───────────────────────────────────────────────────── class APIServer: """Wraps AppController with a FastAPI application exposing control endpoints.""" def __init__(self, controller: AppController): self.controller = controller self.control_connections: List[WebSocket] = [] self.app = FastAPI(title="Local Transcription API", version="1.0.0") # Allow Tauri webview origin self.app.add_middleware( CORSMiddleware, allow_origins=["*"], # Tauri uses tauri://localhost or https://tauri.localhost allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) self._setup_routes() self._wire_controller_callbacks() def _wire_controller_callbacks(self): """Wire AppController callbacks to broadcast over /ws/control.""" original_state_cb = self.controller.on_state_changed def on_state_changed(state: str, message: str): # Isolate the upstream callback so a failure there (e.g. a # broken stdout pipe in main_headless) cannot propagate into # _set_state and tear down engine init / reload_engine / # apply_settings request handling. if original_state_cb: try: original_state_cb(state, message) except Exception: pass self._broadcast_control({"type": "state_changed", "state": state, "message": message}) self.controller.on_state_changed = on_state_changed def on_transcription(data: dict): self._broadcast_control({"type": "transcription", **data}) self.controller.on_transcription = on_transcription def on_preview(data: dict): self._broadcast_control({"type": "preview", **data}) self.controller.on_preview = on_preview def on_error(msg: str): self._broadcast_control({"type": "error", "message": msg}) self.controller.on_error = on_error def on_credits_low(seconds: int): self._broadcast_control({"type": "credits_low", "seconds_remaining": seconds}) self.controller.on_credits_low = on_credits_low def set_event_loop(self, loop: asyncio.AbstractEventLoop): """Set the event loop used for broadcasting (call from uvicorn startup).""" self._event_loop = loop def _broadcast_control(self, data: dict): """Send a message to all connected /ws/control clients.""" if not self.control_connections: return loop = getattr(self, '_event_loop', None) if loop is None: return message = json.dumps(data) disconnected = [] for ws in self.control_connections: try: asyncio.run_coroutine_threadsafe( ws.send_text(message), loop, ) except Exception: disconnected.append(ws) for ws in disconnected: self.control_connections.remove(ws) def _setup_routes(self): """Register all API routes.""" app = self.app ctrl = self.controller @app.on_event("startup") async def on_startup(): self.set_event_loop(asyncio.get_event_loop()) # ── Status ───────────────────────────────────────────── @app.get("/api/status") async def get_status(): return ctrl.get_status() @app.get("/api/version") async def get_version(): from version import __version__ return {"version": __version__} # ── Transcription Control ────────────────────────────── @app.post("/api/start") async def start_transcription(): import asyncio # Run in thread pool to avoid blocking the event loop # (start_recording can block up to 15s waiting for Deepgram WS) loop = asyncio.get_event_loop() success, message = await loop.run_in_executor( None, ctrl.start_transcription ) if not success: raise HTTPException(status_code=400, detail=message) return {"status": "ok", "message": message} @app.post("/api/stop") async def stop_transcription(): import asyncio loop = asyncio.get_event_loop() success, message = await loop.run_in_executor( None, ctrl.stop_transcription ) if not success: raise HTTPException(status_code=400, detail=message) return {"status": "ok", "message": message} @app.post("/api/clear") async def clear_transcriptions(): count = ctrl.clear_transcriptions() return {"status": "ok", "cleared": count} @app.get("/api/transcriptions") async def get_transcriptions(): show_timestamps = ctrl.config.get('display.show_timestamps', True) return { "count": len(ctrl.transcriptions), "text": ctrl.get_transcriptions_text(include_timestamps=show_timestamps), "items": [ { "text": r.text, "user_name": r.user_name, "timestamp": r.timestamp.strftime("%H:%M:%S") if r.timestamp else None, } for r in ctrl.transcriptions ], } @app.post("/api/save-file") async def save_file(req: SaveFileRequest): """Save text to a file (used by Tauri frontend after dialog).""" from pathlib import Path try: Path(req.path).write_text(req.text, encoding="utf-8") return {"status": "ok", "path": req.path} except Exception as e: raise HTTPException(status_code=500, detail=str(e)) # ── Configuration ────────────────────────────────────── @app.get("/api/config") async def get_config(): return ctrl.config.config @app.put("/api/config") async def update_config(update: ConfigUpdate): import asyncio loop = asyncio.get_event_loop() engine_reloaded, message = await loop.run_in_executor( None, ctrl.apply_settings, update.settings ) return { "status": "ok", "message": message, "engine_reloaded": engine_reloaded, } # ── Devices ──────────────────────────────────────────── @app.get("/api/audio-devices") async def get_audio_devices(): return {"devices": ctrl.get_audio_devices()} @app.get("/api/compute-devices") async def get_compute_devices(): return {"devices": ctrl.get_compute_devices()} # ── Engine ───────────────────────────────────────────── @app.post("/api/reload-engine") async def reload_engine(): import asyncio loop = asyncio.get_event_loop() success, message = await loop.run_in_executor( None, ctrl.reload_engine ) if not success: raise HTTPException(status_code=500, detail=message) return {"status": "ok", "message": message} # ── Updates ──────────────────────────────────────────── @app.get("/api/check-update") async def check_update(): return ctrl.check_for_updates() @app.post("/api/skip-version") async def skip_version(req: SkipVersionRequest): ctrl.skip_version(req.version) return {"status": "ok"} # ── Managed Mode Auth Proxy ──────────────────────────── @app.post("/api/login") async def login(req: LoginRequest): """Proxy login to the transcription proxy server.""" import requests as http_requests try: resp = http_requests.post( f"{req.server_url}/api/auth/login", json={"email": req.email, "password": req.password}, timeout=10, ) if resp.status_code == 200: data = resp.json() ctrl.config.set('remote.auth_token', data.get('token', '')) ctrl.config.set('remote.server_url', req.server_url) ctrl.config.set('remote.email', req.email) return {"status": "ok", "token": data.get('token', '')} else: raise HTTPException(status_code=resp.status_code, detail=resp.text) except http_requests.RequestException as e: raise HTTPException(status_code=502, detail=str(e)) @app.post("/api/register") async def register(req: RegisterRequest): """Proxy registration to the transcription proxy server.""" import requests as http_requests try: resp = http_requests.post( f"{req.server_url}/api/auth/register", json={"email": req.email, "password": req.password}, timeout=10, ) if resp.status_code in (200, 201): return {"status": "ok", "data": resp.json()} else: raise HTTPException(status_code=resp.status_code, detail=resp.text) except http_requests.RequestException as e: raise HTTPException(status_code=502, detail=str(e)) @app.get("/api/balance") async def get_balance(): """Proxy balance check to the transcription proxy server.""" import requests as http_requests server_url = ctrl.config.get('remote.server_url', '') token = ctrl.config.get('remote.auth_token', '') if not server_url or not token: raise HTTPException(status_code=400, detail="Not logged in to managed service") try: resp = http_requests.get( f"{server_url}/api/billing/balance", headers={"Authorization": f"Bearer {token}"}, timeout=10, ) if resp.status_code == 200: return resp.json() else: raise HTTPException(status_code=resp.status_code, detail=resp.text) except http_requests.RequestException as e: raise HTTPException(status_code=502, detail=str(e)) # ── Control WebSocket ────────────────────────────────── @app.websocket("/ws/control") async def websocket_control(websocket: WebSocket): """WebSocket channel for real-time state and transcription push.""" await websocket.accept() self.control_connections.append(websocket) # Send current status on connect try: await websocket.send_json({ "type": "state_changed", "state": ctrl.state, "message": "Connected", }) except Exception: pass try: while True: # Keep alive -- client sends pings await websocket.receive_text() except Exception: if websocket in self.control_connections: self.control_connections.remove(websocket) # ── Mount the existing OBS display routes ────────────── # The OBS display (GET / and /ws) is handled by the # TranscriptionWebServer which shares the same Uvicorn # instance. We mount it as a sub-application so the # existing OBS URLs continue to work. if ctrl.web_server: app.mount("/obs", ctrl.web_server.app)