Coverage for src/meta_learning/meta_learning_modules/few_shot_modules/utilities.py: 9%
145 statements
« prev ^ index » next coverage.py v7.10.5, created at 2025-09-03 12:35 +0900
« prev ^ index » next coverage.py v7.10.5, created at 2025-09-03 12:35 +0900
1"""
2Few-Shot Learning Utilities
3==========================
5Utility functions for few-shot learning including factory functions,
6evaluation utilities, and helper functions.
7Extracted from the original monolithic few_shot_learning.py.
8"""
10import torch
11import torch.nn as nn
12import torch.nn.functional as F
13from typing import Dict, List, Tuple, Optional, Any
14import numpy as np
15import logging
17from .configurations import PrototypicalConfig
18from .core_networks import PrototypicalNetworks
20logger = logging.getLogger(__name__)
def create_prototypical_network(
    backbone: nn.Module,
    variant: str = "research_accurate",
    config: Optional[PrototypicalConfig] = None
) -> PrototypicalNetworks:
    """
    Factory function to create Prototypical Networks with a specific configuration.

    Args:
        backbone: Feature extraction backbone network
        variant: Implementation variant ('research_accurate', 'simple', 'enhanced', 'original')
        config: Optional custom configuration; a default PrototypicalConfig is
            created when omitted

    Returns:
        Configured PrototypicalNetworks instance
    """
    if config is None:
        config = PrototypicalConfig()

    def _set(attr: str, value) -> None:
        # Only touch attributes the config class actually defines, so the factory
        # works across config versions with different feature sets. (The original
        # code guarded some assignments with hasattr but not others; this makes
        # the guarding consistent.)
        if hasattr(config, attr):
            setattr(config, attr, value)

    _set('protonet_variant', variant)

    if variant == "research_accurate":
        # Pure research-accurate implementation (squared Euclidean distance,
        # mean prototypes, no extensions).
        _set('use_squared_euclidean', True)
        _set('prototype_method', "mean")
        _set('enable_research_extensions', False)
        _set('multi_scale_features', False)
        _set('adaptive_prototypes', False)
        _set('uncertainty_estimation', False)
    elif variant == "simple":
        # Simplified educational version: all optional extensions off.
        _set('multi_scale_features', False)
        _set('adaptive_prototypes', False)
        _set('uncertainty_estimation', False)
        _set('enable_research_extensions', False)
    elif variant == "enhanced":
        # All extensions enabled.
        _set('multi_scale_features', True)
        _set('adaptive_prototypes', True)
        _set('uncertainty_estimation', True)
        _set('enable_research_extensions', True)
    # NOTE(review): 'original' (and any other variant name) falls through with
    # the config's defaults — matching the original behavior.

    return PrototypicalNetworks(backbone, config)
def compare_with_learn2learn_protonet():
    """
    Summarize how this implementation differs from learn2learn's ProtoNets.

    learn2learn wraps Prototypical Networks in its Lightning training
    framework (``l2l.algorithms.Lightning`` with ``l2l.utils.ProtoLightning``),
    providing built-in benchmark data loaders and automatic meta-batch
    processing, whereas this module is research/education-focused with
    configurable, research-accurate variants.

    Returns:
        Dictionary contrasting the two implementations' strengths and
        intended use cases.
    """
    l2l_strengths = [
        "Lightning framework integration",
        "Built-in benchmark data loaders",
        "Automatic meta-batch processing",
        "Production-ready training loops",
    ]
    local_strengths = [
        "Educational and research-focused",
        "Research-accurate implementations",
        "Configurable variants",
        "Extensive documentation and citations",
        "Advanced extensions with proper attribution",
    ]
    recommended_usage = {
        "learn2learn": "Production systems, quick prototyping",
        "our_implementation": "Research, education, algorithm understanding",
    }

    return {
        "learn2learn_advantages": l2l_strengths,
        "our_advantages": local_strengths,
        "use_cases": recommended_usage,
    }
def evaluate_on_standard_benchmarks(model, dataset_name="omniglot", episodes=600):
    """
    Standard few-shot evaluation following research protocols.

    Conventional setups in the meta-learning literature:
    - Omniglot: 20-way 1-shot and 5-shot
    - miniImageNet / tieredImageNet: 5-way 1-shot and 5-shot

    Runs the model over `episodes` sampled episodes (600 is the literature
    standard) and reports the mean accuracy with a 95% confidence interval.
    Episodes that raise are logged and skipped.

    Args:
        model: Few-shot learning model
        dataset_name: Name of benchmark dataset
        episodes: Number of evaluation episodes

    Returns:
        Dictionary with mean accuracy, confidence interval, std, episode
        count, and the per-episode accuracies.
    """
    episode_scores = []

    for episode_idx in range(episodes):
        try:
            support_x, support_y, query_x, query_y = sample_episode(dataset_name)

            output = model(support_x, support_y, query_x)
            logits = output['logits'] if isinstance(output, dict) else output

            hits = (logits.argmax(dim=1) == query_y).float()
            episode_scores.append(hits.mean().item())
        except Exception as e:
            # Best-effort evaluation: skip failed episodes but keep going.
            logger.warning(f"Episode {episode_idx} failed: {e}")
            continue

    if not episode_scores:
        return {"mean_accuracy": 0.0, "confidence_interval": 0.0, "episodes": 0}

    mean_acc = np.mean(episode_scores)
    std_acc = np.std(episode_scores)
    # Half-width of the 95% confidence interval (normal approximation).
    half_width = 1.96 * std_acc / np.sqrt(len(episode_scores))

    return {
        "mean_accuracy": mean_acc,
        "confidence_interval": half_width,
        "std_accuracy": std_acc,
        "episodes": len(episode_scores),
        "raw_accuracies": episode_scores,
    }
def sample_episode(dataset_name: str, n_way: int = 5, n_support: int = 5, n_query: int = 15):
    """
    Sample a few-shot episode for the named dataset.

    Placeholder implementation that returns synthetic (random) tensors with
    benchmark-appropriate shapes; replace with a real dataset loader in
    practice. Note that `n_way` is overridden to the benchmark-standard value
    for known datasets.

    Args:
        dataset_name: Name of the dataset
        n_way: Number of classes per episode (overridden for known datasets)
        n_support: Number of support examples per class
        n_query: Number of query examples per class

    Returns:
        Tuple of (support_x, support_y, query_x, query_y)
    """
    # Benchmark-standard image shapes and way counts.
    if dataset_name == "omniglot":
        image_shape = (1, 28, 28)
        n_way = 20  # Standard for Omniglot
    elif dataset_name in ("miniImageNet", "tieredImageNet"):
        image_shape = (3, 84, 84)
        n_way = 5  # Standard for ImageNet variants
    else:
        image_shape = (3, 32, 32)  # Default fallback

    def _synthetic_split(per_class: int):
        # Random images plus labels 0..n_way-1, each repeated per_class times.
        xs = torch.randn(n_way * per_class, *image_shape)
        ys = torch.arange(n_way).repeat_interleave(per_class)
        return xs, ys

    support_x, support_y = _synthetic_split(n_support)
    query_x, query_y = _synthetic_split(n_query)
    return support_x, support_y, query_x, query_y
def euclidean_distance_squared(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Squared Euclidean distance as in Snell et al. (2017), Equation 1.

    Squared (rather than plain) Euclidean distance is used for gradient
    stability, matching the paper.

    Args:
        x: Query embeddings [n_query, embedding_dim]
        y: Prototype embeddings [n_prototypes, embedding_dim]

    Returns:
        Squared distances [n_query, n_prototypes]
    """
    # Broadcast to [n_query, n_prototypes, embedding_dim] and reduce the
    # embedding axis.
    diff = x[:, None, :] - y[None, :, :]
    return (diff * diff).sum(dim=-1)
def compute_prototype_statistics(prototypes: torch.Tensor, support_features: torch.Tensor,
                                 support_labels: torch.Tensor) -> Dict[str, float]:
    """
    Compute statistics about learned prototypes for analysis.

    Args:
        prototypes: Class prototypes [n_classes, embedding_dim]; row i is
            assumed to correspond to label i
        support_features: Support set features [n_support, embedding_dim]
        support_labels: Support set labels [n_support]

    Returns:
        Dictionary with inter-prototype and intra-class distance statistics
        plus a separation-ratio quality metric (higher = better separation).
    """
    n_classes = len(prototypes)

    # Pairwise prototype distances, with the zero diagonal masked out.
    pairwise = torch.cdist(prototypes, prototypes, p=2)
    off_diagonal = pairwise[~torch.eye(n_classes, dtype=bool)]

    stats: Dict[str, float] = {
        'mean_inter_prototype_distance': off_diagonal.mean().item(),
        'std_inter_prototype_distance': off_diagonal.std().item(),
        'min_inter_prototype_distance': off_diagonal.min().item(),
        'max_inter_prototype_distance': off_diagonal.max().item(),
    }

    # Distances from each support example to its own class prototype.
    per_class_distances = [
        torch.norm(support_features[support_labels == label] - prototypes[label], p=2, dim=1)
        for label in torch.unique(support_labels)
    ]
    intra = torch.cat(per_class_distances)
    stats['mean_intra_class_distance'] = intra.mean().item()
    stats['std_intra_class_distance'] = intra.std().item()

    # Ratio of between-class spread to within-class spread (epsilon guards
    # against division by zero).
    stats['prototype_separation_ratio'] = (
        stats['mean_inter_prototype_distance']
        / (stats['mean_intra_class_distance'] + 1e-8)
    )

    return stats
def analyze_few_shot_performance(model, test_episodes: int = 100, n_way: int = 5,
                                 n_support: int = 5, n_query: int = 15) -> Dict[str, Any]:
    """
    Comprehensive analysis of few-shot learning performance.

    Runs `test_episodes` synthetic episodes through the model in eval mode,
    collecting per-episode accuracy, prediction-confidence statistics, and
    (when the model returns them) prototype-quality statistics. Episodes
    that raise are logged and skipped.

    Args:
        model: Few-shot learning model; called as model(support_x, support_y,
            query_x) and may return either logits or a dict with 'logits'
            (and optionally 'prototypes')
        test_episodes: Number of test episodes
        n_way: Number of classes per episode
        n_support: Number of support examples per class
        n_query: Number of query examples per class

    Returns:
        Dictionary with 'accuracy_stats', 'confidence_stats' (None when no
        confidences were collected), and 'prototype_stats' (present only when
        prototypes were available).
    """
    model.eval()

    episode_accuracies = []
    prototype_stats_list = []
    confidence_scores = []

    with torch.no_grad():
        for episode in range(test_episodes):
            # Sample episode
            support_x, support_y, query_x, query_y = sample_episode(
                "synthetic", n_way, n_support, n_query
            )

            try:
                # Forward pass; models may return raw logits or a result dict.
                result = model(support_x, support_y, query_x)
                if isinstance(result, dict):
                    logits = result['logits']
                    prototypes = result.get('prototypes')
                else:
                    logits = result
                    prototypes = None

                # Compute accuracy
                predictions = logits.argmax(dim=1)
                accuracy = (predictions == query_y).float().mean().item()
                episode_accuracies.append(accuracy)

                # Analyze prototypes if available
                if prototypes is not None:
                    support_features = model.backbone(support_x)
                    proto_stats = compute_prototype_statistics(
                        prototypes, support_features, support_y
                    )
                    prototype_stats_list.append(proto_stats)

                # Analyze confidence (max class probability per query)
                probs = F.softmax(logits, dim=-1)
                max_probs = probs.max(dim=-1)[0]
                confidence_scores.extend(max_probs.tolist())

            except Exception as e:
                logger.warning(f"Episode {episode} analysis failed: {e}")
                continue

    # Guard against every episode failing: np.min/np.max raise ValueError on
    # empty input (and np.mean yields nan), so return an explicit empty result
    # instead — consistent with evaluate_on_standard_benchmarks.
    if not episode_accuracies:
        return {
            'accuracy_stats': {
                'mean': 0.0,
                'std': 0.0,
                'min': 0.0,
                'max': 0.0,
                'episodes': 0
            },
            'confidence_stats': None
        }

    # Aggregate results
    analysis = {
        'accuracy_stats': {
            'mean': np.mean(episode_accuracies),
            'std': np.std(episode_accuracies),
            'min': np.min(episode_accuracies),
            'max': np.max(episode_accuracies),
            'episodes': len(episode_accuracies)
        },
        'confidence_stats': {
            'mean': np.mean(confidence_scores),
            'std': np.std(confidence_scores),
            'median': np.median(confidence_scores)
        } if confidence_scores else None
    }

    # Prototype analysis: mean/std of each statistic across episodes.
    if prototype_stats_list:
        proto_analysis = {}
        for key in prototype_stats_list[0].keys():
            values = [stats[key] for stats in prototype_stats_list]
            proto_analysis[key] = {
                'mean': np.mean(values),
                'std': np.std(values)
            }
        analysis['prototype_stats'] = proto_analysis

    return analysis
def create_backbone_network(architecture: str = "conv4", input_channels: int = 3,
                            embedding_dim: int = 512) -> nn.Module:
    """
    Create a backbone network for few-shot learning.

    Args:
        architecture: Backbone architecture: 'conv4' (the standard 4-layer
            Conv-BN-ReLU-MaxPool CNN used in the few-shot literature) or
            'simple' (lightweight educational CNN). Note: the original
            docstring also advertised 'resnet', which was never implemented.
        input_channels: Number of input channels
        embedding_dim: Output embedding dimension

    Returns:
        Backbone network mapping [B, C, H, W] images to [B, embedding_dim]
        feature vectors.

    Raises:
        ValueError: If the architecture name is not recognized.
    """
    def _conv_block(in_ch: int, out_ch: int) -> List[nn.Module]:
        # One standard Conv4 unit: Conv-BN-ReLU-MaxPool (halves spatial size).
        return [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        ]

    if architecture == "conv4":
        # Standard 4-layer CNN backbone: one block from input_channels to 64,
        # then three 64->64 blocks.
        layers = _conv_block(input_channels, 64)
        for _ in range(3):
            layers += _conv_block(64, 64)
        layers += [
            nn.AdaptiveAvgPool2d(1),  # global average pooling -> [B, 64, 1, 1]
            nn.Flatten(),
            nn.Linear(64, embedding_dim),  # project to requested embedding size
        ]
        return nn.Sequential(*layers)

    if architecture == "simple":
        # Simple backbone for educational purposes (no batch norm).
        return nn.Sequential(
            nn.Conv2d(input_channels, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, embedding_dim),
        )

    raise ValueError(f"Unknown backbone architecture: {architecture}")