Coverage for src/meta_learning/meta_learning_modules/few_shot_modules/configurations.py: 58%
172 statements
« prev ^ index » next coverage.py v7.10.5, created at 2025-09-03 12:49 +0900
1"""
2Few-Shot Learning Configuration Classes
3=====================================
5Configuration dataclasses for all few-shot learning algorithms.
6Extracted from the original monolithic few_shot_learning.py.
7"""
from dataclasses import dataclass
from typing import List, Optional
@dataclass
class FewShotConfig:
    """Base configuration for few-shot learning algorithms.

    Shared episode/model hyperparameters; algorithm-specific configs
    (prototypical, matching, relation) subclass this and add their own fields.
    """
    embedding_dim: int = 512   # dimensionality of the feature embedding space
    num_classes: int = 5       # number of classes per episode ("N-way")
    num_support: int = 5       # support examples per class ("K-shot")
    num_query: int = 15        # query examples per class per episode
    temperature: float = 1.0   # temperature for scaling logits/similarities
    dropout: float = 0.1       # dropout probability used by the model
@dataclass
class PrototypicalConfig(FewShotConfig):
    """
    Comprehensive Configuration for Prototypical Networks with ALL FIXME Solutions.

    This configuration provides complete control over ALL implementation variants
    and research extensions with automatic conflict resolution.

    List-valued fields default to ``None`` (mutable defaults are not allowed on
    dataclass fields) and are populated with their real defaults in
    ``__post_init__``, which also runs automatic conflict resolution.
    """
    # =========================
    # CORE FIXME SOLUTIONS
    # =========================

    # FIXME SOLUTION: Implementation variant selection
    protonet_variant: str = "research_accurate"  # "original", "research_accurate", "simple", "enhanced"
    use_original_implementation: bool = False  # Pure Snell et al. (2017) implementation

    # FIXME SOLUTION: Distance computation options (Snell et al. 2017)
    use_squared_euclidean: bool = True  # True for research accuracy
    distance_temperature: float = 1.0
    distance_metric: str = "euclidean"  # "euclidean", "cosine", "manhattan", "learned"
    use_learned_distance: bool = False  # Learn distance metric (Vinyals et al. 2016)
    distance_combination: str = "single"  # "single", "ensemble", "weighted_ensemble"
    # Weights for an ensemble of distances; None -> [1.0] in __post_init__.
    distance_weights: Optional[List[float]] = None

    # FIXME SOLUTION: Prototype computation options
    prototype_method: str = "mean"  # "mean", "weighted_mean", "median", "attention"
    use_support_weighting: bool = False
    prototype_refinement_method: str = "none"  # "none", "gradient", "attention", "iterative"
    refinement_learning_rate: float = 0.01
    use_prototype_attention: bool = False  # Attend over support examples
    attention_temperature: float = 1.0
    prototype_refinement_steps: int = 3

    # FIXME SOLUTION: Uncertainty-aware distance metrics (Allen et al. 2019)
    use_uncertainty_aware_distances: bool = False  # "Prototypical Networks with Uncertainty"
    uncertainty_temperature: float = 2.0
    uncertainty_method: str = "none"  # "none", "ensemble", "dropout", "evidential"
    use_monte_carlo_dropout: bool = False
    dropout_samples: int = 10
    use_evidential_learning: bool = False  # Sensoy et al. 2018

    # FIXME SOLUTION: Hierarchical prototype structures (Rusu et al. 2019)
    use_hierarchical_prototypes: bool = False  # "Meta-Learning with Latent Embedding"
    hierarchy_levels: int = 2
    hierarchical_aggregation: str = "bottom_up"  # "bottom_up", "top_down", "bidirectional"

    # FIXME SOLUTION: Task-specific prototype initialization (Finn et al. 2018)
    use_task_adaptive_prototypes: bool = False  # "Meta-Learning for Semi-Supervised Classification"
    adaptation_steps: int = 5
    adaptation_method: str = "none"  # "none", "gradient", "ridge", "bayesian"
    use_bayesian_prototypes: bool = False  # Bayesian treatment of prototypes
    prior_strength: float = 1.0
    use_task_embeddings: bool = False  # Learn task-specific embeddings
    task_embedding_dim: int = 64

    # =========================
    # COMPREHENSIVE EXTENSIONS
    # =========================

    # Multi-scale and feature aggregation combinations
    multi_scale_features: bool = True
    # Scales used for multi-scale features; None -> [1, 2, 4, 8] in __post_init__.
    scale_factors: Optional[List[int]] = None
    feature_aggregation_method: str = "concat"  # "concat", "sum", "attention", "learned"
    use_feature_pyramid: bool = False  # Build feature pyramid
    pyramid_levels: int = 4

    # Compositional and hierarchical options
    use_compositional_prototypes: bool = False  # Compositional few-shot learning
    composition_method: str = "addition"  # "addition", "concatenation", "attention"

    # Memory and continual learning options
    use_memory_bank: bool = False  # Store prototype memory
    memory_bank_size: int = 1000
    memory_update_strategy: str = "fifo"  # "fifo", "random", "importance"
    use_episodic_memory: bool = False

    # Cross-modal and multi-modal support
    use_cross_modal: bool = False  # Cross-modal few-shot learning
    modality_fusion: str = "early"  # "early", "late", "attention"
    modal_alignment: bool = False

    # =========================
    # EVALUATION & COMPARISON
    # =========================

    # Standard evaluation protocols
    use_standard_evaluation: bool = True
    num_episodes: int = 600  # Standard in literature
    confidence_interval_method: str = "t_distribution"  # "bootstrap", "t_distribution"
    # Metrics to compute; None -> ["accuracy"] in __post_init__.
    evaluation_metrics: Optional[List[str]] = None  # ["accuracy", "f1", "auc", "calibration"]
    use_confidence_intervals: bool = True
    bootstrap_samples: int = 1000
    statistical_test: str = "paired_t"  # "paired_t", "wilcoxon", "permutation"

    # Comparison with existing libraries
    compare_with_libraries: bool = False
    # Libraries to compare against; None -> [] in __post_init__.
    library_comparison: Optional[List[str]] = None  # ["learn2learn", "torchmeta"]

    # =========================
    # ADVANCED RESEARCH OPTIONS
    # =========================

    # Advanced features (research-backed only)
    enable_research_extensions: bool = False
    research_extension_year: str = "2017"  # Only enable extensions with citations

    # Regularization and optimization
    prototype_regularization: float = 0.001
    diversity_weight: float = 0.1
    consistency_weight: float = 0.05
    adaptive_prototypes: bool = True

    # Implementation debugging and analysis
    debug_mode: bool = False
    log_intermediate_results: bool = False
    save_prototypes: bool = False
    prototype_analysis: bool = False  # Analyze prototype quality

    def __post_init__(self):
        """Initialize defaults and resolve configuration conflicts."""
        # Set default values for lists (fields default to None, see class docstring)
        if self.scale_factors is None:
            self.scale_factors = [1, 2, 4, 8]
        if self.library_comparison is None:
            self.library_comparison = []
        if self.distance_weights is None:
            self.distance_weights = [1.0]  # Default single weight
        if self.evaluation_metrics is None:
            self.evaluation_metrics = ["accuracy"]

        # Automatic conflict resolution for overlapping solutions
        self._resolve_configuration_conflicts()

    def _resolve_configuration_conflicts(self):
        """
        FIXME SOLUTION: Automatic conflict resolution between overlapping options.

        Resolves conflicts and ensures compatible combinations of features
        to prevent implementation errors and provide clear user guidance.
        Mutates fields in place; informational messages go to stdout.
        """
        # If using original implementation, disable all extensions
        if self.use_original_implementation:
            self.multi_scale_features = False
            self.adaptive_prototypes = False
            self.use_uncertainty_aware_distances = False
            self.use_hierarchical_prototypes = False
            self.use_task_adaptive_prototypes = False
            self.use_compositional_prototypes = False
            self.use_memory_bank = False
            self.use_cross_modal = False
            print("INFO: Original implementation selected - disabling all extensions for research accuracy")

        # If uncertainty is enabled, ensure compatible distance metrics
        if self.use_uncertainty_aware_distances and self.distance_metric == "learned":
            print("WARNING: Uncertainty-aware distances may conflict with learned distance - using euclidean")
            self.distance_metric = "euclidean"

        # If hierarchical prototypes enabled, ensure compatible aggregation
        if self.use_hierarchical_prototypes and self.feature_aggregation_method == "attention":
            print("INFO: Hierarchical prototypes detected - using hierarchical attention aggregation")
            self.hierarchical_aggregation = "bidirectional"

        # Ensure task adaptation and hierarchical don't conflict
        if self.use_task_adaptive_prototypes and self.use_hierarchical_prototypes:
            print("INFO: Both task adaptation and hierarchical enabled - using combined approach")
            self.adaptation_method = "gradient"

        # If using compositional prototypes, enable attention-based aggregation
        if self.use_compositional_prototypes:
            if self.feature_aggregation_method == "concat":
                print("INFO: Compositional prototypes enabled - switching to attention aggregation")
                self.feature_aggregation_method = "attention"

        # If cross-modal is enabled, ensure proper fusion settings
        if self.use_cross_modal and self.modality_fusion == "early":
            print("INFO: Cross-modal learning - using early fusion with modal alignment")
            self.modal_alignment = True

        # Memory bank size validation
        if self.use_memory_bank and self.memory_bank_size < 100:
            print("WARNING: Memory bank size too small, increasing to 100")
            self.memory_bank_size = 100

        # Evaluation configuration validation
        if self.num_episodes < 100:
            print("WARNING: Too few evaluation episodes for statistical significance, increasing to 600")
            self.num_episodes = 600

    def get_variant_description(self) -> str:
        """Get human-readable description of current configuration variant."""
        if self.use_original_implementation:
            return "Pure Research-Accurate (Snell et al. 2017)"
        elif self.protonet_variant == "simple":
            return "Simple Educational Variant"
        elif self.protonet_variant == "research_accurate":
            extensions = []
            if self.use_uncertainty_aware_distances:
                extensions.append("Uncertainty-Aware")
            if self.use_hierarchical_prototypes:
                extensions.append("Hierarchical")
            if self.use_task_adaptive_prototypes:
                extensions.append("Task-Adaptive")
            if self.use_compositional_prototypes:
                extensions.append("Compositional")

            if extensions:
                return f"Research-Accurate + {', '.join(extensions)}"
            else:
                return "Research-Accurate (Base)"
        else:
            return "Enhanced with All Extensions"

    def validate_configuration(self) -> List[str]:
        """
        FIXME SOLUTION: Configuration validation with helpful error messages.

        Returns list of validation warnings/errors for user review.
        Unlike ``_resolve_configuration_conflicts``, this does not mutate state.
        """
        warnings = []

        # Check for conflicting distance metrics
        if self.distance_combination == "ensemble" and len(self.distance_weights) == 1:
            warnings.append("Ensemble distance selected but only one weight provided")

        # Check research extension compatibility
        if self.enable_research_extensions and self.research_extension_year == "2017":
            if any([self.use_uncertainty_aware_distances, self.use_hierarchical_prototypes]):
                warnings.append("Research extensions from later years enabled with 2017 base")

        # Check evaluation configuration
        if self.use_confidence_intervals and self.num_episodes < 100:
            warnings.append("Too few episodes for reliable confidence intervals")

        # Check memory usage implications
        if self.use_memory_bank and self.use_hierarchical_prototypes and self.use_compositional_prototypes:
            warnings.append("Multiple memory-intensive features enabled - may impact performance")

        return warnings
@dataclass
class MatchingConfig(FewShotConfig):
    """Configuration for Matching Networks with 2024 improvements.

    Extends the base few-shot settings with attention and context-encoding
    options consumed by the matching-network implementation (not visible here).
    """
    # Original parameters
    attention_type: str = "cosine"  # cosine, bilinear, additive, scaled_dot_product
    use_lstm: bool = True           # enable LSTM-based context encoding
    lstm_layers: int = 2            # number of stacked LSTM layers
    bidirectional: bool = True      # run the LSTM in both directions

    # 2024 enhancements
    multi_head_attention: bool = True     # use multi-head rather than single-head attention
    num_attention_heads: int = 8          # head count when multi-head attention is enabled
    use_positional_encoding: bool = True  # add positional encodings to the sequence
    graph_attention: bool = True          # enable graph-attention variant
@dataclass
class RelationConfig(FewShotConfig):
    """Configuration for Relation Networks with graph neural enhancements.

    Extends the base few-shot settings with relation-module and GNN options
    consumed by the relation-network implementation (not visible here).
    """
    # Original parameters
    relation_dim: int = 8    # output dimensionality of the relation module
    hidden_dim: int = 512    # hidden layer width of the relation module

    # 2024 Graph Neural Network improvements
    use_graph_relations: bool = True  # model pairwise relations with a GNN
    graph_layers: int = 3             # number of graph message-passing layers
    edge_feature_dim: int = 64        # dimensionality of edge features
    node_feature_dim: int = 256       # dimensionality of node features
    graph_attention_heads: int = 4    # attention heads in graph layers

    # Advanced relation modeling
    relation_aggregation: str = "attention"  # mean, max, attention, graph
    compositional_relations: bool = True     # enable compositional relation modeling
    temporal_relations: bool = False         # enable temporal relation modeling