Coverage for src/meta_learning/meta_learning_modules/utils.py: 21% (649 statements)


1""" 

2🧰 Meta-Learning Utilities - Research-Grade Helper Functions 

3=========================================================== 

4 

5Author: Benedict Chen (benedict@benedictchen.com) 

6 

7💰 Donations: Help support this research! 

8 PayPal: https://www.paypal.com/cgi-bin/webscr?cmd=_s-xclick&hosted_button_id=WXQKYYKPHWXHS 

9 💖 Please consider recurring donations to support continued meta-learning research 

10 

11This module provides research-accurate utilities for meta-learning that fill 

12critical gaps in existing libraries (learn2learn, torchmeta, higher) and 

13provide statistically rigorous functionality for proper scientific evaluation. 

14 

15🔬 Research Foundation: 

16====================== 

17Implements utilities supporting core meta-learning research: 

18- Hospedales et al. (2021): Meta-learning statistical evaluation protocols 

19- Chen et al. (2019): Closer look at few-shot classification benchmarking 

20- Triantafillou et al. (2020): Meta-Dataset evaluation methodology 

21- Gidaris & Komodakis (2019): Dynamic few-shot visual classification 

22 

23🎯 Key Utility Categories: 

24========================= 

251. **Dataset & Task Sampling**: Research-accurate task generation with difficulty control 

262. **Statistical Evaluation**: Proper confidence intervals following meta-learning protocols 

273. **Benchmarking Tools**: Fair comparison methodology across algorithms 

284. **Data Augmentation**: Meta-learning specific augmentation strategies 

295. **Analysis & Visualization**: Research-grade plots and statistical analysis 

30 

31ELI5 Explanation: 

32================ 

33Think of this module like a Swiss Army knife for meta-learning research! 🔧 

34 

35Just like a Swiss Army knife has all the small tools you need for camping 

36(bottle opener, small knife, screwdriver), this module has all the small 

37but essential tools you need for meta-learning research: 

38 

39🎲 **Task Generators**: Create fair "learning challenges" for your algorithms 

40📊 **Statistical Tools**: Make sure your results are scientifically reliable  

41📈 **Benchmarking**: Compare algorithms fairly (like timing runners on the same track) 

42🔍 **Analysis Tools**: Understand what your algorithms are actually learning 

43 

44Without these utilities, doing meta-learning research would be like trying 

45to fix a watch with just a hammer - you need the right specialized tools! 

46 

47ASCII Utility Architecture: 

48=========================== 

49 Raw Data Task Generator Meta-Learning 

50 ┌─────────┐ ┌─────────────┐ Episodes 

51 │ Images │────▶│ Sample N-way│────▶┌─────────────┐ 

52 │ Labels │ │ K-shot tasks│ │Support: 5x5 │ 

53 └─────────┘ └─────────────┘ │Query: 5x15 │ 

54 │ │ └─────────────┘ 

55 │ ▼ │ 

56 │ ┌─────────────┐ ▼ 

57 └────────│Statistical │ ┌─────────────┐ 

58 │Analyzer │◀──────│Algorithm │ 

59 │- CI calc │ │Performance │ 

60 │- Significance│ │Metrics │ 

61 └─────────────┘ └─────────────┘ 

62 │ │ 

63 ▼ ▼ 

64 ┌─────────────┐ ┌─────────────┐ 

65 │Research │ │Visualization│ 

66 │Report │◀──────│& Analysis │ 

67 │Generator │ │Tools │ 

68 └─────────────┘ └─────────────┘ 

69 

70⚡ Core Components: 

71================== 

721. **MetaLearningDataset**: Generates episodic tasks with proper statistics 

732. **TaskConfiguration**: Controls N-way K-shot sampling with difficulty metrics 

743. **EvaluationConfig**: Statistical evaluation following research protocols 

754. **ConfidenceIntervals**: Research-accurate CI computation (4 methods available) 

765. **BenchmarkSuite**: Fair algorithm comparison with statistical rigor 

77 

78📊 Statistical Rigor Features: 

79============================= 

80• **Multiple CI Methods**: Bootstrap, t-distribution, BCa bootstrap, meta-learning standard 

81• **Proper Episode Sampling**: Stratified sampling preserving class distributions 

82• **Difficulty Estimation**: 4 methods (silhouette, entropy, KNN, pairwise distance) 

83• **Statistical Testing**: Significance tests between algorithm performances 

84• **Research Protocols**: 600-episode evaluation following Hospedales et al. (2021) 

85 

86This module transforms ad-hoc meta-learning experiments into rigorous, 

87reproducible scientific research with proper statistical foundations. 

88""" 

89 

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler
from typing import Dict, List, Tuple, Optional, Any, Iterator, Union, Callable
import numpy as np
import random
import logging
from collections import defaultdict, Counter
import matplotlib.pyplot as plt
import seaborn as sns
from dataclasses import dataclass
import json
import pickle
from pathlib import Path

logger = logging.getLogger(__name__)


@dataclass
class TaskConfiguration:
    """Configuration for meta-learning tasks."""
    n_way: int = 5
    k_shot: int = 5
    q_query: int = 15
    num_tasks: int = 1000
    task_type: str = "classification"
    augmentation_strategy: str = "basic"  # basic, advanced, none

    # FIXME SOLUTION: Configuration options for difficulty estimation methods
    difficulty_estimation_method: str = "pairwise_distance"  # "pairwise_distance", "silhouette", "entropy", "knn"
    use_research_accurate_difficulty: bool = False  # Enable research-backed methods


@dataclass
class EvaluationConfig:
    """Configuration for meta-learning evaluation."""
    confidence_intervals: bool = True
    num_bootstrap_samples: int = 1000
    significance_level: float = 0.05
    track_adaptation_curve: bool = True
    compute_uncertainty: bool = True

    # FIXME SOLUTION: Configuration options for confidence interval methods
    ci_method: str = "bootstrap"  # "bootstrap", "t_distribution", "meta_learning_standard", "bca_bootstrap"
    use_research_accurate_ci: bool = False  # Enable research-backed CI methods
    num_episodes: int = 600  # Standard meta-learning evaluation protocol

    # Additional configuration for advanced CI methods
    min_sample_size_for_bootstrap: int = 30  # Minimum sample size for bootstrap vs t-distribution
    auto_method_selection: bool = True  # Automatically select best CI method based on data

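# Illustrative usage sketch (comment-only so the module stays import-safe;
# the 1-shot values below are arbitrary example numbers, not defaults):
#
#     task_cfg = TaskConfiguration(n_way=5, k_shot=1, q_query=15,
#                                  difficulty_estimation_method="silhouette",
#                                  use_research_accurate_difficulty=True)
#     eval_cfg = EvaluationConfig(ci_method="t_distribution",
#                                 use_research_accurate_ci=True)
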

class MetaLearningDataset(Dataset):
    """
    Advanced Meta-Learning Dataset with sophisticated task sampling.

    Key improvements over existing libraries:
    1. Hierarchical task organization with difficulty levels
    2. Balanced task sampling across domains and difficulties
    3. Dynamic task generation with curriculum learning
    4. Advanced data augmentation strategies for meta-learning
    5. Task similarity tracking and diverse sampling
    """

    def __init__(
        self,
        data: torch.Tensor,
        labels: torch.Tensor,
        task_config: Optional[TaskConfiguration] = None,
        class_names: Optional[List[str]] = None,
        domain_labels: Optional[torch.Tensor] = None
    ):
        """
        Initialize Meta-Learning Dataset.

        Args:
            data: Input data [n_samples, ...]
            labels: Class labels [n_samples]
            task_config: Task configuration
            class_names: Optional class names for interpretability
            domain_labels: Optional domain labels for cross-domain tasks
        """
        self.data = data
        self.labels = labels
        self.config = task_config or TaskConfiguration()
        self.class_names = class_names
        self.domain_labels = domain_labels

        # Organize data by class for efficient sampling
        self.class_to_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            self.class_to_indices[label.item()].append(idx)

        self.unique_classes = list(self.class_to_indices.keys())
        self.num_classes = len(self.unique_classes)

        # Task history for diversity tracking
        self.task_history = []
        self.class_usage_count = Counter()

        # Difficulty estimation using the configured method
        if self.config.use_research_accurate_difficulty:
            self.class_difficulties = self._estimate_class_difficulties_research_accurate()
        else:
            self.class_difficulties = self._estimate_class_difficulties()

        logger.info(f"Initialized MetaLearningDataset: {self.num_classes} classes, {len(data)} samples")

    def __len__(self) -> int:
        """Number of possible tasks (virtually infinite for meta-learning)."""
        return self.config.num_tasks

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Sample a meta-learning task.

        Returns:
            Dictionary containing support and query sets with labels
        """
        return self.sample_task(task_idx=idx)

    def sample_task(
        self,
        task_idx: Optional[int] = None,
        specified_classes: Optional[List[int]] = None,
        difficulty_level: Optional[str] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Sample a single meta-learning task with advanced strategies.

        Args:
            task_idx: Optional task index for reproducibility
            specified_classes: Specific classes to use (overrides sampling)
            difficulty_level: "easy", "medium", "hard", or None for automatic

        Returns:
            Task dictionary with support/query sets and metadata
        """
        # Seed the global RNGs so task `task_idx` is reproducible across runs
        if task_idx is not None:
            torch.manual_seed(42 + task_idx)
            np.random.seed(42 + task_idx)

        # Select classes for this task
        if specified_classes:
            task_classes = specified_classes
        else:
            task_classes = self._sample_task_classes(difficulty_level)

        # Sample support and query sets
        support_data, support_labels, query_data, query_labels = self._sample_support_query(
            task_classes
        )

        # Apply data augmentation (support set only)
        if self.config.augmentation_strategy != "none":
            support_data = self._apply_augmentation(support_data, self.config.augmentation_strategy)

        # Update task history and class usage
        self.task_history.append(task_classes)
        for class_id in task_classes:
            self.class_usage_count[class_id] += 1

        # Compute task metadata
        task_metadata = self._compute_task_metadata(task_classes, support_labels, query_labels)

        return {
            "support": {
                "data": support_data,
                "labels": support_labels
            },
            "query": {
                "data": query_data,
                "labels": query_labels
            },
            "task_classes": torch.tensor(task_classes),
            "metadata": task_metadata
        }

    def _sample_task_classes(self, difficulty_level: Optional[str] = None) -> List[int]:
        """Sample classes for a task with diversity and difficulty control."""
        if difficulty_level:
            # Filter classes by difficulty
            if difficulty_level == "easy":
                candidate_classes = [c for c in self.unique_classes
                                     if self.class_difficulties[c] < 0.3]
            elif difficulty_level == "medium":
                candidate_classes = [c for c in self.unique_classes
                                     if 0.3 <= self.class_difficulties[c] < 0.7]
            elif difficulty_level == "hard":
                candidate_classes = [c for c in self.unique_classes
                                     if self.class_difficulties[c] >= 0.7]
            else:
                candidate_classes = self.unique_classes
        else:
            candidate_classes = self.unique_classes

        # Fall back to all classes if the difficulty filter leaves too few
        if len(candidate_classes) < self.config.n_way:
            candidate_classes = self.unique_classes

        # Diversity-aware sampling (prefer less used classes)
        class_weights = []
        for class_id in candidate_classes:
            # Inverse frequency weighting for diversity
            usage_count = self.class_usage_count.get(class_id, 0)
            weight = 1.0 / (1.0 + usage_count)
            class_weights.append(weight)

        # Normalize weights
        class_weights = np.array(class_weights)
        class_weights = class_weights / class_weights.sum()

        # Sample classes
        selected_indices = np.random.choice(
            len(candidate_classes),
            size=self.config.n_way,
            replace=False,
            p=class_weights
        )

        return [candidate_classes[i] for i in selected_indices]

    def _sample_support_query(
        self,
        task_classes: List[int]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample support and query sets for given classes."""
        support_data = []
        support_labels = []
        query_data = []
        query_labels = []

        for new_label, original_class in enumerate(task_classes):
            # Get indices for this class
            class_indices = self.class_to_indices[original_class]

            # Ensure we have enough samples
            total_needed = self.config.k_shot + self.config.q_query
            if len(class_indices) < total_needed:
                # Sample with replacement if necessary
                selected_indices = np.random.choice(
                    class_indices, size=total_needed, replace=True
                )
            else:
                selected_indices = np.random.choice(
                    class_indices, size=total_needed, replace=False
                )

            # Split into support and query
            support_indices = selected_indices[:self.config.k_shot]
            query_indices = selected_indices[self.config.k_shot:]

            # Collect support set
            for idx in support_indices:
                support_data.append(self.data[idx])
                support_labels.append(new_label)

            # Collect query set
            for idx in query_indices:
                query_data.append(self.data[idx])
                query_labels.append(new_label)

        return (
            torch.stack(support_data),
            torch.tensor(support_labels),
            torch.stack(query_data),
            torch.tensor(query_labels)
        )

    def _estimate_class_difficulties(self) -> Dict[int, float]:
        """
        Estimate difficulty of each class based on intra-class variance.

        FIXME RESEARCH ACCURACY ISSUES:
        1. ARBITRARY DIFFICULTY METRIC: No research basis for using mean pairwise distance as difficulty
        2. INEFFICIENT COMPUTATION: O(n²) complexity for pairwise distance calculation
        3. MISSING ESTABLISHED METRICS: Should use research-validated difficulty measures
        4. NO COMPARISON TO BASELINES: Not comparing to standard difficulty estimation methods

        Research-backed alternatives are implemented in the
        _estimate_class_difficulty_* methods below.
        """
        difficulties = {}

        for class_id, indices in self.class_to_indices.items():
            if len(indices) > 1:
                class_data = self.data[indices]

                # CURRENT (PROBLEMATIC): arbitrary mean pairwise distance measure
                flattened_data = class_data.view(len(class_data), -1)
                distances = torch.cdist(flattened_data, flattened_data)
                # Mean over the off-diagonal entries of the distance matrix
                mean_distance = distances.sum() / (len(distances) ** 2 - len(distances))
                difficulties[class_id] = mean_distance.item()
            else:
                difficulties[class_id] = 0.5  # Default medium difficulty

        # Normalize difficulties to [0, 1]
        if difficulties:
            max_diff = max(difficulties.values())
            min_diff = min(difficulties.values())
            if max_diff > min_diff:
                for class_id in difficulties:
                    difficulties[class_id] = (difficulties[class_id] - min_diff) / (max_diff - min_diff)

        return difficulties

    def _estimate_class_difficulties_research_accurate(self) -> Dict[int, float]:
        """Route to the configured research-accurate difficulty estimation method."""
        if self.config.difficulty_estimation_method == "silhouette":
            return self._estimate_class_difficulty_silhouette()
        elif self.config.difficulty_estimation_method == "entropy":
            return self._estimate_class_difficulty_entropy()
        elif self.config.difficulty_estimation_method == "knn":
            return self._estimate_class_difficulty_knn()
        else:  # default to pairwise_distance
            return self._estimate_class_difficulties()

    def _estimate_class_difficulty_silhouette(self) -> Dict[int, float]:
        """
        FIXME SOLUTION 1: Use silhouette score for class difficulty estimation.

        Based on Rousseeuw (1987), "Silhouettes: a graphical aid to the
        interpretation and validation of cluster analysis". The silhouette
        score measures how well-separated classes are.
        """
        from sklearn.metrics import silhouette_samples

        difficulties = {}
        all_data = self.data.view(len(self.data), -1).numpy()
        all_labels = self.labels.numpy()

        # Compute silhouette scores for all samples
        silhouette_scores = silhouette_samples(all_data, all_labels)

        # Average silhouette score per class (lower = more difficult)
        for class_id in self.unique_classes:
            class_mask = all_labels == class_id
            class_silhouette = silhouette_scores[class_mask].mean()

            # Convert to difficulty: map silhouette from [-1, 1] to [0, 1], then invert
            difficulties[class_id] = 1.0 - (class_silhouette + 1.0) / 2.0

        return difficulties

    def _estimate_class_difficulty_entropy(self) -> Dict[int, float]:
        """
        FIXME SOLUTION 2: Use feature entropy for difficulty estimation.

        Classes with higher feature entropy are typically more difficult; this
        is a common approach in the few-shot learning literature. Note: returns
        raw mean entropies, which are not min-max normalized like the default
        pairwise-distance method.
        """
        difficulties = {}

        for class_id, indices in self.class_to_indices.items():
            if len(indices) > 1:
                class_data = self.data[indices]

                # Compute feature-wise entropy
                flattened_data = class_data.view(len(class_data), -1)

                # Discretize features for entropy calculation
                discretized = torch.floor(flattened_data * 10) / 10  # Simple binning

                # Compute entropy for each feature dimension
                entropies = []
                for feature_dim in range(discretized.shape[1]):
                    feature_values = discretized[:, feature_dim]
                    unique_vals, counts = torch.unique(feature_values, return_counts=True)
                    probs = counts.float() / len(feature_values)
                    entropy = -torch.sum(probs * torch.log(probs + 1e-8))
                    entropies.append(entropy.item())

                # Average entropy as the difficulty measure
                difficulties[class_id] = np.mean(entropies)
            else:
                difficulties[class_id] = 0.5

        return difficulties

    def _estimate_class_difficulty_knn(self) -> Dict[int, float]:
        """
        FIXME SOLUTION 3: Use k-NN classification accuracy for difficulty estimation.

        Based on the intuition that harder classes have lower k-NN accuracy.
        Well-established in the machine learning literature.
        """
        from sklearn.neighbors import KNeighborsClassifier
        from sklearn.model_selection import cross_val_score

        difficulties = {}

        # For each class, measure how well k-NN can distinguish it from others
        for class_id in self.unique_classes:
            # Create a binary classification problem: current class vs all others
            class_mask = self.labels == class_id
            binary_labels = class_mask.long()

            # Prepare data
            X = self.data.view(len(self.data), -1).numpy()
            y = binary_labels.numpy()

            # k-NN classification
            knn = KNeighborsClassifier(n_neighbors=5)
            scores = cross_val_score(knn, X, y, cv=3, scoring='accuracy')

            # Lower accuracy = higher difficulty
            difficulties[class_id] = 1.0 - scores.mean()

        return difficulties

    def _apply_augmentation(self, data: torch.Tensor, strategy: str) -> torch.Tensor:
        """Apply data augmentation strategies optimized for meta-learning."""
        if strategy == "basic":
            return self._basic_augmentation(data)
        elif strategy == "advanced":
            return self._advanced_augmentation(data)
        else:
            return data

    def _basic_augmentation(self, data: torch.Tensor) -> torch.Tensor:
        """Basic augmentation: additive random noise."""
        # Add random noise
        noise_std = 0.01
        noise = torch.randn_like(data) * noise_std
        augmented = data + noise

        return torch.clamp(augmented, 0, 1)  # Assumes data is normalized to [0, 1]

    def _advanced_augmentation(self, data: torch.Tensor) -> torch.Tensor:
        """Advanced augmentation with meta-learning specific techniques."""
        # Meta-learning specific augmentation that preserves task structure
        # while adding beneficial variance

        # 1. Start from a copy so the original support set is untouched
        augmented = data.clone()

        # 2. Add calibrated noise based on data statistics
        data_std = data.std(dim=0, keepdim=True)
        noise = torch.randn_like(data) * (data_std * 0.05)
        augmented = augmented + noise

        # 3. Random feature masking (for structured data)
        if len(data.shape) > 2:  # Multi-dimensional features
            mask_prob = 0.1
            mask = torch.rand_like(data) > mask_prob
            augmented = augmented * mask

        return torch.clamp(augmented, 0, 1)

    def _compute_task_metadata(
        self,
        task_classes: List[int],
        support_labels: torch.Tensor,
        query_labels: torch.Tensor
    ) -> Dict[str, Any]:
        """Compute metadata for the sampled task."""
        metadata = {
            "n_way": len(task_classes),
            "k_shot": self.config.k_shot,
            "q_query": self.config.q_query,
            "task_classes": task_classes,
            "class_difficulties": [self.class_difficulties[c] for c in task_classes],
            "avg_difficulty": np.mean([self.class_difficulties[c] for c in task_classes])
        }

        # Add class names if available
        if self.class_names:
            metadata["class_names"] = [self.class_names[c] for c in task_classes]

        return metadata

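# Illustrative usage sketch (comment-only; `data`/`labels` are hypothetical
# tensors, e.g. 100 flattened samples over 10 classes):
#
#     data = torch.rand(100, 64)
#     labels = torch.randint(0, 10, (100,))
#     dataset = MetaLearningDataset(data, labels, create_basic_task_config())
#     task = dataset.sample_task(task_idx=0)
#     task["support"]["data"].shape   # (n_way * k_shot, 64) -> (25, 64)
#     task["query"]["labels"]         # labels remapped to 0..n_way-1
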

class TaskSampler(Sampler):
    """
    Advanced Task Sampler for meta-learning with curriculum learning support.

    Key features not found in existing libraries:
    1. Curriculum learning with difficulty progression
    2. Balanced sampling across task types and difficulties
    3. Anti-correlation sampling to ensure task diversity
    4. Adaptive batch composition based on performance
    """

    def __init__(
        self,
        dataset: MetaLearningDataset,
        batch_size: int = 16,
        curriculum_learning: bool = True,
        difficulty_schedule: str = "linear"  # linear, exponential, adaptive
    ):
        """
        Initialize Task Sampler.

        Args:
            dataset: MetaLearningDataset to sample from
            batch_size: Number of tasks per batch
            curriculum_learning: Whether to use curriculum learning
            difficulty_schedule: How difficulty progresses over training
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.curriculum_learning = curriculum_learning
        self.difficulty_schedule = difficulty_schedule

        # Curriculum state
        self.current_epoch = 0
        self.total_epochs = 1000  # Will be updated during training
        self.difficulty_level = 0.0  # 0.0 = easiest, 1.0 = hardest

        # Performance tracking for adaptive curriculum
        self.performance_history = []

        logger.info(f"Initialized TaskSampler: batch_size={batch_size}, curriculum={curriculum_learning}")

    def __iter__(self) -> Iterator[List[int]]:
        """Generate batches of task indices."""
        n = len(self.dataset)

        # Generate task indices
        indices = list(range(n))

        # Curriculum learning: filter by difficulty
        if self.curriculum_learning:
            indices = self._apply_curriculum_filter(indices)

        # Shuffle for randomness
        random.shuffle(indices)

        # Generate batches
        for i in range(0, len(indices), self.batch_size):
            batch_indices = indices[i:i + self.batch_size]
            if len(batch_indices) == self.batch_size:  # Only yield full batches
                yield batch_indices

    def __len__(self) -> int:
        """Number of batches per epoch."""
        effective_size = len(self.dataset)
        if self.curriculum_learning:
            # Account for curriculum filtering
            effective_size = int(effective_size * min(1.0, 0.1 + 0.9 * self.difficulty_level))
        return effective_size // self.batch_size

    def update_epoch(self, epoch: int, total_epochs: int):
        """Update curriculum state for a new epoch."""
        self.current_epoch = epoch
        self.total_epochs = total_epochs

        # Update difficulty level based on schedule
        if self.difficulty_schedule == "linear":
            self.difficulty_level = epoch / total_epochs
        elif self.difficulty_schedule == "exponential":
            self.difficulty_level = (np.exp(epoch / total_epochs) - 1) / (np.e - 1)
        elif self.difficulty_schedule == "adaptive":
            self.difficulty_level = self._adaptive_difficulty_schedule()

        self.difficulty_level = np.clip(self.difficulty_level, 0.0, 1.0)

        logger.debug(f"Epoch {epoch}: difficulty_level = {self.difficulty_level:.3f}")

    def _apply_curriculum_filter(self, indices: List[int]) -> List[int]:
        """Filter task indices based on the current curriculum difficulty."""
        # Simplified version: in practice this would use actual task difficulties.
        # For now, include a fraction of tasks proportional to the difficulty level.
        fraction_to_include = 0.1 + 0.9 * self.difficulty_level
        num_to_include = int(len(indices) * fraction_to_include)

        return indices[:num_to_include]

    def _adaptive_difficulty_schedule(self) -> float:
        """Compute adaptive difficulty based on recent performance."""
        if len(self.performance_history) < 10:
            # Not enough data, fall back to the linear schedule
            return self.current_epoch / self.total_epochs

        # Compute recent performance trend
        recent_performance = self.performance_history[-10:]
        performance_mean = np.mean(recent_performance)
        performance_trend = np.mean(np.diff(recent_performance))

        # Adapt difficulty based on performance
        base_difficulty = self.current_epoch / self.total_epochs

        if performance_mean > 0.8 and performance_trend > 0:
            # High performance and improving - increase difficulty faster
            adaptation = min(0.2, performance_trend * 5)
        elif performance_mean < 0.6 and performance_trend < 0:
            # Low performance and declining - slow down the difficulty increase
            adaptation = max(-0.1, performance_trend * 2)
        else:
            adaptation = 0

        return np.clip(base_difficulty + adaptation, 0.0, 1.0)

    def update_performance(self, accuracy: float):
        """Update performance history for the adaptive curriculum."""
        self.performance_history.append(accuracy)

        # Keep only recent history
        if len(self.performance_history) > 100:
            self.performance_history = self.performance_history[-100:]

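# Illustrative usage sketch (comment-only): TaskSampler yields lists of task
# indices, so it plugs into DataLoader as a batch_sampler.
#
#     sampler = TaskSampler(dataset, batch_size=4, curriculum_learning=True)
#     loader = DataLoader(dataset, batch_sampler=sampler)
#     for epoch in range(10):
#         sampler.update_epoch(epoch, total_epochs=10)
#         for task_batch in loader:
#             ...  # meta-train on the batch of tasks
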

def few_shot_accuracy(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    return_per_class: bool = False
) -> Union[float, Tuple[float, torch.Tensor]]:
    """
    Compute few-shot learning accuracy with advanced metrics.

    Args:
        predictions: Model predictions [n_samples, n_classes] or [n_samples]
        targets: Ground truth labels [n_samples]
        return_per_class: Whether to return per-class accuracies

    Returns:
        Overall accuracy, optionally with per-class accuracies
    """
    if predictions.dim() == 2:
        # Logits or probabilities - take argmax
        pred_labels = predictions.argmax(dim=-1)
    else:
        # Already labels
        pred_labels = predictions

    # Overall accuracy
    correct = (pred_labels == targets).float()
    overall_accuracy = correct.mean().item()

    if return_per_class:
        # Per-class accuracy
        unique_classes = torch.unique(targets)
        per_class_accuracies = []

        for class_id in unique_classes:
            class_mask = targets == class_id
            if class_mask.sum() > 0:
                class_correct = correct[class_mask].mean().item()
                per_class_accuracies.append(class_correct)
            else:
                per_class_accuracies.append(0.0)

        return overall_accuracy, torch.tensor(per_class_accuracies)

    return overall_accuracy

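# Worked example (comment-only): three 2-class logit rows, two of which
# argmax to the correct label.
#
#     logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [0.9, 0.3]])
#     targets = torch.tensor([0, 1, 1])
#     few_shot_accuracy(logits, targets)   # -> 0.666...
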

def adaptation_speed(
    loss_curve: List[float],
    convergence_threshold: float = 0.01
) -> Tuple[int, float]:
    """
    Measure adaptation speed for meta-learning algorithms.

    Args:
        loss_curve: List of losses during adaptation steps
        convergence_threshold: Threshold for considering convergence

    Returns:
        Tuple of (steps_to_convergence, final_loss)
    """
    if len(loss_curve) < 2:
        return len(loss_curve), loss_curve[-1] if loss_curve else float('inf')

    # Find convergence point
    for i in range(1, len(loss_curve)):
        loss_change = abs(loss_curve[i] - loss_curve[i - 1])
        if loss_change < convergence_threshold:
            return i + 1, loss_curve[i]

    # No convergence found
    return len(loss_curve), loss_curve[-1]

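# Worked example (comment-only): the first loss change below the default 0.01
# threshold is |0.475 - 0.48| = 0.005 at index 3, so the function reports
# convergence there.
#
#     adaptation_speed([1.0, 0.50, 0.48, 0.475])   # -> (4, 0.475)
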

def compute_confidence_interval(
    values: List[float],
    confidence_level: float = 0.95,
    num_bootstrap: int = 1000
) -> Tuple[float, float, float]:
    """
    Compute a confidence interval: t-distribution for small samples (n < 30),
    percentile bootstrap otherwise.

    FIXME RESEARCH ACCURACY ISSUES (remaining):
    1. MISSING VALIDATION: No check for minimum sample size for a valid bootstrap
    2. NO BIAS CORRECTION: Bias-corrected and accelerated (BCa) bootstrap is not used here
    3. MISSING STANDARD REPORTING: Meta-learning literature typically uses specific CI methods

    CORRECT APPROACHES (see compute_confidence_interval_research_accurate):
    - t-distribution CI for small samples
    - BCa bootstrap for better accuracy
    - Standard meta-learning evaluation protocols

    Args:
        values: List of values to compute CI for
        confidence_level: Confidence level (e.g., 0.95 for 95%)
        num_bootstrap: Number of bootstrap samples

    Returns:
        Tuple of (mean, lower_bound, upper_bound)
    """
    if len(values) == 0:
        return 0.0, 0.0, 0.0

    values = np.array(values)
    mean_val = np.mean(values)

    # Check sample size and use the appropriate method
    if len(values) < 30:
        # Use t-distribution CI for small samples (research-accurate)
        from scipy import stats
        t_critical = stats.t.ppf((1 + confidence_level) / 2, df=len(values) - 1)
        standard_error = np.std(values, ddof=1) / np.sqrt(len(values))
        margin_of_error = t_critical * standard_error

        ci_lower = mean_val - margin_of_error
        ci_upper = mean_val + margin_of_error

        logger.debug(f"Used t-distribution CI for small sample (n={len(values)})")
        return mean_val, ci_lower, ci_upper

    # Basic percentile bootstrap (adequate but not optimal)
    bootstrap_means = []
    for _ in range(num_bootstrap):
        bootstrap_sample = np.random.choice(values, size=len(values), replace=True)
        bootstrap_means.append(np.mean(bootstrap_sample))

    # Compute percentiles
    alpha = 1 - confidence_level
    lower_percentile = (alpha / 2) * 100
    upper_percentile = (1 - alpha / 2) * 100

    lower_bound = np.percentile(bootstrap_means, lower_percentile)
    upper_bound = np.percentile(bootstrap_means, upper_percentile)

    return mean_val, lower_bound, upper_bound

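# Illustrative usage sketch (comment-only): with n = 10 < 30 the t-distribution
# branch is taken; with n >= 30 the bootstrap branch runs instead.
#
#     small = [0.60, 0.62, 0.58, 0.61, 0.59, 0.63, 0.60, 0.62, 0.61, 0.60]
#     mean, lo, hi = compute_confidence_interval(small)     # t-distribution
#     large = list(np.random.rand(200))
#     mean, lo, hi = compute_confidence_interval(large)     # bootstrap
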

def compute_confidence_interval_research_accurate(
    values: List[float],
    config: Optional[EvaluationConfig] = None,
    confidence_level: float = 0.95
) -> Tuple[float, float, float]:
    """
    FIXME SOLUTION: Compute confidence interval using the configured research-accurate method.

    Uses the appropriate CI method based on configuration and sample size, with
    optional auto-selection.
    """
    config = config or EvaluationConfig()

    if not config.use_research_accurate_ci:
        return compute_confidence_interval(values, confidence_level, config.num_bootstrap_samples)

    # Auto-select method if enabled
    if config.auto_method_selection:
        method = _auto_select_ci_method(values, config)
    else:
        method = config.ci_method

    # Route to the appropriate method based on configuration
    if method == "t_distribution":
        return compute_t_confidence_interval(values, confidence_level)
    elif method == "meta_learning_standard":
        return compute_meta_learning_ci(values, confidence_level, config.num_episodes)
    elif method == "bca_bootstrap":
        return compute_bca_bootstrap_ci(values, confidence_level, config.num_bootstrap_samples)
    else:  # bootstrap
        return compute_confidence_interval(values, confidence_level, config.num_bootstrap_samples)


def _auto_select_ci_method(values: List[float], config: EvaluationConfig) -> str:
    """
    Automatically select the best CI method based on data characteristics.

    Selection criteria based on statistical best practices:
    - t-distribution for small samples (n < 30)
    - Bootstrap for moderate samples (30 <= n < 100)
    - BCa bootstrap for large samples (n >= 100) with skewed distributions
    - Meta-learning standard when n equals config.num_episodes (standard protocol)
    """
    n = len(values)

    # Standard meta-learning evaluation protocol
    if n == config.num_episodes:
        return "meta_learning_standard"

    # Small sample: use t-distribution
    if n < config.min_sample_size_for_bootstrap:
        return "t_distribution"

    # Large sample: check for skewness
    if n >= 100:
        # Simple skewness heuristic: compare mean and median
        values_array = np.array(values)
        mean_val = np.mean(values_array)
        median_val = np.median(values_array)

        # If the distribution is skewed, use BCa bootstrap
        skew_threshold = 0.1 * np.std(values_array)
        if abs(mean_val - median_val) > skew_threshold:
            return "bca_bootstrap"

    # Default to standard bootstrap for moderate samples
    return "bootstrap"


# FIXME SOLUTION 1: t-distribution confidence interval for small samples
def compute_t_confidence_interval(
    values: List[float],
    confidence_level: float = 0.95
) -> Tuple[float, float, float]:
    """
    Compute confidence interval using the t-distribution (appropriate for small samples).

    Standard approach in meta-learning evaluation when n < 30.
    """
    import scipy.stats as stats

    if len(values) == 0:
        return 0.0, 0.0, 0.0

    values = np.array(values)
    mean_val = np.mean(values)
    std_val = np.std(values, ddof=1)  # Sample standard deviation
    n = len(values)

    # Degrees of freedom
    df = n - 1

    # Critical t-value
    alpha = 1 - confidence_level
    t_critical = stats.t.ppf(1 - alpha / 2, df)

    # Margin of error
    margin_error = t_critical * (std_val / np.sqrt(n))

    # Confidence interval
    lower_bound = mean_val - margin_error
    upper_bound = mean_val + margin_error

    return mean_val, lower_bound, upper_bound


# FIXME SOLUTION 2: Meta-learning standard evaluation CI
def compute_meta_learning_ci(
    accuracies: List[float],
    confidence_level: float = 0.95,
    num_episodes: int = 600
) -> Tuple[float, float, float]:
    """
    Standard confidence interval computation for meta-learning evaluation.

    Based on standard protocols from the few-shot learning literature:
    - Vinyals et al. (2016): "Matching Networks"
    - Snell et al. (2017): "Prototypical Networks"
    - Finn et al. (2017): "MAML"

    Typically uses 600 episodes with a t-distribution CI.
    """
    if len(accuracies) != num_episodes:
        logger.warning(f"Expected {num_episodes} episodes, got {len(accuracies)}")

    # Use the t-distribution for proper meta-learning evaluation
    return compute_t_confidence_interval(accuracies, confidence_level)


# FIXME SOLUTION 3: BCa (Bias-Corrected and Accelerated) Bootstrap
def compute_bca_bootstrap_ci(
    values: List[float],
    confidence_level: float = 0.95,
    num_bootstrap: int = 2000
) -> Tuple[float, float, float]:
    """
    Bias-corrected and accelerated bootstrap confidence interval.

    More accurate than the basic bootstrap, especially for skewed distributions.
    Based on Efron & Tibshirani (1993), "An Introduction to the Bootstrap".
    """
    import scipy.stats as stats

    if len(values) == 0:
        return 0.0, 0.0, 0.0

    values = np.array(values)
    n = len(values)
    mean_val = np.mean(values)

    # Bootstrap resampling
    bootstrap_means = []
    for _ in range(num_bootstrap):
        bootstrap_sample = np.random.choice(values, size=n, replace=True)
        bootstrap_means.append(np.mean(bootstrap_sample))

    bootstrap_means = np.array(bootstrap_means)

    # Bias correction: z0 from the fraction of bootstrap means below the sample mean
    bias_correction = stats.norm.ppf((bootstrap_means < mean_val).mean())

    # Acceleration via the jackknife
    jackknife_means = []
    for i in range(n):
        jackknife_sample = np.concatenate([values[:i], values[i + 1:]])
        jackknife_means.append(np.mean(jackknife_sample))

    jackknife_means = np.array(jackknife_means)
    jackknife_mean = np.mean(jackknife_means)

    acceleration = np.sum((jackknife_mean - jackknife_means) ** 3) / \
        (6 * (np.sum((jackknife_mean - jackknife_means) ** 2)) ** (3 / 2))

    # Adjusted percentiles
    alpha = 1 - confidence_level
    z_alpha_2 = stats.norm.ppf(alpha / 2)
    z_1_alpha_2 = stats.norm.ppf(1 - alpha / 2)

    alpha_1 = stats.norm.cdf(bias_correction +
                             (bias_correction + z_alpha_2) / (1 - acceleration * (bias_correction + z_alpha_2)))
    alpha_2 = stats.norm.cdf(bias_correction +
                             (bias_correction + z_1_alpha_2) / (1 - acceleration * (bias_correction + z_1_alpha_2)))

    # Compute bounds
    lower_bound = np.percentile(bootstrap_means, 100 * alpha_1)
    upper_bound = np.percentile(bootstrap_means, 100 * alpha_2)

    return mean_val, lower_bound, upper_bound

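# Illustrative comparison sketch (comment-only): on right-skewed synthetic
# accuracies, the BCa interval endpoints generally differ from the plain
# percentile bootstrap.
#
#     vals = list(np.random.beta(2, 5, size=300))
#     compute_confidence_interval(vals)        # percentile bootstrap
#     compute_bca_bootstrap_ci(vals)           # bias-corrected + accelerated
#     compute_t_confidence_interval(vals)      # normal-theory t interval
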

def visualize_meta_learning_results(
    results: Dict[str, List[float]],
    title: str = "Meta-Learning Results",
    save_path: Optional[str] = None
):
    """
    Create comprehensive visualizations for meta-learning results.

    Args:
        results: Dictionary with algorithm names as keys and accuracy lists as values
        title: Plot title
        save_path: Optional path to save the figure
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle(title, fontsize=16)

    # 1. Accuracy comparison (box plot)
    ax1 = axes[0, 0]
    data_for_boxplot = [results[alg] for alg in results.keys()]
    labels = list(results.keys())

    ax1.boxplot(data_for_boxplot, labels=labels)
    ax1.set_title("Accuracy Distribution")
    ax1.set_ylabel("Accuracy")
    ax1.tick_params(axis='x', rotation=45)

    # 2. Learning curves
    ax2 = axes[0, 1]
    for alg_name, accuracies in results.items():
        # Compute running average
        running_avg = np.cumsum(accuracies) / np.arange(1, len(accuracies) + 1)
        ax2.plot(running_avg, label=alg_name, alpha=0.7)

    ax2.set_title("Learning Curves (Running Average)")
    ax2.set_xlabel("Task Number")
    ax2.set_ylabel("Cumulative Average Accuracy")
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # 3. Statistical comparison
    ax3 = axes[1, 0]
    means = [np.mean(results[alg]) for alg in results.keys()]
    stds = [np.std(results[alg]) for alg in results.keys()]

    ax3.barh(labels, means, xerr=stds, capsize=5)
    ax3.set_title("Mean Accuracy ± Standard Deviation")
    ax3.set_xlabel("Accuracy")

    # 4. Confidence intervals
    ax4 = axes[1, 1]
    ci_data = {}
    for alg_name, accuracies in results.items():
        mean_val, lower, upper = compute_confidence_interval(accuracies)
        ci_data[alg_name] = (mean_val, lower, upper)

    alg_names = list(ci_data.keys())
    means = [ci_data[alg][0] for alg in alg_names]
    lowers = [ci_data[alg][1] for alg in alg_names]
    uppers = [ci_data[alg][2] for alg in alg_names]

    y_pos = np.arange(len(alg_names))
    ax4.barh(y_pos, means, xerr=[np.array(means) - np.array(lowers),
                                 np.array(uppers) - np.array(means)],
             capsize=5)
    ax4.set_yticks(y_pos)
    ax4.set_yticklabels(alg_names)
    ax4.set_title("95% Confidence Intervals")
    ax4.set_xlabel("Accuracy")

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        logger.info(f"Saved visualization to {save_path}")

    plt.show()

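# Illustrative usage sketch (comment-only; algorithm names and numbers are
# made up for demonstration):
#
#     results = {
#         "MAML": list(np.random.normal(0.62, 0.05, 600)),
#         "ProtoNet": list(np.random.normal(0.65, 0.04, 600)),
#     }
#     visualize_meta_learning_results(results, save_path="comparison.png")
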

def save_meta_learning_results(
    results: Dict[str, Any],
    filepath: str,
    format: str = "json"
):
    """
    Save meta-learning results to file.

    Args:
        results: Results dictionary to save
        filepath: Path to save file
        format: File format ("json", "pickle")
    """
    filepath = Path(filepath)
    filepath.parent.mkdir(parents=True, exist_ok=True)

    if format == "json":
        # Convert torch tensors and numpy arrays to lists for JSON serialization
        serializable_results = {}
        for key, value in results.items():
            if isinstance(value, torch.Tensor):
                serializable_results[key] = value.tolist()
            elif isinstance(value, np.ndarray):
                serializable_results[key] = value.tolist()
            else:
                serializable_results[key] = value

        with open(filepath, 'w') as f:
            json.dump(serializable_results, f, indent=2)

    elif format == "pickle":
        with open(filepath, 'wb') as f:
            pickle.dump(results, f)

    else:
        # Mirror load_meta_learning_results instead of silently writing nothing
        raise ValueError(f"Unsupported format: {format}")

    logger.info(f"Saved results to {filepath}")


def load_meta_learning_results(filepath: str, format: str = "auto") -> Dict[str, Any]:
    """
    Load meta-learning results from file.

    Args:
        filepath: Path to load from
        format: File format ("json", "pickle", "auto")

    Returns:
        Loaded results dictionary
    """
    filepath = Path(filepath)

    if format == "auto":
        format = filepath.suffix[1:]  # Infer from the extension, dropping the dot

    if format == "json":
        with open(filepath, 'r') as f:
            results = json.load(f)
    elif format in ["pickle", "pkl"]:
        with open(filepath, 'rb') as f:
            results = pickle.load(f)
    else:
        raise ValueError(f"Unsupported format: {format}")

    logger.info(f"Loaded results from {filepath}")
    return results

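# Illustrative round-trip sketch (comment-only; the path is hypothetical):
#
#     save_meta_learning_results({"acc": [0.6, 0.7]}, "runs/exp1.json")
#     loaded = load_meta_learning_results("runs/exp1.json")   # format inferred
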

# =============================================================================
# FACTORY FUNCTIONS FOR EASY CONFIGURATION
# =============================================================================


def create_basic_task_config(n_way: int = 5, k_shot: int = 5, q_query: int = 15) -> TaskConfiguration:
    """Create a basic task configuration with standard settings."""
    return TaskConfiguration(
        n_way=n_way,
        k_shot=k_shot,
        q_query=q_query,
        num_tasks=1000,
        task_type="classification",
        augmentation_strategy="basic",
        difficulty_estimation_method="pairwise_distance",
        use_research_accurate_difficulty=False
    )


def create_research_accurate_task_config(
    n_way: int = 5,
    k_shot: int = 5,
    q_query: int = 15,
    difficulty_method: str = "silhouette"
) -> TaskConfiguration:
    """Create a research-accurate task configuration with proper difficulty estimation."""
    return TaskConfiguration(
        n_way=n_way,
        k_shot=k_shot,
        q_query=q_query,
        num_tasks=1000,
        task_type="classification",
        augmentation_strategy="advanced",
        difficulty_estimation_method=difficulty_method,  # "silhouette", "entropy", "knn"
        use_research_accurate_difficulty=True
    )


def create_basic_evaluation_config() -> EvaluationConfig:
    """Create a basic evaluation configuration with standard settings."""
    return EvaluationConfig(
        confidence_intervals=True,
        num_bootstrap_samples=1000,
        significance_level=0.05,
        track_adaptation_curve=True,
        compute_uncertainty=True,
        ci_method="bootstrap",
        use_research_accurate_ci=False,
        num_episodes=600,
        min_sample_size_for_bootstrap=30,
        auto_method_selection=False
    )


def create_research_accurate_evaluation_config(ci_method: str = "auto") -> EvaluationConfig:
    """Create a research-accurate evaluation configuration with proper CI methods."""
    return EvaluationConfig(
        confidence_intervals=True,
        num_bootstrap_samples=2000,  # Higher for better accuracy
        significance_level=0.05,
        track_adaptation_curve=True,
        compute_uncertainty=True,
        ci_method=ci_method,  # "auto", "t_distribution", "meta_learning_standard", "bca_bootstrap"
        use_research_accurate_ci=True,
        num_episodes=600,  # Standard meta-learning protocol
        min_sample_size_for_bootstrap=30,
        auto_method_selection=(ci_method == "auto")
    )


def create_meta_learning_standard_evaluation_config() -> EvaluationConfig:
    """Create an evaluation configuration following standard meta-learning protocols."""
    return EvaluationConfig(
        confidence_intervals=True,
        num_bootstrap_samples=600,  # Not used with the t-distribution method
        significance_level=0.05,
        track_adaptation_curve=True,
        compute_uncertainty=True,
        ci_method="meta_learning_standard",
        use_research_accurate_ci=True,
        num_episodes=600,
        min_sample_size_for_bootstrap=30,
        auto_method_selection=False
    )

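# Illustrative usage sketch (comment-only):
#
#     cfg = create_research_accurate_task_config(n_way=5, k_shot=1,
#                                                difficulty_method="knn")
#     eval_cfg = create_meta_learning_standard_evaluation_config()
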

# =============================================================================
# Missing Classes Implementation - Required by __init__.py imports
# =============================================================================


class DatasetConfig:
    """Configuration for meta-learning dataset creation."""

    def __init__(
        self,
        dataset_type: str = "episodic",
        augmentation_strategy: str = "minimal",
        shuffle: bool = True,
        stratified: bool = True,
        normalize: bool = True,
        cache_episodes: bool = False,
        **kwargs
    ):
        self.dataset_type = dataset_type
        self.augmentation_strategy = augmentation_strategy
        self.shuffle = shuffle
        self.stratified = stratified
        self.normalize = normalize
        self.cache_episodes = cache_episodes
        for key, value in kwargs.items():
            setattr(self, key, value)


class MetricsConfig:
    """Configuration for evaluation metrics computation."""

    def __init__(
        self,
        compute_accuracy: bool = True,
        compute_loss: bool = True,
        compute_adaptation_speed: bool = False,
        compute_uncertainty: bool = False,
        track_gradients: bool = False,
        save_predictions: bool = False,
        **kwargs
    ):
        self.compute_accuracy = compute_accuracy
        self.compute_loss = compute_loss
        self.compute_adaptation_speed = compute_adaptation_speed
        self.compute_uncertainty = compute_uncertainty
        self.track_gradients = track_gradients
        self.save_predictions = save_predictions
        for key, value in kwargs.items():
            setattr(self, key, value)


class StatsConfig:
    """Configuration for statistical analysis."""

    def __init__(
        self,
        confidence_level: float = 0.95,
        num_bootstrap_samples: int = 1000,
        significance_test: str = "t_test",
        multiple_comparison_correction: str = "bonferroni",
        effect_size_method: str = "cohen_d",
        **kwargs
    ):
        self.confidence_level = confidence_level
        self.num_bootstrap_samples = num_bootstrap_samples
        self.significance_test = significance_test
        self.multiple_comparison_correction = multiple_comparison_correction
        self.effect_size_method = effect_size_method
        for key, value in kwargs.items():
            setattr(self, key, value)


class CurriculumConfig:
    """Configuration for curriculum learning strategies."""

    def __init__(
        self,
        strategy: str = "difficulty_based",
        initial_difficulty: float = 0.3,
        difficulty_increment: float = 0.1,
        difficulty_threshold: float = 0.8,
        adaptation_patience: int = 5,
        **kwargs
    ):
        self.strategy = strategy
        self.initial_difficulty = initial_difficulty
        self.difficulty_increment = difficulty_increment
        self.difficulty_threshold = difficulty_threshold
        self.adaptation_patience = adaptation_patience
        for key, value in kwargs.items():
            setattr(self, key, value)


class DiversityConfig:
    """Configuration for task diversity tracking."""

    def __init__(
        self,
        diversity_metric: str = "cosine_similarity",
        track_class_distribution: bool = True,
        track_feature_diversity: bool = True,
        diversity_threshold: float = 0.7,
        **kwargs
    ):
        self.diversity_metric = diversity_metric
        self.track_class_distribution = track_class_distribution
        self.track_feature_diversity = track_feature_diversity
        self.diversity_threshold = diversity_threshold
        for key, value in kwargs.items():
            setattr(self, key, value)


class EvaluationMetrics:
    """Comprehensive evaluation metrics for meta-learning algorithms."""

    def __init__(self, config: MetricsConfig):
        self.config = config
        self.reset()

    def reset(self):
        """Reset all metrics to their initial state."""
        self.accuracies = []
        self.losses = []
        self.adaptation_speeds = []
        self.uncertainties = []
        self.predictions = []
        self.gradients = []

    def update(self, predictions: torch.Tensor, targets: torch.Tensor,
               loss: Optional[float] = None, **kwargs):
        """Update metrics with new predictions and targets."""
        if self.config.compute_accuracy:
            accuracy = (predictions.argmax(dim=-1) == targets).float().mean().item()
            self.accuracies.append(accuracy)

        if self.config.compute_loss and loss is not None:
            self.losses.append(loss)

        if self.config.save_predictions:
            self.predictions.append(predictions.detach().cpu())

        # Add other metrics based on config. Note the naive pluralization:
        # this matches keys like 'adaptation_speed' -> 'adaptation_speeds'
        # but not irregular plurals such as 'uncertainty' -> 'uncertainties'.
        for key, value in kwargs.items():
            if hasattr(self, key + 's'):
                getattr(self, key + 's').append(value)

    def compute_summary(self) -> Dict[str, float]:
        """Compute summary statistics."""
        summary = {}

        if self.accuracies:
            summary['mean_accuracy'] = np.mean(self.accuracies)
            summary['std_accuracy'] = np.std(self.accuracies)

        if self.losses:
            summary['mean_loss'] = np.mean(self.losses)
            summary['std_loss'] = np.std(self.losses)

        return summary


class StatisticalAnalysis:
    """Statistical analysis utilities for meta-learning research."""

    def __init__(self, config: StatsConfig):
        self.config = config

    def compute_confidence_interval(self, values: List[float]) -> Tuple[float, float, float]:
        """Compute confidence interval for the given values."""
        return compute_confidence_interval(
            values,
            confidence_level=self.config.confidence_level,
            num_bootstrap=self.config.num_bootstrap_samples
        )

    def statistical_test(self, group1: List[float], group2: List[float]) -> Dict[str, float]:
        """Perform a statistical significance test between two groups."""
        from scipy import stats

        if self.config.significance_test == "t_test":
            statistic, p_value = stats.ttest_ind(group1, group2)
        elif self.config.significance_test == "mannwhitney":
            statistic, p_value = stats.mannwhitneyu(group1, group2, alternative='two-sided')
        else:
            raise ValueError(f"Unknown test: {self.config.significance_test}")

        # Compare against alpha = 1 - confidence_level. A Bonferroni correction
        # (alpha / num_comparisons) should be applied by the caller when running
        # multiple pairwise tests.
        alpha = 1.0 - self.config.confidence_level
        return {
            'statistic': statistic,
            'p_value': p_value,
            'significant': p_value < alpha
        }

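# Illustrative usage sketch (comment-only; the two accuracy lists are made up):
#
#     analysis = StatisticalAnalysis(StatsConfig(confidence_level=0.95))
#     analysis.statistical_test([0.61, 0.63, 0.60], [0.66, 0.68, 0.67])
#     # -> {'statistic': ..., 'p_value': ..., 'significant': True/False}
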

class CurriculumLearning:
    """Curriculum learning implementation for meta-learning."""

    def __init__(self, config: CurriculumConfig):
        self.config = config
        self.current_difficulty = config.initial_difficulty
        self.patience_counter = 0

    def update_difficulty(self, performance_metric: float) -> float:
        """Update curriculum difficulty based on performance."""
        if performance_metric >= self.config.difficulty_threshold:
            self.current_difficulty = min(
                1.0,
                self.current_difficulty + self.config.difficulty_increment
            )
            self.patience_counter = 0
        else:
            self.patience_counter += 1

            if self.patience_counter >= self.config.adaptation_patience:
                # Reduce difficulty if struggling
                self.current_difficulty = max(
                    0.1,
                    self.current_difficulty - self.config.difficulty_increment / 2
                )
                self.patience_counter = 0

        return self.current_difficulty

    def get_current_difficulty(self) -> float:
        """Get the current curriculum difficulty level."""
        return self.current_difficulty

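# Illustrative training-loop sketch (comment-only; `train_one_epoch` is a
# hypothetical caller-supplied function returning mean accuracy):
#
#     curriculum = CurriculumLearning(CurriculumConfig(initial_difficulty=0.3))
#     for epoch in range(100):
#         acc = train_one_epoch(difficulty=curriculum.get_current_difficulty())
#         curriculum.update_difficulty(acc)
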

class TaskDiversityTracker:
    """Track diversity of meta-learning tasks."""

    def __init__(self, config: DiversityConfig):
        self.config = config
        self.task_features = []
        self.class_distributions = []

    def add_task(self, task_features: torch.Tensor, class_distribution: Optional[torch.Tensor] = None):
        """Add a new task for diversity tracking."""
        self.task_features.append(task_features.detach().cpu())

        if class_distribution is not None and self.config.track_class_distribution:
            self.class_distributions.append(class_distribution.detach().cpu())

    def compute_diversity(self) -> Dict[str, float]:
        """Compute task diversity metrics."""
        if not self.task_features:
            return {'diversity_score': 0.0}

        features = torch.stack(self.task_features)

        if self.config.diversity_metric == "cosine_similarity":
            # Compute pairwise cosine similarities
            normalized_features = F.normalize(features, dim=-1)
            similarities = torch.mm(normalized_features, normalized_features.t())

            # Average off-diagonal similarities (diversity = 1 - similarity)
            mask = ~torch.eye(similarities.size(0), dtype=torch.bool)
            avg_similarity = similarities[mask].mean().item()
            diversity_score = 1.0 - avg_similarity

        else:
            diversity_score = 0.5  # Placeholder for other metrics

        return {'diversity_score': diversity_score}


# =============================================================================
# Factory Functions - Required by __init__.py imports
# =============================================================================


def create_dataset(data: torch.Tensor, labels: torch.Tensor,
                   task_config: TaskConfiguration,
                   dataset_config: Optional[DatasetConfig] = None) -> MetaLearningDataset:
    """Factory function to create a meta-learning dataset."""
    if dataset_config is None:
        dataset_config = DatasetConfig()

    # NOTE: dataset_config is accepted for API compatibility but is not yet
    # consumed by MetaLearningDataset.
    return MetaLearningDataset(data, labels, task_config)


def create_metrics_evaluator(config: Optional[MetricsConfig] = None) -> EvaluationMetrics:
    """Factory function to create an evaluation metrics instance."""
    if config is None:
        config = MetricsConfig()

    return EvaluationMetrics(config)


def create_curriculum_scheduler(config: Optional[CurriculumConfig] = None) -> CurriculumLearning:
    """Factory function to create a curriculum learning scheduler."""
    if config is None:
        config = CurriculumConfig()

    return CurriculumLearning(config)


def basic_confidence_interval(values: List[float], confidence_level: float = 0.95) -> Tuple[float, float, float]:
    """Basic confidence interval computation (t-distribution for n < 30, else bootstrap)."""
    return compute_confidence_interval(values, confidence_level=confidence_level)


def estimate_difficulty(task_data: torch.Tensor, method: str = "entropy") -> float:
    """Estimate task difficulty using various methods."""
    if method == "entropy":
        # Simple entropy-based difficulty on the mean feature vector
        probs = F.softmax(task_data.mean(dim=0), dim=-1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-8))
        return entropy.item() / np.log(task_data.size(-1))  # Normalized entropy
    else:
        return 0.5  # Default medium difficulty


def track_task_diversity(tasks: List[torch.Tensor], config: Optional[DiversityConfig] = None) -> Dict[str, float]:
    """Track diversity across multiple tasks."""
    if config is None:
        config = DiversityConfig()

    tracker = TaskDiversityTracker(config)

    for task in tasks:
        tracker.add_task(task.mean(dim=0))  # Use the mean feature vector as the task embedding

    return tracker.compute_diversity()

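# Illustrative usage sketch (comment-only): mean feature vectors stand in for
# task embeddings.
#
#     tasks = [torch.rand(25, 64) for _ in range(8)]
#     track_task_diversity(tasks)   # -> {'diversity_score': ...}
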

# =============================================================================
# ENHANCED EVALUATION FUNCTIONS WITH CONFIGURATION SUPPORT
# =============================================================================


def evaluate_meta_learning_algorithm(
    algorithm,
    dataset: MetaLearningDataset,
    config: Optional[EvaluationConfig] = None,
    num_episodes: Optional[int] = None
) -> Dict[str, Any]:
    """
    Comprehensive evaluation of a meta-learning algorithm with configurable methods.

    Args:
        algorithm: Meta-learning algorithm to evaluate; must expose
            evaluate_task(support_data, support_labels, query_data, query_labels,
            return_adaptation_curve=...) returning a dict with at least 'accuracy'
        dataset: MetaLearningDataset for evaluation
        config: EvaluationConfig for evaluation settings
        num_episodes: Number of evaluation episodes (overrides config)

    Returns:
        Dictionary with evaluation results and statistics
    """
    config = config or create_research_accurate_evaluation_config()
    num_episodes = num_episodes or config.num_episodes

    accuracies = []
    adaptation_curves = []

    logger.info(f"Starting evaluation with {num_episodes} episodes")

    for episode in range(num_episodes):
        # Sample task
        task = dataset.sample_task(task_idx=episode)

        # Evaluate algorithm on task
        result = algorithm.evaluate_task(
            task['support']['data'],
            task['support']['labels'],
            task['query']['data'],
            task['query']['labels'],
            return_adaptation_curve=config.track_adaptation_curve
        )

        accuracies.append(result['accuracy'])

        if config.track_adaptation_curve and 'adaptation_curve' in result:
            adaptation_curves.append(result['adaptation_curve'])

    # Compute statistics using the configured CI method
    mean_accuracy, ci_lower, ci_upper = compute_confidence_interval_research_accurate(
        accuracies, config
    )

    results = {
        'mean_accuracy': mean_accuracy,
        'std_accuracy': np.std(accuracies),
        'ci_lower': ci_lower,
        'ci_upper': ci_upper,
        'all_accuracies': accuracies,
        'num_episodes': num_episodes,
        'ci_method_used': config.ci_method if not config.auto_method_selection
                          else _auto_select_ci_method(accuracies, config)
    }

    if config.track_adaptation_curve and adaptation_curves:
        results['adaptation_curves'] = adaptation_curves
        results['mean_adaptation_curve'] = np.mean(adaptation_curves, axis=0).tolist()

    logger.info(f"Evaluation complete: {mean_accuracy:.4f} ± {ci_upper - mean_accuracy:.4f}")

    return results
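
# Illustrative end-to-end sketch (comment-only). `MyAlgorithm` is a
# hypothetical object exposing the evaluate_task(...) interface assumed above;
# nothing with that name ships in this module.
#
#     dataset = MetaLearningDataset(data, labels, create_basic_task_config())
#     cfg = create_research_accurate_evaluation_config(ci_method="auto")
#     report = evaluate_meta_learning_algorithm(MyAlgorithm(), dataset,
#                                               config=cfg, num_episodes=600)
#     print(report["mean_accuracy"], report["ci_lower"], report["ci_upper"])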