"""Utilities for detecting and managing compute devices (CPU/GPU)."""
from typing import List, Tuple

import torch

class DeviceManager:
    """Manages device detection and selection for transcription.

    Detects which compute backends (CPU, CUDA, MPS) the local torch build
    supports and tracks the currently selected device. Also translates the
    selection into the device/compute-type strings faster-whisper expects.
    """

    def __init__(self):
        """Initialize device manager and detect available devices.

        Defaults ``current_device`` to the first detected device; since
        ``_detect_devices`` always includes ``"cpu"``, that is the fallback.
        """
        # List of backend names usable on this machine, e.g. ["cpu", "cuda"].
        self.available_devices = self._detect_devices()
        # Guard against an empty list defensively, even though "cpu" is always present.
        self.current_device = self.available_devices[0] if self.available_devices else "cpu"

    def _detect_devices(self) -> List[str]:
        """
        Detect available compute devices.

        Returns:
            List of available device names; always contains "cpu", and
            additionally "cuda" and/or "mps" when torch reports them.
        """
        devices = ["cpu"]

        # Check for CUDA (NVIDIA GPU)
        if torch.cuda.is_available():
            devices.append("cuda")

        # Check for MPS (Apple Silicon GPU); hasattr guards older torch
        # builds that predate the torch.backends.mps module.
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            devices.append("mps")

        return devices

    def get_device_info(self) -> List[Tuple[str, str]]:
        """
        Get detailed information about available devices.

        Returns:
            List of (device_name, device_description) tuples, one per
            entry in ``available_devices``.
        """
        info = []

        for device in self.available_devices:
            if device == "cpu":
                info.append(("cpu", "CPU"))
            elif device == "cuda":
                try:
                    gpu_name = torch.cuda.get_device_name(0)
                    info.append(("cuda", f"CUDA GPU: {gpu_name}"))
                except Exception:
                    # Querying the GPU name can fail (e.g. driver issues);
                    # fall back to a generic label rather than crashing.
                    # Narrowed from a bare ``except:`` so KeyboardInterrupt/
                    # SystemExit are not swallowed.
                    info.append(("cuda", "CUDA GPU"))
            elif device == "mps":
                info.append(("mps", "Apple Silicon GPU (MPS)"))

        return info

    def set_device(self, device: str) -> bool:
        """
        Set the current device for transcription.

        Args:
            device: Device name ('cpu', 'cuda', 'mps', or 'auto')

        Returns:
            True if device was set successfully, False otherwise
        """
        if device == "auto":
            # Auto-select best available device: prefer CUDA, then MPS, then CPU.
            if "cuda" in self.available_devices:
                self.current_device = "cuda"
            elif "mps" in self.available_devices:
                self.current_device = "mps"
            else:
                self.current_device = "cpu"
            return True

        # Explicit selection only succeeds for a detected device.
        if device in self.available_devices:
            self.current_device = device
            return True

        return False

    def get_device(self) -> str:
        """
        Get the currently selected device.

        Returns:
            Current device name
        """
        return self.current_device

    def is_gpu_available(self) -> bool:
        """
        Check if any GPU is available.

        Returns:
            True if CUDA or MPS is available
        """
        return "cuda" in self.available_devices or "mps" in self.available_devices

    def get_device_for_whisper(self) -> str:
        """
        Get device string formatted for faster-whisper.

        Returns:
            Device string for faster-whisper ('cpu', 'cuda', etc.)
        """
        if self.current_device == "mps":
            # faster-whisper (CTranslate2) doesn't support MPS, fall back to CPU
            return "cpu"
        return self.current_device

    def get_compute_type(self) -> str:
        """
        Get the appropriate compute type for the current device.

        Returns:
            Compute type string for faster-whisper: "float16" on CUDA,
            "int8" otherwise (CPU and MPS-fallback paths).
        """
        if self.current_device == "cuda":
            # Use float16 for GPU for better performance
            return "float16"
        else:
            # Use int8 for CPU for better performance
            return "int8"

    def __repr__(self) -> str:
        return f"DeviceManager(current={self.current_device}, available={self.available_devices})"
|