perf/pipeline-improvements #1

Merged
jknapp merged 18 commits from perf/pipeline-improvements into main 2026-03-21 04:53:45 +00:00
9 changed files with 39 additions and 30 deletions
Showing only changes of commit 42ccd3e21d - Show all commits

View File

@@ -49,8 +49,9 @@ def test_diarize_threading_progress(monkeypatch):
# Mock pipeline that takes ~5 seconds # Mock pipeline that takes ~5 seconds
def slow_pipeline(file_path, **kwargs): def slow_pipeline(file_path, **kwargs):
time.sleep(5) time.sleep(5)
# Return a mock diarization result # Return a mock diarization result (use spec=object to prevent
mock_result = MagicMock() # hasattr returning True for speaker_diarization)
mock_result = MagicMock(spec=[])
mock_track = MagicMock() mock_track = MagicMock()
mock_track.start = 0.0 mock_track.start = 0.0
mock_track.end = 5.0 mock_track.end = 5.0

View File

@@ -39,7 +39,11 @@ pub fn ai_chat(
if response.msg_type == "error" { if response.msg_type == "error" {
return Err(format!( return Err(format!(
"AI error: {}", "AI error: {}",
response.payload.get("message").and_then(|v| v.as_str()).unwrap_or("unknown") response
.payload
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
)); ));
} }

View File

@@ -33,7 +33,11 @@ pub fn export_transcript(
if response.msg_type == "error" { if response.msg_type == "error" {
return Err(format!( return Err(format!(
"Export error: {}", "Export error: {}",
response.payload.get("message").and_then(|v| v.as_str()).unwrap_or("unknown") response
.payload
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
)); ));
} }

View File

@@ -22,9 +22,7 @@ pub fn llama_start(
threads: Option<u32>, threads: Option<u32>,
) -> Result<LlamaStatus, String> { ) -> Result<LlamaStatus, String> {
let config = LlamaConfig { let config = LlamaConfig {
binary_path: PathBuf::from( binary_path: PathBuf::from(binary_path.unwrap_or_else(|| "llama-server".to_string())),
binary_path.unwrap_or_else(|| "llama-server".to_string()),
),
model_path: PathBuf::from(model_path), model_path: PathBuf::from(model_path),
port: port.unwrap_or(0), port: port.unwrap_or(0),
n_gpu_layers: n_gpu_layers.unwrap_or(0), n_gpu_layers: n_gpu_layers.unwrap_or(0),

View File

@@ -33,7 +33,11 @@ pub fn transcribe_file(
if response.msg_type == "error" { if response.msg_type == "error" {
return Err(format!( return Err(format!(
"Transcription error: {}", "Transcription error: {}",
response.payload.get("message").and_then(|v| v.as_str()).unwrap_or("unknown") response
.payload
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
)); ));
} }
@@ -42,9 +46,7 @@ pub fn transcribe_file(
/// Download and validate the diarization model via the Python sidecar. /// Download and validate the diarization model via the Python sidecar.
#[tauri::command] #[tauri::command]
pub fn download_diarize_model( pub fn download_diarize_model(hf_token: String) -> Result<Value, String> {
hf_token: String,
) -> Result<Value, String> {
let manager = sidecar(); let manager = sidecar();
manager.ensure_running()?; manager.ensure_running()?;
@@ -116,7 +118,11 @@ pub fn run_pipeline(
if response.msg_type == "error" { if response.msg_type == "error" {
return Err(format!( return Err(format!(
"Pipeline error: {}", "Pipeline error: {}",
response.payload.get("message").and_then(|v| v.as_str()).unwrap_or("unknown") response
.payload
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
)); ));
} }

View File

@@ -96,11 +96,7 @@ pub fn create_tables(conn: &Connection) -> Result<(), DatabaseError> {
)?; )?;
// Initialize schema version if empty // Initialize schema version if empty
let count: i32 = conn.query_row( let count: i32 = conn.query_row("SELECT COUNT(*) FROM schema_version", [], |row| row.get(0))?;
"SELECT COUNT(*) FROM schema_version",
[],
|row| row.get(0),
)?;
if count == 0 { if count == 0 {
conn.execute( conn.execute(
"INSERT INTO schema_version (version) VALUES (?1)", "INSERT INTO schema_version (version) VALUES (?1)",

View File

@@ -237,11 +237,7 @@ impl LlamaManager {
/// Get the current status. /// Get the current status.
pub fn status(&self) -> LlamaStatus { pub fn status(&self) -> LlamaStatus {
let running = self let running = self.process.lock().ok().map_or(false, |p| p.is_some());
.process
.lock()
.ok()
.map_or(false, |p| p.is_some());
let port = self.port.lock().ok().map_or(0, |p| *p); let port = self.port.lock().ok().map_or(0, |p| *p);
let model = self let model = self
.model_path .model_path

View File

@@ -108,7 +108,10 @@ impl SidecarManager {
} }
} }
// Non-JSON or non-ready line — skip and keep waiting // Non-JSON or non-ready line — skip and keep waiting
eprintln!("[sidecar-rs] Skipping pre-ready line: {}", &trimmed[..trimmed.len().min(200)]); eprintln!(
"[sidecar-rs] Skipping pre-ready line: {}",
&trimmed[..trimmed.len().min(200)]
);
continue; continue;
} }
} }
@@ -165,7 +168,10 @@ impl SidecarManager {
let response: IPCMessage = match serde_json::from_str(trimmed) { let response: IPCMessage = match serde_json::from_str(trimmed) {
Ok(msg) => msg, Ok(msg) => msg,
Err(_) => { Err(_) => {
eprintln!("[sidecar-rs] Skipping non-JSON line: {}", &trimmed[..trimmed.len().min(200)]); eprintln!(
"[sidecar-rs] Skipping non-JSON line: {}",
&trimmed[..trimmed.len().min(200)]
);
continue; continue;
} }
}; };
@@ -226,8 +232,8 @@ impl SidecarManager {
if trimmed.is_empty() { if trimmed.is_empty() {
continue; continue;
} }
let response: IPCMessage = serde_json::from_str(trimmed) let response: IPCMessage =
.map_err(|e| format!("Parse error: {e}"))?; serde_json::from_str(trimmed).map_err(|e| format!("Parse error: {e}"))?;
// Forward intermediate messages via callback, return the final result/error // Forward intermediate messages via callback, return the final result/error
let is_intermediate = matches!( let is_intermediate = matches!(

View File

@@ -15,12 +15,10 @@ pub struct AppState {
impl AppState { impl AppState {
pub fn new() -> Result<Self, String> { pub fn new() -> Result<Self, String> {
let data_dir = LlamaManager::data_dir(); let data_dir = LlamaManager::data_dir();
std::fs::create_dir_all(&data_dir) std::fs::create_dir_all(&data_dir).map_err(|e| format!("Cannot create data dir: {e}"))?;
.map_err(|e| format!("Cannot create data dir: {e}"))?;
let db_path = data_dir.join("voice_to_notes.db"); let db_path = data_dir.join("voice_to_notes.db");
let conn = db::open_database(&db_path) let conn = db::open_database(&db_path).map_err(|e| format!("Cannot open database: {e}"))?;
.map_err(|e| format!("Cannot open database: {e}"))?;
Ok(Self { Ok(Self {
db: Mutex::new(conn), db: Mutex::new(conn),