Coverage for src/meta_learning/meta_learning_modules/config_factory.py: 9%
245 statements
« prev ^ index » next coverage.py v7.10.5, created at 2025-09-03 12:35 +0900
« prev ^ index » next coverage.py v7.10.5, created at 2025-09-03 12:35 +0900
1#!/usr/bin/env python3
2"""
3Comprehensive Configuration Factory for ALL FIXME Solutions
4=========================================================
6This module provides factory functions to create configurations for ALL
7implemented FIXME solutions across all modules in the meta-learning package.
9Users can pick and choose which solutions to enable with overlapping
10configurations handled intelligently.
12All configurations are research-accurate and production-ready.
13"""
15from typing import Dict, List, Optional, Any, Union
16from dataclasses import dataclass
18# Import all configuration classes
19from .test_time_compute import TestTimeComputeConfig
20from .few_shot_learning import PrototypicalConfig, MatchingConfig, RelationConfig
21from .continual_meta_learning import ContinualMetaConfig, OnlineMetaConfig
22from .maml_variants import MAMLConfig
23from .utils import TaskConfiguration, EvaluationConfig
@dataclass
class ComprehensiveMetaLearningConfig:
    """
    Master configuration class that encompasses ALL FIXME solutions.

    Users can configure every aspect of the meta-learning pipeline from
    a single unified configuration object.  A sub-configuration left as
    ``None`` means the factory helpers did not configure that component.
    """
    # Test-Time Compute Configuration
    test_time_compute: Optional[TestTimeComputeConfig] = None

    # Few-Shot Learning Configurations (one per architecture family)
    prototypical: Optional[PrototypicalConfig] = None
    matching: Optional[MatchingConfig] = None
    relation: Optional[RelationConfig] = None

    # Continual Learning Configurations
    continual_meta: Optional[ContinualMetaConfig] = None
    online_meta: Optional[OnlineMetaConfig] = None

    # MAML Configuration
    maml: Optional[MAMLConfig] = None

    # Utility Configurations (task generation and evaluation protocol)
    task: Optional[TaskConfiguration] = None
    evaluation: Optional[EvaluationConfig] = None

    # Global settings that affect multiple modules
    global_seed: int = 42
    device: str = "auto"  # "auto", "cpu", "cuda"
    verbose: bool = True
59# =============================================================================
60# COMPREHENSIVE FACTORY FUNCTIONS FOR ALL FIXME SOLUTIONS
61# =============================================================================
def create_all_fixme_solutions_config() -> ComprehensiveMetaLearningConfig:
    """
    Build a configuration with every implemented FIXME solution switched on.

    COMPREHENSIVE: every single FIXME solution across all modules is enabled,
    using balanced settings for research accuracy and performance.
    """
    cfg = ComprehensiveMetaLearningConfig()

    # Test-time compute: enable every verification and reasoning mechanism.
    ttc = TestTimeComputeConfig()
    ttc.compute_strategy = "hybrid"
    ttc.use_process_reward = True
    ttc.use_test_time_training = True
    ttc.use_gradient_verification = True
    ttc.use_chain_of_thought = True
    ttc.cot_method = "attention_based"
    ttc.use_optimal_allocation = True
    ttc.use_adaptive_distribution = True
    cfg.test_time_compute = ttc

    # Prototypical networks: every prototype enhancement enabled.
    proto = PrototypicalConfig()
    proto.use_uncertainty_aware_distances = True
    proto.use_hierarchical_prototypes = True
    proto.use_task_adaptive_prototypes = True
    proto.protonet_variant = "research_accurate"
    proto.multi_scale_features = True
    proto.adaptive_prototypes = True
    proto.uncertainty_estimation = True
    cfg.prototypical = proto

    # Matching networks: advanced attention stack.
    matching = MatchingConfig()
    matching.attention_mechanism = "scaled_dot_product"
    matching.context_encoding = True
    matching.support_set_encoding = "transformer"
    matching.bidirectional_lstm = True
    cfg.matching = matching

    # Relation networks: graph neural network relation module.
    relation = RelationConfig()
    relation.use_graph_neural_network = True
    relation.edge_features = True
    relation.self_attention = True
    cfg.relation = relation

    # Continual learning: full EWC / Fisher machinery.
    continual = ContinualMetaConfig()
    continual.ewc_method = "full"  # full Fisher matrix, not just the diagonal
    continual.fisher_estimation_method = "exact"
    continual.fisher_accumulation_method = "ema"
    continual.memory_consolidation_method = "ewc"
    continual.use_task_specific_importance = True
    continual.use_gradient_importance = True
    cfg.continual_meta = continual

    # Online meta-learning: prioritized replay with adaptive learning rates.
    online = OnlineMetaConfig()
    online.experience_replay = True
    online.prioritized_replay = True
    online.importance_sampling = True
    online.adaptive_lr = True
    cfg.online_meta = online

    # MAML: second-order variant with higher-style functional forward.
    maml = MAMLConfig()
    maml.functional_forward_method = "higher_style"
    maml.maml_variant = "maml"
    maml.inner_lr = 0.01
    maml.inner_steps = 5
    maml.first_order = False
    cfg.maml = maml

    # Task utilities: entropy-based difficulty estimation.
    task = TaskConfiguration()
    task.difficulty_estimation_method = "entropy"  # other methods can be swapped in
    cfg.task = task

    # Evaluation: BCa bootstrap confidence intervals over 600 episodes.
    evaluation = EvaluationConfig()
    evaluation.confidence_interval_method = "bca_bootstrap"
    evaluation.num_episodes = 600
    cfg.evaluation = evaluation

    return cfg
def create_research_accurate_config() -> ComprehensiveMetaLearningConfig:
    """
    Build a configuration focused on research accuracy over performance.

    RESEARCH-FIRST: prioritizes exact implementations from the original
    papers over computational speed.
    """
    cfg = ComprehensiveMetaLearningConfig()

    # Test-time compute: research-accurate compute scaling.
    ttc = TestTimeComputeConfig()
    ttc.compute_strategy = "snell2024"
    ttc.use_process_reward = True
    ttc.prm_scoring_method = "weighted"
    cfg.test_time_compute = ttc

    # Prototypical networks: pure original implementation.
    proto = PrototypicalConfig()
    proto.use_original_implementation = True
    proto.use_squared_euclidean = True
    proto.prototype_method = "mean"
    cfg.prototypical = proto

    # Continual learning: Kirkpatrick et al. 2017 exact EWC.
    continual = ContinualMetaConfig()
    continual.ewc_method = "diagonal"
    continual.fisher_estimation_method = "empirical"
    continual.fisher_sampling_method = "true_posterior"
    cfg.continual_meta = continual

    # MAML: original Finn et al. 2017 second-order formulation.
    maml = MAMLConfig()
    maml.maml_variant = "maml"
    maml.functional_forward_method = "basic"
    maml.first_order = False
    cfg.maml = maml

    # Evaluation: standard meta-learning protocol (t-distribution CIs).
    evaluation = EvaluationConfig()
    evaluation.confidence_interval_method = "t_distribution"
    evaluation.num_episodes = 600
    cfg.evaluation = evaluation

    return cfg
def create_performance_optimized_config() -> ComprehensiveMetaLearningConfig:
    """
    Build a configuration optimized for performance and speed.

    PERFORMANCE-FIRST: keeps reasonable accuracy while favoring
    computational efficiency everywhere a choice exists.
    """
    cfg = ComprehensiveMetaLearningConfig()

    # Test-time compute: small budget, fastest chain-of-thought method.
    ttc = TestTimeComputeConfig()
    ttc.compute_strategy = "basic"
    ttc.max_compute_budget = 100
    ttc.min_compute_steps = 3
    ttc.use_chain_of_thought = True
    ttc.cot_method = "prototype_based"  # Fastest method
    cfg.test_time_compute = ttc

    # Prototypical networks: simple variant, no extra feature machinery.
    proto = PrototypicalConfig()
    proto.protonet_variant = "simple"
    proto.multi_scale_features = False
    proto.adaptive_prototypes = False
    cfg.prototypical = proto

    # Continual learning: diagonal EWC with cheap empirical Fisher.
    continual = ContinualMetaConfig()
    continual.ewc_method = "diagonal"
    continual.fisher_estimation_method = "empirical"
    continual.fisher_accumulation_method = "sum"
    cfg.continual_meta = continual

    # MAML: first-order variant, compiled functional forward, few inner steps.
    maml = MAMLConfig()
    maml.maml_variant = "fomaml"  # First-order MAML
    maml.functional_forward_method = "compiled"  # PyTorch 2.0 optimization
    maml.inner_steps = 3
    cfg.maml = maml

    # Evaluation: plain bootstrap with a reduced episode count.
    evaluation = EvaluationConfig()
    evaluation.confidence_interval_method = "bootstrap"
    evaluation.num_episodes = 300  # Reduced for speed
    cfg.evaluation = evaluation

    return cfg
def create_specific_solution_config(
    solutions: List[str]
) -> ComprehensiveMetaLearningConfig:
    """
    Create configuration for specific FIXME solutions only.

    Args:
        solutions: List of solution identifiers to enable

    Available solutions:
        - "process_reward_model": Test-time compute process reward verification
        - "consistency_verification": Test-time training consistency checks
        - "gradient_verification": Gradient-based step verification
        - "attention_reasoning": Attention-based reasoning paths
        - "feature_reasoning": Feature-based reasoning decomposition
        - "prototype_reasoning": Prototype-distance reasoning steps
        - "uncertainty_distances": Uncertainty-aware distance metrics
        - "hierarchical_prototypes": Multi-level prototype structures
        - "task_adaptive_prototypes": Task-specific prototype initialization
        - "full_fisher": Full Fisher Information Matrix computation
        - "evcl": Elastic Variational Continual Learning
        - "kfac_fisher": Kronecker-factored Fisher approximation
        - "functional_forward": Advanced functional forward methods
        - "difficulty_estimation": Advanced difficulty estimation methods
        - "bootstrap_ci": Advanced confidence interval methods

    Returns:
        A ComprehensiveMetaLearningConfig with only the requested solutions
        enabled.  Unknown identifiers are ignored with a printed warning.
    """
    config = ComprehensiveMetaLearningConfig()

    # Initialize basic configurations
    config.test_time_compute = TestTimeComputeConfig()
    config.prototypical = PrototypicalConfig()
    config.continual_meta = ContinualMetaConfig()
    config.maml = MAMLConfig()
    config.evaluation = EvaluationConfig()

    # --- Handlers: each enables exactly one solution on the shared config ---

    def _process_reward() -> None:
        config.test_time_compute.use_process_reward = True
        config.test_time_compute.use_process_reward_model = True

    def _consistency_verification() -> None:
        config.test_time_compute.use_test_time_training = True
        config.test_time_compute.adaptation_weight = 0.6

    def _chain_of_thought(method: str) -> None:
        # All three reasoning solutions share the CoT switch;
        # only the decomposition method differs.
        config.test_time_compute.use_chain_of_thought = True
        config.test_time_compute.cot_method = method

    def _full_fisher() -> None:
        config.continual_meta.ewc_method = "full"
        config.continual_meta.fisher_estimation_method = "exact"

    def _difficulty_estimation() -> None:
        config.task = TaskConfiguration(difficulty_estimation_method="entropy")

    # Dispatch table replaces the long if/elif chain.
    handlers = {
        "process_reward_model": _process_reward,
        "consistency_verification": _consistency_verification,
        "gradient_verification": lambda: setattr(
            config.test_time_compute, "use_gradient_verification", True),
        "attention_reasoning": lambda: _chain_of_thought("attention_based"),
        "feature_reasoning": lambda: _chain_of_thought("feature_based"),
        "prototype_reasoning": lambda: _chain_of_thought("prototype_based"),
        "uncertainty_distances": lambda: setattr(
            config.prototypical, "use_uncertainty_aware_distances", True),
        "hierarchical_prototypes": lambda: setattr(
            config.prototypical, "use_hierarchical_prototypes", True),
        "task_adaptive_prototypes": lambda: setattr(
            config.prototypical, "use_task_adaptive_prototypes", True),
        "full_fisher": _full_fisher,
        "evcl": lambda: setattr(config.continual_meta, "ewc_method", "evcl"),
        "kfac_fisher": lambda: setattr(
            config.continual_meta, "fisher_estimation_method", "kfac"),
        "functional_forward": lambda: setattr(
            config.maml, "functional_forward_method", "higher_style"),
        "difficulty_estimation": _difficulty_estimation,
        "bootstrap_ci": lambda: setattr(
            config.evaluation, "confidence_interval_method", "bca_bootstrap"),
    }

    # Enable specific solutions based on user selection
    for solution in solutions:
        handler = handlers.get(solution)
        if handler is None:
            # Preserve original behavior: warn and continue.
            print(f"Warning: Unknown solution '{solution}'. Ignoring.")
        else:
            handler()

    return config
def create_modular_config(
    test_time_compute: Optional[str] = None,
    few_shot_method: Optional[str] = None,
    continual_method: Optional[str] = None,
    maml_variant: Optional[str] = None,
    evaluation_method: Optional[str] = None
) -> ComprehensiveMetaLearningConfig:
    """
    Assemble a configuration by choosing a specific method per component.

    Args:
        test_time_compute: "basic", "snell2024", "akyurek2024", "openai_o1", "hybrid"
        few_shot_method: "prototypical", "matching", "relation"
        continual_method: "ewc", "mas", "packnet", "hat"
        maml_variant: "maml", "fomaml", "reptile", "anil", "boil"
        evaluation_method: "bootstrap", "t_distribution", "bca_bootstrap"
    """
    cfg = ComprehensiveMetaLearningConfig()

    # Test-time compute: "hybrid" stacks the mechanisms of all strategies.
    if test_time_compute:
        ttc = TestTimeComputeConfig()
        ttc.compute_strategy = test_time_compute
        if test_time_compute in ("snell2024", "hybrid"):
            ttc.use_process_reward = True
        if test_time_compute in ("akyurek2024", "hybrid"):
            ttc.use_test_time_training = True
        if test_time_compute in ("openai_o1", "hybrid"):
            ttc.use_chain_of_thought = True
        cfg.test_time_compute = ttc

    # Few-shot learning: exactly one architecture family is configured.
    if few_shot_method == "prototypical":
        proto = PrototypicalConfig()
        proto.protonet_variant = "research_accurate"
        cfg.prototypical = proto
    elif few_shot_method == "matching":
        matching = MatchingConfig()
        matching.attention_mechanism = "scaled_dot_product"
        cfg.matching = matching
    elif few_shot_method == "relation":
        relation = RelationConfig()
        relation.use_graph_neural_network = True
        cfg.relation = relation

    # Continual learning: method-specific tweaks on top of the chosen scheme.
    if continual_method:
        continual = ContinualMetaConfig()
        continual.memory_consolidation_method = continual_method
        if continual_method == "ewc":
            continual.ewc_method = "diagonal"
        elif continual_method == "mas":
            continual.use_gradient_importance = True
        cfg.continual_meta = continual

    # MAML: variant selection plus variant-specific settings.
    if maml_variant:
        maml = MAMLConfig()
        maml.maml_variant = maml_variant
        if maml_variant == "fomaml":
            maml.first_order = True
        elif maml_variant in ("anil", "boil"):
            maml.functional_forward_method = "l2l_style"
        cfg.maml = maml

    # Evaluation: confidence-interval method only.
    if evaluation_method:
        evaluation = EvaluationConfig()
        evaluation.confidence_interval_method = evaluation_method
        cfg.evaluation = evaluation

    return cfg
def create_educational_config() -> ComprehensiveMetaLearningConfig:
    """
    Build a configuration optimized for educational use and understanding.

    EDUCATIONAL: simplified settings that remain research-accurate, with
    small budgets so experiments finish quickly.
    """
    cfg = ComprehensiveMetaLearningConfig()

    # Basic test-time compute with a small budget.
    ttc = TestTimeComputeConfig()
    ttc.compute_strategy = "basic"
    ttc.max_compute_budget = 50
    cfg.test_time_compute = ttc

    # Simplest prototypical-network variant.
    proto = PrototypicalConfig()
    proto.protonet_variant = "simple"
    cfg.prototypical = proto

    # Plain MAML with few inner steps.
    maml = MAMLConfig()
    maml.maml_variant = "maml"
    maml.inner_steps = 3
    cfg.maml = maml

    # Quick evaluation: t-distribution CIs over 100 episodes.
    evaluation = EvaluationConfig()
    evaluation.confidence_interval_method = "t_distribution"
    evaluation.num_episodes = 100
    cfg.evaluation = evaluation

    return cfg
def get_available_solutions() -> Dict[str, List[str]]:
    """
    List every available FIXME solution, grouped by module.

    Returns:
        Dictionary mapping module names to lists of available solutions
    """
    catalog: Dict[str, List[str]] = {}
    catalog["test_time_compute"] = [
        "process_reward_model",
        "consistency_verification",
        "gradient_verification",
        "attention_reasoning",
        "feature_reasoning",
        "prototype_reasoning",
    ]
    catalog["few_shot_learning"] = [
        "uncertainty_distances",
        "hierarchical_prototypes",
        "task_adaptive_prototypes",
        "research_accurate_original",
    ]
    catalog["continual_meta_learning"] = [
        "diagonal_fisher",
        "full_fisher",
        "kfac_fisher",
        "evcl",
        "gradient_importance",
    ]
    catalog["maml_variants"] = [
        "l2l_functional_forward",
        "higher_functional_forward",
        "manual_functional_forward",
        "compiled_functional_forward",
    ]
    catalog["utils"] = [
        "silhouette_difficulty",
        "entropy_difficulty",
        "knn_difficulty",
        "t_distribution_ci",
        "meta_learning_ci",
        "bca_bootstrap_ci",
    ]
    return catalog
def print_solution_summary():
    """Print a comprehensive summary of all available FIXME solutions."""
    catalog = get_available_solutions()
    total = sum(len(entries) for entries in catalog.values())

    print("🔧 Meta-Learning Package - All Available FIXME Solutions")
    print("=" * 70)
    print(f"Total: {total} solutions across {len(catalog)} modules")

    for module_name, entries in catalog.items():
        print(f"\n📦 {module_name.replace('_', ' ').title()}:")
        for index, entry in enumerate(entries, 1):
            print(f"   {index}. ✅ {entry.replace('_', ' ').title()}")

    print("\n🏭 Factory Functions Available:")
    print("   • create_all_fixme_solutions_config() - Enable ALL solutions")
    print("   • create_research_accurate_config() - Research-first approach")
    print("   • create_performance_optimized_config() - Performance-first approach")
    print("   • create_specific_solution_config([solutions]) - Pick specific solutions")
    print("   • create_modular_config(...) - Mix and match by module")
    print("   • create_educational_config() - Simplified for learning")
483# Configuration validation
484def validate_config(config: ComprehensiveMetaLearningConfig) -> Dict[str, List[str]]:
485 """
486 Validate configuration for potential conflicts or issues.
488 Returns:
489 Dictionary with 'warnings' and 'errors' lists
490 """
491 issues = {"warnings": [], "errors": []}
493 # Check for conflicting settings
494 if config.test_time_compute and config.maml:
495 if (config.test_time_compute.use_test_time_training and
496 config.maml.maml_variant in ["anil", "boil"]):
497 issues["warnings"].append(
498 "Test-time training with ANIL/BOIL may have conflicting adaptation strategies"
499 )
501 # Check for performance implications
502 if config.continual_meta and config.continual_meta.fisher_estimation_method == "exact":
503 issues["warnings"].append(
504 "Exact Fisher Information computation is very expensive - consider 'empirical' for large models"
505 )
507 # Check for research accuracy
508 if (config.prototypical and
509 config.prototypical.use_uncertainty_aware_distances and
510 not config.prototypical.uncertainty_estimation):
511 issues["warnings"].append(
512 "Uncertainty-aware distances require uncertainty_estimation=True for best results"
513 )
515 return issues