# src/meta_learning/meta_learning_modules/utils.py
1"""
2🧰 Meta-Learning Utilities - Research-Grade Helper Functions
3===========================================================
5Author: Benedict Chen (benedict@benedictchen.com)
7💰 Donations: Help support this research!
8 PayPal: https://www.paypal.com/cgi-bin/webscr?cmd=_s-xclick&hosted_button_id=WXQKYYKPHWXHS
9 💖 Please consider recurring donations to support continued meta-learning research
11This module provides research-accurate utilities for meta-learning that fill
12critical gaps in existing libraries (learn2learn, torchmeta, higher) and
13provide statistically rigorous functionality for proper scientific evaluation.
15🔬 Research Foundation:
16======================
17Implements utilities supporting core meta-learning research:
18- Hospedales et al. (2021): Meta-learning statistical evaluation protocols
19- Chen et al. (2019): Closer look at few-shot classification benchmarking
20- Triantafillou et al. (2020): Meta-Dataset evaluation methodology
21- Gidaris & Komodakis (2019): Dynamic few-shot visual classification
23🎯 Key Utility Categories:
24=========================
251. **Dataset & Task Sampling**: Research-accurate task generation with difficulty control
262. **Statistical Evaluation**: Proper confidence intervals following meta-learning protocols
273. **Benchmarking Tools**: Fair comparison methodology across algorithms
284. **Data Augmentation**: Meta-learning specific augmentation strategies
295. **Analysis & Visualization**: Research-grade plots and statistical analysis
31ELI5 Explanation:
32================
33Think of this module like a Swiss Army knife for meta-learning research! 🔧
35Just like a Swiss Army knife has all the small tools you need for camping
36(bottle opener, small knife, screwdriver), this module has all the small
37but essential tools you need for meta-learning research:
39🎲 **Task Generators**: Create fair "learning challenges" for your algorithms
40📊 **Statistical Tools**: Make sure your results are scientifically reliable
41📈 **Benchmarking**: Compare algorithms fairly (like timing runners on the same track)
42🔍 **Analysis Tools**: Understand what your algorithms are actually learning
44Without these utilities, doing meta-learning research would be like trying
45to fix a watch with just a hammer - you need the right specialized tools!
47ASCII Utility Architecture:
48===========================
49 Raw Data Task Generator Meta-Learning
50 ┌─────────┐ ┌─────────────┐ Episodes
51 │ Images │────▶│ Sample N-way│────▶┌─────────────┐
52 │ Labels │ │ K-shot tasks│ │Support: 5x5 │
53 └─────────┘ └─────────────┘ │Query: 5x15 │
54 │ │ └─────────────┘
55 │ ▼ │
56 │ ┌─────────────┐ ▼
57 └────────│Statistical │ ┌─────────────┐
58 │Analyzer │◀──────│Algorithm │
59 │- CI calc │ │Performance │
60 │- Significance│ │Metrics │
61 └─────────────┘ └─────────────┘
62 │ │
63 ▼ ▼
64 ┌─────────────┐ ┌─────────────┐
65 │Research │ │Visualization│
66 │Report │◀──────│& Analysis │
67 │Generator │ │Tools │
68 └─────────────┘ └─────────────┘
70⚡ Core Components:
71==================
721. **MetaLearningDataset**: Generates episodic tasks with proper statistics
732. **TaskConfiguration**: Controls N-way K-shot sampling with difficulty metrics
743. **EvaluationConfig**: Statistical evaluation following research protocols
754. **ConfidenceIntervals**: Research-accurate CI computation (4 methods available)
765. **BenchmarkSuite**: Fair algorithm comparison with statistical rigor
78📊 Statistical Rigor Features:
79=============================
80• **Multiple CI Methods**: Bootstrap, t-distribution, BCa bootstrap, meta-learning standard
81• **Proper Episode Sampling**: Stratified sampling preserving class distributions
82• **Difficulty Estimation**: 4 methods (silhouette, entropy, KNN, pairwise distance)
83• **Statistical Testing**: Significance tests between algorithm performances
84• **Research Protocols**: 600-episode evaluation following Hospedales et al. (2021)
86This module transforms ad-hoc meta-learning experiments into rigorous,
87reproducible scientific research with proper statistical foundations.
88"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler
from typing import Dict, List, Tuple, Optional, Any, Iterator, Union, Callable
import numpy as np
import random
import logging
from collections import defaultdict, Counter
import matplotlib.pyplot as plt
import seaborn as sns
from dataclasses import dataclass
import json
import pickle
from pathlib import Path

logger = logging.getLogger(__name__)


@dataclass
class TaskConfiguration:
    """Configuration for meta-learning tasks."""
    n_way: int = 5
    k_shot: int = 5
    q_query: int = 15
    num_tasks: int = 1000
    task_type: str = "classification"
    augmentation_strategy: str = "basic"  # basic, advanced, none

    # FIXME SOLUTION: Configuration options for difficulty estimation methods
    difficulty_estimation_method: str = "pairwise_distance"  # "pairwise_distance", "silhouette", "entropy", "knn"
    use_research_accurate_difficulty: bool = False  # Enable research-backed methods


@dataclass
class EvaluationConfig:
    """Configuration for meta-learning evaluation."""
    confidence_intervals: bool = True
    num_bootstrap_samples: int = 1000
    significance_level: float = 0.05
    track_adaptation_curve: bool = True
    compute_uncertainty: bool = True

    # FIXME SOLUTION: Configuration options for confidence interval methods
    ci_method: str = "bootstrap"  # "bootstrap", "t_distribution", "meta_learning_standard", "bca_bootstrap"
    use_research_accurate_ci: bool = False  # Enable research-backed CI methods
    num_episodes: int = 600  # Standard meta-learning evaluation protocol

    # Additional configuration for advanced CI methods
    min_sample_size_for_bootstrap: int = 30  # Minimum sample size for bootstrap vs t-distribution
    auto_method_selection: bool = True  # Automatically select best CI method based on data
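

# Illustrative sketch (not part of the original API surface): how the two
# config dataclasses above are typically constructed together. The values are
# arbitrary examples, not recommended settings.
def _demo_configs():
    task_cfg = TaskConfiguration(n_way=5, k_shot=1, q_query=15,
                                 difficulty_estimation_method="silhouette",
                                 use_research_accurate_difficulty=True)
    eval_cfg = EvaluationConfig(ci_method="bca_bootstrap",
                                use_research_accurate_ci=True)
    return task_cfg, eval_cfg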


class MetaLearningDataset(Dataset):
    """
    Advanced meta-learning dataset with sophisticated task sampling.

    Key improvements over existing libraries:
    1. Hierarchical task organization with difficulty levels
    2. Balanced task sampling across domains and difficulties
    3. Dynamic task generation with curriculum learning
    4. Advanced data augmentation strategies for meta-learning
    5. Task similarity tracking and diverse sampling
    """

    def __init__(
        self,
        data: torch.Tensor,
        labels: torch.Tensor,
        task_config: TaskConfiguration = None,
        class_names: Optional[List[str]] = None,
        domain_labels: Optional[torch.Tensor] = None
    ):
        """
        Initialize Meta-Learning Dataset.

        Args:
            data: Input data [n_samples, ...]
            labels: Class labels [n_samples]
            task_config: Task configuration
            class_names: Optional class names for interpretability
            domain_labels: Optional domain labels for cross-domain tasks
        """
        self.data = data
        self.labels = labels
        self.config = task_config or TaskConfiguration()
        self.class_names = class_names
        self.domain_labels = domain_labels

        # Organize data by class for efficient sampling
        self.class_to_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            self.class_to_indices[label.item()].append(idx)

        self.unique_classes = list(self.class_to_indices.keys())
        self.num_classes = len(self.unique_classes)

        # Task history for diversity tracking
        self.task_history = []
        self.class_usage_count = Counter()

        # Difficulty estimation using the configured method
        if self.config.use_research_accurate_difficulty:
            self.class_difficulties = self._estimate_class_difficulties_research_accurate()
        else:
            self.class_difficulties = self._estimate_class_difficulties()

        logger.info(f"Initialized MetaLearningDataset: {self.num_classes} classes, {len(data)} samples")

    def __len__(self) -> int:
        """Number of possible tasks (virtually infinite for meta-learning)."""
        return self.config.num_tasks

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Sample a meta-learning task.

        Returns:
            Dictionary containing support and query sets with labels
        """
        return self.sample_task(task_idx=idx)

    def sample_task(
        self,
        task_idx: Optional[int] = None,
        specified_classes: Optional[List[int]] = None,
        difficulty_level: Optional[str] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Sample a single meta-learning task with advanced strategies.

        Args:
            task_idx: Optional task index for reproducibility
            specified_classes: Specific classes to use (overrides sampling)
            difficulty_level: "easy", "medium", "hard", or None for automatic

        Returns:
            Task dictionary with support/query sets and metadata
        """
        # Set random seed for reproducible task sampling
        if task_idx is not None:
            torch.manual_seed(42 + task_idx)
            np.random.seed(42 + task_idx)

        # Select classes for this task
        if specified_classes:
            task_classes = specified_classes
        else:
            task_classes = self._sample_task_classes(difficulty_level)

        # Sample support and query sets
        support_data, support_labels, query_data, query_labels = self._sample_support_query(
            task_classes
        )

        # Apply data augmentation
        if self.config.augmentation_strategy != "none":
            support_data = self._apply_augmentation(support_data, self.config.augmentation_strategy)

        # Update task history and class usage
        self.task_history.append(task_classes)
        for class_id in task_classes:
            self.class_usage_count[class_id] += 1

        # Compute task metadata
        task_metadata = self._compute_task_metadata(task_classes, support_labels, query_labels)

        return {
            "support": {
                "data": support_data,
                "labels": support_labels
            },
            "query": {
                "data": query_data,
                "labels": query_labels
            },
            "task_classes": torch.tensor(task_classes),
            "metadata": task_metadata
        }

    def _sample_task_classes(self, difficulty_level: Optional[str] = None) -> List[int]:
        """Sample classes for a task with diversity and difficulty control."""
        if difficulty_level:
            # Filter classes by difficulty
            if difficulty_level == "easy":
                candidate_classes = [c for c in self.unique_classes
                                     if self.class_difficulties[c] < 0.3]
            elif difficulty_level == "medium":
                candidate_classes = [c for c in self.unique_classes
                                     if 0.3 <= self.class_difficulties[c] < 0.7]
            elif difficulty_level == "hard":
                candidate_classes = [c for c in self.unique_classes
                                     if self.class_difficulties[c] >= 0.7]
            else:
                candidate_classes = self.unique_classes
        else:
            candidate_classes = self.unique_classes

        # Ensure we have enough classes
        if len(candidate_classes) < self.config.n_way:
            candidate_classes = self.unique_classes

        # Diversity-aware sampling (prefer less used classes)
        class_weights = []
        for class_id in candidate_classes:
            # Inverse frequency weighting for diversity
            usage_count = self.class_usage_count.get(class_id, 0)
            weight = 1.0 / (1.0 + usage_count)
            class_weights.append(weight)

        # Normalize weights
        class_weights = np.array(class_weights)
        class_weights = class_weights / class_weights.sum()

        # Sample classes
        selected_indices = np.random.choice(
            len(candidate_classes),
            size=self.config.n_way,
            replace=False,
            p=class_weights
        )

        return [candidate_classes[i] for i in selected_indices]

    def _sample_support_query(
        self,
        task_classes: List[int]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample support and query sets for given classes."""
        support_data = []
        support_labels = []
        query_data = []
        query_labels = []

        for new_label, original_class in enumerate(task_classes):
            # Get indices for this class
            class_indices = self.class_to_indices[original_class]

            # Ensure we have enough samples
            total_needed = self.config.k_shot + self.config.q_query
            if len(class_indices) < total_needed:
                # Sample with replacement if necessary
                selected_indices = np.random.choice(
                    class_indices, size=total_needed, replace=True
                )
            else:
                selected_indices = np.random.choice(
                    class_indices, size=total_needed, replace=False
                )

            # Split into support and query
            support_indices = selected_indices[:self.config.k_shot]
            query_indices = selected_indices[self.config.k_shot:]

            # Collect support set
            for idx in support_indices:
                support_data.append(self.data[idx])
                support_labels.append(new_label)

            # Collect query set
            for idx in query_indices:
                query_data.append(self.data[idx])
                query_labels.append(new_label)

        return (
            torch.stack(support_data),
            torch.tensor(support_labels),
            torch.stack(query_data),
            torch.tensor(query_labels)
        )

    def _estimate_class_difficulties(self) -> Dict[int, float]:
        """
        Estimate difficulty of each class based on intra-class variance.

        FIXME RESEARCH ACCURACY ISSUES:
        1. ARBITRARY DIFFICULTY METRIC: No research basis for using mean pairwise distance as difficulty
        2. INEFFICIENT COMPUTATION: O(n²) complexity for pairwise distance calculation
        3. MISSING ESTABLISHED METRICS: Should use research-validated difficulty measures
        4. NO COMPARISON TO BASELINES: Not comparing to standard difficulty estimation methods

        Better research-backed alternatives (silhouette, entropy, k-NN) are
        implemented in the methods below.
        """
        difficulties = {}

        for class_id, indices in self.class_to_indices.items():
            if len(indices) > 1:
                class_data = self.data[indices]

                # CURRENT (PROBLEMATIC): arbitrary mean pairwise distance measure,
                # averaged over the n² - n off-diagonal entries
                flattened_data = class_data.view(len(class_data), -1)
                distances = torch.cdist(flattened_data, flattened_data)
                mean_distance = distances.sum() / (len(distances) ** 2 - len(distances))
                difficulties[class_id] = mean_distance.item()
            else:
                difficulties[class_id] = 0.5  # Default medium difficulty

        # Normalize difficulties to [0, 1]
        if difficulties:
            max_diff = max(difficulties.values())
            min_diff = min(difficulties.values())
            if max_diff > min_diff:
                for class_id in difficulties:
                    difficulties[class_id] = (difficulties[class_id] - min_diff) / (max_diff - min_diff)

        return difficulties

    def _estimate_class_difficulties_research_accurate(self) -> Dict[int, float]:
        """
        Route to the research-accurate difficulty estimation method selected in
        the configuration.
        """
        if self.config.difficulty_estimation_method == "silhouette":
            return self._estimate_class_difficulty_silhouette()
        elif self.config.difficulty_estimation_method == "entropy":
            return self._estimate_class_difficulty_entropy()
        elif self.config.difficulty_estimation_method == "knn":
            return self._estimate_class_difficulty_knn()
        else:  # default to pairwise_distance
            return self._estimate_class_difficulties()

    def _estimate_class_difficulty_silhouette(self) -> Dict[int, float]:
        """
        FIXME SOLUTION 1: Use silhouette score for class difficulty estimation.

        Based on Rousseeuw (1987), "Silhouettes: a graphical aid to the
        interpretation and validation of cluster analysis". The silhouette
        score measures how well-separated classes are.
        """
        from sklearn.metrics import silhouette_samples

        difficulties = {}
        all_data = self.data.view(len(self.data), -1).numpy()
        all_labels = self.labels.numpy()

        # Compute silhouette scores for all samples
        silhouette_scores = silhouette_samples(all_data, all_labels)

        # Average silhouette score per class (lower = more difficult)
        for class_id in self.unique_classes:
            class_mask = all_labels == class_id
            class_silhouette = silhouette_scores[class_mask].mean()

            # Convert to difficulty: silhouette in [-1, 1] maps to difficulty in [0, 1]
            difficulties[class_id] = 1.0 - (class_silhouette + 1.0) / 2.0

        return difficulties

    def _estimate_class_difficulty_entropy(self) -> Dict[int, float]:
        """
        FIXME SOLUTION 2: Use feature entropy for difficulty estimation.

        Classes with higher feature entropy are typically more difficult;
        a common approach in the few-shot learning literature. Note that the
        returned scores are raw mean entropies, not rescaled to [0, 1].
        """
        difficulties = {}

        for class_id, indices in self.class_to_indices.items():
            if len(indices) > 1:
                class_data = self.data[indices]

                # Compute feature-wise entropy
                flattened_data = class_data.view(len(class_data), -1)

                # Discretize features for entropy calculation
                discretized = torch.floor(flattened_data * 10) / 10  # Simple binning

                # Compute entropy for each feature dimension
                entropies = []
                for feature_dim in range(discretized.shape[1]):
                    feature_values = discretized[:, feature_dim]
                    unique_vals, counts = torch.unique(feature_values, return_counts=True)
                    probs = counts.float() / len(feature_values)
                    entropy = -torch.sum(probs * torch.log(probs + 1e-8))
                    entropies.append(entropy.item())

                # Average entropy as the difficulty measure
                difficulties[class_id] = np.mean(entropies)
            else:
                difficulties[class_id] = 0.5

        return difficulties

    def _estimate_class_difficulty_knn(self) -> Dict[int, float]:
        """
        FIXME SOLUTION 3: Use k-NN classification accuracy for difficulty estimation.

        Based on the intuition that harder classes have lower k-NN accuracy;
        well established in the machine learning literature.
        """
        from sklearn.neighbors import KNeighborsClassifier
        from sklearn.model_selection import cross_val_score

        difficulties = {}

        # For each class, measure how well k-NN can distinguish it from others
        for class_id in self.unique_classes:
            # Create binary classification problem: current class vs. all others
            class_mask = self.labels == class_id
            binary_labels = class_mask.long()

            # Prepare data
            X = self.data.view(len(self.data), -1).numpy()
            y = binary_labels.numpy()

            # k-NN classification
            knn = KNeighborsClassifier(n_neighbors=5)
            scores = cross_val_score(knn, X, y, cv=3, scoring='accuracy')

            # Lower accuracy = higher difficulty
            difficulties[class_id] = 1.0 - scores.mean()

        return difficulties

    def _apply_augmentation(self, data: torch.Tensor, strategy: str) -> torch.Tensor:
        """Apply data augmentation strategies optimized for meta-learning."""
        if strategy == "basic":
            return self._basic_augmentation(data)
        elif strategy == "advanced":
            return self._advanced_augmentation(data)
        else:
            return data

    def _basic_augmentation(self, data: torch.Tensor) -> torch.Tensor:
        """Basic augmentation: small additive Gaussian noise."""
        noise_std = 0.01
        noise = torch.randn_like(data) * noise_std
        augmented = data + noise

        return torch.clamp(augmented, 0, 1)  # Assumes data is normalized to [0, 1]

    def _advanced_augmentation(self, data: torch.Tensor) -> torch.Tensor:
        """Advanced augmentation with meta-learning specific techniques."""
        # Meta-learning specific augmentation that preserves task structure
        # while adding beneficial variance.

        # 1. Work on a copy so the original support set is untouched
        augmented = data.clone()

        # 2. Add calibrated noise based on data statistics
        data_std = data.std(dim=0, keepdim=True)
        noise = torch.randn_like(data) * (data_std * 0.05)
        augmented = augmented + noise

        # 3. Random feature masking (for structured data)
        if len(data.shape) > 2:  # Multi-dimensional features
            mask_prob = 0.1
            mask = torch.rand_like(data) > mask_prob
            augmented = augmented * mask

        return torch.clamp(augmented, 0, 1)

    def _compute_task_metadata(
        self,
        task_classes: List[int],
        support_labels: torch.Tensor,
        query_labels: torch.Tensor
    ) -> Dict[str, Any]:
        """Compute metadata for the sampled task."""
        metadata = {
            "n_way": len(task_classes),
            "k_shot": self.config.k_shot,
            "q_query": self.config.q_query,
            "task_classes": task_classes,
            "class_difficulties": [self.class_difficulties[c] for c in task_classes],
            "avg_difficulty": np.mean([self.class_difficulties[c] for c in task_classes])
        }

        # Add class names if available
        if self.class_names:
            metadata["class_names"] = [self.class_names[c] for c in task_classes]

        return metadata
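

# Illustrative usage sketch for MetaLearningDataset, assuming small random
# tensors as stand-in data (shapes and values are placeholders, not a real
# benchmark):
def _demo_meta_learning_dataset():
    data = torch.rand(200, 1, 8, 8)        # 200 samples of 1x8x8 "images" in [0, 1]
    labels = torch.randint(0, 10, (200,))  # 10 classes
    dataset = MetaLearningDataset(
        data, labels, TaskConfiguration(n_way=5, k_shot=5, q_query=15)
    )
    task = dataset.sample_task(task_idx=0)
    # Support holds n_way * k_shot samples; query holds n_way * q_query samples.
    assert task["support"]["data"].shape[0] == 5 * 5
    assert task["query"]["data"].shape[0] == 5 * 15
    return task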


class TaskSampler(Sampler):
    """
    Advanced task sampler for meta-learning with curriculum learning support.

    Key features not found in existing libraries:
    1. Curriculum learning with difficulty progression
    2. Balanced sampling across task types and difficulties
    3. Anti-correlation sampling to ensure task diversity
    4. Adaptive batch composition based on performance
    """

    def __init__(
        self,
        dataset: MetaLearningDataset,
        batch_size: int = 16,
        curriculum_learning: bool = True,
        difficulty_schedule: str = "linear"  # linear, exponential, adaptive
    ):
        """
        Initialize Task Sampler.

        Args:
            dataset: MetaLearningDataset to sample from
            batch_size: Number of tasks per batch
            curriculum_learning: Whether to use curriculum learning
            difficulty_schedule: How difficulty progresses over training
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.curriculum_learning = curriculum_learning
        self.difficulty_schedule = difficulty_schedule

        # Curriculum state
        self.current_epoch = 0
        self.total_epochs = 1000  # Will be updated during training
        self.difficulty_level = 0.0  # 0.0 = easiest, 1.0 = hardest

        # Performance tracking for adaptive curriculum
        self.performance_history = []

        logger.info(f"Initialized TaskSampler: batch_size={batch_size}, curriculum={curriculum_learning}")

    def __iter__(self) -> Iterator[List[int]]:
        """Generate batches of task indices."""
        n = len(self.dataset)

        # Generate task indices
        indices = list(range(n))

        # Curriculum learning: filter by difficulty
        if self.curriculum_learning:
            indices = self._apply_curriculum_filter(indices)

        # Shuffle for randomness
        random.shuffle(indices)

        # Generate batches
        for i in range(0, len(indices), self.batch_size):
            batch_indices = indices[i:i + self.batch_size]
            if len(batch_indices) == self.batch_size:  # Only yield full batches
                yield batch_indices

    def __len__(self) -> int:
        """Number of batches per epoch."""
        effective_size = len(self.dataset)
        if self.curriculum_learning:
            # Account for curriculum filtering
            effective_size = int(effective_size * min(1.0, 0.1 + 0.9 * self.difficulty_level))
        return effective_size // self.batch_size

    def update_epoch(self, epoch: int, total_epochs: int):
        """Update curriculum state for a new epoch."""
        self.current_epoch = epoch
        self.total_epochs = total_epochs

        # Update difficulty level based on schedule
        if self.difficulty_schedule == "linear":
            self.difficulty_level = epoch / total_epochs
        elif self.difficulty_schedule == "exponential":
            self.difficulty_level = (np.exp(epoch / total_epochs) - 1) / (np.e - 1)
        elif self.difficulty_schedule == "adaptive":
            self.difficulty_level = self._adaptive_difficulty_schedule()

        self.difficulty_level = np.clip(self.difficulty_level, 0.0, 1.0)

        logger.debug(f"Epoch {epoch}: difficulty_level = {self.difficulty_level:.3f}")

    def _apply_curriculum_filter(self, indices: List[int]) -> List[int]:
        """Filter task indices based on the current curriculum difficulty."""
        # Simplified version - in practice this would use actual task difficulties.
        # For now, include a fraction of tasks proportional to the difficulty level.
        fraction_to_include = 0.1 + 0.9 * self.difficulty_level
        num_to_include = int(len(indices) * fraction_to_include)

        return indices[:num_to_include]

    def _adaptive_difficulty_schedule(self) -> float:
        """Compute adaptive difficulty based on recent performance."""
        if len(self.performance_history) < 10:
            # Not enough data, use linear schedule
            return self.current_epoch / self.total_epochs

        # Compute recent performance trend
        recent_performance = self.performance_history[-10:]
        performance_mean = np.mean(recent_performance)
        performance_trend = np.mean(np.diff(recent_performance))

        # Adapt difficulty based on performance
        base_difficulty = self.current_epoch / self.total_epochs

        if performance_mean > 0.8 and performance_trend > 0:
            # High performance and improving - increase difficulty faster
            adaptation = min(0.2, performance_trend * 5)
        elif performance_mean < 0.6 and performance_trend < 0:
            # Low performance and declining - slow down difficulty increase
            adaptation = max(-0.1, performance_trend * 2)
        else:
            adaptation = 0

        return np.clip(base_difficulty + adaptation, 0.0, 1.0)

    def update_performance(self, accuracy: float):
        """Update performance history for the adaptive curriculum."""
        self.performance_history.append(accuracy)

        # Keep only recent history
        if len(self.performance_history) > 100:
            self.performance_history = self.performance_history[-100:]
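

# Illustrative sketch: pairing TaskSampler with a MetaLearningDataset. The loop
# below only counts batches; real training code would fetch dataset[idx] for
# each index in a batch.
def _demo_task_sampler(dataset: MetaLearningDataset) -> int:
    sampler = TaskSampler(dataset, batch_size=4, curriculum_learning=True)
    sampler.update_epoch(epoch=10, total_epochs=100)  # linear -> difficulty 0.1
    num_batches = sum(1 for _batch in sampler)
    return num_batches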


def few_shot_accuracy(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    return_per_class: bool = False
) -> Union[float, Tuple[float, torch.Tensor]]:
    """
    Compute few-shot learning accuracy with advanced metrics.

    Args:
        predictions: Model predictions [n_samples, n_classes] or [n_samples]
        targets: Ground truth labels [n_samples]
        return_per_class: Whether to return per-class accuracies

    Returns:
        Overall accuracy, optionally with per-class accuracies
    """
    if predictions.dim() == 2:
        # Logits or probabilities - take argmax
        pred_labels = predictions.argmax(dim=-1)
    else:
        # Already labels
        pred_labels = predictions

    # Overall accuracy
    correct = (pred_labels == targets).float()
    overall_accuracy = correct.mean().item()

    if return_per_class:
        # Per-class accuracy
        unique_classes = torch.unique(targets)
        per_class_accuracies = []

        for class_id in unique_classes:
            class_mask = targets == class_id
            if class_mask.sum() > 0:
                class_correct = correct[class_mask].mean().item()
                per_class_accuracies.append(class_correct)
            else:
                per_class_accuracies.append(0.0)

        return overall_accuracy, torch.tensor(per_class_accuracies)

    return overall_accuracy
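

# Quick sanity example for few_shot_accuracy on toy logits (values made up):
def _demo_few_shot_accuracy():
    logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [3.0, 0.0]])
    targets = torch.tensor([0, 1, 1])
    # argmax predictions are [0, 1, 0] -> overall accuracy 2/3;
    # per-class: class 0 -> 1.0, class 1 -> 0.5
    acc, per_class = few_shot_accuracy(logits, targets, return_per_class=True)
    return acc, per_class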


def adaptation_speed(
    loss_curve: List[float],
    convergence_threshold: float = 0.01
) -> Tuple[int, float]:
    """
    Measure adaptation speed for meta-learning algorithms.

    Args:
        loss_curve: List of losses during adaptation steps
        convergence_threshold: Threshold for considering convergence

    Returns:
        Tuple of (steps_to_convergence, final_loss)
    """
    if len(loss_curve) < 2:
        return len(loss_curve), loss_curve[-1] if loss_curve else float('inf')

    # Find convergence point
    for i in range(1, len(loss_curve)):
        loss_change = abs(loss_curve[i] - loss_curve[i-1])
        if loss_change < convergence_threshold:
            return i + 1, loss_curve[i]

    # No convergence found
    return len(loss_curve), loss_curve[-1]


def compute_confidence_interval(
    values: List[float],
    confidence_level: float = 0.95,
    num_bootstrap: int = 1000
) -> Tuple[float, float, float]:
    """
    Compute a confidence interval, using a t-distribution for small samples
    (n < 30) and percentile bootstrap sampling otherwise.

    FIXME RESEARCH ACCURACY ISSUES (remaining):
    1. NO BIAS CORRECTION: the bootstrap branch is a plain percentile bootstrap;
       the bias-corrected and accelerated (BCa) bootstrap is more accurate
       (see compute_bca_bootstrap_ci)
    2. MISSING STANDARD REPORTING: the meta-learning literature typically uses
       specific CI protocols (see compute_meta_learning_ci)

    Args:
        values: List of values to compute CI for
        confidence_level: Confidence level (e.g., 0.95 for 95%)
        num_bootstrap: Number of bootstrap samples

    Returns:
        Tuple of (mean, lower_bound, upper_bound)
    """
    if len(values) == 0:
        return 0.0, 0.0, 0.0

    values = np.array(values)
    mean_val = np.mean(values)

    # Check sample size and use the appropriate method
    if len(values) < 30:
        # Use t-distribution CI for small samples (research-accurate)
        from scipy import stats
        t_critical = stats.t.ppf((1 + confidence_level) / 2, df=len(values) - 1)
        standard_error = np.std(values, ddof=1) / np.sqrt(len(values))
        margin_of_error = t_critical * standard_error

        ci_lower = mean_val - margin_of_error
        ci_upper = mean_val + margin_of_error

        logger.debug(f"Used t-distribution CI for small sample (n={len(values)})")
        return mean_val, ci_lower, ci_upper

    # Larger samples: percentile bootstrap (adequate but not optimal)
    bootstrap_means = []
    for _ in range(num_bootstrap):
        bootstrap_sample = np.random.choice(values, size=len(values), replace=True)
        bootstrap_means.append(np.mean(bootstrap_sample))

    # Compute percentiles
    alpha = 1 - confidence_level
    lower_percentile = (alpha / 2) * 100
    upper_percentile = (1 - alpha / 2) * 100

    lower_bound = np.percentile(bootstrap_means, lower_percentile)
    upper_bound = np.percentile(bootstrap_means, upper_percentile)

    return mean_val, lower_bound, upper_bound
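

# Illustrative check of compute_confidence_interval on synthetic accuracies;
# with 600 normal draws (made-up parameters), the bootstrap branch is
# exercised since n >= 30.
def _demo_confidence_interval():
    rng = np.random.default_rng(0)
    accuracies = rng.normal(loc=0.62, scale=0.05, size=600).tolist()
    mean_val, lower, upper = compute_confidence_interval(accuracies)
    assert lower <= mean_val <= upper
    return mean_val, lower, upper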


def compute_confidence_interval_research_accurate(
    values: List[float],
    config: EvaluationConfig = None,
    confidence_level: float = 0.95
) -> Tuple[float, float, float]:
    """
    FIXME SOLUTION: Compute confidence interval using the configured
    research-accurate method.

    Selects the appropriate CI method based on the configuration and sample
    size, with optional auto-selection.
    """
    config = config or EvaluationConfig()

    if not config.use_research_accurate_ci:
        return compute_confidence_interval(values, confidence_level, config.num_bootstrap_samples)

    # Auto-select method if enabled
    if config.auto_method_selection:
        method = _auto_select_ci_method(values, config)
    else:
        method = config.ci_method

    # Route to the appropriate method based on configuration
    if method == "t_distribution":
        return compute_t_confidence_interval(values, confidence_level)
    elif method == "meta_learning_standard":
        return compute_meta_learning_ci(values, confidence_level, config.num_episodes)
    elif method == "bca_bootstrap":
        return compute_bca_bootstrap_ci(values, confidence_level, config.num_bootstrap_samples)
    else:  # bootstrap
        return compute_confidence_interval(values, confidence_level, config.num_bootstrap_samples)


def _auto_select_ci_method(values: List[float], config: EvaluationConfig) -> str:
    """
    Automatically select the best CI method based on data characteristics.

    Selection criteria based on statistical best practices:
    - meta-learning standard for exactly `config.num_episodes` values (standard protocol)
    - t-distribution for small samples (n < 30)
    - BCa bootstrap for large samples (n >= 100) with skewed distributions
    - plain bootstrap otherwise
    """
    n = len(values)

    # Standard meta-learning evaluation protocol
    if n == config.num_episodes:
        return "meta_learning_standard"

    # Small sample: use t-distribution
    if n < config.min_sample_size_for_bootstrap:
        return "t_distribution"

    # Large sample: check for skewness
    if n >= 100:
        # Check for skewness (simple mean-median heuristic)
        values_array = np.array(values)
        mean_val = np.mean(values_array)
        median_val = np.median(values_array)

        # If the distribution is skewed, use BCa bootstrap
        skew_threshold = 0.1 * np.std(values_array)
        if abs(mean_val - median_val) > skew_threshold:
            return "bca_bootstrap"

    # Default to standard bootstrap for moderate samples
    return "bootstrap"


# FIXME SOLUTION 1: t-distribution confidence interval for small samples
def compute_t_confidence_interval(
    values: List[float],
    confidence_level: float = 0.95
) -> Tuple[float, float, float]:
    """
    Compute confidence interval using the t-distribution (appropriate for
    small samples).

    Standard approach in meta-learning evaluation when n < 30.
    """
    import scipy.stats as stats

    if len(values) == 0:
        return 0.0, 0.0, 0.0

    values = np.array(values)
    mean_val = np.mean(values)
    std_val = np.std(values, ddof=1)  # Sample standard deviation
    n = len(values)

    # Degrees of freedom
    df = n - 1

    # Critical t-value
    alpha = 1 - confidence_level
    t_critical = stats.t.ppf(1 - alpha / 2, df)

    # Margin of error
    margin_error = t_critical * (std_val / np.sqrt(n))

    # Confidence interval
    lower_bound = mean_val - margin_error
    upper_bound = mean_val + margin_error

    return mean_val, lower_bound, upper_bound
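

# Worked example for the t-interval: with n = 25 (df = 24) at 95% confidence,
# the critical value is roughly t* ≈ 2.064, so the half-width is about
# 2.064 · s / √25. Synthetic values below are arbitrary.
def _demo_t_interval():
    values = [0.60 + 0.01 * i for i in range(25)]
    mean_val, lower, upper = compute_t_confidence_interval(values)
    return mean_val, upper - mean_val  # mean and margin of error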


# FIXME SOLUTION 2: Meta-learning standard evaluation CI
def compute_meta_learning_ci(
    accuracies: List[float],
    confidence_level: float = 0.95,
    num_episodes: int = 600
) -> Tuple[float, float, float]:
    """
    Standard confidence interval computation for meta-learning evaluation.

    Based on standard protocols from the few-shot learning literature:
    - Vinyals et al. (2016): "Matching Networks"
    - Snell et al. (2017): "Prototypical Networks"
    - Finn et al. (2017): "MAML"

    Typically uses 600 episodes with a t-distribution CI.
    """
    if len(accuracies) != num_episodes:
        logger.warning(f"Expected {num_episodes} episodes, got {len(accuracies)}")

    # Use t-distribution for proper meta-learning evaluation
    return compute_t_confidence_interval(accuracies, confidence_level)


# FIXME SOLUTION 3: BCa (Bias-Corrected and Accelerated) Bootstrap
def compute_bca_bootstrap_ci(
    values: List[float],
    confidence_level: float = 0.95,
    num_bootstrap: int = 2000
) -> Tuple[float, float, float]:
    """
    Bias-corrected and accelerated bootstrap confidence interval.

    More accurate than basic bootstrap, especially for skewed distributions.
    Based on Efron & Tibshirani (1993), "An Introduction to the Bootstrap".
    """
    import scipy.stats as stats

    if len(values) == 0:
        return 0.0, 0.0, 0.0

    values = np.array(values)
    n = len(values)
    mean_val = np.mean(values)

    # Bootstrap resampling
    bootstrap_means = []
    for _ in range(num_bootstrap):
        bootstrap_sample = np.random.choice(values, size=n, replace=True)
        bootstrap_means.append(np.mean(bootstrap_sample))

    bootstrap_means = np.array(bootstrap_means)

    # Bias correction (the proportion is clipped away from 0 and 1 so the
    # normal quantile stays finite for degenerate resamples)
    proportion_below = np.clip((bootstrap_means < mean_val).mean(), 1e-6, 1 - 1e-6)
    bias_correction = stats.norm.ppf(proportion_below)

    # Acceleration (jackknife)
    jackknife_means = []
    for i in range(n):
        jackknife_sample = np.concatenate([values[:i], values[i+1:]])
        jackknife_means.append(np.mean(jackknife_sample))

    jackknife_means = np.array(jackknife_means)
    jackknife_mean = np.mean(jackknife_means)

    acceleration = np.sum((jackknife_mean - jackknife_means)**3) / \
                   (6 * (np.sum((jackknife_mean - jackknife_means)**2))**(3/2))

    # Adjusted percentiles
    alpha = 1 - confidence_level
    z_alpha_2 = stats.norm.ppf(alpha / 2)
    z_1_alpha_2 = stats.norm.ppf(1 - alpha / 2)

    alpha_1 = stats.norm.cdf(bias_correction +
                             (bias_correction + z_alpha_2) / (1 - acceleration * (bias_correction + z_alpha_2)))
    alpha_2 = stats.norm.cdf(bias_correction +
                             (bias_correction + z_1_alpha_2) / (1 - acceleration * (bias_correction + z_1_alpha_2)))

    # Compute bounds
    lower_bound = np.percentile(bootstrap_means, 100 * alpha_1)
    upper_bound = np.percentile(bootstrap_means, 100 * alpha_2)

    return mean_val, lower_bound, upper_bound
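

# Sketch: BCa on a right-skewed sample (synthetic exponential draws), where
# the bias correction typically shifts the interval relative to the plain
# percentile bootstrap above.
def _demo_bca_bootstrap():
    rng = np.random.default_rng(1)
    skewed = rng.exponential(scale=0.1, size=200).tolist()
    return compute_bca_bootstrap_ci(skewed, confidence_level=0.95)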


def visualize_meta_learning_results(
    results: Dict[str, List[float]],
    title: str = "Meta-Learning Results",
    save_path: Optional[str] = None
):
    """
    Create comprehensive visualizations for meta-learning results.

    Args:
        results: Dictionary with algorithm names as keys and accuracy lists as values
        title: Plot title
        save_path: Optional path to save the figure
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle(title, fontsize=16)

    # 1. Accuracy comparison (box plot)
    ax1 = axes[0, 0]
    data_for_boxplot = [results[alg] for alg in results.keys()]
    labels = list(results.keys())

    ax1.boxplot(data_for_boxplot, labels=labels)
    ax1.set_title("Accuracy Distribution")
    ax1.set_ylabel("Accuracy")
    ax1.tick_params(axis='x', rotation=45)

    # 2. Learning curves
    ax2 = axes[0, 1]
    for alg_name, accuracies in results.items():
        # Compute running average
        running_avg = np.cumsum(accuracies) / np.arange(1, len(accuracies) + 1)
        ax2.plot(running_avg, label=alg_name, alpha=0.7)

    ax2.set_title("Learning Curves (Running Average)")
    ax2.set_xlabel("Task Number")
    ax2.set_ylabel("Cumulative Average Accuracy")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # 3. Statistical comparison
    ax3 = axes[1, 0]
    means = [np.mean(results[alg]) for alg in results.keys()]
    stds = [np.std(results[alg]) for alg in results.keys()]

    ax3.barh(labels, means, xerr=stds, capsize=5)
    ax3.set_title("Mean Accuracy ± Standard Deviation")
    ax3.set_xlabel("Accuracy")

    # 4. Confidence intervals
    ax4 = axes[1, 1]
    ci_data = {}
    for alg_name, accuracies in results.items():
        mean_val, lower, upper = compute_confidence_interval(accuracies)
        ci_data[alg_name] = (mean_val, lower, upper)

    alg_names = list(ci_data.keys())
    means = [ci_data[alg][0] for alg in alg_names]
    lowers = [ci_data[alg][1] for alg in alg_names]
    uppers = [ci_data[alg][2] for alg in alg_names]

    y_pos = np.arange(len(alg_names))
    ax4.barh(y_pos, means, xerr=[np.array(means) - np.array(lowers),
                                 np.array(uppers) - np.array(means)],
             capsize=5)
    ax4.set_yticks(y_pos)
    ax4.set_yticklabels(alg_names)
    ax4.set_title("95% Confidence Intervals")
    ax4.set_xlabel("Accuracy")

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        logger.info(f"Saved visualization to {save_path}")

    plt.show()


def save_meta_learning_results(
    results: Dict[str, Any],
    filepath: str,
    format: str = "json"
):
    """
    Save meta-learning results to file.

    Args:
        results: Results dictionary to save
        filepath: Path to save file
        format: File format ("json", "pickle")
    """
    filepath = Path(filepath)
    filepath.parent.mkdir(parents=True, exist_ok=True)

    if format == "json":
        # Convert torch tensors to lists for JSON serialization
        serializable_results = {}
        for key, value in results.items():
            if isinstance(value, torch.Tensor):
                serializable_results[key] = value.tolist()
            elif isinstance(value, np.ndarray):
                serializable_results[key] = value.tolist()
            else:
                serializable_results[key] = value

        with open(filepath, 'w') as f:
            json.dump(serializable_results, f, indent=2)

    elif format == "pickle":
        with open(filepath, 'wb') as f:
            pickle.dump(results, f)

    logger.info(f"Saved results to {filepath}")


def load_meta_learning_results(filepath: str, format: str = "auto") -> Dict[str, Any]:
    """
    Load meta-learning results from file.

    Args:
        filepath: Path to load from
        format: File format ("json", "pickle", "auto")

    Returns:
        Loaded results dictionary
    """
    filepath = Path(filepath)

    if format == "auto":
        format = filepath.suffix[1:]  # Remove the dot

    if format == "json":
        with open(filepath, 'r') as f:
            results = json.load(f)
    elif format in ["pickle", "pkl"]:
        with open(filepath, 'rb') as f:
            results = pickle.load(f)
    else:
        raise ValueError(f"Unsupported format: {format}")

    logger.info(f"Loaded results from {filepath}")
    return results
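

# Round-trip sketch for the save/load helpers above, using a temporary
# directory so nothing is written into the repository. Result values are
# made-up placeholders.
def _demo_results_roundtrip():
    import tempfile
    results = {"maml": [0.61, 0.63, 0.65], "protonet": [0.66, 0.64, 0.67]}
    with tempfile.TemporaryDirectory() as tmp:
        path = str(Path(tmp) / "results.json")
        save_meta_learning_results(results, path, format="json")
        return load_meta_learning_results(path)  # format="auto" -> json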


# =============================================================================
# FACTORY FUNCTIONS FOR EASY CONFIGURATION
# =============================================================================

def create_basic_task_config(n_way: int = 5, k_shot: int = 5, q_query: int = 15) -> TaskConfiguration:
    """Create basic task configuration with standard settings."""
    return TaskConfiguration(
        n_way=n_way,
        k_shot=k_shot,
        q_query=q_query,
        num_tasks=1000,
        task_type="classification",
        augmentation_strategy="basic",
        difficulty_estimation_method="pairwise_distance",
        use_research_accurate_difficulty=False
    )


def create_research_accurate_task_config(
    n_way: int = 5,
    k_shot: int = 5,
    q_query: int = 15,
    difficulty_method: str = "silhouette"
) -> TaskConfiguration:
    """Create research-accurate task configuration with proper difficulty estimation."""
    return TaskConfiguration(
        n_way=n_way,
        k_shot=k_shot,
        q_query=q_query,
        num_tasks=1000,
        task_type="classification",
        augmentation_strategy="advanced",
        difficulty_estimation_method=difficulty_method,  # "silhouette", "entropy", "knn"
        use_research_accurate_difficulty=True
    )


def create_basic_evaluation_config() -> EvaluationConfig:
    """Create basic evaluation configuration with standard settings."""
    return EvaluationConfig(
        confidence_intervals=True,
        num_bootstrap_samples=1000,
        significance_level=0.05,
        track_adaptation_curve=True,
        compute_uncertainty=True,
        ci_method="bootstrap",
        use_research_accurate_ci=False,
        num_episodes=600,
        min_sample_size_for_bootstrap=30,
        auto_method_selection=False
    )


def create_research_accurate_evaluation_config(ci_method: str = "auto") -> EvaluationConfig:
    """Create research-accurate evaluation configuration with proper CI methods."""
    return EvaluationConfig(
        confidence_intervals=True,
        num_bootstrap_samples=2000,  # Higher for better accuracy
        significance_level=0.05,
        track_adaptation_curve=True,
        compute_uncertainty=True,
        ci_method=ci_method,  # "auto", "t_distribution", "meta_learning_standard", "bca_bootstrap"
        use_research_accurate_ci=True,
        num_episodes=600,  # Standard meta-learning protocol
        min_sample_size_for_bootstrap=30,
        auto_method_selection=(ci_method == "auto")
    )


def create_meta_learning_standard_evaluation_config() -> EvaluationConfig:
    """Create evaluation configuration following standard meta-learning protocols."""
    return EvaluationConfig(
        confidence_intervals=True,
        num_bootstrap_samples=600,  # Not used with t-distribution
        significance_level=0.05,
        track_adaptation_curve=True,
        compute_uncertainty=True,
        ci_method="meta_learning_standard",
        use_research_accurate_ci=True,
        num_episodes=600,
        min_sample_size_for_bootstrap=30,
        auto_method_selection=False
    )
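

# Sketch tying the factory configs to the CI dispatcher above: with 600 values
# and auto-selection enabled, _auto_select_ci_method picks the standard
# meta-learning protocol (t-distribution over 600 episodes). Synthetic values
# are placeholders.
def _demo_config_driven_ci():
    config = create_research_accurate_evaluation_config(ci_method="auto")
    rng = np.random.default_rng(2)
    accuracies = rng.normal(0.58, 0.04, size=config.num_episodes).tolist()
    return compute_confidence_interval_research_accurate(accuracies, config)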


# =============================================================================
# Missing Classes Implementation - Required by __init__.py imports
# =============================================================================

class DatasetConfig:
    """Configuration for meta-learning dataset creation."""

    def __init__(
        self,
        dataset_type: str = "episodic",
        augmentation_strategy: str = "minimal",
        shuffle: bool = True,
        stratified: bool = True,
        normalize: bool = True,
        cache_episodes: bool = False,
        **kwargs
    ):
        self.dataset_type = dataset_type
        self.augmentation_strategy = augmentation_strategy
        self.shuffle = shuffle
        self.stratified = stratified
        self.normalize = normalize
        self.cache_episodes = cache_episodes
        for key, value in kwargs.items():
            setattr(self, key, value)


class MetricsConfig:
    """Configuration for evaluation metrics computation."""

    def __init__(
        self,
        compute_accuracy: bool = True,
        compute_loss: bool = True,
        compute_adaptation_speed: bool = False,
        compute_uncertainty: bool = False,
        track_gradients: bool = False,
        save_predictions: bool = False,
        **kwargs
    ):
        self.compute_accuracy = compute_accuracy
        self.compute_loss = compute_loss
        self.compute_adaptation_speed = compute_adaptation_speed
        self.compute_uncertainty = compute_uncertainty
        self.track_gradients = track_gradients
        self.save_predictions = save_predictions
        for key, value in kwargs.items():
            setattr(self, key, value)


class StatsConfig:
    """Configuration for statistical analysis."""

    def __init__(
        self,
        confidence_level: float = 0.95,
        num_bootstrap_samples: int = 1000,
        significance_test: str = "t_test",
        multiple_comparison_correction: str = "bonferroni",
        effect_size_method: str = "cohen_d",
        **kwargs
    ):
        self.confidence_level = confidence_level
        self.num_bootstrap_samples = num_bootstrap_samples
        self.significance_test = significance_test
        self.multiple_comparison_correction = multiple_comparison_correction
        self.effect_size_method = effect_size_method
        for key, value in kwargs.items():
            setattr(self, key, value)


class CurriculumConfig:
    """Configuration for curriculum learning strategies."""

    def __init__(
        self,
        strategy: str = "difficulty_based",
        initial_difficulty: float = 0.3,
        difficulty_increment: float = 0.1,
        difficulty_threshold: float = 0.8,
        adaptation_patience: int = 5,
        **kwargs
    ):
        self.strategy = strategy
        self.initial_difficulty = initial_difficulty
        self.difficulty_increment = difficulty_increment
        self.difficulty_threshold = difficulty_threshold
        self.adaptation_patience = adaptation_patience
        for key, value in kwargs.items():
            setattr(self, key, value)


class DiversityConfig:
    """Configuration for task diversity tracking."""

    def __init__(
        self,
        diversity_metric: str = "cosine_similarity",
        track_class_distribution: bool = True,
        track_feature_diversity: bool = True,
        diversity_threshold: float = 0.7,
        **kwargs
    ):
        self.diversity_metric = diversity_metric
        self.track_class_distribution = track_class_distribution
        self.track_feature_diversity = track_feature_diversity
        self.diversity_threshold = diversity_threshold
        for key, value in kwargs.items():
            setattr(self, key, value)


class EvaluationMetrics:
    """Comprehensive evaluation metrics for meta-learning algorithms."""

    def __init__(self, config: MetricsConfig):
        self.config = config
        self.reset()

    def reset(self):
        """Reset all metrics to their initial state."""
        self.accuracies = []
        self.losses = []
        self.adaptation_speeds = []
        self.uncertainties = []
        self.predictions = []
        self.gradients = []

    def update(self, predictions: torch.Tensor, targets: torch.Tensor,
               loss: Optional[float] = None, **kwargs):
        """Update metrics with new predictions and targets."""
        if self.config.compute_accuracy:
            accuracy = (predictions.argmax(dim=-1) == targets).float().mean().item()
            self.accuracies.append(accuracy)

        if self.config.compute_loss and loss is not None:
            self.losses.append(loss)

        if self.config.save_predictions:
            self.predictions.append(predictions.detach().cpu())

        # Add other metrics based on config
        for key, value in kwargs.items():
            if hasattr(self, key + 's'):
                getattr(self, key + 's').append(value)

    def compute_summary(self) -> Dict[str, float]:
        """Compute summary statistics."""
        summary = {}

        if self.accuracies:
            summary['mean_accuracy'] = np.mean(self.accuracies)
            summary['std_accuracy'] = np.std(self.accuracies)

        if self.losses:
            summary['mean_loss'] = np.mean(self.losses)
            summary['std_loss'] = np.std(self.losses)

        return summary


class StatisticalAnalysis:
    """Statistical analysis utilities for meta-learning research."""

    def __init__(self, config: StatsConfig):
        self.config = config

    def compute_confidence_interval(self, values: List[float]) -> Tuple[float, float, float]:
        """Compute confidence interval for given values."""
        return compute_confidence_interval(
            values,
            confidence_level=self.config.confidence_level,
            num_bootstrap=self.config.num_bootstrap_samples
        )

    def statistical_test(self, group1: List[float], group2: List[float]) -> Dict[str, float]:
        """Perform statistical significance test between two groups."""
        from scipy import stats

        if self.config.significance_test == "t_test":
            statistic, p_value = stats.ttest_ind(group1, group2)
        elif self.config.significance_test == "mannwhitney":
            statistic, p_value = stats.mannwhitneyu(group1, group2, alternative='two-sided')
        else:
            raise ValueError(f"Unknown test: {self.config.significance_test}")

        # Significance at alpha = 1 - confidence_level (e.g., 0.05 at the 95%
        # level); when running many pairwise tests, the caller should apply
        # the correction named in self.config.multiple_comparison_correction.
        alpha = 1.0 - self.config.confidence_level
        return {
            'statistic': statistic,
            'p_value': p_value,
            'significant': p_value < alpha
        }


class CurriculumLearning:
    """Curriculum learning implementation for meta-learning."""

    def __init__(self, config: CurriculumConfig):
        self.config = config
        self.current_difficulty = config.initial_difficulty
        self.patience_counter = 0

    def update_difficulty(self, performance_metric: float) -> float:
        """Update curriculum difficulty based on performance."""
        if performance_metric >= self.config.difficulty_threshold:
            self.current_difficulty = min(
                1.0,
                self.current_difficulty + self.config.difficulty_increment
            )
            self.patience_counter = 0
        else:
            self.patience_counter += 1

            if self.patience_counter >= self.config.adaptation_patience:
                # Reduce difficulty if struggling
                self.current_difficulty = max(
                    0.1,
                    self.current_difficulty - self.config.difficulty_increment / 2
                )
                self.patience_counter = 0

        return self.current_difficulty

    def get_current_difficulty(self) -> float:
        """Get the current curriculum difficulty level."""
        return self.current_difficulty
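

# Sketch: driving CurriculumLearning with made-up episode accuracies.
# Difficulty rises while accuracy clears the threshold and backs off after
# adaptation_patience consecutive misses.
def _demo_curriculum():
    scheduler = CurriculumLearning(CurriculumConfig(initial_difficulty=0.3,
                                                    difficulty_increment=0.1,
                                                    difficulty_threshold=0.8,
                                                    adaptation_patience=2))
    trace = [scheduler.update_difficulty(acc) for acc in (0.85, 0.9, 0.5, 0.5, 0.5)]
    return trace  # [0.4, 0.5, 0.5, 0.45, 0.45]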


class TaskDiversityTracker:
    """Track diversity of meta-learning tasks."""

    def __init__(self, config: DiversityConfig):
        self.config = config
        self.task_features = []
        self.class_distributions = []

    def add_task(self, task_features: torch.Tensor, class_distribution: Optional[torch.Tensor] = None):
        """Add a new task for diversity tracking."""
        self.task_features.append(task_features.detach().cpu())

        if class_distribution is not None and self.config.track_class_distribution:
            self.class_distributions.append(class_distribution.detach().cpu())

    def compute_diversity(self) -> Dict[str, float]:
        """Compute task diversity metrics."""
        if len(self.task_features) < 2:
            # Pairwise diversity is undefined for fewer than two tasks
            return {'diversity_score': 0.0}

        features = torch.stack(self.task_features)

        if self.config.diversity_metric == "cosine_similarity":
            # Compute pairwise cosine similarities
            normalized_features = F.normalize(features, dim=-1)
            similarities = torch.mm(normalized_features, normalized_features.t())

            # Average off-diagonal similarities (diversity = 1 - similarity)
            mask = ~torch.eye(similarities.size(0), dtype=torch.bool)
            avg_similarity = similarities[mask].mean().item()
            diversity_score = 1.0 - avg_similarity
        else:
            diversity_score = 0.5  # Placeholder for other metrics

        return {'diversity_score': diversity_score}
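

# Sketch: diversity over three random task-feature vectors (placeholder data).
def _demo_task_diversity():
    tracker = TaskDiversityTracker(DiversityConfig())
    for _ in range(3):
        tracker.add_task(torch.randn(32))
    return tracker.compute_diversity()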


# =============================================================================
# Factory Functions - Required by __init__.py imports
# =============================================================================

def create_dataset(data: torch.Tensor, labels: torch.Tensor,
                   task_config: TaskConfiguration,
                   dataset_config: Optional[DatasetConfig] = None) -> MetaLearningDataset:
    """Factory function to create a meta-learning dataset.

    Note: dataset_config is accepted for API compatibility but is not yet
    consumed by MetaLearningDataset.
    """
    if dataset_config is None:
        dataset_config = DatasetConfig()

    return MetaLearningDataset(data, labels, task_config)


def create_metrics_evaluator(config: Optional[MetricsConfig] = None) -> EvaluationMetrics:
    """Factory function to create an evaluation metrics instance."""
    if config is None:
        config = MetricsConfig()

    return EvaluationMetrics(config)


def create_curriculum_scheduler(config: Optional[CurriculumConfig] = None) -> CurriculumLearning:
    """Factory function to create a curriculum learning scheduler."""
    if config is None:
        config = CurriculumConfig()

    return CurriculumLearning(config)


def basic_confidence_interval(values: List[float], confidence_level: float = 0.95) -> Tuple[float, float, float]:
    """Basic confidence interval computation (t-distribution)."""
    return compute_t_confidence_interval(values, confidence_level=confidence_level)


def estimate_difficulty(task_data: torch.Tensor, method: str = "entropy") -> float:
    """Estimate task difficulty using various methods."""
    if method == "entropy":
        # Simple entropy-based difficulty
        probs = F.softmax(task_data.mean(dim=0), dim=-1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-8))
        return entropy.item() / np.log(task_data.size(-1))  # Normalized entropy
    else:
        return 0.5  # Default medium difficulty


def track_task_diversity(tasks: List[torch.Tensor], config: Optional[DiversityConfig] = None) -> Dict[str, float]:
    """Track diversity across multiple tasks."""
    if config is None:
        config = DiversityConfig()

    tracker = TaskDiversityTracker(config)

    for task in tasks:
        tracker.add_task(task.mean(dim=0))  # Use the mean as the task feature

    return tracker.compute_diversity()


# =============================================================================
# ENHANCED EVALUATION FUNCTIONS WITH CONFIGURATION SUPPORT
# =============================================================================

def evaluate_meta_learning_algorithm(
    algorithm,
    dataset: MetaLearningDataset,
    config: EvaluationConfig = None,
    num_episodes: int = None
) -> Dict[str, Any]:
    """
    Comprehensive evaluation of a meta-learning algorithm with configurable methods.

    Args:
        algorithm: Meta-learning algorithm to evaluate
        dataset: MetaLearningDataset for evaluation
        config: EvaluationConfig for evaluation settings
        num_episodes: Number of evaluation episodes (overrides config)

    Returns:
        Dictionary with evaluation results and statistics
    """
    config = config or create_research_accurate_evaluation_config()
    num_episodes = num_episodes or config.num_episodes

    accuracies = []
    adaptation_curves = []

    logger.info(f"Starting evaluation with {num_episodes} episodes")

    for episode in range(num_episodes):
        # Sample task
        task = dataset.sample_task(task_idx=episode)

        # Evaluate algorithm on the task
        result = algorithm.evaluate_task(
            task['support']['data'],
            task['support']['labels'],
            task['query']['data'],
            task['query']['labels'],
            return_adaptation_curve=config.track_adaptation_curve
        )

        accuracies.append(result['accuracy'])

        if config.track_adaptation_curve and 'adaptation_curve' in result:
            adaptation_curves.append(result['adaptation_curve'])

    # Compute statistics using the configured CI method
    mean_accuracy, ci_lower, ci_upper = compute_confidence_interval_research_accurate(
        accuracies, config
    )

    results = {
        'mean_accuracy': mean_accuracy,
        'std_accuracy': np.std(accuracies),
        'ci_lower': ci_lower,
        'ci_upper': ci_upper,
        'all_accuracies': accuracies,
        'num_episodes': num_episodes,
        'ci_method_used': (config.ci_method if not config.auto_method_selection
                           else _auto_select_ci_method(accuracies, config))
    }

    if config.track_adaptation_curve and adaptation_curves:
        results['adaptation_curves'] = adaptation_curves
        results['mean_adaptation_curve'] = np.mean(adaptation_curves, axis=0).tolist()

    logger.info(f"Evaluation complete: {mean_accuracy:.4f} ± {ci_upper - mean_accuracy:.4f}")

    return results
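

# End-to-end sketch with a stand-in "algorithm": _RandomGuesser is hypothetical
# and exists only so the evaluation loop has something to call; any object
# exposing a compatible evaluate_task() would work. Data is random placeholder.
class _RandomGuesser:
    def evaluate_task(self, support_x, support_y, query_x, query_y,
                      return_adaptation_curve=False):
        # Guess uniformly among the labels present in the query set
        preds = torch.randint(0, int(query_y.max().item()) + 1, query_y.shape)
        result = {"accuracy": (preds == query_y).float().mean().item()}
        if return_adaptation_curve:
            result["adaptation_curve"] = [1.0, 0.8, 0.7]  # made-up losses
        return result


def _demo_evaluate_algorithm():
    data = torch.rand(200, 1, 8, 8)
    labels = torch.randint(0, 10, (200,))
    dataset = MetaLearningDataset(data, labels, TaskConfiguration())
    return evaluate_meta_learning_algorithm(_RandomGuesser(), dataset, num_episodes=50)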