Coverage for src/meta_learning/meta_learning_modules/few_shot_modules/core_networks.py: 27%

178 statements  

coverage.py v7.10.5, created at 2025-09-03 12:49 +0900

1""" 

2Few-Shot Learning Core Network Architectures 

3========================================== 

4 

5Core neural network implementations for few-shot learning algorithms. 

6Extracted from the original monolithic few_shot_learning.py. 

7""" 

8 

9import torch 

10import torch.nn as nn 

11import torch.nn.functional as F 

12from typing import Dict, List, Tuple, Optional, Any 

13import numpy as np 

14import logging 

15 

16from .configurations import FewShotConfig, PrototypicalConfig, MatchingConfig, RelationConfig 

17from .advanced_components import ( 

18 MultiScaleFeatureAggregator, PrototypeRefiner, UncertaintyEstimator, 

19 ScaledDotProductAttention, AdditiveAttention, BilinearAttention, 

20 GraphRelationModule, StandardRelationModule, 

21 UncertaintyAwareDistance, HierarchicalPrototypes, TaskAdaptivePrototypes 

22) 

23 

24logger = logging.getLogger(__name__) 

25 

26 

27class PrototypicalNetworks: 

28 """ 

29 Advanced Prototypical Networks with 2024 improvements. 

30  

31 Based on Snell et al. (2017) "Prototypical Networks for Few-shot Learning" 

32 with research-accurate extensions and configurable variants. 

33 """ 

34 

35 def __init__(self, backbone: nn.Module, config: PrototypicalConfig = None): 

36 """Initialize advanced Prototypical Networks.""" 

37 self.backbone = backbone 

38 self.config = config or PrototypicalConfig() 

39 

40 # Multi-scale feature aggregation 

41 if self.config.multi_scale_features: 41 ↛ 48line 41 didn't jump to line 48 because the condition on line 41 was always true

42 self.scale_aggregator = MultiScaleFeatureAggregator( 

43 self.config.embedding_dim, 

44 self.config.scale_factors 

45 ) 

46 

47 # Adaptive prototype refinement 

48 if self.config.adaptive_prototypes: 

49 self.prototype_refiner = PrototypeRefiner( 

50 self.config.embedding_dim, 

51 self.config.prototype_refinement_steps 

52 ) 

53 

54 # Uncertainty estimation 

55 if hasattr(self.config, 'uncertainty_estimation') and self.config.uncertainty_estimation: 

56 self.uncertainty_estimator = UncertaintyEstimator( 

57 self.config.embedding_dim 

58 ) 

59 

60 # Advanced components based on config 

61 if hasattr(self.config, 'use_uncertainty_aware_distances') and self.config.use_uncertainty_aware_distances: 

62 self.uncertainty_distance = UncertaintyAwareDistance( 

63 self.config.embedding_dim, 

64 getattr(self.config, 'uncertainty_temperature', 2.0) 

65 ) 

66 

67 if hasattr(self.config, 'use_hierarchical_prototypes') and self.config.use_hierarchical_prototypes: 

68 self.hierarchical_prototypes = HierarchicalPrototypes( 

69 self.config.embedding_dim, 

70 getattr(self.config, 'hierarchy_levels', 2) 

71 ) 

72 

73 if hasattr(self.config, 'use_task_adaptive_prototypes') and self.config.use_task_adaptive_prototypes: 

74 self.adaptive_initializer = TaskAdaptivePrototypes( 

75 self.config.embedding_dim, 

76 getattr(self.config, 'adaptation_steps', 5) 

77 ) 

78 

79 logger.info(f"Initialized Advanced Prototypical Networks: {self.config}") 

80 self._setup_implementation_variant() 

81 

82 def forward( 

83 self, 

84 support_x: torch.Tensor, 

85 support_y: torch.Tensor, 

86 query_x: torch.Tensor, 

87 return_uncertainty: bool = False 

88 ) -> Dict[str, torch.Tensor]: 

89 """ 

90 Configurable forward pass that routes to appropriate implementation. 

91 """ 

92 return self._forward_impl(support_x, support_y, query_x, return_uncertainty) 

93 

94 def _setup_implementation_variant(self): 

95 """Setup the appropriate implementation based on configuration.""" 

96 variant = getattr(self.config, 'protonet_variant', 'enhanced') 

97 

98 if variant == "research_accurate": 

99 self._forward_impl = self._forward_research_accurate 

100 elif variant == "simple": 

101 self._forward_impl = self._forward_simple 

102 elif variant == "original": 

103 self._forward_impl = self._forward_original 

104 else: # enhanced 

105 self._forward_impl = self._forward_enhanced 

106 

107 def _forward_research_accurate( 

108 self, 

109 support_x: torch.Tensor, 

110 support_y: torch.Tensor, 

111 query_x: torch.Tensor, 

112 return_uncertainty: bool = False 

113 ) -> Dict[str, torch.Tensor]: 

114 """Research-accurate implementation following Snell et al. (2017) exactly.""" 

115 # Embed support and query examples 

116 support_features = self.backbone(support_x) 

117 query_features = self.backbone(query_x) 

118 

119 # Compute class prototypes 

120 n_way = len(torch.unique(support_y)) 

121 prototypes = torch.zeros(n_way, support_features.size(1), device=support_features.device) 

122 

123 for k in range(n_way): 

124 class_mask = support_y == k 

125 if class_mask.any(): 

126 class_features = support_features[class_mask] 

127 prototypes[k] = class_features.mean(dim=0) 

128 

129 # Compute squared Euclidean distances 

130 distances = torch.cdist(query_features, prototypes, p=2) ** 2 

131 

132 # Convert to logits via negative distances with temperature 

133 temperature = getattr(self.config, 'distance_temperature', 1.0) 

134 logits = -distances / temperature 

135 

136 result = {"logits": logits} 

137 

138 if return_uncertainty: 

139 probs = F.softmax(logits, dim=-1) 

140 entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=-1) 

141 result["uncertainty"] = entropy 

142 

143 return result 

144 
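The uncertainty branch above (source lines 138-141) reports the per-query Shannon entropy of the softmax over the logits, so a uniform prediction scores ln(n_way) and a confident one scores near zero. A minimal numeric sketch of that computation, with made-up logits:

import torch
import torch.nn.functional as F

logits = torch.tensor([[0.0, 0.0],     # query that is completely uncertain between 2 classes
                       [10.0, 0.0]])   # query that is confidently class 0
probs = F.softmax(logits, dim=-1)
entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=-1)
print(entropy)   # roughly tensor([0.6931, 0.0005]): ln(2) for the uniform row, near zero for the confident one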

145      def _forward_simple(self, support_x, support_y, query_x, return_uncertainty=False):
146          """Simplified implementation without extensions."""
147          simple_protonet = SimplePrototypicalNetworks(self.backbone)
148          logits = simple_protonet.forward(support_x, support_y, query_x)
149          return {"logits": logits}
150
151      def _forward_original(self, support_x, support_y, query_x, return_uncertainty=False):
152          """Original implementation (preserved for backward compatibility)."""
153          return self._forward_enhanced(support_x, support_y, query_x, return_uncertainty)
154
155      def _forward_enhanced(self, support_x, support_y, query_x, return_uncertainty=False):
156          """Enhanced implementation with all features."""
157          # Extract features
158          support_features = self.backbone(support_x)
159          query_features = self.backbone(query_x)
160
161          # Multi-scale features if configured
162          if self.config.multi_scale_features and hasattr(self, 'scale_aggregator'):
163              support_features = self.scale_aggregator(support_features, support_x)
164              query_features = self.scale_aggregator(query_features, query_x)
165
166          # Compute initial prototypes
167          prototypes = self._compute_prototypes(support_features, support_y)
168
169          # Adaptive refinement if configured
170          if self.config.adaptive_prototypes and hasattr(self, 'prototype_refiner'):
171              prototypes = self.prototype_refiner(prototypes, support_features, support_y)
172
173          # Compute distances
174          distances = self._compute_distances(query_features, prototypes)
175          logits = -distances / self.config.temperature
176
177          result = {"logits": logits}
178
179          # Uncertainty estimation if requested
180          if (return_uncertainty and hasattr(self.config, 'uncertainty_estimation')
181                  and self.config.uncertainty_estimation and hasattr(self, 'uncertainty_estimator')):
182              uncertainty = self.uncertainty_estimator(query_features, prototypes, distances)
183              result["uncertainty"] = uncertainty
184
185          return result
186
187      def _compute_prototypes(self, support_features, support_y):
188          """Compute class prototypes from the support set."""
189          unique_classes = torch.unique(support_y)
190          prototypes = []
191
192          for class_id in unique_classes:
193              class_mask = support_y == class_id
194              class_features = support_features[class_mask]
195              class_prototype = class_features.mean(dim=0)
196              prototypes.append(class_prototype)
197
198          return torch.stack(prototypes)
199
200      def _compute_distances(self, query_features, prototypes):
201          """Compute distances between queries and prototypes."""
202          query_expanded = query_features.unsqueeze(1)
203          proto_expanded = prototypes.unsqueeze(0)
204          distances = torch.sum((query_expanded - proto_expanded) ** 2, dim=-1)
205          return distances
206
207
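A minimal usage sketch for PrototypicalNetworks on a 5-way 1-shot episode. It is illustrative only: the toy backbone, the tensor shapes, and the assumption that a protonet_variant attribute can simply be assigned on the config (it is read via getattr on source line 96) are not taken from this module. Because the class is a plain Python wrapper rather than an nn.Module, forward() is called explicitly.

import torch
import torch.nn as nn

backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 28 * 28, 64))   # toy embedding network
config = PrototypicalConfig()
config.protonet_variant = "research_accurate"   # route forward() to the Snell et al. (2017) path
proto_net = PrototypicalNetworks(backbone, config)

support_x = torch.randn(5, 3, 28, 28)   # 5 classes, 1 shot each
support_y = torch.arange(5)             # the research-accurate path expects labels 0..n_way-1
query_x = torch.randn(15, 3, 28, 28)    # 3 queries per class

out = proto_net.forward(support_x, support_y, query_x, return_uncertainty=True)
print(out["logits"].shape)        # expected: torch.Size([15, 5])
print(out["uncertainty"].shape)   # expected: torch.Size([15])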

208  class SimplePrototypicalNetworks:
209      """
210      Research-accurate implementation of Prototypical Networks (Snell et al. 2017).
211
212      Core algorithm:
213      1. Compute class prototypes: c_k = 1/|S_k| Σ f_φ(x_i) for (x_i, y_i) ∈ S_k
214      2. Classify via softmax over negative squared distances
215      3. Distance: d(f_φ(x), c_k) = ||f_φ(x) - c_k||²
216      """
217
218      def __init__(self, embedding_net: nn.Module):
219          """Initialize with embedding network f_φ."""
220          self.embedding_net = embedding_net
221
222      def forward(self, support_x, support_y, query_x):
223          """Standard Prototypical Networks forward pass."""
224          # Embed support and query examples
225          support_features = self.embedding_net(support_x)
226          query_features = self.embedding_net(query_x)
227
228          # Compute class prototypes
229          n_way = len(torch.unique(support_y))
230          prototypes = torch.zeros(n_way, support_features.size(1), device=support_features.device)
231
232          for k in range(n_way):
233              class_mask = support_y == k
234              if class_mask.any():
235                  class_examples = support_features[class_mask]
236                  prototypes[k] = class_examples.mean(dim=0)
237
238          # Compute distances and convert to logits
239          distances = torch.cdist(query_features, prototypes, p=2) ** 2
240          logits = -distances
241
242          return logits
243
244
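A tiny numeric check of the formulas in the docstring above, using an identity embedding so the logits are exactly -||x - c_k||². The values are chosen for illustration, and nn.Identity stands in for a real f_φ:

import torch
import torch.nn as nn

proto = SimplePrototypicalNetworks(nn.Identity())     # f_phi = identity, so features == inputs
support_x = torch.tensor([[0.0, 0.0], [2.0, 2.0]])    # one support example per class
support_y = torch.tensor([0, 1])
query_x = torch.tensor([[0.0, 1.0]])

logits = proto.forward(support_x, support_y, query_x)
# prototypes are c_0 = (0, 0) and c_1 = (2, 2); squared distances are 1 and 5,
# so the logits should be about tensor([[-1., -5.]]) up to floating-point rounding
print(logits)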

245  class MatchingNetworks:
246      """
247      Advanced Matching Networks with 2024 attention mechanisms.
248
249      Key innovations beyond existing libraries:
250      1. Multi-head attention for support-query matching
251      2. Bidirectional LSTM context encoding
252      3. Transformer-based support set encoding
253      4. Adaptive attention temperature
254      5. Context-aware similarity metrics
255      """
256
257      def __init__(self, backbone: nn.Module, config: MatchingConfig = None):
258          """Initialize advanced Matching Networks."""
259          self.backbone = backbone
260          self.config = config or MatchingConfig()
261
262          # Context encoding for support set
263          if getattr(self.config, 'use_lstm', True):    [263 ↛ 277: line 263 didn't jump to line 277 because the condition on line 263 was always true]
264              self.context_encoder = nn.LSTM(
265                  self.config.embedding_dim,
266                  getattr(self.config, 'lstm_layers', 256),
267                  bidirectional=getattr(self.config, 'bidirectional', True),
268                  batch_first=True
269              )
270              hidden_multiplier = 2 if getattr(self.config, 'bidirectional', True) else 1
271              self.context_projection = nn.Linear(
272                  getattr(self.config, 'lstm_layers', 256) * hidden_multiplier,
273                  self.config.embedding_dim
274              )
275
276          # Attention mechanism
277          self.attention = self._create_attention_mechanism()
278
279          # Adaptive temperature
280          self.temperature_net = nn.Sequential(
281              nn.Linear(self.config.embedding_dim, 64),
282              nn.ReLU(),
283              nn.Linear(64, 1),
284              nn.Softplus()
285          )
286
287          logger.info(f"Initialized Advanced Matching Networks: {self.config}")
288
289      def forward(self, support_x, support_y, query_x):
290          """Forward pass with advanced matching networks."""
291          # Extract features
292          support_features = self.backbone(support_x)
293          query_features = self.backbone(query_x)
294
295          # Context encoding for support set
296          if hasattr(self, 'context_encoder'):
297              support_features = self._encode_context(support_features)
298
299          # Compute attention weights
300          attention_weights = self.attention(query_features, support_features, support_features)
301
302          # Adaptive temperature
303          temperatures = self.temperature_net(query_features.mean(dim=0))
304          temperatures = temperatures.clamp(min=0.1, max=10.0)
305
306          # Apply temperature scaling
307          scaled_attention = attention_weights / temperatures
308          attention_probs = F.softmax(scaled_attention, dim=-1)
309
310          # Convert to predictions
311          n_classes = len(torch.unique(support_y))
312          support_one_hot = F.one_hot(support_y, n_classes).float()
313          predictions = torch.matmul(attention_probs, support_one_hot)
314          logits = torch.log(predictions + 1e-8)
315
316          return {
317              "logits": logits,
318              "probabilities": predictions,
319              "attention_weights": attention_weights
320          }
321
322      def _encode_context(self, support_features):
323          """Encode support set with contextual information."""
324          support_expanded = support_features.unsqueeze(0)
325          encoded, _ = self.context_encoder(support_expanded)
326          encoded = self.context_projection(encoded)
327          return encoded.squeeze(0)
328
329      def _create_attention_mechanism(self):
330          """Create attention mechanism based on configuration."""
331          attention_type = getattr(self.config, 'attention_type', 'cosine')
332
333          if attention_type == "scaled_dot_product":    [333 ↛ 334: line 333 didn't jump to line 334 because the condition on line 333 was never true]
334              return ScaledDotProductAttention(
335                  self.config.embedding_dim,
336                  getattr(self.config, 'num_attention_heads', 8),
337                  self.config.dropout
338              )
339          elif attention_type == "additive":    [339 ↛ 340: line 339 didn't jump to line 340 because the condition on line 339 was never true]
340              return AdditiveAttention(self.config.embedding_dim)
341          elif attention_type == "bilinear":    [341 ↛ 342: line 341 didn't jump to line 342 because the condition on line 341 was never true]
342              return BilinearAttention(self.config.embedding_dim)
343          else:
344              # Default cosine attention
345              return ScaledDotProductAttention(
346                  self.config.embedding_dim, 8, self.config.dropout
347              )
348
349
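A short sketch of the label-mixing step in MatchingNetworks.forward() (source lines 310-314): the softmax-normalised attention over support items is multiplied by the one-hot support labels, so each class receives the attention mass of its own support examples. All values below are made up for illustration:

import torch
import torch.nn.functional as F

support_y = torch.tensor([0, 0, 1])                   # two class-0 supports, one class-1 support
attention_probs = torch.tensor([[0.5, 0.3, 0.2]])     # one query, already softmax-normalised
support_one_hot = F.one_hot(support_y, num_classes=2).float()

predictions = attention_probs @ support_one_hot       # tensor([[0.8000, 0.2000]]): class 0 collects 0.5 + 0.3
logits = torch.log(predictions + 1e-8)                # same log-transform as source line 314
print(predictions, logits)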

350  class RelationNetworks:
351      """
352      Advanced Relation Networks with Graph Neural Network components (2024).
353
354      Key innovations beyond existing libraries:
355      1. Graph Neural Network for relation modeling
356      2. Edge features and message passing
357      3. Self-attention for relation refinement
358      4. Hierarchical relation structures
359      5. Multi-hop reasoning capabilities
360      """
361
362      def __init__(self, backbone: nn.Module, config: RelationConfig = None):
363          """Initialize advanced Relation Networks."""
364          self.backbone = backbone
365          self.config = config or RelationConfig()
366
367          # Relation module
368          if getattr(self.config, 'use_graph_neural_network', True):    [368 ↛ 378: line 368 didn't jump to line 378 because the condition on line 368 was always true]
369              self.relation_module = GraphRelationModule(
370                  self.config.embedding_dim,
371                  self.config.relation_dim,
372                  getattr(self.config, 'gnn_layers', 3),
373                  getattr(self.config, 'gnn_hidden_dim', 256),
374                  getattr(self.config, 'edge_features', True),
375                  getattr(self.config, 'message_passing_steps', 3)
376              )
377          else:
378              self.relation_module = StandardRelationModule(
379                  self.config.embedding_dim,
380                  self.config.relation_dim
381              )
382
383          # Self-attention for relation refinement
384          if getattr(self.config, 'self_attention', True):    [384 ↛ 392: line 384 didn't jump to line 392 because the condition on line 384 was always true]
385              self.self_attention = nn.MultiheadAttention(
386                  self.config.embedding_dim,
387                  num_heads=8,
388                  dropout=self.config.dropout,
389                  batch_first=True
390              )
391
392          logger.info(f"Initialized Advanced Relation Networks: {self.config}")
393
394      def forward(self, support_x, support_y, query_x):
395          """Forward pass with advanced relation networks."""
396          # Extract features
397          support_features = self.backbone(support_x)
398          query_features = self.backbone(query_x)
399
400          # Self-attention refinement
401          if hasattr(self, 'self_attention'):
402              support_features, _ = self.self_attention(
403                  support_features.unsqueeze(0),
404                  support_features.unsqueeze(0),
405                  support_features.unsqueeze(0)
406              )
407              support_features = support_features.squeeze(0)
408
409          # Compute relations
410          relation_scores = self.relation_module(
411              query_features, support_features, support_y
412          )
413
414          # Convert to class predictions
415          predictions = self._aggregate_relation_scores(relation_scores, support_y)
416
417          return {
418              "logits": predictions,
419              "probabilities": F.softmax(predictions, dim=-1),
420              "relation_scores": relation_scores
421          }
422
423      def _aggregate_relation_scores(self, relation_scores, support_y):
424          """Aggregate relation scores to class-level predictions."""
425          unique_classes = torch.unique(support_y)
426          n_query = relation_scores.shape[0]
427          n_classes = len(unique_classes)
428
429          class_scores = torch.zeros(n_query, n_classes, device=relation_scores.device)
430
431          for i, class_id in enumerate(unique_classes):
432              class_mask = support_y == class_id
433              class_relations = relation_scores[:, class_mask]
434              class_scores[:, i] = class_relations.mean(dim=-1)
435
436          return class_scores
436 return class_scores