Coverage for src/meta_learning/meta_learning_modules/hardware_utils.py: 30%

259 statements  

coverage.py v7.10.5, created at 2025-09-03 12:49 +0900

"""
Modern Hardware Support Utilities for Meta-Learning
====================================================

Comprehensive support for modern hardware accelerators including:
- NVIDIA GPUs (RTX 4090, A100, H100, etc.)
- Apple Silicon (M1/M2/M3/M4 with MPS)
- Multi-GPU distributed training
- Mixed precision training (FP16, BF16)
- Memory optimization and efficient computation

This module provides hardware abstraction that automatically detects
and utilizes the best available hardware for meta-learning workloads.

Author: Benedict Chen (benedict@benedictchen.com)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from typing import Dict, List, Tuple, Optional, Any, Union
import logging
import psutil
import gc
import time
from contextlib import contextmanager
from dataclasses import dataclass
import warnings

logger = logging.getLogger(__name__)

@dataclass
class HardwareConfig:
    """Configuration for modern hardware utilization."""
    # Device selection
    device: Optional[str] = None           # Auto-detect if None
    use_mixed_precision: bool = True       # AMP for faster training
    precision_dtype: str = "float16"       # "float16", "bfloat16", or "float32"

    # Multi-GPU settings
    use_data_parallel: bool = False        # Use DataParallel
    use_distributed: bool = False          # Use DistributedDataParallel
    world_size: int = 1
    rank: int = 0

    # Memory optimization
    gradient_checkpointing: bool = False   # Trade compute for memory
    memory_efficient: bool = True          # Enable memory optimizations
    max_memory_fraction: float = 0.9       # Max GPU memory to use

    # Apple Silicon specific
    use_mps_fallback: bool = True          # Fallback for unsupported ops

    # Performance tuning
    compile_model: bool = False            # PyTorch 2.0 compilation
    channels_last: bool = False            # Memory format optimization
    benchmark_mode: bool = True            # cuDNN benchmark mode
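
# A minimal sketch of overriding HardwareConfig defaults; the field values
# below are illustrative, not recommendations:
#
#   config = HardwareConfig(
#       device="cuda",                # skip auto-detection
#       precision_dtype="bfloat16",   # prefer BF16 on Ampere+ GPUs
#       gradient_checkpointing=True,  # trade compute for memory
#   )
#   manager = HardwareManager(config)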

class HardwareManager:
    """Manages modern hardware resources for meta-learning."""

    def __init__(self, config: Optional[HardwareConfig] = None):
        self.config = config or HardwareConfig()
        self.device = self._detect_best_device()
        self.scaler = None
        self.is_distributed = False

        # Initialize hardware-specific settings
        self._initialize_hardware()

    def _detect_best_device(self) -> torch.device:
        """Automatically detect the best available device."""
        if self.config.device:
            return torch.device(self.config.device)

        # Priority order: CUDA > MPS > CPU
        if torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(0)
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
            logger.info(f"🚀 Using CUDA GPU: {device_name} ({gpu_memory:.1f}GB)")
            return torch.device("cuda")

        elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            logger.info("🍎 Using Apple Silicon MPS acceleration")
            return torch.device("mps")

        else:
            cpu_count = psutil.cpu_count()
            ram_gb = psutil.virtual_memory().total / 1e9
            logger.info(f"💻 Using CPU: {cpu_count} cores, {ram_gb:.1f}GB RAM")
            return torch.device("cpu")

    def _initialize_hardware(self):
        """Initialize hardware-specific optimizations."""
        if self.device.type == "cuda":
            self._setup_cuda()
        elif self.device.type == "mps":
            self._setup_mps()
        else:
            self._setup_cpu()

    def _setup_cuda(self):
        """Setup CUDA-specific optimizations."""
        # Enable cuDNN benchmark for consistent input sizes
        if self.config.benchmark_mode:
            torch.backends.cudnn.benchmark = True

        # Mixed precision training
        if self.config.use_mixed_precision:
            self.scaler = GradScaler()
            logger.info("⚡ Enabled mixed precision training (AMP)")

        # Memory management
        if self.config.max_memory_fraction < 1.0:
            torch.cuda.set_per_process_memory_fraction(self.config.max_memory_fraction)

        # Log GPU information
        props = torch.cuda.get_device_properties(0)
        logger.info(
            f"🎯 GPU: {props.name}, Compute: {props.major}.{props.minor}, "
            f"Memory: {props.total_memory/1e9:.1f}GB"
        )

    def _setup_mps(self):
        """Setup Apple Silicon MPS optimizations."""
        # MPS doesn't support mixed precision yet (as of PyTorch 2.1)
        if self.config.use_mixed_precision:
            logger.warning("⚠️ Mixed precision not supported on MPS, disabling")
            self.config.use_mixed_precision = False

        logger.info("🍎 Configured for Apple Silicon optimization")

    def _setup_cpu(self):
        """Setup CPU optimizations."""
        # Set optimal thread count (psutil.cpu_count() can return None)
        cpu_count = psutil.cpu_count() or 1
        if torch.get_num_threads() < cpu_count:
            torch.set_num_threads(cpu_count)

        # Disable mixed precision on CPU
        if self.config.use_mixed_precision:
            logger.warning("⚠️ Mixed precision not efficient on CPU, disabling")
            self.config.use_mixed_precision = False

    def prepare_model(self, model: nn.Module) -> nn.Module:
        """Prepare model for optimal hardware utilization."""
        # Move to device
        model = model.to(self.device)

        # Enable gradient checkpointing if requested
        if self.config.gradient_checkpointing and hasattr(model, 'gradient_checkpointing_enable'):
            model.gradient_checkpointing_enable()
            logger.info("💾 Enabled gradient checkpointing for memory efficiency")

        # Channels-last memory format (for conv nets)
        if self.config.channels_last and hasattr(model, 'to'):
            try:
                model = model.to(memory_format=torch.channels_last)
                logger.info("🔄 Using channels-last memory format")
            except Exception as e:
                logger.warning(f"Could not use channels-last format: {e}")

        # Multi-GPU setup
        if self.config.use_data_parallel and torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
            logger.info(f"🔗 Using DataParallel across {torch.cuda.device_count()} GPUs")

        # PyTorch 2.0 compilation
        if self.config.compile_model and hasattr(torch, 'compile'):
            try:
                model = torch.compile(model)
                logger.info("⚡ Model compiled with PyTorch 2.0")
            except Exception as e:
                logger.warning(f"Model compilation failed: {e}")

        return model

    def prepare_data(self, data: Union[torch.Tensor, Tuple, List]) -> Any:
        """Prepare data tensors for optimal hardware utilization."""
        if isinstance(data, torch.Tensor):
            tensor = data.to(self.device, non_blocking=True)

            # Convert to channels-last if enabled and applicable
            if (self.config.channels_last and
                    len(tensor.shape) == 4 and
                    self.device.type == "cuda"):
                try:
                    tensor = tensor.to(memory_format=torch.channels_last)
                except RuntimeError:
                    pass  # Ignore if not applicable

            return tensor

        elif isinstance(data, (tuple, list)):
            return type(data)(self.prepare_data(item) for item in data)

        elif isinstance(data, dict):
            return {key: self.prepare_data(value) for key, value in data.items()}

        else:
            return data

    @contextmanager
    def autocast_context(self):
        """Context manager for mixed precision computation."""
        if self.config.use_mixed_precision and self.device.type == "cuda":
            dtype = torch.float16 if self.config.precision_dtype == "float16" else torch.bfloat16
            with autocast(enabled=True, dtype=dtype):
                yield
        else:
            yield

    def backward_and_step(self, loss: torch.Tensor, optimizer: torch.optim.Optimizer) -> torch.Tensor:
        """Perform backward pass with hardware optimizations."""
        if self.config.use_mixed_precision and self.scaler:
            # Mixed precision backward: scale loss, step, then update scale factor
            self.scaler.scale(loss).backward()
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            # Regular backward
            loss.backward()
            optimizer.step()

        return loss
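
    # A minimal training-step sketch combining prepare_data(),
    # autocast_context(), and backward_and_step(). `manager`, `model`,
    # `optimizer`, and `batch` are assumed to exist; the loss function is
    # illustrative:
    #
    #   inputs, targets = manager.prepare_data(batch)
    #   optimizer.zero_grad()
    #   with manager.autocast_context():
    #       loss = F.cross_entropy(model(inputs), targets)
    #   manager.backward_and_step(loss, optimizer)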

    def get_memory_stats(self) -> Dict[str, Any]:
        """Get current memory usage statistics (values in GB unless noted)."""
        stats = {}

        if self.device.type == "cuda":
            stats.update({
                'gpu_memory_allocated': torch.cuda.memory_allocated() / 1e9,
                'gpu_memory_reserved': torch.cuda.memory_reserved() / 1e9,
                'gpu_memory_max_allocated': torch.cuda.max_memory_allocated() / 1e9,
                'gpu_utilization': self._get_gpu_utilization()  # percent
            })

        # CPU/system memory
        vm = psutil.virtual_memory()
        stats.update({
            'cpu_memory_used': vm.used / 1e9,
            'cpu_memory_total': vm.total / 1e9,
            'cpu_memory_percent': vm.percent
        })

        return stats

    def _get_gpu_utilization(self) -> float:
        """Get GPU utilization percentage (0.0 if NVML is unavailable)."""
        try:
            # The nvidia-ml-py / nvidia-ml-py3 packages expose the NVML
            # bindings as the `pynvml` module
            import pynvml as nvml
            nvml.nvmlInit()
            handle = nvml.nvmlDeviceGetHandleByIndex(0)
            util = nvml.nvmlDeviceGetUtilizationRates(handle)
            return float(util.gpu)
        except Exception:
            return 0.0

    def clear_cache(self):
        """Clear GPU cache and run garbage collection to free memory."""
        if self.device.type == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

        # Python garbage collection
        gc.collect()

    def benchmark_device(self, model: nn.Module, input_shape: Tuple[int, ...]) -> Dict[str, float]:
        """Benchmark model performance on the current device."""
        model = model.to(self.device)
        dummy_input = torch.randn(input_shape).to(self.device)

        # Warmup
        for _ in range(10):
            with torch.no_grad():
                _ = model(dummy_input)

        if self.device.type == "cuda":
            torch.cuda.synchronize()

        # Benchmark inference
        start_time = time.time()

        with torch.no_grad():
            for _ in range(100):
                _ = model(dummy_input)

        if self.device.type == "cuda":
            torch.cuda.synchronize()

        inference_time = (time.time() - start_time) / 100

        # Benchmark training
        optimizer = torch.optim.Adam(model.parameters())

        if self.device.type == "cuda":
            torch.cuda.synchronize()

        start_time = time.time()

        for _ in range(50):
            optimizer.zero_grad()
            output = model(dummy_input)
            loss = output.mean()
            loss.backward()
            optimizer.step()

        if self.device.type == "cuda":
            torch.cuda.synchronize()

        training_time = (time.time() - start_time) / 50

        return {
            'inference_time_ms': inference_time * 1000,
            'training_time_ms': training_time * 1000,
            'device': str(self.device),
            'memory_used_gb': self.get_memory_stats().get('gpu_memory_allocated', 0)
        }
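
# Sketch of benchmarking a model; the MLP and input shape are placeholders:
#
#   manager = HardwareManager()
#   net = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
#   print(manager.benchmark_device(net, input_shape=(32, 128)))
#   # -> {'inference_time_ms': ..., 'training_time_ms': ..., 'device': ..., ...}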

class MultiGPUManager:
    """Manager for multi-GPU distributed training."""

    def __init__(self, config: HardwareConfig):
        self.config = config
        self.world_size = config.world_size
        self.rank = config.rank
        self.is_initialized = False

    def setup_distributed(self, backend: str = "nccl"):
        """Setup distributed training."""
        if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
            logger.warning("Distributed training requires multiple CUDA GPUs")
            return False

        try:
            if not torch.distributed.is_initialized():
                torch.distributed.init_process_group(
                    backend=backend,
                    world_size=self.world_size,
                    rank=self.rank
                )
            self.is_initialized = True
            logger.info(f"🔗 Initialized distributed training: rank {self.rank}/{self.world_size}")

            return True
        except Exception as e:
            logger.error(f"Failed to setup distributed training: {e}")
            return False

    def wrap_model(self, model: nn.Module) -> nn.Module:
        """Wrap model for distributed training."""
        if not self.is_initialized:
            return model

        from torch.nn.parallel import DistributedDataParallel as DDP

        device = torch.device(f"cuda:{self.rank}")
        model = model.to(device)
        model = DDP(model, device_ids=[self.rank])

        logger.info(f"📦 Wrapped model with DistributedDataParallel on GPU {self.rank}")
        return model

    def cleanup(self):
        """Cleanup distributed resources."""
        if self.is_initialized and torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()
            self.is_initialized = False
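
# Typical per-process usage sketch, assuming rank and world size come from a
# launcher such as torchrun via the standard RANK/WORLD_SIZE environment
# variables (`os` and `model` are assumed to exist in the caller):
#
#   config = HardwareConfig(use_distributed=True,
#                           world_size=int(os.environ["WORLD_SIZE"]),
#                           rank=int(os.environ["RANK"]))
#   mgpu = MultiGPUManager(config)
#   if mgpu.setup_distributed():
#       model = mgpu.wrap_model(model)
#   ...
#   mgpu.cleanup()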

def get_optimal_batch_size(model: nn.Module, input_shape: Tuple[int, ...],
                           device: torch.device) -> int:
    """Find optimal batch size for current hardware."""
    if device.type != "cuda":
        return 32  # Conservative default for CPU/MPS

    # Start with a reasonable batch size
    batch_size = 16
    max_batch_size = 1024

    model = model.to(device)
    model.train()  # Enable training mode for accurate memory estimation

    optimizer = torch.optim.Adam(model.parameters())

    while batch_size <= max_batch_size:
        try:
            # Clear cache
            torch.cuda.empty_cache()

            # Test batch
            dummy_input = torch.randn(batch_size, *input_shape[1:]).to(device)

            optimizer.zero_grad()
            output = model(dummy_input)
            loss = output.mean()
            loss.backward()
            optimizer.step()

            # If successful, try larger batch
            batch_size *= 2

        except RuntimeError as e:
            if "out of memory" in str(e):
                # Return previous successful batch size
                optimal_batch_size = batch_size // 2
                logger.info(f"🎯 Optimal batch size: {optimal_batch_size}")
                return max(1, optimal_batch_size)
            else:
                raise

    return batch_size // 2
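
# Example (values are placeholders; input_shape includes the batch dimension,
# which is replaced during probing):
#
#   bs = get_optimal_batch_size(model, input_shape=(1, 3, 224, 224),
#                               device=torch.device("cuda"))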

def create_hardware_manager(device: Optional[str] = None,
                            use_mixed_precision: bool = True,
                            **kwargs) -> HardwareManager:
    """Factory function to create a hardware manager with optimal settings."""
    config = HardwareConfig(
        device=device,
        use_mixed_precision=use_mixed_precision,
        **kwargs
    )

    return HardwareManager(config)

# Compatibility functions for easy integration
def auto_device() -> torch.device:
    """Get the best available device."""
    manager = HardwareManager()
    return manager.device


def prepare_for_hardware(model: nn.Module, device: Optional[torch.device] = None) -> nn.Module:
    """Prepare model for optimal hardware usage."""
    if device is None:
        device = auto_device()

    manager = HardwareManager(HardwareConfig(device=str(device)))
    return manager.prepare_model(model)

def log_hardware_info():
    """Log detailed hardware information."""
    logger.info("🔧 Hardware Configuration:")

    # CPU info
    cpu_count = psutil.cpu_count()
    ram_gb = psutil.virtual_memory().total / 1e9
    logger.info(f"  CPU: {cpu_count} cores, {ram_gb:.1f}GB RAM")

    # GPU info
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            logger.info(f"  GPU {i}: {props.name}, {props.total_memory/1e9:.1f}GB")

    # Apple Silicon
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        logger.info("  Apple Silicon MPS: Available")

    # PyTorch info
    logger.info(f"  PyTorch: {torch.__version__}")
    logger.info(f"  CUDA: {torch.version.cuda if torch.cuda.is_available() else 'Not available'}")
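

if __name__ == "__main__":
    # Smoke-test sketch: log the detected hardware, then benchmark a tiny
    # placeholder model. The MLP and shapes are illustrative only.
    logging.basicConfig(level=logging.INFO)
    log_hardware_info()

    manager = create_hardware_manager()
    net = manager.prepare_model(nn.Sequential(
        nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 8)))
    print(manager.benchmark_device(net, input_shape=(16, 64)))
    print(manager.get_memory_stats())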