Coverage for src/meta_learning/meta_learning_modules/few_shot_modules/core_networks.py: 27%
178 statements
« prev ^ index » next coverage.py v7.10.5, created at 2025-09-03 12:49 +0900
1"""
2Few-Shot Learning Core Network Architectures
3==========================================
5Core neural network implementations for few-shot learning algorithms.
6Extracted from the original monolithic few_shot_learning.py.
7"""
9import torch
10import torch.nn as nn
11import torch.nn.functional as F
12from typing import Dict, List, Tuple, Optional, Any
13import numpy as np
14import logging
16from .configurations import FewShotConfig, PrototypicalConfig, MatchingConfig, RelationConfig
17from .advanced_components import (
18 MultiScaleFeatureAggregator, PrototypeRefiner, UncertaintyEstimator,
19 ScaledDotProductAttention, AdditiveAttention, BilinearAttention,
20 GraphRelationModule, StandardRelationModule,
21 UncertaintyAwareDistance, HierarchicalPrototypes, TaskAdaptivePrototypes
22)
24logger = logging.getLogger(__name__)
class PrototypicalNetworks(nn.Module):
    """
    Advanced Prototypical Networks with 2024 improvements.

    Based on Snell et al. (2017) "Prototypical Networks for Few-shot Learning"
    with research-accurate extensions and configurable variants.
    """

    def __init__(self, backbone: nn.Module, config: "PrototypicalConfig" = None):
        """Initialize advanced Prototypical Networks.

        Args:
            backbone: Embedding network f_phi mapping raw inputs to feature
                vectors of size ``config.embedding_dim``.
            config: Optional configuration; a default ``PrototypicalConfig``
                is created when omitted.
        """
        # Subclass nn.Module and call super().__init__() so the sub-modules
        # built below are registered: without this, .parameters(), .to(device),
        # .train()/.eval() and state_dict() silently skip them and the
        # extensions can never be trained or moved with the model.
        super().__init__()
        self.backbone = backbone
        self.config = config or PrototypicalConfig()

        # Multi-scale feature aggregation
        if self.config.multi_scale_features:
            self.scale_aggregator = MultiScaleFeatureAggregator(
                self.config.embedding_dim,
                self.config.scale_factors
            )

        # Adaptive prototype refinement
        if self.config.adaptive_prototypes:
            self.prototype_refiner = PrototypeRefiner(
                self.config.embedding_dim,
                self.config.prototype_refinement_steps
            )

        # Uncertainty estimation (getattr tolerates configs predating the flag)
        if getattr(self.config, 'uncertainty_estimation', False):
            self.uncertainty_estimator = UncertaintyEstimator(
                self.config.embedding_dim
            )

        # Advanced components based on config
        if getattr(self.config, 'use_uncertainty_aware_distances', False):
            self.uncertainty_distance = UncertaintyAwareDistance(
                self.config.embedding_dim,
                getattr(self.config, 'uncertainty_temperature', 2.0)
            )

        if getattr(self.config, 'use_hierarchical_prototypes', False):
            self.hierarchical_prototypes = HierarchicalPrototypes(
                self.config.embedding_dim,
                getattr(self.config, 'hierarchy_levels', 2)
            )

        if getattr(self.config, 'use_task_adaptive_prototypes', False):
            self.adaptive_initializer = TaskAdaptivePrototypes(
                self.config.embedding_dim,
                getattr(self.config, 'adaptation_steps', 5)
            )

        logger.info(f"Initialized Advanced Prototypical Networks: {self.config}")
        self._setup_implementation_variant()

    def forward(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        return_uncertainty: bool = False
    ) -> Dict[str, torch.Tensor]:
        """Configurable forward pass that routes to the variant selected
        by ``config.protonet_variant`` (see _setup_implementation_variant)."""
        return self._forward_impl(support_x, support_y, query_x, return_uncertainty)

    def _setup_implementation_variant(self):
        """Bind self._forward_impl to the implementation named in the config."""
        variant = getattr(self.config, 'protonet_variant', 'enhanced')

        if variant == "research_accurate":
            self._forward_impl = self._forward_research_accurate
        elif variant == "simple":
            self._forward_impl = self._forward_simple
        elif variant == "original":
            self._forward_impl = self._forward_original
        else:  # enhanced
            self._forward_impl = self._forward_enhanced

    def _forward_research_accurate(
        self,
        support_x: torch.Tensor,
        support_y: torch.Tensor,
        query_x: torch.Tensor,
        return_uncertainty: bool = False
    ) -> Dict[str, torch.Tensor]:
        """Research-accurate implementation following Snell et al. (2017) exactly.

        Returns a dict with "logits" (negative squared distances over
        temperature) and, when requested, predictive-entropy "uncertainty".
        Assumes support_y holds labels 0..n_way-1 — TODO confirm with callers.
        """
        # Embed support and query examples
        support_features = self.backbone(support_x)
        query_features = self.backbone(query_x)

        # Compute class prototypes: c_k = mean of class-k support embeddings
        n_way = len(torch.unique(support_y))
        prototypes = torch.zeros(n_way, support_features.size(1), device=support_features.device)

        for k in range(n_way):
            class_mask = support_y == k
            if class_mask.any():
                class_features = support_features[class_mask]
                prototypes[k] = class_features.mean(dim=0)

        # Squared Euclidean distances between every query and every prototype
        distances = torch.cdist(query_features, prototypes, p=2) ** 2

        # Convert to logits via negative distances with temperature
        temperature = getattr(self.config, 'distance_temperature', 1.0)
        logits = -distances / temperature

        result = {"logits": logits}

        if return_uncertainty:
            # Predictive entropy of the softmax distribution (eps avoids log 0)
            probs = F.softmax(logits, dim=-1)
            entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=-1)
            result["uncertainty"] = entropy

        return result

    def _forward_simple(self, support_x, support_y, query_x, return_uncertainty=False):
        """Simplified implementation without extensions (delegates to
        SimplePrototypicalNetworks; return_uncertainty is ignored)."""
        simple_protonet = SimplePrototypicalNetworks(self.backbone)
        logits = simple_protonet.forward(support_x, support_y, query_x)
        return {"logits": logits}

    def _forward_original(self, support_x, support_y, query_x, return_uncertainty=False):
        """Original implementation (preserved for backward compatibility)."""
        return self._forward_enhanced(support_x, support_y, query_x, return_uncertainty)

    def _forward_enhanced(self, support_x, support_y, query_x, return_uncertainty=False):
        """Enhanced implementation with all configured extensions enabled."""
        # Extract features
        support_features = self.backbone(support_x)
        query_features = self.backbone(query_x)

        # Multi-scale features if configured
        if self.config.multi_scale_features and hasattr(self, 'scale_aggregator'):
            support_features = self.scale_aggregator(support_features, support_x)
            query_features = self.scale_aggregator(query_features, query_x)

        # Compute initial prototypes
        prototypes = self._compute_prototypes(support_features, support_y)

        # Adaptive refinement if configured
        if self.config.adaptive_prototypes and hasattr(self, 'prototype_refiner'):
            prototypes = self.prototype_refiner(prototypes, support_features, support_y)

        # Compute distances and temperature-scaled logits
        distances = self._compute_distances(query_features, prototypes)
        logits = -distances / self.config.temperature

        result = {"logits": logits}

        # Uncertainty estimation if requested and configured
        if (return_uncertainty and getattr(self.config, 'uncertainty_estimation', False)
            and hasattr(self, 'uncertainty_estimator')):
            uncertainty = self.uncertainty_estimator(query_features, prototypes, distances)
            result["uncertainty"] = uncertainty

        return result

    def _compute_prototypes(self, support_features, support_y):
        """Compute class prototypes (per-class embedding means) from the support set.

        Prototype rows follow the sorted order of torch.unique(support_y).
        """
        unique_classes = torch.unique(support_y)
        prototypes = []

        for class_id in unique_classes:
            class_mask = support_y == class_id
            class_features = support_features[class_mask]
            class_prototype = class_features.mean(dim=0)
            prototypes.append(class_prototype)

        return torch.stack(prototypes)

    def _compute_distances(self, query_features, prototypes):
        """Return squared Euclidean distances, shape (n_query, n_prototypes)."""
        query_expanded = query_features.unsqueeze(1)
        proto_expanded = prototypes.unsqueeze(0)
        distances = torch.sum((query_expanded - proto_expanded) ** 2, dim=-1)
        return distances
class SimplePrototypicalNetworks:
    """
    Research-accurate implementation of Prototypical Networks (Snell et al. 2017).

    Core algorithm:
    1. Compute class prototypes: c_k = 1/|S_k| Σ f_φ(x_i) for (x_i,y_i) ∈ S_k
    2. Classify via softmax over negative squared distances
    3. Distance: d(f_φ(x), c_k) = ||f_φ(x) - c_k||²
    """

    def __init__(self, embedding_net: nn.Module):
        """Store the embedding network f_φ used for support and query points."""
        self.embedding_net = embedding_net

    def forward(self, support_x, support_y, query_x):
        """Return logits = negative squared distances to class prototypes.

        Labels in support_y are assumed to be 0..n_way-1; prototype k is the
        mean embedding of the support examples labelled k.
        """
        # Embed both sets with the shared network f_φ
        z_support = self.embedding_net(support_x)
        z_query = self.embedding_net(query_x)

        # One prototype per distinct label, in label order
        num_classes = len(torch.unique(support_y))
        feat_dim = z_support.size(1)
        prototypes = torch.zeros(num_classes, feat_dim, device=z_support.device)
        for label in range(num_classes):
            selected = support_y == label
            if selected.any():
                prototypes[label] = z_support[selected].mean(dim=0)

        # Squared Euclidean distance, negated so closer => higher score
        sq_dist = torch.cdist(z_query, prototypes, p=2).pow(2)
        return -sq_dist
class MatchingNetworks(nn.Module):
    """
    Advanced Matching Networks with 2024 attention mechanisms.

    Key innovations beyond existing libraries:
    1. Multi-head attention for support-query matching
    2. Bidirectional LSTM context encoding
    3. Transformer-based support set encoding
    4. Adaptive attention temperature
    5. Context-aware similarity metrics
    """

    def __init__(self, backbone: nn.Module, config: "MatchingConfig" = None):
        """Initialize advanced Matching Networks.

        Args:
            backbone: Embedding network mapping raw inputs to feature vectors.
            config: Optional ``MatchingConfig``; defaults to ``MatchingConfig()``.
        """
        # Subclass nn.Module so the LSTM, projection, attention and
        # temperature nets below are registered — otherwise .parameters(),
        # .to(device) and state_dict() would silently skip them.
        super().__init__()
        self.backbone = backbone
        self.config = config or MatchingConfig()

        # Context encoding for support set
        if getattr(self.config, 'use_lstm', True):
            # NOTE(review): 'lstm_layers' is used here as the LSTM *hidden size*
            # (and again for the projection input) — confirm the config key's
            # intent; the name suggests a layer count.
            self.context_encoder = nn.LSTM(
                self.config.embedding_dim,
                getattr(self.config, 'lstm_layers', 256),
                bidirectional=getattr(self.config, 'bidirectional', True),
                batch_first=True
            )
            # Bidirectional LSTM doubles the output width
            hidden_multiplier = 2 if getattr(self.config, 'bidirectional', True) else 1
            self.context_projection = nn.Linear(
                getattr(self.config, 'lstm_layers', 256) * hidden_multiplier,
                self.config.embedding_dim
            )

        # Attention mechanism
        self.attention = self._create_attention_mechanism()

        # Adaptive temperature: small MLP producing a positive scalar
        self.temperature_net = nn.Sequential(
            nn.Linear(self.config.embedding_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Softplus()
        )

        logger.info(f"Initialized Advanced Matching Networks: {self.config}")

    def forward(self, support_x, support_y, query_x):
        """Classify queries as attention-weighted combinations of support labels.

        Returns a dict with "logits" (log of smoothed class probabilities),
        "probabilities", and the raw "attention_weights".
        Assumes support_y holds labels 0..n_classes-1 (one_hot requires it).
        """
        # Extract features
        support_features = self.backbone(support_x)
        query_features = self.backbone(query_x)

        # Context encoding for support set
        if hasattr(self, 'context_encoder'):
            support_features = self._encode_context(support_features)

        # Compute attention weights (query attends over support set)
        attention_weights = self.attention(query_features, support_features, support_features)

        # Adaptive temperature from the mean query embedding, clamped to a
        # sane range so softmax neither saturates nor flattens completely
        temperatures = self.temperature_net(query_features.mean(dim=0))
        temperatures = temperatures.clamp(min=0.1, max=10.0)

        # Apply temperature scaling
        scaled_attention = attention_weights / temperatures
        attention_probs = F.softmax(scaled_attention, dim=-1)

        # Convert to predictions: weighted vote over one-hot support labels
        n_classes = len(torch.unique(support_y))
        support_one_hot = F.one_hot(support_y, n_classes).float()
        predictions = torch.matmul(attention_probs, support_one_hot)
        logits = torch.log(predictions + 1e-8)  # eps avoids log(0)

        return {
            "logits": logits,
            "probabilities": predictions,
            "attention_weights": attention_weights
        }

    def _encode_context(self, support_features):
        """Encode the support set with bidirectional context, projecting the
        LSTM output back to embedding_dim. Input/output: (n_support, dim)."""
        # LSTM expects a batch dimension; treat the support set as one sequence
        support_expanded = support_features.unsqueeze(0)
        encoded, _ = self.context_encoder(support_expanded)
        encoded = self.context_projection(encoded)
        return encoded.squeeze(0)

    def _create_attention_mechanism(self):
        """Build the attention module named by config.attention_type
        (falls back to scaled dot-product attention)."""
        attention_type = getattr(self.config, 'attention_type', 'cosine')

        if attention_type == "scaled_dot_product":
            return ScaledDotProductAttention(
                self.config.embedding_dim,
                getattr(self.config, 'num_attention_heads', 8),
                self.config.dropout
            )
        elif attention_type == "additive":
            return AdditiveAttention(self.config.embedding_dim)
        elif attention_type == "bilinear":
            return BilinearAttention(self.config.embedding_dim)
        else:
            # Default: scaled dot-product attention with 8 heads
            return ScaledDotProductAttention(
                self.config.embedding_dim, 8, self.config.dropout
            )
class RelationNetworks(nn.Module):
    """
    Advanced Relation Networks with Graph Neural Network components (2024).

    Key innovations beyond existing libraries:
    1. Graph Neural Network for relation modeling
    2. Edge features and message passing
    3. Self-attention for relation refinement
    4. Hierarchical relation structures
    5. Multi-hop reasoning capabilities
    """

    def __init__(self, backbone: nn.Module, config: "RelationConfig" = None):
        """Initialize advanced Relation Networks.

        Args:
            backbone: Embedding network mapping raw inputs to feature vectors.
            config: Optional ``RelationConfig``; defaults to ``RelationConfig()``.
        """
        # Subclass nn.Module so the relation module and self-attention below
        # are registered — otherwise .parameters(), .to(device) and
        # state_dict() would silently skip them.
        super().__init__()
        self.backbone = backbone
        self.config = config or RelationConfig()

        # Relation module: GNN-based by default, standard module otherwise
        if getattr(self.config, 'use_graph_neural_network', True):
            self.relation_module = GraphRelationModule(
                self.config.embedding_dim,
                self.config.relation_dim,
                getattr(self.config, 'gnn_layers', 3),
                getattr(self.config, 'gnn_hidden_dim', 256),
                getattr(self.config, 'edge_features', True),
                getattr(self.config, 'message_passing_steps', 3)
            )
        else:
            self.relation_module = StandardRelationModule(
                self.config.embedding_dim,
                self.config.relation_dim
            )

        # Optional self-attention refinement of support features
        if getattr(self.config, 'self_attention', True):
            self.self_attention = nn.MultiheadAttention(
                self.config.embedding_dim,
                num_heads=8,
                dropout=self.config.dropout,
                batch_first=True
            )

        logger.info(f"Initialized Advanced Relation Networks: {self.config}")

    def forward(self, support_x, support_y, query_x):
        """Classify queries via learned relation scores against the support set.

        Returns a dict with class-level "logits", softmax "probabilities",
        and per-support-example "relation_scores".
        """
        # Extract features
        support_features = self.backbone(support_x)
        query_features = self.backbone(query_x)

        # Self-attention refinement; the unsqueeze/squeeze pair adds and drops
        # the batch dimension MultiheadAttention (batch_first) expects
        if hasattr(self, 'self_attention'):
            support_features, _ = self.self_attention(
                support_features.unsqueeze(0),
                support_features.unsqueeze(0),
                support_features.unsqueeze(0)
            )
            support_features = support_features.squeeze(0)

        # Pairwise query/support relation scores from the relation module
        relation_scores = self.relation_module(
            query_features, support_features, support_y
        )

        # Aggregate per-support scores into class-level predictions
        predictions = self._aggregate_relation_scores(relation_scores, support_y)

        return {
            "logits": predictions,
            "probabilities": F.softmax(predictions, dim=-1),
            "relation_scores": relation_scores
        }

    def _aggregate_relation_scores(self, relation_scores, support_y):
        """Mean-pool (n_query, n_support) relation scores into
        (n_query, n_classes) class scores; class columns follow the sorted
        order of torch.unique(support_y)."""
        unique_classes = torch.unique(support_y)
        n_query = relation_scores.shape[0]
        n_classes = len(unique_classes)

        class_scores = torch.zeros(n_query, n_classes, device=relation_scores.device)

        for i, class_id in enumerate(unique_classes):
            class_mask = support_y == class_id
            class_relations = relation_scores[:, class_mask]
            class_scores[:, i] = class_relations.mean(dim=-1)

        return class_scores