Coverage for src/meta_learning/meta_learning_modules/utils_modules/dataset_sampling.py: 0%
239 statements
« prev ^ index » next coverage.py v7.10.5, created at 2025-09-03 12:49 +0900
« prev ^ index » next coverage.py v7.10.5, created at 2025-09-03 12:49 +0900
1"""
2Dataset and Sampling Classes for Meta-Learning
3==============================================
5Author: Benedict Chen (benedict@benedictchen.com)
7This module contains the core dataset and sampling functionality for meta-learning,
8including advanced task sampling with curriculum learning support.
9"""
11import torch
12import torch.nn.functional as F
13from torch.utils.data import Dataset, Sampler
14from typing import Dict, List, Tuple, Optional, Any, Iterator, Union
15import numpy as np
16import random
17import logging
18from collections import defaultdict, Counter
20from .configurations import TaskConfiguration
22logger = logging.getLogger(__name__)
class MetaLearningDataset(Dataset):
    """
    Advanced meta-learning dataset with sophisticated task sampling.

    Key features:
    1. Hierarchical task organization with difficulty levels
    2. Balanced task sampling across domains and difficulties
    3. Dynamic task generation with curriculum learning
    4. Data augmentation strategies tailored to meta-learning
    5. Task similarity tracking and diversity-aware sampling
    """

    def __init__(
        self,
        data: torch.Tensor,
        labels: torch.Tensor,
        task_config: TaskConfiguration = None,
        class_names: Optional[List[str]] = None,
        domain_labels: Optional[torch.Tensor] = None
    ):
        """
        Initialize the meta-learning dataset.

        Args:
            data: Input data [n_samples, ...]
            labels: Class labels [n_samples]
            task_config: Task configuration; defaults to ``TaskConfiguration()``
            class_names: Optional class names for interpretability
            domain_labels: Optional domain labels for cross-domain tasks
        """
        self.data = data
        self.labels = labels
        self.config = task_config or TaskConfiguration()
        self.class_names = class_names
        self.domain_labels = domain_labels

        # Index samples by class for efficient episodic (N-way K-shot) sampling.
        self.class_to_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            self.class_to_indices[label.item()].append(idx)

        self.unique_classes = list(self.class_to_indices.keys())
        self.num_classes = len(self.unique_classes)

        # History of sampled tasks and per-class usage counts; used by the
        # diversity-aware class sampler to prefer rarely used classes.
        self.task_history = []
        self.class_usage_count = Counter()

        # Estimate per-class difficulty with the configured method. All
        # estimators return scores in [0, 1] so the thresholds used in
        # _sample_task_classes apply uniformly.
        if self.config.use_research_accurate_difficulty:
            self.class_difficulties = self._estimate_class_difficulties_research_accurate()
        else:
            self.class_difficulties = self._estimate_class_difficulties()

        logger.info(f"Initialized MetaLearningDataset: {self.num_classes} classes, {len(data)} samples")

    def __len__(self) -> int:
        """Number of possible tasks (virtually infinite for meta-learning)."""
        return self.config.num_tasks

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Sample the meta-learning task keyed by ``idx``.

        Returns:
            Dictionary containing support and query sets with labels.
        """
        return self.sample_task(task_idx=idx)

    def sample_task(
        self,
        task_idx: Optional[int] = None,
        specified_classes: Optional[List[int]] = None,
        difficulty_level: Optional[str] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Sample a single meta-learning task with advanced strategies.

        Args:
            task_idx: Optional task index for reproducibility
            specified_classes: Specific classes to use (overrides sampling)
            difficulty_level: "easy", "medium", "hard", or None for automatic

        Returns:
            Task dictionary with support/query sets and metadata.
        """
        # Reseed for reproducible task sampling. NOTE(review): this reseeds
        # the *global* torch and numpy RNGs as a side effect; callers relying
        # on independent global random streams should be aware of this.
        if task_idx is not None:
            torch.manual_seed(42 + task_idx)
            np.random.seed(42 + task_idx)

        # Select the classes that make up this task.
        if specified_classes:
            task_classes = specified_classes
        else:
            task_classes = self._sample_task_classes(difficulty_level)

        # Sample support and query sets for the chosen classes.
        support_data, support_labels, query_data, query_labels = self._sample_support_query(
            task_classes
        )

        # Augmentation is applied to the support set only.
        if self.config.augmentation_strategy != "none":
            support_data = self._apply_augmentation(support_data, self.config.augmentation_strategy)

        # Update task history and per-class usage counts for diversity tracking.
        self.task_history.append(task_classes)
        for class_id in task_classes:
            self.class_usage_count[class_id] += 1

        task_metadata = self._compute_task_metadata(task_classes, support_labels, query_labels)

        return {
            "support": {
                "data": support_data,
                "labels": support_labels
            },
            "query": {
                "data": query_data,
                "labels": query_labels
            },
            "task_classes": torch.tensor(task_classes),
            "metadata": task_metadata
        }

    def _sample_task_classes(self, difficulty_level: Optional[str] = None) -> List[int]:
        """Sample classes for a task with diversity and difficulty control."""
        if difficulty_level:
            # Filter classes by difficulty band (scores are normalized to [0, 1]).
            if difficulty_level == "easy":
                candidate_classes = [c for c in self.unique_classes
                                     if self.class_difficulties[c] < 0.3]
            elif difficulty_level == "medium":
                candidate_classes = [c for c in self.unique_classes
                                     if 0.3 <= self.class_difficulties[c] < 0.7]
            elif difficulty_level == "hard":
                candidate_classes = [c for c in self.unique_classes
                                     if self.class_difficulties[c] >= 0.7]
            else:
                candidate_classes = self.unique_classes
        else:
            candidate_classes = self.unique_classes

        # Fall back to the full class pool if the difficulty band is too small.
        if len(candidate_classes) < self.config.n_way:
            candidate_classes = self.unique_classes

        # Diversity-aware sampling: inverse-frequency weights favour classes
        # that have appeared in fewer tasks so far.
        class_weights = np.array([
            1.0 / (1.0 + self.class_usage_count.get(class_id, 0))
            for class_id in candidate_classes
        ])
        class_weights = class_weights / class_weights.sum()

        selected_indices = np.random.choice(
            len(candidate_classes),
            size=self.config.n_way,
            replace=False,
            p=class_weights
        )

        return [candidate_classes[i] for i in selected_indices]

    def _sample_support_query(
        self,
        task_classes: List[int]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Sample support and query sets for the given classes.

        Classes are relabelled 0..n_way-1 in the order they appear in
        ``task_classes``. Samples are drawn with replacement only when a
        class has fewer than k_shot + q_query examples.
        """
        support_data = []
        support_labels = []
        query_data = []
        query_labels = []

        for new_label, original_class in enumerate(task_classes):
            class_indices = self.class_to_indices[original_class]

            total_needed = self.config.k_shot + self.config.q_query
            if len(class_indices) < total_needed:
                # Not enough distinct examples: sample with replacement.
                selected_indices = np.random.choice(
                    class_indices, size=total_needed, replace=True
                )
            else:
                selected_indices = np.random.choice(
                    class_indices, size=total_needed, replace=False
                )

            # First k_shot samples become the support set, the rest the query set.
            support_indices = selected_indices[:self.config.k_shot]
            query_indices = selected_indices[self.config.k_shot:]

            for idx in support_indices:
                support_data.append(self.data[idx])
                support_labels.append(new_label)

            for idx in query_indices:
                query_data.append(self.data[idx])
                query_labels.append(new_label)

        return (
            torch.stack(support_data),
            torch.tensor(support_labels),
            torch.stack(query_data),
            torch.tensor(query_labels)
        )

    @staticmethod
    def _normalize_difficulties(difficulties: Dict[int, float]) -> Dict[int, float]:
        """Min-max normalize difficulty scores to [0, 1] (no-op if constant/empty)."""
        if difficulties:
            max_diff = max(difficulties.values())
            min_diff = min(difficulties.values())
            if max_diff > min_diff:
                for class_id in difficulties:
                    difficulties[class_id] = (difficulties[class_id] - min_diff) / (max_diff - min_diff)
        return difficulties

    def _estimate_class_difficulties(self) -> Dict[int, float]:
        """
        Estimate per-class difficulty from intra-class spread.

        Uses the mean pairwise Euclidean distance within each class as a
        difficulty proxy (higher spread = harder), min-max normalized to
        [0, 1]. NOTE(review): this heuristic has no specific research basis
        and is O(n^2) per class; the configurable estimators below
        (silhouette / entropy / k-NN) are better-grounded alternatives.
        """
        difficulties = {}

        for class_id, indices in self.class_to_indices.items():
            if len(indices) > 1:
                class_data = self.data[indices]

                flattened_data = class_data.view(len(class_data), -1)
                distances = torch.cdist(flattened_data, flattened_data)
                # Mean over the off-diagonal entries (the diagonal is zero).
                mean_distance = distances.sum() / (len(distances) ** 2 - len(distances))
                difficulties[class_id] = mean_distance.item()
            else:
                difficulties[class_id] = 0.5  # Default medium difficulty for singleton classes

        return self._normalize_difficulties(difficulties)

    def _estimate_class_difficulties_research_accurate(self) -> Dict[int, float]:
        """Route to the difficulty estimator selected in the configuration."""
        if self.config.difficulty_estimation_method == "silhouette":
            return self._estimate_class_difficulty_silhouette()
        elif self.config.difficulty_estimation_method == "entropy":
            return self._estimate_class_difficulty_entropy()
        elif self.config.difficulty_estimation_method == "knn":
            return self._estimate_class_difficulty_knn()
        else:  # default to pairwise_distance
            return self._estimate_class_difficulties()

    def _estimate_class_difficulty_silhouette(self) -> Dict[int, float]:
        """
        Estimate class difficulty via silhouette scores.

        Based on Rousseeuw (1987), "Silhouettes: a graphical aid to the
        interpretation and validation of cluster analysis". Well-separated
        classes get high mean silhouette; difficulty maps the silhouette
        range [-1, 1] down to [0, 1], inverted (lower silhouette = harder).
        """
        from sklearn.metrics import silhouette_samples

        difficulties = {}
        all_data = self.data.view(len(self.data), -1).numpy()
        all_labels = self.labels.numpy()

        # Per-sample silhouette scores over the whole dataset.
        silhouette_scores = silhouette_samples(all_data, all_labels)

        for class_id in self.unique_classes:
            class_mask = all_labels == class_id
            class_silhouette = silhouette_scores[class_mask].mean()

            # Map silhouette in [-1, 1] to difficulty in [0, 1] (inverted).
            difficulties[class_id] = 1.0 - (class_silhouette + 1.0) / 2.0

        return difficulties

    def _estimate_class_difficulty_entropy(self) -> Dict[int, float]:
        """
        Estimate class difficulty via feature entropy.

        Classes with higher average feature entropy are treated as more
        difficult (a common heuristic in the few-shot learning literature).
        Scores are min-max normalized to [0, 1] so they are comparable with
        the other estimators and the thresholds in _sample_task_classes.
        """
        difficulties = {}

        for class_id, indices in self.class_to_indices.items():
            if len(indices) > 1:
                class_data = self.data[indices]

                flattened_data = class_data.view(len(class_data), -1)

                # Discretize features into 0.1-wide bins for entropy estimation.
                discretized = torch.floor(flattened_data * 10) / 10

                # Shannon entropy per feature dimension.
                entropies = []
                for feature_dim in range(discretized.shape[1]):
                    feature_values = discretized[:, feature_dim]
                    unique_vals, counts = torch.unique(feature_values, return_counts=True)
                    probs = counts.float() / len(feature_values)
                    entropy = -torch.sum(probs * torch.log(probs + 1e-8))
                    entropies.append(entropy.item())

                difficulties[class_id] = np.mean(entropies)
            else:
                difficulties[class_id] = 0.5

        # BUGFIX: raw entropies are unbounded, unlike the other estimators;
        # normalize to [0, 1] so the difficulty bands work correctly.
        return self._normalize_difficulties(difficulties)

    def _estimate_class_difficulty_knn(self) -> Dict[int, float]:
        """
        Estimate class difficulty via one-vs-rest k-NN accuracy.

        Harder classes are those a 5-NN classifier separates poorly from the
        rest: difficulty = 1 - cross-validated accuracy, already in [0, 1].
        """
        from sklearn.neighbors import KNeighborsClassifier
        from sklearn.model_selection import cross_val_score

        difficulties = {}
        # Flatten once: the feature matrix is identical for every class.
        X = self.data.view(len(self.data), -1).numpy()

        for class_id in self.unique_classes:
            # Binary classification problem: current class vs. all others.
            binary_labels = (self.labels == class_id).long()
            y = binary_labels.numpy()

            knn = KNeighborsClassifier(n_neighbors=5)
            try:
                scores = cross_val_score(knn, X, y, cv=3, scoring='accuracy')
                # Lower accuracy = higher difficulty.
                difficulties[class_id] = 1.0 - scores.mean()
            except ValueError:
                # Fewer than 3 members in a class makes 3-fold stratified CV
                # impossible; fall back to medium difficulty.
                difficulties[class_id] = 0.5

        return difficulties

    def _apply_augmentation(self, data: torch.Tensor, strategy: str) -> torch.Tensor:
        """Dispatch to the configured augmentation strategy ("basic"/"advanced")."""
        if strategy == "basic":
            return self._basic_augmentation(data)
        elif strategy == "advanced":
            return self._advanced_augmentation(data)
        else:
            return data

    def _basic_augmentation(self, data: torch.Tensor) -> torch.Tensor:
        """Basic augmentation: small additive Gaussian noise (std 0.01)."""
        noise_std = 0.01
        noise = torch.randn_like(data) * noise_std
        augmented = data + noise

        # Assumes inputs are normalized to [0, 1] -- TODO confirm upstream.
        return torch.clamp(augmented, 0, 1)

    def _advanced_augmentation(self, data: torch.Tensor) -> torch.Tensor:
        """
        Advanced augmentation: statistics-calibrated additive noise plus
        random feature masking for multi-dimensional inputs.
        """
        augmented = data.clone()

        # Additive noise scaled to 5% of the per-feature std of the batch.
        data_std = data.std(dim=0, keepdim=True)
        noise = torch.randn_like(data) * (data_std * 0.05)
        augmented = augmented + noise

        # Random feature masking (dropout-style) for structured data.
        if len(data.shape) > 2:  # Multi-dimensional features
            mask_prob = 0.1
            mask = torch.rand_like(data) > mask_prob
            augmented = augmented * mask

        # Assumes inputs are normalized to [0, 1] -- TODO confirm upstream.
        return torch.clamp(augmented, 0, 1)

    def _compute_task_metadata(
        self,
        task_classes: List[int],
        support_labels: torch.Tensor,
        query_labels: torch.Tensor
    ) -> Dict[str, Any]:
        """Compute metadata (sizes, difficulties, optional names) for a sampled task."""
        # Look up the difficulty list once and reuse it for the average.
        class_difficulties = [self.class_difficulties[c] for c in task_classes]
        metadata = {
            "n_way": len(task_classes),
            "k_shot": self.config.k_shot,
            "q_query": self.config.q_query,
            "task_classes": task_classes,
            "class_difficulties": class_difficulties,
            "avg_difficulty": np.mean(class_difficulties)
        }

        # Add human-readable class names when available.
        if self.class_names:
            metadata["class_names"] = [self.class_names[c] for c in task_classes]

        return metadata
class TaskSampler(Sampler):
    """
    Advanced task sampler for meta-learning with curriculum learning support.

    Key features:
    1. Curriculum learning with difficulty progression
    2. Balanced sampling across task types and difficulties
    3. Anti-correlation sampling to ensure task diversity
    4. Adaptive batch composition based on performance
    """

    def __init__(
        self,
        dataset: MetaLearningDataset,
        batch_size: int = 16,
        curriculum_learning: bool = True,
        difficulty_schedule: str = "linear"  # linear, exponential, adaptive
    ):
        """
        Initialize the task sampler.

        Args:
            dataset: MetaLearningDataset to sample from
            batch_size: Number of tasks per batch
            curriculum_learning: Whether to use curriculum learning
            difficulty_schedule: How difficulty progresses over training
                ("linear", "exponential", or "adaptive")
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.curriculum_learning = curriculum_learning
        self.difficulty_schedule = difficulty_schedule

        # Curriculum state; total_epochs is refreshed by update_epoch().
        self.current_epoch = 0
        self.total_epochs = 1000
        self.difficulty_level = 0.0  # 0.0 = easiest, 1.0 = hardest

        # Recent accuracies, consumed by the adaptive schedule.
        self.performance_history = []

        logger.info(f"Initialized TaskSampler: batch_size={batch_size}, curriculum={curriculum_learning}")

    def _curriculum_fraction(self) -> float:
        """Fraction of the task pool admitted at the current difficulty (in [0.1, 1])."""
        # Shared by __iter__ (via _apply_curriculum_filter) and __len__ so the
        # reported length always matches what iteration actually yields.
        return min(1.0, 0.1 + 0.9 * self.difficulty_level)

    def __iter__(self) -> Iterator[List[int]]:
        """Yield batches of task indices (full batches only)."""
        indices = list(range(len(self.dataset)))

        # Curriculum learning: restrict the task pool by current difficulty.
        if self.curriculum_learning:
            indices = self._apply_curriculum_filter(indices)

        random.shuffle(indices)

        for start in range(0, len(indices), self.batch_size):
            batch_indices = indices[start:start + self.batch_size]
            if len(batch_indices) == self.batch_size:  # Only yield full batches
                yield batch_indices

    def __len__(self) -> int:
        """Number of (full) batches per epoch, accounting for curriculum filtering."""
        effective_size = len(self.dataset)
        if self.curriculum_learning:
            effective_size = int(effective_size * self._curriculum_fraction())
        return effective_size // self.batch_size

    def update_epoch(self, epoch: int, total_epochs: int):
        """Update curriculum state for a new epoch."""
        self.current_epoch = epoch
        self.total_epochs = total_epochs
        # BUGFIX: guard against ZeroDivisionError when total_epochs == 0.
        denom = max(1, total_epochs)

        if self.difficulty_schedule == "linear":
            self.difficulty_level = epoch / denom
        elif self.difficulty_schedule == "exponential":
            self.difficulty_level = (np.exp(epoch / denom) - 1) / (np.e - 1)
        elif self.difficulty_schedule == "adaptive":
            self.difficulty_level = self._adaptive_difficulty_schedule()

        self.difficulty_level = np.clip(self.difficulty_level, 0.0, 1.0)

        logger.debug(f"Epoch {epoch}: difficulty_level = {self.difficulty_level:.3f}")

    def _apply_curriculum_filter(self, indices: List[int]) -> List[int]:
        """Keep the leading fraction of task indices allowed by the curriculum."""
        # Simplified curriculum: task identity is deterministic in its index
        # (MetaLearningDataset seeds by task_idx), so a prefix is a stable pool.
        num_to_include = int(len(indices) * self._curriculum_fraction())
        return indices[:num_to_include]

    def _adaptive_difficulty_schedule(self) -> float:
        """Compute adaptive difficulty from the recent performance trend."""
        # Same ZeroDivisionError guard as update_epoch.
        base_difficulty = self.current_epoch / max(1, self.total_epochs)

        if len(self.performance_history) < 10:
            # Not enough data yet; fall back to the linear schedule.
            return base_difficulty

        recent_performance = self.performance_history[-10:]
        performance_mean = np.mean(recent_performance)
        performance_trend = np.mean(np.diff(recent_performance))

        if performance_mean > 0.8 and performance_trend > 0:
            # High performance and improving: speed up the difficulty ramp.
            adaptation = min(0.2, performance_trend * 5)
        elif performance_mean < 0.6 and performance_trend < 0:
            # Low performance and declining: slow the ramp down.
            adaptation = max(-0.1, performance_trend * 2)
        else:
            adaptation = 0

        return np.clip(base_difficulty + adaptation, 0.0, 1.0)

    def update_performance(self, accuracy: float):
        """Record an accuracy observation for the adaptive curriculum."""
        self.performance_history.append(accuracy)

        # Bound memory: keep only the 100 most recent observations.
        if len(self.performance_history) > 100:
            self.performance_history = self.performance_history[-100:]