Coverage for src/meta_learning/meta_learning_modules/utils_modules/statistical_evaluation.py: 0%

183 statements  

coverage.py v7.10.5, created at 2025-09-03 12:49 +0900

1""" 

2Statistical Evaluation Functions for Meta-Learning 

3================================================= 

4 

5Author: Benedict Chen (benedict@benedictchen.com) 

6 

7This module contains statistical functions for rigorous meta-learning evaluation, 

8including multiple confidence interval methods and research-accurate protocols. 

9""" 

10 

11import torch 

12import torch.nn.functional as F 

13from typing import Dict, List, Tuple, Optional, Any, Union 

14import numpy as np 

15import logging 

16from .configurations import EvaluationConfig, MetricsConfig, StatsConfig 

17 

18logger = logging.getLogger(__name__) 

19 

20 

def few_shot_accuracy(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    return_per_class: bool = False
) -> Union[float, Tuple[float, torch.Tensor]]:
    """
    Compute few-shot classification accuracy, optionally broken down per class.

    Args:
        predictions: Model predictions [n_samples, n_classes] or [n_samples]
        targets: Ground truth labels [n_samples]
        return_per_class: Whether to return per-class accuracies

    Returns:
        Overall accuracy, optionally with per-class accuracies
    """
    if predictions.dim() == 2:
        # Logits or probabilities - take argmax
        pred_labels = predictions.argmax(dim=-1)
    else:
        # Already labels
        pred_labels = predictions

    # Overall accuracy
    correct = (pred_labels == targets).float()
    overall_accuracy = correct.mean().item()

    if return_per_class:
        # Per-class accuracy over the classes present in targets
        unique_classes = torch.unique(targets)
        per_class_accuracies = []

        for class_id in unique_classes:
            class_mask = targets == class_id
            if class_mask.sum() > 0:
                class_correct = correct[class_mask].mean().item()
                per_class_accuracies.append(class_correct)
            else:
                per_class_accuracies.append(0.0)

        return overall_accuracy, torch.tensor(per_class_accuracies)

    return overall_accuracy
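
# Illustrative usage sketch (a hypothetical 5-way episode with 75 query
# samples; not part of the module API):
#
#     logits = torch.randn(75, 5)                 # [n_samples, n_classes]
#     labels = torch.randint(0, 5, (75,))
#     acc, per_class = few_shot_accuracy(logits, labels, return_per_class=True)
#     # acc is a float in [0, 1]; per_class has one entry per class present
#     # in `labels`, ordered as torch.unique(labels).
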

def adaptation_speed(
    loss_curve: List[float],
    convergence_threshold: float = 0.01
) -> Tuple[int, float]:
    """
    Measure adaptation speed for meta-learning algorithms.

    Args:
        loss_curve: List of losses during adaptation steps
        convergence_threshold: Threshold for considering convergence

    Returns:
        Tuple of (steps_to_convergence, final_loss)
    """
    if len(loss_curve) < 2:
        return len(loss_curve), loss_curve[-1] if loss_curve else float('inf')

    # Find convergence point
    for i in range(1, len(loss_curve)):
        loss_change = abs(loss_curve[i] - loss_curve[i - 1])
        if loss_change < convergence_threshold:
            return i + 1, loss_curve[i]

    # No convergence found
    return len(loss_curve), loss_curve[-1]
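
# Worked example (hypothetical numbers): for loss_curve = [1.0, 0.5, 0.495, 0.49],
# the change 0.5 -> 0.495 is the first below the default threshold of 0.01, so
# the function returns (3, 0.495) - three loss evaluations until convergence.
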

def compute_confidence_interval(
    values: List[float],
    confidence_level: float = 0.95,
    num_bootstrap: int = 1000
) -> Tuple[float, float, float]:
    """
    Compute a confidence interval, selecting the method by sample size.

    For small samples (n < 30) a t-distribution CI is used, the standard
    choice in meta-learning evaluation; otherwise a percentile bootstrap
    is used. Known limitations:

    1. The bootstrap branch performs no minimum-sample-size validation.
    2. No bias correction: for better accuracy on skewed distributions, use
       compute_bca_bootstrap_ci (bias-corrected and accelerated bootstrap).
    3. For the standard meta-learning reporting protocol, prefer
       compute_meta_learning_ci.

    Args:
        values: List of values to compute CI for
        confidence_level: Confidence level (e.g., 0.95 for 95%)
        num_bootstrap: Number of bootstrap samples

    Returns:
        Tuple of (mean, lower_bound, upper_bound)
    """
    if len(values) == 0:
        return 0.0, 0.0, 0.0

    values = np.array(values)
    mean_val = np.mean(values)

    if len(values) == 1:
        # Degenerate case: a single observation carries no spread information
        return mean_val, mean_val, mean_val

    # Check sample size and use appropriate method
    if len(values) < 30:
        # Use t-distribution CI for small samples (research-accurate)
        from scipy import stats
        t_critical = stats.t.ppf((1 + confidence_level) / 2, df=len(values) - 1)
        standard_error = np.std(values, ddof=1) / np.sqrt(len(values))
        margin_of_error = t_critical * standard_error

        ci_lower = mean_val - margin_of_error
        ci_upper = mean_val + margin_of_error

        logger.debug(f"Used t-distribution CI for small sample (n={len(values)})")
        return mean_val, ci_lower, ci_upper

    # Percentile bootstrap (adequate for large samples, but not bias-corrected)
    bootstrap_means = []
    for _ in range(num_bootstrap):
        bootstrap_sample = np.random.choice(values, size=len(values), replace=True)
        bootstrap_means.append(np.mean(bootstrap_sample))

    # Compute percentiles
    alpha = 1 - confidence_level
    lower_percentile = (alpha / 2) * 100
    upper_percentile = (1 - alpha / 2) * 100

    lower_bound = np.percentile(bootstrap_means, lower_percentile)
    upper_bound = np.percentile(bootstrap_means, upper_percentile)

    return mean_val, lower_bound, upper_bound
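
# Usage sketch (illustrative values): with n < 30 the t-distribution branch
# is taken, otherwise the percentile bootstrap:
#
#     mean, lo, hi = compute_confidence_interval([0.62, 0.71, 0.66, 0.69, 0.64])
#     # n = 5 < 30, so this is a t-distribution CI
#     # (critical value: stats.t.ppf(0.975, df=4))
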

def compute_confidence_interval_research_accurate(
    values: List[float],
    config: Optional[EvaluationConfig] = None,
    confidence_level: float = 0.95
) -> Tuple[float, float, float]:
    """
    Compute a confidence interval using the configured research-accurate method.

    Uses the appropriate CI method based on configuration and sample size,
    with optional auto-selection (see _auto_select_ci_method).
    """
    config = config or EvaluationConfig()

    if not config.use_research_accurate_ci:
        return compute_confidence_interval(values, confidence_level, config.num_bootstrap_samples)

    # Auto-select method if enabled
    if config.auto_method_selection:
        method = _auto_select_ci_method(values, config)
    else:
        method = config.ci_method

    # Route to the appropriate method based on configuration
    if method == "t_distribution":
        return compute_t_confidence_interval(values, confidence_level)
    elif method == "meta_learning_standard":
        return compute_meta_learning_ci(values, confidence_level, config.num_episodes)
    elif method == "bca_bootstrap":
        return compute_bca_bootstrap_ci(values, confidence_level, config.num_bootstrap_samples)
    else:  # bootstrap
        return compute_confidence_interval(values, confidence_level, config.num_bootstrap_samples)
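
# Usage sketch (variable names are illustrative; config fields are the ones
# referenced in the routing logic above):
#
#     cfg = EvaluationConfig()
#     cfg.auto_method_selection = False
#     cfg.ci_method = "bca_bootstrap"          # force a specific method
#     mean, lo, hi = compute_confidence_interval_research_accurate(accs, cfg)
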

def _auto_select_ci_method(values: List[float], config: EvaluationConfig) -> str:
    """
    Automatically select the best CI method based on data characteristics.

    Selection criteria based on statistical best practices:
    - t-distribution for small samples (n < config.min_sample_size_for_bootstrap)
    - Bootstrap for moderate samples
    - BCa bootstrap for large samples (n >= 100) with skewed distributions
    - Meta-learning standard when n equals config.num_episodes
      (600 in the standard protocol)
    """
    n = len(values)

    # Standard meta-learning evaluation protocol
    if n == config.num_episodes:
        return "meta_learning_standard"

    # Small sample: use t-distribution
    if n < config.min_sample_size_for_bootstrap:
        return "t_distribution"

    # Large sample: check for skewness
    if n >= 100:
        # Simple skewness heuristic: compare mean and median
        values_array = np.array(values)
        mean_val = np.mean(values_array)
        median_val = np.median(values_array)

        # If the distribution is noticeably skewed, use BCa bootstrap
        skew_threshold = 0.1 * np.std(values_array)
        if abs(mean_val - median_val) > skew_threshold:
            return "bca_bootstrap"

    # Default to standard bootstrap for moderate samples
    return "bootstrap"
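
# Worked selection examples under the criteria above (assuming default config
# values of num_episodes=600 and min_sample_size_for_bootstrap=30, which the
# code above implies but does not guarantee):
#
#     n = 600  -> "meta_learning_standard"
#     n = 12   -> "t_distribution"
#     n = 50   -> "bootstrap"
#     n = 1000 with mean far from median -> "bca_bootstrap"
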

def compute_t_confidence_interval(
    values: List[float],
    confidence_level: float = 0.95
) -> Tuple[float, float, float]:
    """
    Compute confidence interval using the t-distribution (appropriate for small samples).

    Standard approach in meta-learning evaluation when n < 30.
    """
    import scipy.stats as stats

    if len(values) == 0:
        return 0.0, 0.0, 0.0

    values = np.array(values)
    mean_val = np.mean(values)
    std_val = np.std(values, ddof=1)  # Sample standard deviation
    n = len(values)

    # Degrees of freedom
    df = n - 1

    # Critical t-value
    alpha = 1 - confidence_level
    t_critical = stats.t.ppf(1 - alpha / 2, df)

    # Margin of error
    margin_error = t_critical * (std_val / np.sqrt(n))

    # Confidence interval
    lower_bound = mean_val - margin_error
    upper_bound = mean_val + margin_error

    return mean_val, lower_bound, upper_bound
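
# The interval computed above is mean +/- t_{1-alpha/2, n-1} * s / sqrt(n).
# Worked example: at 95% confidence with n = 5 (df = 4), t_critical is about
# 2.776, so a sample standard deviation of 0.02 gives a margin of roughly
# 2.776 * 0.02 / sqrt(5) ~= 0.025.
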

def compute_meta_learning_ci(
    accuracies: List[float],
    confidence_level: float = 0.95,
    num_episodes: int = 600
) -> Tuple[float, float, float]:
    """
    Standard confidence interval computation for meta-learning evaluation.

    Based on standard protocols from the few-shot learning literature:
    - Vinyals et al. (2016): "Matching Networks"
    - Snell et al. (2017): "Prototypical Networks"
    - Finn et al. (2017): "MAML"

    Typically uses 600 episodes with a t-distribution CI.
    """
    if len(accuracies) != num_episodes:
        logger.warning(f"Expected {num_episodes} episodes, got {len(accuracies)}")

    # Use t-distribution for proper meta-learning evaluation
    return compute_t_confidence_interval(accuracies, confidence_level)
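
# Reporting sketch (the "mean +/- half-width" format used in the papers cited
# above; variable names are illustrative):
#
#     mean, lo, hi = compute_meta_learning_ci(episode_accuracies)
#     print(f"{100 * mean:.2f} +/- {100 * (hi - lo) / 2:.2f}%")
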

def compute_bca_bootstrap_ci(
    values: List[float],
    confidence_level: float = 0.95,
    num_bootstrap: int = 2000
) -> Tuple[float, float, float]:
    """
    Bias-corrected and accelerated (BCa) bootstrap confidence interval.

    More accurate than the basic percentile bootstrap, especially for skewed
    distributions. Based on Efron & Tibshirani (1993), "An Introduction to
    the Bootstrap".
    """
    import scipy.stats as stats

    if len(values) == 0:
        return 0.0, 0.0, 0.0

    values = np.array(values)
    n = len(values)
    mean_val = np.mean(values)

    # Bootstrap resampling
    bootstrap_means = []
    for _ in range(num_bootstrap):
        bootstrap_sample = np.random.choice(values, size=n, replace=True)
        bootstrap_means.append(np.mean(bootstrap_sample))

    bootstrap_means = np.array(bootstrap_means)

    # Bias correction: clip the proportion away from {0, 1} so that
    # norm.ppf stays finite (the proportion is 0 or 1 for constant data)
    proportion = (bootstrap_means < mean_val).mean()
    proportion = np.clip(proportion, 1.0 / num_bootstrap, 1.0 - 1.0 / num_bootstrap)
    bias_correction = stats.norm.ppf(proportion)

    # Acceleration (jackknife estimate)
    jackknife_means = []
    for i in range(n):
        jackknife_sample = np.concatenate([values[:i], values[i + 1:]])
        jackknife_means.append(np.mean(jackknife_sample))

    jackknife_means = np.array(jackknife_means)
    jackknife_mean = np.mean(jackknife_means)

    # Guard against a zero denominator (constant data)
    denominator = 6 * (np.sum((jackknife_mean - jackknife_means) ** 2)) ** 1.5
    if denominator == 0:
        acceleration = 0.0
    else:
        acceleration = np.sum((jackknife_mean - jackknife_means) ** 3) / denominator

    # Adjusted percentiles
    alpha = 1 - confidence_level
    z_alpha_2 = stats.norm.ppf(alpha / 2)
    z_1_alpha_2 = stats.norm.ppf(1 - alpha / 2)

    alpha_1 = stats.norm.cdf(bias_correction +
                             (bias_correction + z_alpha_2) / (1 - acceleration * (bias_correction + z_alpha_2)))
    alpha_2 = stats.norm.cdf(bias_correction +
                             (bias_correction + z_1_alpha_2) / (1 - acceleration * (bias_correction + z_1_alpha_2)))

    # Compute bounds
    lower_bound = np.percentile(bootstrap_means, 100 * alpha_1)
    upper_bound = np.percentile(bootstrap_means, 100 * alpha_2)

    return mean_val, lower_bound, upper_bound
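
# Note on the adjusted percentiles: BCa shifts the usual alpha/2 and
# 1 - alpha/2 percentiles using the bias correction z0 and acceleration a,
# alpha_i = Phi(z0 + (z0 + z_i) / (1 - a * (z0 + z_i))), which reduces to the
# plain percentile bootstrap when z0 = 0 and a = 0.
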

def basic_confidence_interval(values: List[float], confidence_level: float = 0.95) -> Tuple[float, float, float]:
    """Basic confidence interval computation."""
    return compute_confidence_interval(values, confidence_level=confidence_level)

def estimate_difficulty(task_data: torch.Tensor, method: str = "entropy") -> float:
    """Estimate task difficulty using various methods."""
    if method == "entropy":
        # Entropy-based difficulty: softmax over the mean features,
        # normalized by log(n_classes) so the result lies in [0, 1]
        probs = F.softmax(task_data.mean(dim=0), dim=-1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-8))
        return entropy.item() / np.log(task_data.size(-1))  # Normalized entropy
    else:
        return 0.5  # Default medium difficulty for unknown methods
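
# Sanity check for the normalization: a uniform distribution over C classes
# has entropy log(C), so the normalized difficulty is 1.0; a (near-)one-hot
# distribution has near-zero entropy, giving difficulty ~0.0.
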

class EvaluationMetrics:
    """Comprehensive evaluation metrics for meta-learning algorithms."""

    def __init__(self, config: MetricsConfig):
        self.config = config
        self.reset()

    def reset(self):
        """Reset all metrics to initial state."""
        self.accuracies = []
        self.losses = []
        self.adaptation_speeds = []
        self.uncertainties = []
        self.predictions = []
        self.gradients = []

    def update(self, predictions: torch.Tensor, targets: torch.Tensor,
               loss: Optional[float] = None, **kwargs):
        """Update metrics with new predictions and targets."""
        if self.config.compute_accuracy:
            accuracy = (predictions.argmax(dim=-1) == targets).float().mean().item()
            self.accuracies.append(accuracy)

        if self.config.compute_loss and loss is not None:
            self.losses.append(loss)

        if self.config.save_predictions:
            self.predictions.append(predictions.detach().cpu())

        # Route extra metrics by naive pluralization (key + 's'); only keys
        # with regular plurals match, e.g. gradient -> gradients and
        # adaptation_speed -> adaptation_speeds, but not uncertainty
        for key, value in kwargs.items():
            if hasattr(self, key + 's'):
                getattr(self, key + 's').append(value)

    def compute_summary(self) -> Dict[str, float]:
        """Compute summary statistics."""
        summary = {}

        if self.accuracies:
            summary['mean_accuracy'] = np.mean(self.accuracies)
            summary['std_accuracy'] = np.std(self.accuracies)

        if self.losses:
            summary['mean_loss'] = np.mean(self.losses)
            summary['std_loss'] = np.std(self.losses)

        return summary
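
# Usage sketch (assumes a MetricsConfig with compute_accuracy and compute_loss
# enabled; variable names are illustrative):
#
#     metrics = EvaluationMetrics(MetricsConfig())
#     metrics.update(logits, labels, loss=0.42, gradient=grad_norm)
#     metrics.compute_summary()   # {'mean_accuracy': ..., 'mean_loss': ...}
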

class StatisticalAnalysis:
    """Statistical analysis utilities for meta-learning research."""

    def __init__(self, config: StatsConfig):
        self.config = config

    def compute_confidence_interval(self, values: List[float]) -> Tuple[float, float, float]:
        """Compute confidence interval for given values."""
        return compute_confidence_interval(
            values,
            confidence_level=self.config.confidence_level
        )

    def statistical_test(self, group1: List[float], group2: List[float]) -> Dict[str, float]:
        """Perform statistical significance test between two groups."""
        from scipy import stats

        if self.config.significance_test == "t_test":
            statistic, p_value = stats.ttest_ind(group1, group2)
        elif self.config.significance_test == "mannwhitney":
            statistic, p_value = stats.mannwhitneyu(group1, group2, alternative='two-sided')
        else:
            raise ValueError(f"Unknown test: {self.config.significance_test}")

        return {
            'statistic': statistic,
            'p_value': p_value,
            # Significant at alpha = 1 - confidence_level; this is a single
            # comparison, so no multiple-testing (Bonferroni) correction applies
            'significant': p_value < (1 - self.config.confidence_level)
        }