Coverage for src/meta_learning/meta_learning_modules/few_shot_modules/configurations.py: 58%

172 statements  

« prev     ^ index     » next       coverage.py v7.10.5, created at 2025-09-03 12:49 +0900

"""
Few-Shot Learning Configuration Classes
=======================================

Configuration dataclasses for all few-shot learning algorithms.
Extracted from the original monolithic few_shot_learning.py.
"""

8 

from dataclasses import dataclass
from typing import List, Optional

11 

12 

@dataclass
class FewShotConfig:
    """Shared base configuration for few-shot learning algorithms.

    Captures the standard N-way / K-shot episode layout together with the
    knobs common to every algorithm variant in this package.
    """

    # Dimensionality of the feature embedding.
    embedding_dim: int = 512
    # N-way: number of classes per episode.
    num_classes: int = 5
    # K-shot: labelled support examples per class.
    num_support: int = 5
    # Query examples per class evaluated within an episode.
    num_query: int = 15
    # Temperature applied when turning similarities into logits.
    temperature: float = 1.0
    # Dropout probability for the embedding network.
    dropout: float = 0.1

22 

23 

@dataclass
class PrototypicalConfig(FewShotConfig):
    """
    Comprehensive Configuration for Prototypical Networks with ALL FIXME Solutions.

    This configuration provides complete control over ALL implementation variants
    and research extensions with automatic conflict resolution (performed in
    ``__post_init__`` via ``_resolve_configuration_conflicts``) plus explicit
    validation via ``validate_configuration``.

    Note: list-valued fields are annotated ``Optional[...]`` and default to
    ``None`` because a dataclass field cannot carry a mutable default;
    ``__post_init__`` replaces ``None`` with the documented default lists.
    """

    # =========================
    # CORE FIXME SOLUTIONS
    # =========================

    # FIXME SOLUTION: Implementation variant selection
    protonet_variant: str = "research_accurate"  # "original", "research_accurate", "simple", "enhanced"
    use_original_implementation: bool = False  # Pure Snell et al. (2017) implementation

    # FIXME SOLUTION: Distance computation options (Snell et al. 2017)
    use_squared_euclidean: bool = True  # True for research accuracy
    distance_temperature: float = 1.0
    distance_metric: str = "euclidean"  # "euclidean", "cosine", "manhattan", "learned"
    use_learned_distance: bool = False  # Learn distance metric (Vinyals et al. 2016)
    distance_combination: str = "single"  # "single", "ensemble", "weighted_ensemble"
    distance_weights: Optional[List[float]] = None  # Weighted-ensemble weights; None -> [1.0]

    # FIXME SOLUTION: Prototype computation options
    prototype_method: str = "mean"  # "mean", "weighted_mean", "median", "attention"
    use_support_weighting: bool = False
    prototype_refinement_method: str = "none"  # "none", "gradient", "attention", "iterative"
    refinement_learning_rate: float = 0.01
    use_prototype_attention: bool = False  # Attend over support examples
    attention_temperature: float = 1.0
    prototype_refinement_steps: int = 3

    # FIXME SOLUTION: Uncertainty-aware distance metrics (Allen et al. 2019)
    use_uncertainty_aware_distances: bool = False  # "Prototypical Networks with Uncertainty"
    uncertainty_temperature: float = 2.0
    uncertainty_method: str = "none"  # "none", "ensemble", "dropout", "evidential"
    use_monte_carlo_dropout: bool = False
    dropout_samples: int = 10
    use_evidential_learning: bool = False  # Sensoy et al. 2018

    # FIXME SOLUTION: Hierarchical prototype structures (Rusu et al. 2019)
    use_hierarchical_prototypes: bool = False  # "Meta-Learning with Latent Embedding"
    hierarchy_levels: int = 2
    hierarchical_aggregation: str = "bottom_up"  # "bottom_up", "top_down", "bidirectional"

    # FIXME SOLUTION: Task-specific prototype initialization (Finn et al. 2018)
    use_task_adaptive_prototypes: bool = False  # "Meta-Learning for Semi-Supervised Classification"
    adaptation_steps: int = 5
    adaptation_method: str = "none"  # "none", "gradient", "ridge", "bayesian"
    use_bayesian_prototypes: bool = False  # Bayesian treatment of prototypes
    prior_strength: float = 1.0
    use_task_embeddings: bool = False  # Learn task-specific embeddings
    task_embedding_dim: int = 64

    # =========================
    # COMPREHENSIVE EXTENSIONS
    # =========================

    # Multi-scale and feature aggregation combinations
    multi_scale_features: bool = True
    scale_factors: Optional[List[int]] = None  # None -> [1, 2, 4, 8]
    feature_aggregation_method: str = "concat"  # "concat", "sum", "attention", "learned"
    use_feature_pyramid: bool = False  # Build feature pyramid
    pyramid_levels: int = 4

    # Compositional and hierarchical options
    use_compositional_prototypes: bool = False  # Compositional few-shot learning
    composition_method: str = "addition"  # "addition", "concatenation", "attention"

    # Memory and continual learning options
    use_memory_bank: bool = False  # Store prototype memory
    memory_bank_size: int = 1000
    memory_update_strategy: str = "fifo"  # "fifo", "random", "importance"
    use_episodic_memory: bool = False

    # Cross-modal and multi-modal support
    use_cross_modal: bool = False  # Cross-modal few-shot learning
    modality_fusion: str = "early"  # "early", "late", "attention"
    modal_alignment: bool = False

    # =========================
    # EVALUATION & COMPARISON
    # =========================

    # Standard evaluation protocols
    use_standard_evaluation: bool = True
    num_episodes: int = 600  # Standard in literature
    confidence_interval_method: str = "t_distribution"  # "bootstrap", "t_distribution"
    evaluation_metrics: Optional[List[str]] = None  # None -> ["accuracy"]; e.g. ["accuracy", "f1", "auc", "calibration"]
    use_confidence_intervals: bool = True
    bootstrap_samples: int = 1000
    statistical_test: str = "paired_t"  # "paired_t", "wilcoxon", "permutation"

    # Comparison with existing libraries
    compare_with_libraries: bool = False
    library_comparison: Optional[List[str]] = None  # None -> []; e.g. ["learn2learn", "torchmeta"]

    # =========================
    # ADVANCED RESEARCH OPTIONS
    # =========================

    # Advanced features (research-backed only)
    enable_research_extensions: bool = False
    research_extension_year: str = "2017"  # Only enable extensions with citations

    # Regularization and optimization
    prototype_regularization: float = 0.001
    diversity_weight: float = 0.1
    consistency_weight: float = 0.05
    adaptive_prototypes: bool = True

    # Implementation debugging and analysis
    debug_mode: bool = False
    log_intermediate_results: bool = False
    save_prototypes: bool = False
    prototype_analysis: bool = False  # Analyze prototype quality

    def __post_init__(self):
        """Initialize defaults and resolve configuration conflicts."""
        # Replace the None sentinels with the documented list defaults.
        if self.scale_factors is None:
            self.scale_factors = [1, 2, 4, 8]
        if self.library_comparison is None:
            self.library_comparison = []
        if self.distance_weights is None:
            self.distance_weights = [1.0]  # Default single weight
        if self.evaluation_metrics is None:
            self.evaluation_metrics = ["accuracy"]

        # Automatic conflict resolution for overlapping solutions
        self._resolve_configuration_conflicts()

    def _resolve_configuration_conflicts(self):
        """
        FIXME SOLUTION: Automatic conflict resolution between overlapping options.

        Resolves conflicts and ensures compatible combinations of features
        to prevent implementation errors and provide clear user guidance.
        Mutates the instance in place; informational messages go to stdout.
        """
        # If using original implementation, disable all extensions
        if self.use_original_implementation:
            self.multi_scale_features = False
            self.adaptive_prototypes = False
            self.use_uncertainty_aware_distances = False
            self.use_hierarchical_prototypes = False
            self.use_task_adaptive_prototypes = False
            self.use_compositional_prototypes = False
            self.use_memory_bank = False
            self.use_cross_modal = False
            print("INFO: Original implementation selected - disabling all extensions for research accuracy")

        # If uncertainty is enabled, ensure compatible distance metrics
        if self.use_uncertainty_aware_distances and self.distance_metric == "learned":
            print("WARNING: Uncertainty-aware distances may conflict with learned distance - using euclidean")
            self.distance_metric = "euclidean"

        # If hierarchical prototypes enabled, ensure compatible aggregation
        if self.use_hierarchical_prototypes and self.feature_aggregation_method == "attention":
            print("INFO: Hierarchical prototypes detected - using hierarchical attention aggregation")
            self.hierarchical_aggregation = "bidirectional"

        # Ensure task adaptation and hierarchical don't conflict
        if self.use_task_adaptive_prototypes and self.use_hierarchical_prototypes:
            print("INFO: Both task adaptation and hierarchical enabled - using combined approach")
            self.adaptation_method = "gradient"

        # If using compositional prototypes, enable attention-based aggregation
        if self.use_compositional_prototypes:
            if self.feature_aggregation_method == "concat":
                print("INFO: Compositional prototypes enabled - switching to attention aggregation")
                self.feature_aggregation_method = "attention"

        # If cross-modal is enabled, ensure proper fusion settings
        if self.use_cross_modal and self.modality_fusion == "early":
            print("INFO: Cross-modal learning - using early fusion with modal alignment")
            self.modal_alignment = True

        # Memory bank size validation
        if self.use_memory_bank and self.memory_bank_size < 100:
            print("WARNING: Memory bank size too small, increasing to 100")
            self.memory_bank_size = 100

        # Evaluation configuration validation
        if self.num_episodes < 100:
            print("WARNING: Too few evaluation episodes for statistical significance, increasing to 600")
            self.num_episodes = 600

    def get_variant_description(self) -> str:
        """Get human-readable description of current configuration variant."""
        if self.use_original_implementation:
            return "Pure Research-Accurate (Snell et al. 2017)"
        elif self.protonet_variant == "simple":
            return "Simple Educational Variant"
        elif self.protonet_variant == "research_accurate":
            extensions = []
            if self.use_uncertainty_aware_distances:
                extensions.append("Uncertainty-Aware")
            if self.use_hierarchical_prototypes:
                extensions.append("Hierarchical")
            if self.use_task_adaptive_prototypes:
                extensions.append("Task-Adaptive")
            if self.use_compositional_prototypes:
                extensions.append("Compositional")

            if extensions:
                return f"Research-Accurate + {', '.join(extensions)}"
            else:
                return "Research-Accurate (Base)"
        else:
            return "Enhanced with All Extensions"

    def validate_configuration(self) -> List[str]:
        """
        FIXME SOLUTION: Configuration validation with helpful error messages.

        Returns list of validation warnings/errors for user review.
        """
        warnings = []

        # Check for conflicting distance metrics
        if self.distance_combination == "ensemble" and len(self.distance_weights) == 1:
            warnings.append("Ensemble distance selected but only one weight provided")

        # Check research extension compatibility
        if self.enable_research_extensions and self.research_extension_year == "2017":
            if any([self.use_uncertainty_aware_distances, self.use_hierarchical_prototypes]):
                warnings.append("Research extensions from later years enabled with 2017 base")

        # Check evaluation configuration
        if self.use_confidence_intervals and self.num_episodes < 100:
            warnings.append("Too few episodes for reliable confidence intervals")

        # Check memory usage implications
        if self.use_memory_bank and self.use_hierarchical_prototypes and self.use_compositional_prototypes:
            warnings.append("Multiple memory-intensive features enabled - may impact performance")

        return warnings

263 

264 

@dataclass
class MatchingConfig(FewShotConfig):
    """Configuration for Matching Networks with 2024 improvements."""

    # --- Original Matching Networks parameters ---
    attention_type: str = "cosine"  # cosine, bilinear, additive, scaled_dot_product
    use_lstm: bool = True
    lstm_layers: int = 2
    bidirectional: bool = True

    # --- 2024 enhancements ---
    multi_head_attention: bool = True
    num_attention_heads: int = 8
    use_positional_encoding: bool = True
    graph_attention: bool = True

279 

280 

@dataclass
class RelationConfig(FewShotConfig):
    """Configuration for Relation Networks with graph neural enhancements."""

    # --- Original Relation Networks parameters ---
    relation_dim: int = 8
    hidden_dim: int = 512

    # --- 2024 Graph Neural Network improvements ---
    use_graph_relations: bool = True
    graph_layers: int = 3
    edge_feature_dim: int = 64
    node_feature_dim: int = 256
    graph_attention_heads: int = 4

    # --- Advanced relation modeling ---
    relation_aggregation: str = "attention"  # mean, max, attention, graph
    compositional_relations: bool = True
    temporal_relations: bool = False