Coverage for src/meta_learning/meta_learning_modules/utils_modules/dataset_sampling.py: 0%

239 statements  

« prev     ^ index     » next       coverage.py v7.10.5, created at 2025-09-03 12:49 +0900

1""" 

2Dataset and Sampling Classes for Meta-Learning 

3============================================== 

4 

5Author: Benedict Chen (benedict@benedictchen.com) 

6 

7This module contains the core dataset and sampling functionality for meta-learning, 

8including advanced task sampling with curriculum learning support. 

9""" 

10 

11import torch 

12import torch.nn.functional as F 

13from torch.utils.data import Dataset, Sampler 

14from typing import Dict, List, Tuple, Optional, Any, Iterator, Union 

15import numpy as np 

16import random 

17import logging 

18from collections import defaultdict, Counter 

19 

20from .configurations import TaskConfiguration 

21 

22logger = logging.getLogger(__name__) 

23 

24 

class MetaLearningDataset(Dataset):
    """
    Advanced meta-learning dataset with sophisticated task sampling.

    Key features:
    1. Hierarchical task organization with difficulty levels
    2. Balanced task sampling across domains and difficulties
    3. Dynamic task generation with curriculum learning
    4. Data augmentation strategies for meta-learning
    5. Task usage tracking for diversity-aware sampling
    """

    def __init__(
        self,
        data: torch.Tensor,
        labels: torch.Tensor,
        task_config: Optional[TaskConfiguration] = None,
        class_names: Optional[List[str]] = None,
        domain_labels: Optional[torch.Tensor] = None
    ):
        """
        Initialize the meta-learning dataset.

        Args:
            data: Input data [n_samples, ...]
            labels: Class labels [n_samples]
            task_config: Task configuration; a default TaskConfiguration is
                created when None.
            class_names: Optional class names for interpretability
            domain_labels: Optional domain labels for cross-domain tasks
        """
        self.data = data
        self.labels = labels
        self.config = task_config or TaskConfiguration()
        self.class_names = class_names
        self.domain_labels = domain_labels

        # Group sample indices by class for efficient per-class sampling.
        self.class_to_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            self.class_to_indices[label.item()].append(idx)

        self.unique_classes = list(self.class_to_indices.keys())
        self.num_classes = len(self.unique_classes)

        # Task history and per-class usage counts for diversity tracking.
        self.task_history = []
        self.class_usage_count = Counter()

        # Difficulty estimation using the configured method.
        if self.config.use_research_accurate_difficulty:
            self.class_difficulties = self._estimate_class_difficulties_research_accurate()
        else:
            self.class_difficulties = self._estimate_class_difficulties()

        logger.info(f"Initialized MetaLearningDataset: {self.num_classes} classes, {len(data)} samples")

    def __len__(self) -> int:
        """Number of tasks per epoch (the task space is virtually infinite)."""
        return self.config.num_tasks

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Sample a meta-learning task.

        Returns:
            Dictionary containing support and query sets with labels.
        """
        return self.sample_task(task_idx=idx)

    def sample_task(
        self,
        task_idx: Optional[int] = None,
        specified_classes: Optional[List[int]] = None,
        difficulty_level: Optional[str] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Sample a single meta-learning task with advanced strategies.

        Args:
            task_idx: Optional task index. NOTE: when given, the *global*
                torch and numpy RNGs are reseeded with (42 + task_idx) so the
                same index always yields the same task — a deliberate
                reproducibility side effect.
            specified_classes: Specific classes to use (overrides sampling).
            difficulty_level: "easy", "medium", "hard", or None for automatic.

        Returns:
            Task dictionary with support/query sets and metadata.
        """
        # Reseed global RNGs for reproducible task sampling.
        if task_idx is not None:
            torch.manual_seed(42 + task_idx)
            np.random.seed(42 + task_idx)

        # Select classes for this task.
        if specified_classes:
            task_classes = specified_classes
        else:
            task_classes = self._sample_task_classes(difficulty_level)

        # Sample support and query sets.
        support_data, support_labels, query_data, query_labels = self._sample_support_query(
            task_classes
        )

        # Augmentation is applied to the support set only; the query set is
        # left untouched so evaluation stays on clean data.
        if self.config.augmentation_strategy != "none":
            support_data = self._apply_augmentation(support_data, self.config.augmentation_strategy)

        # Track usage so diversity-aware sampling can prefer rarer classes.
        self.task_history.append(task_classes)
        for class_id in task_classes:
            self.class_usage_count[class_id] += 1

        task_metadata = self._compute_task_metadata(task_classes, support_labels, query_labels)

        return {
            "support": {
                "data": support_data,
                "labels": support_labels
            },
            "query": {
                "data": query_data,
                "labels": query_labels
            },
            "task_classes": torch.tensor(task_classes),
            "metadata": task_metadata
        }

    def _sample_task_classes(self, difficulty_level: Optional[str] = None) -> List[int]:
        """
        Sample n_way classes for a task with diversity and difficulty control.

        Difficulty thresholds assume class difficulties lie in [0, 1]; all
        estimation methods normalize (or naturally produce) scores in that range.
        """
        if difficulty_level:
            # Filter candidate classes by estimated difficulty band.
            if difficulty_level == "easy":
                candidate_classes = [c for c in self.unique_classes
                                     if self.class_difficulties[c] < 0.3]
            elif difficulty_level == "medium":
                candidate_classes = [c for c in self.unique_classes
                                     if 0.3 <= self.class_difficulties[c] < 0.7]
            elif difficulty_level == "hard":
                candidate_classes = [c for c in self.unique_classes
                                     if self.class_difficulties[c] >= 0.7]
            else:
                candidate_classes = self.unique_classes
        else:
            candidate_classes = self.unique_classes

        # Fall back to all classes if the band is too small for an n_way task.
        if len(candidate_classes) < self.config.n_way:
            candidate_classes = self.unique_classes

        # Diversity-aware sampling: inverse-frequency weights prefer classes
        # that have appeared in fewer previous tasks.
        class_weights = []
        for class_id in candidate_classes:
            usage_count = self.class_usage_count.get(class_id, 0)
            class_weights.append(1.0 / (1.0 + usage_count))

        # Normalize weights into a probability distribution.
        class_weights = np.array(class_weights)
        class_weights = class_weights / class_weights.sum()

        selected_indices = np.random.choice(
            len(candidate_classes),
            size=self.config.n_way,
            replace=False,
            p=class_weights
        )

        return [candidate_classes[i] for i in selected_indices]

    def _sample_support_query(
        self,
        task_classes: List[int]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Sample support and query sets for the given classes.

        Classes are relabeled 0..n_way-1 in task_classes order. Samples with
        replacement only when a class has fewer than k_shot + q_query examples.

        Returns:
            (support_data, support_labels, query_data, query_labels)
        """
        support_data = []
        support_labels = []
        query_data = []
        query_labels = []

        for new_label, original_class in enumerate(task_classes):
            class_indices = self.class_to_indices[original_class]

            total_needed = self.config.k_shot + self.config.q_query
            # Sample with replacement only when the class is too small.
            replace = len(class_indices) < total_needed
            selected_indices = np.random.choice(
                class_indices, size=total_needed, replace=replace
            )

            # First k_shot go to support, the remaining q_query to query.
            support_indices = selected_indices[:self.config.k_shot]
            query_indices = selected_indices[self.config.k_shot:]

            for idx in support_indices:
                support_data.append(self.data[idx])
                support_labels.append(new_label)

            for idx in query_indices:
                query_data.append(self.data[idx])
                query_labels.append(new_label)

        return (
            torch.stack(support_data),
            torch.tensor(support_labels),
            torch.stack(query_data),
            torch.tensor(query_labels)
        )

    @staticmethod
    def _normalize_difficulties(difficulties: Dict[int, float]) -> Dict[int, float]:
        """Min-max normalize difficulty scores to [0, 1]; constant scores are left unchanged."""
        if not difficulties:
            return difficulties
        min_diff = min(difficulties.values())
        max_diff = max(difficulties.values())
        if max_diff > min_diff:
            return {c: (d - min_diff) / (max_diff - min_diff)
                    for c, d in difficulties.items()}
        return difficulties

    def _estimate_class_difficulties(self) -> Dict[int, float]:
        """
        Estimate per-class difficulty from mean intra-class pairwise distance.

        This is a heuristic (O(n^2) per class, no specific research basis);
        research-backed alternatives are selected via
        config.difficulty_estimation_method when
        config.use_research_accurate_difficulty is set.

        Returns:
            Mapping from class id to difficulty, min-max normalized to [0, 1].
        """
        difficulties = {}

        for class_id, indices in self.class_to_indices.items():
            if len(indices) > 1:
                class_data = self.data[indices]
                flattened_data = class_data.view(len(class_data), -1)
                distances = torch.cdist(flattened_data, flattened_data)
                n = len(distances)
                # Mean over off-diagonal entries (the diagonal is zero).
                difficulties[class_id] = (distances.sum() / (n * n - n)).item()
            else:
                # Single-sample class: no intra-class spread to measure.
                difficulties[class_id] = 0.5

        return self._normalize_difficulties(difficulties)

    def _estimate_class_difficulties_research_accurate(self) -> Dict[int, float]:
        """Route to the difficulty estimator named by config.difficulty_estimation_method."""
        if self.config.difficulty_estimation_method == "silhouette":
            return self._estimate_class_difficulty_silhouette()
        elif self.config.difficulty_estimation_method == "entropy":
            return self._estimate_class_difficulty_entropy()
        elif self.config.difficulty_estimation_method == "knn":
            return self._estimate_class_difficulty_knn()
        else:  # Default: pairwise-distance heuristic.
            return self._estimate_class_difficulties()

    def _estimate_class_difficulty_silhouette(self) -> Dict[int, float]:
        """
        Estimate difficulty via per-class mean silhouette score.

        Based on Rousseeuw (1987), "Silhouette: a graphical aid to the
        interpretation and validation of cluster analysis". Lower silhouette
        (worse separation) maps to higher difficulty; output is in [0, 1].

        NOTE(review): assumes self.data/self.labels are CPU tensors —
        .numpy() fails on CUDA tensors; confirm against callers.
        """
        from sklearn.metrics import silhouette_samples

        difficulties = {}
        all_data = self.data.view(len(self.data), -1).numpy()
        all_labels = self.labels.numpy()

        silhouette_scores = silhouette_samples(all_data, all_labels)

        for class_id in self.unique_classes:
            class_mask = all_labels == class_id
            class_silhouette = silhouette_scores[class_mask].mean()

            # Silhouette lies in [-1, 1]; map to difficulty in [0, 1].
            difficulties[class_id] = 1.0 - (class_silhouette + 1.0) / 2.0

        return difficulties

    def _estimate_class_difficulty_entropy(self) -> Dict[int, float]:
        """
        Estimate difficulty via mean per-feature entropy of discretized features.

        Classes with higher feature entropy are treated as more difficult.
        Scores are min-max normalized to [0, 1] so they are compatible with
        the fixed difficulty thresholds in _sample_task_classes (raw natural-log
        entropies can exceed 1).
        """
        difficulties = {}

        for class_id, indices in self.class_to_indices.items():
            if len(indices) > 1:
                class_data = self.data[indices]
                flattened_data = class_data.view(len(class_data), -1)

                # Simple fixed-width binning (0.1) for discrete entropy.
                discretized = torch.floor(flattened_data * 10) / 10

                entropies = []
                for feature_dim in range(discretized.shape[1]):
                    feature_values = discretized[:, feature_dim]
                    unique_vals, counts = torch.unique(feature_values, return_counts=True)
                    probs = counts.float() / len(feature_values)
                    # +1e-8 guards log(0) for numerically-empty bins.
                    entropy = -torch.sum(probs * torch.log(probs + 1e-8))
                    entropies.append(entropy.item())

                difficulties[class_id] = np.mean(entropies)
            else:
                # Single-sample class: default medium difficulty.
                difficulties[class_id] = 0.5

        return self._normalize_difficulties(difficulties)

    def _estimate_class_difficulty_knn(self) -> Dict[int, float]:
        """
        Estimate difficulty as 1 - cross-validated one-vs-rest k-NN accuracy.

        Harder classes are those a 5-NN classifier separates poorly from the
        rest; scores are naturally in [0, 1].

        NOTE(review): assumes self.data/self.labels are CPU tensors —
        .numpy() fails on CUDA tensors; confirm against callers.
        """
        from sklearn.neighbors import KNeighborsClassifier
        from sklearn.model_selection import cross_val_score

        difficulties = {}

        for class_id in self.unique_classes:
            # Binary problem: this class vs. all others.
            class_mask = self.labels == class_id
            binary_labels = class_mask.long()

            X = self.data.view(len(self.data), -1).numpy()
            y = binary_labels.numpy()

            knn = KNeighborsClassifier(n_neighbors=5)
            scores = cross_val_score(knn, X, y, cv=3, scoring='accuracy')

            # Lower accuracy = higher difficulty.
            difficulties[class_id] = 1.0 - scores.mean()

        return difficulties

    def _apply_augmentation(self, data: torch.Tensor, strategy: str) -> torch.Tensor:
        """Dispatch to the augmentation strategy ("basic", "advanced"; anything else is a no-op)."""
        if strategy == "basic":
            return self._basic_augmentation(data)
        elif strategy == "advanced":
            return self._advanced_augmentation(data)
        else:
            return data

    def _basic_augmentation(self, data: torch.Tensor) -> torch.Tensor:
        """
        Basic augmentation: small additive Gaussian noise.

        NOTE(review): the final clamp assumes data is normalized to [0, 1] —
        confirm against the data pipeline.
        """
        noise_std = 0.01
        noise = torch.randn_like(data) * noise_std
        augmented = data + noise

        return torch.clamp(augmented, 0, 1)

    def _advanced_augmentation(self, data: torch.Tensor) -> torch.Tensor:
        """
        Advanced augmentation: statistics-calibrated noise plus random feature
        masking for multi-dimensional inputs.

        NOTE(review): the final clamp assumes data is normalized to [0, 1] —
        confirm against the data pipeline.
        """
        augmented = data.clone()

        # Noise scaled to 5% of the per-feature standard deviation.
        data_std = data.std(dim=0, keepdim=True)
        noise = torch.randn_like(data) * (data_std * 0.05)
        augmented = augmented + noise

        # Random feature masking for structured (multi-dimensional) data.
        if len(data.shape) > 2:
            mask_prob = 0.1
            mask = torch.rand_like(data) > mask_prob
            augmented = augmented * mask

        return torch.clamp(augmented, 0, 1)

    def _compute_task_metadata(
        self,
        task_classes: List[int],
        support_labels: torch.Tensor,
        query_labels: torch.Tensor
    ) -> Dict[str, Any]:
        """Compute metadata (sizes, per-class and average difficulty) for a sampled task."""
        # Computed once and reused for both fields.
        difficulties = [self.class_difficulties[c] for c in task_classes]
        metadata = {
            "n_way": len(task_classes),
            "k_shot": self.config.k_shot,
            "q_query": self.config.q_query,
            "task_classes": task_classes,
            "class_difficulties": difficulties,
            "avg_difficulty": np.mean(difficulties)
        }

        if self.class_names:
            metadata["class_names"] = [self.class_names[c] for c in task_classes]

        return metadata

445 

446 

class TaskSampler(Sampler):
    """
    Advanced task sampler for meta-learning with curriculum learning support.

    Key features:
    1. Curriculum learning with difficulty progression
    2. Balanced sampling across task types and difficulties
    3. Anti-correlation sampling to encourage task diversity
    4. Adaptive batch composition based on observed performance
    """

    def __init__(
        self,
        dataset: MetaLearningDataset,
        batch_size: int = 16,
        curriculum_learning: bool = True,
        difficulty_schedule: str = "linear"  # linear, exponential, adaptive
    ):
        """
        Initialize the task sampler.

        Args:
            dataset: MetaLearningDataset to sample from.
            batch_size: Number of tasks per batch.
            curriculum_learning: Whether to use curriculum learning.
            difficulty_schedule: How difficulty progresses over training
                ("linear", "exponential", or "adaptive").
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.curriculum_learning = curriculum_learning
        self.difficulty_schedule = difficulty_schedule

        # Curriculum state.
        self.current_epoch = 0
        self.total_epochs = 1000  # Updated via update_epoch() during training.
        self.difficulty_level = 0.0  # 0.0 = easiest, 1.0 = hardest.

        # Recent episode accuracies used by the adaptive schedule.
        self.performance_history = []

        logger.info(f"Initialized TaskSampler: batch_size={batch_size}, curriculum={curriculum_learning}")

    def __iter__(self) -> Iterator[List[int]]:
        """Yield batches of task indices; the trailing partial batch is dropped."""
        indices = list(range(len(self.dataset)))

        # Curriculum learning: restrict to a difficulty-dependent subset.
        if self.curriculum_learning:
            indices = self._apply_curriculum_filter(indices)

        # Shuffle the (possibly filtered) indices for randomness.
        random.shuffle(indices)

        for start in range(0, len(indices), self.batch_size):
            batch_indices = indices[start:start + self.batch_size]
            if len(batch_indices) == self.batch_size:  # Only yield full batches.
                yield batch_indices

    def __len__(self) -> int:
        """Number of full batches per epoch, accounting for curriculum filtering."""
        effective_size = len(self.dataset)
        if self.curriculum_learning:
            # Mirrors the fraction used by _apply_curriculum_filter.
            effective_size = int(effective_size * min(1.0, 0.1 + 0.9 * self.difficulty_level))
        return effective_size // self.batch_size

    def update_epoch(self, epoch: int, total_epochs: int):
        """
        Update curriculum state for a new epoch.

        Args:
            epoch: Current epoch index.
            total_epochs: Planned total number of epochs; values < 1 are
                treated as 1 to avoid division by zero.
        """
        self.current_epoch = epoch
        self.total_epochs = total_epochs
        total = max(total_epochs, 1)  # Guard against zero/negative totals.

        # Update difficulty level according to the configured schedule.
        if self.difficulty_schedule == "linear":
            self.difficulty_level = epoch / total
        elif self.difficulty_schedule == "exponential":
            self.difficulty_level = (np.exp(epoch / total) - 1) / (np.e - 1)
        elif self.difficulty_schedule == "adaptive":
            self.difficulty_level = self._adaptive_difficulty_schedule()

        self.difficulty_level = np.clip(self.difficulty_level, 0.0, 1.0)

        logger.debug(f"Epoch {epoch}: difficulty_level = {self.difficulty_level:.3f}")

    def _apply_curriculum_filter(self, indices: List[int]) -> List[int]:
        """
        Keep a difficulty-dependent prefix of the task index list.

        Simplified curriculum: a fraction (0.1 + 0.9 * difficulty_level) of
        tasks is included. NOTE(review): this takes a prefix of the ordered
        index list, so the same leading task indices are always the ones
        included; a real curriculum would filter on actual task difficulties.
        """
        fraction_to_include = 0.1 + 0.9 * self.difficulty_level
        num_to_include = int(len(indices) * fraction_to_include)
        return indices[:num_to_include]

    def _adaptive_difficulty_schedule(self) -> float:
        """
        Compute adaptive difficulty from recent performance.

        Falls back to the linear schedule until at least 10 accuracy
        observations are available. The result is always clipped to [0, 1],
        and total_epochs < 1 is treated as 1 to avoid division by zero.
        """
        total = max(self.total_epochs, 1)  # Guard against zero/negative totals.

        if len(self.performance_history) < 10:
            # Not enough data: linear schedule.
            return float(np.clip(self.current_epoch / total, 0.0, 1.0))

        # Recent performance level and trend.
        recent_performance = self.performance_history[-10:]
        performance_mean = np.mean(recent_performance)
        performance_trend = np.mean(np.diff(recent_performance))

        base_difficulty = self.current_epoch / total

        if performance_mean > 0.8 and performance_trend > 0:
            # High performance and improving: increase difficulty faster.
            adaptation = min(0.2, performance_trend * 5)
        elif performance_mean < 0.6 and performance_trend < 0:
            # Low performance and declining: slow the difficulty increase.
            adaptation = max(-0.1, performance_trend * 2)
        else:
            adaptation = 0

        return float(np.clip(base_difficulty + adaptation, 0.0, 1.0))

    def update_performance(self, accuracy: float):
        """Record an episode accuracy for the adaptive curriculum (keeps the last 100)."""
        self.performance_history.append(accuracy)

        if len(self.performance_history) > 100:
            self.performance_history = self.performance_history[-100:]