Fix cargo fmt formatting and diarize threading test mock
This commit is contained in:
@@ -49,8 +49,9 @@ def test_diarize_threading_progress(monkeypatch):
|
||||
# Mock pipeline that takes ~5 seconds
|
||||
def slow_pipeline(file_path, **kwargs):
|
||||
time.sleep(5)
|
||||
# Return a mock diarization result
|
||||
mock_result = MagicMock()
|
||||
# Return a mock diarization result (use spec=object to prevent
|
||||
# hasattr returning True for speaker_diarization)
|
||||
mock_result = MagicMock(spec=[])
|
||||
mock_track = MagicMock()
|
||||
mock_track.start = 0.0
|
||||
mock_track.end = 5.0
|
||||
|
||||
@@ -39,7 +39,11 @@ pub fn ai_chat(
|
||||
if response.msg_type == "error" {
|
||||
return Err(format!(
|
||||
"AI error: {}",
|
||||
response.payload.get("message").and_then(|v| v.as_str()).unwrap_or("unknown")
|
||||
response
|
||||
.payload
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown")
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,11 @@ pub fn export_transcript(
|
||||
if response.msg_type == "error" {
|
||||
return Err(format!(
|
||||
"Export error: {}",
|
||||
response.payload.get("message").and_then(|v| v.as_str()).unwrap_or("unknown")
|
||||
response
|
||||
.payload
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown")
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
@@ -22,9 +22,7 @@ pub fn llama_start(
|
||||
threads: Option<u32>,
|
||||
) -> Result<LlamaStatus, String> {
|
||||
let config = LlamaConfig {
|
||||
binary_path: PathBuf::from(
|
||||
binary_path.unwrap_or_else(|| "llama-server".to_string()),
|
||||
),
|
||||
binary_path: PathBuf::from(binary_path.unwrap_or_else(|| "llama-server".to_string())),
|
||||
model_path: PathBuf::from(model_path),
|
||||
port: port.unwrap_or(0),
|
||||
n_gpu_layers: n_gpu_layers.unwrap_or(0),
|
||||
|
||||
@@ -33,7 +33,11 @@ pub fn transcribe_file(
|
||||
if response.msg_type == "error" {
|
||||
return Err(format!(
|
||||
"Transcription error: {}",
|
||||
response.payload.get("message").and_then(|v| v.as_str()).unwrap_or("unknown")
|
||||
response
|
||||
.payload
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown")
|
||||
));
|
||||
}
|
||||
|
||||
@@ -42,9 +46,7 @@ pub fn transcribe_file(
|
||||
|
||||
/// Download and validate the diarization model via the Python sidecar.
|
||||
#[tauri::command]
|
||||
pub fn download_diarize_model(
|
||||
hf_token: String,
|
||||
) -> Result<Value, String> {
|
||||
pub fn download_diarize_model(hf_token: String) -> Result<Value, String> {
|
||||
let manager = sidecar();
|
||||
manager.ensure_running()?;
|
||||
|
||||
@@ -116,7 +118,11 @@ pub fn run_pipeline(
|
||||
if response.msg_type == "error" {
|
||||
return Err(format!(
|
||||
"Pipeline error: {}",
|
||||
response.payload.get("message").and_then(|v| v.as_str()).unwrap_or("unknown")
|
||||
response
|
||||
.payload
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown")
|
||||
));
|
||||
}
|
||||
|
||||
|
||||
@@ -96,11 +96,7 @@ pub fn create_tables(conn: &Connection) -> Result<(), DatabaseError> {
|
||||
)?;
|
||||
|
||||
// Initialize schema version if empty
|
||||
let count: i32 = conn.query_row(
|
||||
"SELECT COUNT(*) FROM schema_version",
|
||||
[],
|
||||
|row| row.get(0),
|
||||
)?;
|
||||
let count: i32 = conn.query_row("SELECT COUNT(*) FROM schema_version", [], |row| row.get(0))?;
|
||||
if count == 0 {
|
||||
conn.execute(
|
||||
"INSERT INTO schema_version (version) VALUES (?1)",
|
||||
|
||||
@@ -237,11 +237,7 @@ impl LlamaManager {
|
||||
|
||||
/// Get the current status.
|
||||
pub fn status(&self) -> LlamaStatus {
|
||||
let running = self
|
||||
.process
|
||||
.lock()
|
||||
.ok()
|
||||
.map_or(false, |p| p.is_some());
|
||||
let running = self.process.lock().ok().map_or(false, |p| p.is_some());
|
||||
let port = self.port.lock().ok().map_or(0, |p| *p);
|
||||
let model = self
|
||||
.model_path
|
||||
|
||||
@@ -108,7 +108,10 @@ impl SidecarManager {
|
||||
}
|
||||
}
|
||||
// Non-JSON or non-ready line — skip and keep waiting
|
||||
eprintln!("[sidecar-rs] Skipping pre-ready line: {}", &trimmed[..trimmed.len().min(200)]);
|
||||
eprintln!(
|
||||
"[sidecar-rs] Skipping pre-ready line: {}",
|
||||
&trimmed[..trimmed.len().min(200)]
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -165,7 +168,10 @@ impl SidecarManager {
|
||||
let response: IPCMessage = match serde_json::from_str(trimmed) {
|
||||
Ok(msg) => msg,
|
||||
Err(_) => {
|
||||
eprintln!("[sidecar-rs] Skipping non-JSON line: {}", &trimmed[..trimmed.len().min(200)]);
|
||||
eprintln!(
|
||||
"[sidecar-rs] Skipping non-JSON line: {}",
|
||||
&trimmed[..trimmed.len().min(200)]
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
@@ -226,8 +232,8 @@ impl SidecarManager {
|
||||
if trimmed.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let response: IPCMessage = serde_json::from_str(trimmed)
|
||||
.map_err(|e| format!("Parse error: {e}"))?;
|
||||
let response: IPCMessage =
|
||||
serde_json::from_str(trimmed).map_err(|e| format!("Parse error: {e}"))?;
|
||||
|
||||
// Forward intermediate messages via callback, return the final result/error
|
||||
let is_intermediate = matches!(
|
||||
|
||||
@@ -15,12 +15,10 @@ pub struct AppState {
|
||||
impl AppState {
|
||||
pub fn new() -> Result<Self, String> {
|
||||
let data_dir = LlamaManager::data_dir();
|
||||
std::fs::create_dir_all(&data_dir)
|
||||
.map_err(|e| format!("Cannot create data dir: {e}"))?;
|
||||
std::fs::create_dir_all(&data_dir).map_err(|e| format!("Cannot create data dir: {e}"))?;
|
||||
|
||||
let db_path = data_dir.join("voice_to_notes.db");
|
||||
let conn = db::open_database(&db_path)
|
||||
.map_err(|e| format!("Cannot open database: {e}"))?;
|
||||
let conn = db::open_database(&db_path).map_err(|e| format!("Cannot open database: {e}"))?;
|
||||
|
||||
Ok(Self {
|
||||
db: Mutex::new(conn),
|
||||
|
||||
Reference in New Issue
Block a user