1"""
2Statistical Evaluation Functions for Meta-Learning
3=================================================
5Author: Benedict Chen (benedict@benedictchen.com)
7This module contains statistical functions for rigorous meta-learning evaluation,
8including multiple confidence interval methods and research-accurate protocols.
9"""

import torch
import torch.nn.functional as F
from typing import Dict, List, Tuple, Optional, Any, Union
import numpy as np
import logging
from .configurations import EvaluationConfig, MetricsConfig, StatsConfig

logger = logging.getLogger(__name__)


def few_shot_accuracy(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    return_per_class: bool = False
) -> Union[float, Tuple[float, torch.Tensor]]:
    """
    Compute few-shot learning accuracy with advanced metrics.

    Args:
        predictions: Model predictions [n_samples, n_classes] or [n_samples]
        targets: Ground truth labels [n_samples]
        return_per_class: Whether to return per-class accuracies

    Returns:
        Overall accuracy, optionally with per-class accuracies
    """
    if predictions.dim() == 2:
        # Logits or probabilities - take argmax
        pred_labels = predictions.argmax(dim=-1)
    else:
        # Already labels
        pred_labels = predictions

    # Overall accuracy
    correct = (pred_labels == targets).float()
    overall_accuracy = correct.mean().item()

    if return_per_class:
        # Per-class accuracy
        unique_classes = torch.unique(targets)
        per_class_accuracies = []

        for class_id in unique_classes:
            class_mask = targets == class_id
            if class_mask.sum() > 0:
                class_correct = correct[class_mask].mean().item()
                per_class_accuracies.append(class_correct)
            else:
                per_class_accuracies.append(0.0)

        return overall_accuracy, torch.tensor(per_class_accuracies)

    return overall_accuracy
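

# A minimal usage sketch for few_shot_accuracy (the 5-way, 10-query shapes
# below are illustrative, not prescribed by this module):
#
#     >>> import torch
#     >>> logits = torch.randn(10, 5)            # [n_samples, n_classes]
#     >>> targets = torch.randint(0, 5, (10,))   # [n_samples]
#     >>> acc, per_class = few_shot_accuracy(logits, targets, return_per_class=True)
#     >>> 0.0 <= acc <= 1.0
#     True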


def adaptation_speed(
    loss_curve: List[float],
    convergence_threshold: float = 0.01
) -> Tuple[int, float]:
    """
    Measure adaptation speed for meta-learning algorithms.

    Args:
        loss_curve: List of losses during adaptation steps
        convergence_threshold: Threshold for considering convergence

    Returns:
        Tuple of (steps_to_convergence, final_loss)
    """
    if len(loss_curve) < 2:
        return len(loss_curve), loss_curve[-1] if loss_curve else float('inf')

    # Find convergence point
    for i in range(1, len(loss_curve)):
        loss_change = abs(loss_curve[i] - loss_curve[i-1])
        if loss_change < convergence_threshold:
            return i + 1, loss_curve[i]

    # No convergence found
    return len(loss_curve), loss_curve[-1]
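

# A quick sketch of adaptation_speed on a hand-made loss curve: the change
# from 0.52 to 0.515 (0.005) is the first step below the 0.01 threshold,
# so convergence is reported at that step:
#
#     >>> adaptation_speed([1.0, 0.6, 0.52, 0.515], convergence_threshold=0.01)
#     (4, 0.515)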


def compute_confidence_interval(
    values: List[float],
    confidence_level: float = 0.95,
    num_bootstrap: int = 1000
) -> Tuple[float, float, float]:
    """
    Compute a confidence interval for the mean.

    Uses a t-distribution CI for small samples (n < 30) and a percentile
    bootstrap otherwise. The percentile bootstrap applies no bias correction;
    for bias-corrected and accelerated (BCa) intervals, or the standard
    meta-learning evaluation protocol, see
    compute_confidence_interval_research_accurate.

    Args:
        values: List of values to compute the CI for
        confidence_level: Confidence level (e.g., 0.95 for 95%)
        num_bootstrap: Number of bootstrap samples

    Returns:
        Tuple of (mean, lower_bound, upper_bound)
    """
    if len(values) == 0:
        return 0.0, 0.0, 0.0

    values = np.array(values)
    mean_val = np.mean(values)

    # Check sample size and use the appropriate method
    if len(values) < 30:
        # t-distribution CI for small samples (research-accurate)
        from scipy import stats
        t_critical = stats.t.ppf((1 + confidence_level) / 2, df=len(values) - 1)
        standard_error = np.std(values, ddof=1) / np.sqrt(len(values))
        margin_of_error = t_critical * standard_error

        ci_lower = mean_val - margin_of_error
        ci_upper = mean_val + margin_of_error

        logger.debug(f"Used t-distribution CI for small sample (n={len(values)})")
        return mean_val, ci_lower, ci_upper

    # Percentile bootstrap for larger samples
    bootstrap_means = []
    for _ in range(num_bootstrap):
        bootstrap_sample = np.random.choice(values, size=len(values), replace=True)
        bootstrap_means.append(np.mean(bootstrap_sample))

    # Compute percentiles
    alpha = 1 - confidence_level
    lower_percentile = (alpha / 2) * 100
    upper_percentile = (1 - alpha / 2) * 100

    lower_bound = np.percentile(bootstrap_means, lower_percentile)
    upper_bound = np.percentile(bootstrap_means, upper_percentile)

    return mean_val, lower_bound, upper_bound
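

# Sketch: with fewer than 30 values this returns a t-distribution CI, and a
# percentile-bootstrap CI otherwise (the accuracies below are illustrative):
#
#     >>> mean, lo, hi = compute_confidence_interval([0.52, 0.55, 0.49, 0.58, 0.51])
#     >>> lo <= mean <= hi
#     True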


def compute_confidence_interval_research_accurate(
    values: List[float],
    config: Optional[EvaluationConfig] = None,
    confidence_level: float = 0.95
) -> Tuple[float, float, float]:
    """
    Compute a confidence interval using the configured research-accurate method.

    Selects the CI method from the configuration and, when auto-selection is
    enabled, from the sample size and distribution shape.
    """
    config = config or EvaluationConfig()

    if not config.use_research_accurate_ci:
        return compute_confidence_interval(values, confidence_level, config.num_bootstrap_samples)

    # Auto-select method if enabled
    if config.auto_method_selection:
        method = _auto_select_ci_method(values, config)
    else:
        method = config.ci_method

    # Route to the appropriate method based on configuration
    if method == "t_distribution":
        return compute_t_confidence_interval(values, confidence_level)
    elif method == "meta_learning_standard":
        return compute_meta_learning_ci(values, confidence_level, config.num_episodes)
    elif method == "bca_bootstrap":
        return compute_bca_bootstrap_ci(values, confidence_level, config.num_bootstrap_samples)
    else:  # bootstrap
        return compute_confidence_interval(values, confidence_level, config.num_bootstrap_samples)
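

# Sketch of the configured entry point, assuming EvaluationConfig() is
# constructible with defaults and exposes the fields used above
# (use_research_accurate_ci, auto_method_selection, ci_method):
#
#     >>> cfg = EvaluationConfig()
#     >>> cfg.use_research_accurate_ci = True
#     >>> cfg.auto_method_selection = False
#     >>> cfg.ci_method = "t_distribution"
#     >>> mean, lo, hi = compute_confidence_interval_research_accurate(
#     ...     [0.61, 0.64, 0.59, 0.66], config=cfg)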


def _auto_select_ci_method(values: List[float], config: EvaluationConfig) -> str:
    """
    Automatically select the best CI method based on data characteristics.

    Selection criteria, based on statistical best practice:
    - Meta-learning standard when n matches config.num_episodes
      (the standard 600-episode protocol)
    - t-distribution for small samples (n < config.min_sample_size_for_bootstrap)
    - BCa bootstrap for large samples (n >= 100) with skewed distributions
    - Standard bootstrap otherwise
    """
    n = len(values)

    # Standard meta-learning evaluation protocol
    if n == config.num_episodes:
        return "meta_learning_standard"

    # Small sample: use t-distribution
    if n < config.min_sample_size_for_bootstrap:
        return "t_distribution"

    # Large sample: check for skewness
    if n >= 100:
        # Simple skewness heuristic: compare mean and median
        values_array = np.array(values)
        mean_val = np.mean(values_array)
        median_val = np.median(values_array)

        # If the distribution is skewed, use BCa bootstrap
        skew_threshold = 0.1 * np.std(values_array)
        if abs(mean_val - median_val) > skew_threshold:
            return "bca_bootstrap"

    # Default to standard bootstrap for moderate samples
    return "bootstrap"
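

# Sketch of the auto-selection behaviour, assuming the EvaluationConfig
# defaults are min_sample_size_for_bootstrap=30 and num_episodes=600
# (both are read from the config, so adjust if your defaults differ):
#
#     >>> cfg = EvaluationConfig()
#     >>> _auto_select_ci_method([0.5] * 10, cfg)   # small sample
#     't_distribution'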


def compute_t_confidence_interval(
    values: List[float],
    confidence_level: float = 0.95
) -> Tuple[float, float, float]:
    """
    Compute a confidence interval using the t-distribution (appropriate for small samples).

    Standard approach in meta-learning evaluation when n < 30.
    """
    import scipy.stats as stats

    if len(values) == 0:
        return 0.0, 0.0, 0.0

    values = np.array(values)
    mean_val = np.mean(values)
    std_val = np.std(values, ddof=1)  # Sample standard deviation
    n = len(values)

    # Degrees of freedom
    df = n - 1

    # Critical t-value
    alpha = 1 - confidence_level
    t_critical = stats.t.ppf(1 - alpha / 2, df)

    # Margin of error
    margin_error = t_critical * (std_val / np.sqrt(n))

    # Confidence interval
    lower_bound = mean_val - margin_error
    upper_bound = mean_val + margin_error

    return mean_val, lower_bound, upper_bound
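

# The interval computed above is mean +/- t_{1-alpha/2, n-1} * s / sqrt(n),
# with s the sample standard deviation. A small sanity check on illustrative
# values:
#
#     >>> mean, lo, hi = compute_t_confidence_interval([0.70, 0.72, 0.68, 0.71])
#     >>> lo < mean < hi
#     True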


def compute_meta_learning_ci(
    accuracies: List[float],
    confidence_level: float = 0.95,
    num_episodes: int = 600
) -> Tuple[float, float, float]:
    """
    Standard confidence interval computation for meta-learning evaluation.

    Based on standard protocols from the few-shot learning literature:
    - Vinyals et al. (2016): "Matching Networks"
    - Snell et al. (2017): "Prototypical Networks"
    - Finn et al. (2017): "MAML"

    Typically uses 600 episodes with a t-distribution CI.
    """
    if len(accuracies) != num_episodes:
        logger.warning(f"Expected {num_episodes} episodes, got {len(accuracies)}")

    # Use t-distribution for proper meta-learning evaluation
    return compute_t_confidence_interval(accuracies, confidence_level)
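

# Sketch of the standard protocol: 600 episode accuracies reduced to a
# t-distribution CI (the accuracies below are synthetic, for illustration):
#
#     >>> import numpy as np
#     >>> rng = np.random.default_rng(0)
#     >>> episode_accs = rng.uniform(0.4, 0.8, size=600).tolist()
#     >>> mean, lo, hi = compute_meta_learning_ci(episode_accs)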


def compute_bca_bootstrap_ci(
    values: List[float],
    confidence_level: float = 0.95,
    num_bootstrap: int = 2000
) -> Tuple[float, float, float]:
    """
    Bias-corrected and accelerated (BCa) bootstrap confidence interval.

    More accurate than the basic percentile bootstrap, especially for skewed
    distributions. Based on Efron & Tibshirani (1993), "An Introduction to
    the Bootstrap".
    """
    import scipy.stats as stats

    if len(values) == 0:
        return 0.0, 0.0, 0.0

    values = np.array(values)
    n = len(values)
    mean_val = np.mean(values)

    # Bootstrap resampling
    bootstrap_means = []
    for _ in range(num_bootstrap):
        bootstrap_sample = np.random.choice(values, size=n, replace=True)
        bootstrap_means.append(np.mean(bootstrap_sample))

    bootstrap_means = np.array(bootstrap_means)

    # Bias correction (proportion clipped away from 0 and 1 so norm.ppf stays finite)
    proportion_below = np.clip((bootstrap_means < mean_val).mean(), 1e-6, 1 - 1e-6)
    bias_correction = stats.norm.ppf(proportion_below)

    # Acceleration via the jackknife
    jackknife_means = []
    for i in range(n):
        jackknife_sample = np.concatenate([values[:i], values[i+1:]])
        jackknife_means.append(np.mean(jackknife_sample))

    jackknife_means = np.array(jackknife_means)
    jackknife_mean = np.mean(jackknife_means)

    # Guard against a zero denominator when all jackknife means coincide
    denominator = 6 * (np.sum((jackknife_mean - jackknife_means)**2))**(3/2)
    if denominator == 0:
        acceleration = 0.0
    else:
        acceleration = np.sum((jackknife_mean - jackknife_means)**3) / denominator

    # Adjusted percentiles
    alpha = 1 - confidence_level
    z_alpha_2 = stats.norm.ppf(alpha / 2)
    z_1_alpha_2 = stats.norm.ppf(1 - alpha / 2)

    alpha_1 = stats.norm.cdf(bias_correction +
        (bias_correction + z_alpha_2) / (1 - acceleration * (bias_correction + z_alpha_2)))
    alpha_2 = stats.norm.cdf(bias_correction +
        (bias_correction + z_1_alpha_2) / (1 - acceleration * (bias_correction + z_1_alpha_2)))

    # Compute bounds
    lower_bound = np.percentile(bootstrap_means, 100 * alpha_1)
    upper_bound = np.percentile(bootstrap_means, 100 * alpha_2)

    return mean_val, lower_bound, upper_bound
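

# Sketch on a right-skewed sample, where BCa's bias correction and
# acceleration matter most (the data is synthetic):
#
#     >>> import numpy as np
#     >>> rng = np.random.default_rng(0)
#     >>> skewed = np.exp(rng.normal(size=200)).tolist()
#     >>> mean, lo, hi = compute_bca_bootstrap_ci(skewed, num_bootstrap=2000)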


def basic_confidence_interval(values: List[float], confidence_level: float = 0.95) -> Tuple[float, float, float]:
    """Basic confidence interval computation."""
    return compute_confidence_interval(values, confidence_level=confidence_level)


def estimate_difficulty(task_data: torch.Tensor, method: str = "entropy") -> float:
    """Estimate task difficulty using various methods."""
    if method == "entropy":
        # Simple entropy-based difficulty
        probs = F.softmax(task_data.mean(dim=0), dim=-1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-8))
        return entropy.item() / np.log(task_data.size(-1))  # Normalized entropy
    else:
        return 0.5  # Default medium difficulty
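

# Sketch: near-uniform class scores yield a normalized entropy close to 1
# (a hard task); a sharply peaked distribution yields a value near 0:
#
#     >>> import torch
#     >>> uniform_logits = torch.zeros(16, 5)   # [n_samples, n_classes]
#     >>> round(estimate_difficulty(uniform_logits), 3)
#     1.0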


class EvaluationMetrics:
    """Comprehensive evaluation metrics for meta-learning algorithms."""

    def __init__(self, config: MetricsConfig):
        self.config = config
        self.reset()

    def reset(self):
        """Reset all metrics to initial state."""
        self.accuracies = []
        self.losses = []
        self.adaptation_speeds = []
        self.uncertainties = []
        self.predictions = []
        self.gradients = []

    def update(self, predictions: torch.Tensor, targets: torch.Tensor,
               loss: Optional[float] = None, **kwargs):
        """Update metrics with new predictions and targets."""
        if self.config.compute_accuracy:
            accuracy = (predictions.argmax(dim=-1) == targets).float().mean().item()
            self.accuracies.append(accuracy)

        if self.config.compute_loss and loss is not None:
            self.losses.append(loss)

        if self.config.save_predictions:
            self.predictions.append(predictions.detach().cpu())

        # Route extra metrics to matching '<key>s' lists. Note the naive
        # pluralization: 'gradient' matches 'gradients', but 'uncertainty'
        # would not match 'uncertainties'.
        for key, value in kwargs.items():
            if hasattr(self, key + 's'):
                getattr(self, key + 's').append(value)

    def compute_summary(self) -> Dict[str, float]:
        """Compute summary statistics."""
        summary = {}

        if self.accuracies:
            summary['mean_accuracy'] = np.mean(self.accuracies)
            summary['std_accuracy'] = np.std(self.accuracies)

        if self.losses:
            summary['mean_loss'] = np.mean(self.losses)
            summary['std_loss'] = np.std(self.losses)

        return summary
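

# A minimal usage sketch for EvaluationMetrics, assuming MetricsConfig() is
# constructible with defaults and exposes the compute_accuracy, compute_loss,
# and save_predictions flags used above:
#
#     >>> import torch
#     >>> metrics = EvaluationMetrics(MetricsConfig())
#     >>> preds = torch.randn(8, 5)
#     >>> targets = torch.randint(0, 5, (8,))
#     >>> metrics.update(preds, targets, loss=0.9)
#     >>> summary = metrics.compute_summary()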


class StatisticalAnalysis:
    """Statistical analysis utilities for meta-learning research."""

    def __init__(self, config: StatsConfig):
        self.config = config

    def compute_confidence_interval(self, values: List[float]) -> Tuple[float, float, float]:
        """Compute confidence interval for given values."""
        return compute_confidence_interval(
            values,
            confidence_level=self.config.confidence_level
        )

    def statistical_test(self, group1: List[float], group2: List[float]) -> Dict[str, float]:
        """Perform statistical significance test between two groups."""
        from scipy import stats

        if self.config.significance_test == "t_test":
            statistic, p_value = stats.ttest_ind(group1, group2)
        elif self.config.significance_test == "mannwhitney":
            statistic, p_value = stats.mannwhitneyu(group1, group2, alternative='two-sided')
        else:
            raise ValueError(f"Unknown test: {self.config.significance_test}")

        # Two-sided significance at alpha = 1 - confidence_level
        alpha = 1 - self.config.confidence_level
        return {
            'statistic': statistic,
            'p_value': p_value,
            'significant': p_value < alpha
        }
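

# Sketch of a two-group comparison, assuming StatsConfig() is constructible
# with defaults and that its significance_test field defaults to "t_test":
#
#     >>> analysis = StatisticalAnalysis(StatsConfig())
#     >>> result = analysis.statistical_test([0.62, 0.65, 0.61], [0.55, 0.54, 0.57])
#     >>> sorted(result)
#     ['p_value', 'significant', 'statistic']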