Coverage for src/meta_learning/meta_learning_modules/config_factory.py: 9%

245 statements  

« prev     ^ index     » next       coverage.py v7.10.5, created at 2025-09-03 12:35 +0900

1#!/usr/bin/env python3 

2""" 

3Comprehensive Configuration Factory for ALL FIXME Solutions 

4========================================================= 

5 

6This module provides factory functions to create configurations for ALL 

7implemented FIXME solutions across all modules in the meta-learning package. 

8 

9Users can pick and choose which solutions to enable with overlapping 

10configurations handled intelligently. 

11 

12All configurations are research-accurate and production-ready. 

13""" 

14 

15from typing import Dict, List, Optional, Any, Union 

16from dataclasses import dataclass 

17 

18# Import all configuration classes 

19from .test_time_compute import TestTimeComputeConfig 

20from .few_shot_learning import PrototypicalConfig, MatchingConfig, RelationConfig 

21from .continual_meta_learning import ContinualMetaConfig, OnlineMetaConfig 

22from .maml_variants import MAMLConfig 

23from .utils import TaskConfiguration, EvaluationConfig 

24 

25 

@dataclass
class ComprehensiveMetaLearningConfig:
    """
    Master configuration class that encompasses ALL FIXME solutions.

    Users can configure every aspect of the meta-learning pipeline from
    a single unified configuration object.  Each sub-configuration field
    defaults to ``None``, which means "leave that component unconfigured";
    the factory functions in this module populate the relevant subset.
    """
    # Test-Time Compute Configuration
    test_time_compute: Optional[TestTimeComputeConfig] = None

    # Few-Shot Learning Configurations
    prototypical: Optional[PrototypicalConfig] = None
    matching: Optional[MatchingConfig] = None
    relation: Optional[RelationConfig] = None

    # Continual Learning Configurations
    continual_meta: Optional[ContinualMetaConfig] = None
    online_meta: Optional[OnlineMetaConfig] = None

    # MAML Configuration
    maml: Optional[MAMLConfig] = None

    # Utility Configurations
    task: Optional[TaskConfiguration] = None
    evaluation: Optional[EvaluationConfig] = None

    # Global settings that affect multiple modules
    global_seed: int = 42    # shared RNG seed for reproducibility across components
    device: str = "auto"     # "auto", "cpu", "cuda"
    verbose: bool = True     # enable progress/diagnostic output

57 

58 

59# ============================================================================= 

60# COMPREHENSIVE FACTORY FUNCTIONS FOR ALL FIXME SOLUTIONS 

61# ============================================================================= 

62 

def create_all_fixme_solutions_config() -> ComprehensiveMetaLearningConfig:
    """
    Create configuration that enables ALL implemented FIXME solutions.

    COMPREHENSIVE: Every single FIXME solution across all modules enabled
    with balanced settings for research accuracy and performance.
    """

    def _build(cfg, **overrides):
        # Construct-then-override: apply each keyword onto the fresh config.
        for field_name, value in overrides.items():
            setattr(cfg, field_name, value)
        return cfg

    config = ComprehensiveMetaLearningConfig()

    # Test-Time Compute: all solutions enabled.
    config.test_time_compute = _build(
        TestTimeComputeConfig(),
        compute_strategy="hybrid",
        use_process_reward=True,
        use_test_time_training=True,
        use_gradient_verification=True,
        use_chain_of_thought=True,
        cot_method="attention_based",
        use_optimal_allocation=True,
        use_adaptive_distribution=True,
    )

    # Prototypical Networks: every FIXME solution switched on.
    config.prototypical = _build(
        PrototypicalConfig(),
        use_uncertainty_aware_distances=True,
        use_hierarchical_prototypes=True,
        use_task_adaptive_prototypes=True,
        protonet_variant="research_accurate",
        multi_scale_features=True,
        adaptive_prototypes=True,
        uncertainty_estimation=True,
    )

    # Matching Networks: advanced attention mechanisms.
    config.matching = _build(
        MatchingConfig(),
        attention_mechanism="scaled_dot_product",
        context_encoding=True,
        support_set_encoding="transformer",
        bidirectional_lstm=True,
    )

    # Relation Networks: graph neural networks enabled.
    config.relation = _build(
        RelationConfig(),
        use_graph_neural_network=True,
        edge_features=True,
        self_attention=True,
    )

    # Continual Learning: full Fisher matrix plus all EWC solutions.
    config.continual_meta = _build(
        ContinualMetaConfig(),
        ewc_method="full",
        fisher_estimation_method="exact",
        fisher_accumulation_method="ema",
        memory_consolidation_method="ewc",
        use_task_specific_importance=True,
        use_gradient_importance=True,
    )

    # Online Meta-Learning: advanced replay and adaptation.
    config.online_meta = _build(
        OnlineMetaConfig(),
        experience_replay=True,
        prioritized_replay=True,
        importance_sampling=True,
        adaptive_lr=True,
    )

    # MAML: second-order, with the advanced functional forward path.
    config.maml = _build(
        MAMLConfig(),
        functional_forward_method="higher_style",
        maml_variant="maml",
        inner_lr=0.01,
        inner_steps=5,
        first_order=False,
    )

    # Task difficulty estimation (method can be switched later by the user).
    config.task = _build(TaskConfiguration(), difficulty_estimation_method="entropy")

    # Evaluation: bias-corrected bootstrap confidence intervals.
    config.evaluation = _build(
        EvaluationConfig(),
        confidence_interval_method="bca_bootstrap",
        num_episodes=600,
    )

    return config

140 

141 

def create_research_accurate_config() -> ComprehensiveMetaLearningConfig:
    """
    Create configuration focused on research accuracy over performance.

    RESEARCH-FIRST: Prioritizes exact implementations from papers over speed.
    """

    def _build(cfg, **overrides):
        # Apply each keyword override onto the freshly built config object.
        for field_name, value in overrides.items():
            setattr(cfg, field_name, value)
        return cfg

    config = ComprehensiveMetaLearningConfig()

    # Test-Time Compute: research-accurate strategy with process rewards.
    config.test_time_compute = _build(
        TestTimeComputeConfig(),
        compute_strategy="snell2024",
        use_process_reward=True,
        prm_scoring_method="weighted",
    )

    # Prototypical Networks: pure original-paper implementation.
    config.prototypical = _build(
        PrototypicalConfig(),
        use_original_implementation=True,
        use_squared_euclidean=True,
        prototype_method="mean",
    )

    # Continual Learning: Kirkpatrick et al. 2017 exact EWC formulation.
    config.continual_meta = _build(
        ContinualMetaConfig(),
        ewc_method="diagonal",
        fisher_estimation_method="empirical",
        fisher_sampling_method="true_posterior",
    )

    # MAML: original Finn et al. 2017 second-order implementation.
    config.maml = _build(
        MAMLConfig(),
        maml_variant="maml",
        functional_forward_method="basic",
        first_order=False,
    )

    # Evaluation: standard meta-learning protocol (600 episodes, t-based CIs).
    config.evaluation = _build(
        EvaluationConfig(),
        confidence_interval_method="t_distribution",
        num_episodes=600,
    )

    return config

180 

181 

def create_performance_optimized_config() -> ComprehensiveMetaLearningConfig:
    """
    Create configuration optimized for performance and speed.

    PERFORMANCE-FIRST: Balanced accuracy with computational efficiency.
    """

    def _build(cfg, **overrides):
        # Apply each keyword override onto the freshly built config object.
        for field_name, value in overrides.items():
            setattr(cfg, field_name, value)
        return cfg

    config = ComprehensiveMetaLearningConfig()

    # Test-Time Compute: small budget, cheapest chain-of-thought method.
    config.test_time_compute = _build(
        TestTimeComputeConfig(),
        compute_strategy="basic",
        max_compute_budget=100,
        min_compute_steps=3,
        use_chain_of_thought=True,
        cot_method="prototype_based",  # fastest method
    )

    # Prototypical Networks: simple but effective variant.
    config.prototypical = _build(
        PrototypicalConfig(),
        protonet_variant="simple",
        multi_scale_features=False,
        adaptive_prototypes=False,
    )

    # Continual Learning: diagonal EWC for speed.
    config.continual_meta = _build(
        ContinualMetaConfig(),
        ewc_method="diagonal",
        fisher_estimation_method="empirical",
        fisher_accumulation_method="sum",
    )

    # MAML: first-order variant with compiled functional forward.
    config.maml = _build(
        MAMLConfig(),
        maml_variant="fomaml",  # first-order MAML
        functional_forward_method="compiled",  # PyTorch 2.0 optimization
        inner_steps=3,
    )

    # Evaluation: plain bootstrap with fewer episodes for speed.
    config.evaluation = _build(
        EvaluationConfig(),
        confidence_interval_method="bootstrap",
        num_episodes=300,  # reduced for speed
    )

    return config

222 

223 

def create_specific_solution_config(
    solutions: List[str]
) -> ComprehensiveMetaLearningConfig:
    """
    Create configuration for specific FIXME solutions only.

    Args:
        solutions: List of solution identifiers to enable

    Available solutions:
        - "process_reward_model": Test-time compute process reward verification
        - "consistency_verification": Test-time training consistency checks
        - "gradient_verification": Gradient-based step verification
        - "attention_reasoning": Attention-based reasoning paths
        - "feature_reasoning": Feature-based reasoning decomposition
        - "prototype_reasoning": Prototype-distance reasoning steps
        - "uncertainty_distances": Uncertainty-aware distance metrics
        - "hierarchical_prototypes": Multi-level prototype structures
        - "task_adaptive_prototypes": Task-specific prototype initialization
        - "full_fisher": Full Fisher Information Matrix computation
        - "evcl": Elastic Variational Continual Learning
        - "kfac_fisher": Kronecker-factored Fisher approximation
        - "functional_forward": Advanced functional forward methods
        - "difficulty_estimation": Advanced difficulty estimation methods
        - "bootstrap_ci": Advanced confidence interval methods

    Unknown identifiers are ignored with a printed warning.
    """
    config = ComprehensiveMetaLearningConfig()

    # Start every potentially-touched module from its default configuration.
    config.test_time_compute = TestTimeComputeConfig()
    config.prototypical = PrototypicalConfig()
    config.continual_meta = ContinualMetaConfig()
    config.maml = MAMLConfig()
    config.evaluation = EvaluationConfig()

    def _set(target, **fields):
        # Assign every keyword onto the given (sub-)config object.
        for attr, value in fields.items():
            setattr(target, attr, value)

    # Dispatch table: solution identifier -> action applied to `config`.
    actions = {
        "process_reward_model": lambda c: _set(
            c.test_time_compute,
            use_process_reward=True,
            use_process_reward_model=True,
        ),
        "consistency_verification": lambda c: _set(
            c.test_time_compute,
            use_test_time_training=True,
            adaptation_weight=0.6,
        ),
        "gradient_verification": lambda c: _set(
            c.test_time_compute, use_gradient_verification=True
        ),
        "attention_reasoning": lambda c: _set(
            c.test_time_compute,
            use_chain_of_thought=True,
            cot_method="attention_based",
        ),
        "feature_reasoning": lambda c: _set(
            c.test_time_compute,
            use_chain_of_thought=True,
            cot_method="feature_based",
        ),
        "prototype_reasoning": lambda c: _set(
            c.test_time_compute,
            use_chain_of_thought=True,
            cot_method="prototype_based",
        ),
        "uncertainty_distances": lambda c: _set(
            c.prototypical, use_uncertainty_aware_distances=True
        ),
        "hierarchical_prototypes": lambda c: _set(
            c.prototypical, use_hierarchical_prototypes=True
        ),
        "task_adaptive_prototypes": lambda c: _set(
            c.prototypical, use_task_adaptive_prototypes=True
        ),
        "full_fisher": lambda c: _set(
            c.continual_meta,
            ewc_method="full",
            fisher_estimation_method="exact",
        ),
        "evcl": lambda c: _set(c.continual_meta, ewc_method="evcl"),
        "kfac_fisher": lambda c: _set(
            c.continual_meta, fisher_estimation_method="kfac"
        ),
        "functional_forward": lambda c: _set(
            c.maml, functional_forward_method="higher_style"
        ),
        "difficulty_estimation": lambda c: _set(
            c, task=TaskConfiguration(difficulty_estimation_method="entropy")
        ),
        "bootstrap_ci": lambda c: _set(
            c.evaluation, confidence_interval_method="bca_bootstrap"
        ),
    }

    for name in solutions:
        action = actions.get(name)
        if action is None:
            print(f"Warning: Unknown solution '{name}'. Ignoring.")
        else:
            action(config)

    return config

316 

317 

def create_modular_config(
    test_time_compute: Optional[str] = None,
    few_shot_method: Optional[str] = None,
    continual_method: Optional[str] = None,
    maml_variant: Optional[str] = None,
    evaluation_method: Optional[str] = None
) -> ComprehensiveMetaLearningConfig:
    """
    Create modular configuration by choosing specific methods for each component.

    Components left as ``None`` stay unconfigured on the returned object.

    Args:
        test_time_compute: "basic", "snell2024", "akyurek2024", "openai_o1", "hybrid"
        few_shot_method: "prototypical", "matching", "relation"
        continual_method: "ewc", "mas", "packnet", "hat"
        maml_variant: "maml", "fomaml", "reptile", "anil", "boil"
        evaluation_method: "bootstrap", "t_distribution", "bca_bootstrap"
    """
    config = ComprehensiveMetaLearningConfig()

    # --- Test-time compute -------------------------------------------------
    if test_time_compute:
        ttc = TestTimeComputeConfig()
        ttc.compute_strategy = test_time_compute
        # "hybrid" enables every mechanism; the named strategies each enable
        # their own primary mechanism.
        if test_time_compute in ("snell2024", "hybrid"):
            ttc.use_process_reward = True
        if test_time_compute in ("akyurek2024", "hybrid"):
            ttc.use_test_time_training = True
        if test_time_compute in ("openai_o1", "hybrid"):
            ttc.use_chain_of_thought = True
        config.test_time_compute = ttc

    # --- Few-shot learner --------------------------------------------------
    if few_shot_method == "prototypical":
        proto = PrototypicalConfig()
        proto.protonet_variant = "research_accurate"
        config.prototypical = proto
    elif few_shot_method == "matching":
        matching_cfg = MatchingConfig()
        matching_cfg.attention_mechanism = "scaled_dot_product"
        config.matching = matching_cfg
    elif few_shot_method == "relation":
        relation_cfg = RelationConfig()
        relation_cfg.use_graph_neural_network = True
        config.relation = relation_cfg

    # --- Continual learning ------------------------------------------------
    if continual_method:
        continual_cfg = ContinualMetaConfig()
        continual_cfg.memory_consolidation_method = continual_method
        if continual_method == "ewc":
            continual_cfg.ewc_method = "diagonal"
        elif continual_method == "mas":
            continual_cfg.use_gradient_importance = True
        config.continual_meta = continual_cfg

    # --- MAML variant ------------------------------------------------------
    if maml_variant:
        maml_cfg = MAMLConfig()
        maml_cfg.maml_variant = maml_variant
        if maml_variant == "fomaml":
            maml_cfg.first_order = True
        elif maml_variant in ("anil", "boil"):
            maml_cfg.functional_forward_method = "l2l_style"
        config.maml = maml_cfg

    # --- Evaluation --------------------------------------------------------
    if evaluation_method:
        eval_cfg = EvaluationConfig()
        eval_cfg.confidence_interval_method = evaluation_method
        config.evaluation = eval_cfg

    return config

386 

387 

def create_educational_config() -> ComprehensiveMetaLearningConfig:
    """
    Create configuration optimized for educational use and understanding.

    EDUCATIONAL: Simplified but still research-accurate implementations.
    """

    def _build(cfg, **overrides):
        # Apply each keyword override onto the freshly built config object.
        for field_name, value in overrides.items():
            setattr(cfg, field_name, value)
        return cfg

    config = ComprehensiveMetaLearningConfig()

    # Simple but working implementations with small budgets/episode counts.
    config.test_time_compute = _build(
        TestTimeComputeConfig(),
        compute_strategy="basic",
        max_compute_budget=50,
    )
    config.prototypical = _build(PrototypicalConfig(), protonet_variant="simple")
    config.maml = _build(MAMLConfig(), maml_variant="maml", inner_steps=3)
    config.evaluation = _build(
        EvaluationConfig(),
        confidence_interval_method="t_distribution",
        num_episodes=100,
    )

    return config

413 

414 

def get_available_solutions() -> Dict[str, List[str]]:
    """
    Get a dictionary of all available FIXME solutions organized by module.

    Returns:
        Dictionary mapping module names to lists of available solutions
    """
    # Per-module solution lists, named for readability before assembly.
    test_time = [
        "process_reward_model",
        "consistency_verification",
        "gradient_verification",
        "attention_reasoning",
        "feature_reasoning",
        "prototype_reasoning",
    ]
    few_shot = [
        "uncertainty_distances",
        "hierarchical_prototypes",
        "task_adaptive_prototypes",
        "research_accurate_original",
    ]
    continual = [
        "diagonal_fisher",
        "full_fisher",
        "kfac_fisher",
        "evcl",
        "gradient_importance",
    ]
    maml = [
        "l2l_functional_forward",
        "higher_functional_forward",
        "manual_functional_forward",
        "compiled_functional_forward",
    ]
    utils = [
        "silhouette_difficulty",
        "entropy_difficulty",
        "knn_difficulty",
        "t_distribution_ci",
        "meta_learning_ci",
        "bca_bootstrap_ci",
    ]
    return {
        "test_time_compute": test_time,
        "few_shot_learning": few_shot,
        "continual_meta_learning": continual,
        "maml_variants": maml,
        "utils": utils,
    }

459 

460 

def print_solution_summary():
    """Print a comprehensive summary of all available FIXME solutions."""
    solutions = get_available_solutions()
    total = sum(len(module_solutions) for module_solutions in solutions.values())

    print("🔧 Meta-Learning Package - All Available FIXME Solutions")
    print("=" * 70)
    print(f"Total: {total} solutions across {len(solutions)} modules")

    # One section per module, numbering the solutions inside it.
    for module, names in solutions.items():
        print(f"\n📦 {module.replace('_', ' ').title()}:")
        for idx, name in enumerate(names, 1):
            print(f" {idx}. ✅ {name.replace('_', ' ').title()}")

    # Factory-function quick reference.
    factories = [
        "create_all_fixme_solutions_config() - Enable ALL solutions",
        "create_research_accurate_config() - Research-first approach",
        "create_performance_optimized_config() - Performance-first approach",
        "create_specific_solution_config([solutions]) - Pick specific solutions",
        "create_modular_config(...) - Mix and match by module",
        "create_educational_config() - Simplified for learning",
    ]
    print(f"\n🏭 Factory Functions Available:")
    for entry in factories:
        print(f" • {entry}")

481 

482 

483# Configuration validation 

def validate_config(config: ComprehensiveMetaLearningConfig) -> Dict[str, List[str]]:
    """
    Validate configuration for potential conflicts or issues.

    Returns:
        Dictionary with 'warnings' and 'errors' lists
    """
    warning_msgs: List[str] = []
    error_msgs: List[str] = []

    ttc = config.test_time_compute
    maml = config.maml
    continual = config.continual_meta
    proto = config.prototypical

    # Conflicting settings: test-time training vs ANIL/BOIL adaptation.
    if ttc and maml and ttc.use_test_time_training and maml.maml_variant in ("anil", "boil"):
        warning_msgs.append(
            "Test-time training with ANIL/BOIL may have conflicting adaptation strategies"
        )

    # Performance implications: exact Fisher is expensive for large models.
    if continual and continual.fisher_estimation_method == "exact":
        warning_msgs.append(
            "Exact Fisher Information computation is very expensive - consider 'empirical' for large models"
        )

    # Research accuracy: uncertainty-aware distances want estimation enabled.
    if proto and proto.use_uncertainty_aware_distances and not proto.uncertainty_estimation:
        warning_msgs.append(
            "Uncertainty-aware distances require uncertainty_estimation=True for best results"
        )

    return {"warnings": warning_msgs, "errors": error_msgs}