Coverage for src/meta_learning/meta_learning_modules/few_shot_modules/utilities.py: 9%

145 statements  

« prev     ^ index     » next       coverage.py v7.10.5, created at 2025-09-03 12:35 +0900

1""" 

2Few-Shot Learning Utilities 

3========================== 

4 

5Utility functions for few-shot learning including factory functions, 

6evaluation utilities, and helper functions. 

7Extracted from the original monolithic few_shot_learning.py. 

8""" 

9 

10import torch 

11import torch.nn as nn 

12import torch.nn.functional as F 

13from typing import Dict, List, Tuple, Optional, Any 

14import numpy as np 

15import logging 

16 

17from .configurations import PrototypicalConfig 

18from .core_networks import PrototypicalNetworks 

19 

20logger = logging.getLogger(__name__) 

21 

22 

def create_prototypical_network(
    backbone: nn.Module,
    variant: str = "research_accurate",
    config: Optional[PrototypicalConfig] = None
) -> PrototypicalNetworks:
    """
    Factory function to create Prototypical Networks with specific configuration.

    Args:
        backbone: Feature extraction backbone network.
        variant: Implementation variant ('research_accurate', 'simple',
            'enhanced', 'original'). 'original' leaves the config untouched.
        config: Optional custom configuration; a default PrototypicalConfig
            is created when omitted.

    Returns:
        Configured PrototypicalNetworks instance.
    """
    if config is None:
        config = PrototypicalConfig()

    # Record the requested variant on the config when it supports that field.
    if hasattr(config, 'protonet_variant'):
        config.protonet_variant = variant

    # Configure based on variant. All attribute writes are guarded with
    # hasattr so the factory works across config versions.
    if variant == "research_accurate":
        # Pure research-accurate implementation (Snell et al., 2017):
        # squared Euclidean distance, mean prototypes, no extensions.
        if hasattr(config, 'use_squared_euclidean'):
            config.use_squared_euclidean = True
        if hasattr(config, 'prototype_method'):
            config.prototype_method = "mean"
        if hasattr(config, 'enable_research_extensions'):
            config.enable_research_extensions = False
            config.multi_scale_features = False
            config.adaptive_prototypes = False
        if hasattr(config, 'uncertainty_estimation'):
            config.uncertainty_estimation = False

    elif variant == "simple":
        # Simplified educational version: everything optional switched off.
        config.multi_scale_features = False
        config.adaptive_prototypes = False
        if hasattr(config, 'uncertainty_estimation'):
            config.uncertainty_estimation = False
        if hasattr(config, 'enable_research_extensions'):
            config.enable_research_extensions = False

    elif variant == "enhanced":
        # All extensions enabled.
        config.multi_scale_features = True
        config.adaptive_prototypes = True
        if hasattr(config, 'uncertainty_estimation'):
            config.uncertainty_estimation = True
        if hasattr(config, 'enable_research_extensions'):
            config.enable_research_extensions = True

    elif variant != "original":
        # Fixed: unrecognized variant strings (e.g. typos) were previously
        # ignored in silence; 'original' intentionally keeps config as-is.
        logger.warning(
            f"Unknown variant '{variant}'; using the provided config unchanged"
        )

    return PrototypicalNetworks(backbone, config)

79 

80 

def compare_with_learn2learn_protonet():
    """
    Summarize how this implementation differs from learn2learn's
    Prototypical Networks.

    learn2learn wraps prototypical networks in Lightning-based training
    automation and ships benchmark data loaders, handling meta-batch
    processing automatically. This module instead targets research and
    education: research-accurate variants, heavy documentation, and
    configurable extensions.

    Returns:
        Dict with keys 'learn2learn_advantages', 'our_advantages',
        and 'use_cases'.
    """
    learn2learn_advantages = [
        "Lightning framework integration",
        "Built-in benchmark data loaders",
        "Automatic meta-batch processing",
        "Production-ready training loops",
    ]
    our_advantages = [
        "Educational and research-focused",
        "Research-accurate implementations",
        "Configurable variants",
        "Extensive documentation and citations",
        "Advanced extensions with proper attribution",
    ]
    use_cases = {
        "learn2learn": "Production systems, quick prototyping",
        "our_implementation": "Research, education, algorithm understanding",
    }

    return {
        "learn2learn_advantages": learn2learn_advantages,
        "our_advantages": our_advantages,
        "use_cases": use_cases,
    }

132 

133 

def evaluate_on_standard_benchmarks(model, dataset_name="omniglot", episodes=600):
    """
    Standard few-shot evaluation following research protocols.

    Based on standard evaluation in meta-learning literature:
    - Omniglot: 20-way 1-shot and 5-shot
    - miniImageNet: 5-way 1-shot and 5-shot
    - tieredImageNet: 5-way 1-shot and 5-shot

    Returns confidence intervals over specified episodes (standard: 600).

    Args:
        model: Few-shot model, called as model(support_x, support_y, query_x);
            expected to return logits or a dict containing a 'logits' key.
        dataset_name: Name of benchmark dataset (forwarded to sample_episode).
        episodes: Number of evaluation episodes.

    Returns:
        Dictionary with keys 'mean_accuracy', 'confidence_interval',
        'std_accuracy', 'episodes', and 'raw_accuracies'. The schema is the
        same whether or not any episode succeeded.
    """
    accuracies = []

    for episode in range(episodes):
        try:
            # Sample episode (N-way K-shot)
            support_x, support_y, query_x, query_y = sample_episode(dataset_name)

            logits = model(support_x, support_y, query_x)
            if isinstance(logits, dict):
                logits = logits['logits']

            predictions = logits.argmax(dim=1)
            accuracy = (predictions == query_y).float().mean()
            accuracies.append(accuracy.item())

        except Exception as e:
            # Best-effort evaluation: one failing episode should not abort
            # the whole benchmark run.
            logger.warning(f"Episode {episode} failed: {e}")
            continue

    if not accuracies:
        # Fixed: this path previously omitted 'std_accuracy' and
        # 'raw_accuracies', so callers indexing the documented schema
        # crashed when every episode failed.
        return {
            "mean_accuracy": 0.0,
            "confidence_interval": 0.0,
            "std_accuracy": 0.0,
            "episodes": 0,
            "raw_accuracies": [],
        }

    # 95% confidence interval via the normal approximation.
    mean_acc = np.mean(accuracies)
    std_acc = np.std(accuracies)
    ci = 1.96 * std_acc / np.sqrt(len(accuracies))  # 95% CI

    return {
        "mean_accuracy": mean_acc,
        "confidence_interval": ci,
        "std_accuracy": std_acc,
        "episodes": len(accuracies),
        "raw_accuracies": accuracies,
    }

190 

191 

def sample_episode(dataset_name: str, n_way: int = 5, n_support: int = 5, n_query: int = 15):
    """
    Sample a synthetic few-shot episode for the given dataset.

    Placeholder for demonstration purposes: the returned tensors are random
    noise shaped like the requested benchmark. Swap in a real dataset loader
    for actual experiments.

    Args:
        dataset_name: Dataset identifier ('omniglot', 'miniImageNet',
            'tieredImageNet', or anything else for a generic 32x32 episode).
        n_way: Classes per episode. Overridden by benchmark convention:
            omniglot is forced to 20-way, ImageNet variants to 5-way.
        n_support: Support examples per class.
        n_query: Query examples per class.

    Returns:
        Tuple of (support_x, support_y, query_x, query_y).
    """
    # Per-benchmark (input_size, conventional n_way) lookup.
    benchmark_specs = {
        "omniglot": ((1, 28, 28), 20),
        "miniImageNet": ((3, 84, 84), 5),
        "tieredImageNet": ((3, 84, 84), 5),
    }
    if dataset_name in benchmark_specs:
        input_size, n_way = benchmark_specs[dataset_name]
    else:
        input_size = (3, 32, 32)  # generic default

    def make_split(per_class: int):
        # Synthetic data: random inputs, labels 0..n_way-1 repeated per_class times.
        xs = torch.randn(n_way * per_class, *input_size)
        ys = torch.repeat_interleave(torch.arange(n_way), per_class)
        return xs, ys

    support_x, support_y = make_split(n_support)
    query_x, query_y = make_split(n_query)
    return support_x, support_y, query_x, query_y

226 

227 

def euclidean_distance_squared(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Squared Euclidean distance as in Snell et al. (2017) Equation 1.

    Squared (rather than plain) distance is used for gradient stability.

    Args:
        x: Query embeddings [n_query, embedding_dim].
        y: Prototype embeddings [n_prototypes, embedding_dim].

    Returns:
        Squared distances [n_query, n_prototypes].
    """
    # Broadcast to [n_query, n_prototypes, embedding_dim] and reduce the
    # feature axis.
    diff = x[:, None, :] - y[None, :, :]
    return diff.pow(2).sum(dim=-1)

245 

246 

def compute_prototype_statistics(prototypes: torch.Tensor, support_features: torch.Tensor,
                                 support_labels: torch.Tensor) -> Dict[str, float]:
    """
    Compute diagnostic statistics about learned prototypes.

    Args:
        prototypes: Class prototypes [n_classes, embedding_dim]; row i is
            assumed to correspond to label value i.
        support_features: Support set features [n_support, embedding_dim].
        support_labels: Support set labels [n_support].

    Returns:
        Dict of scalar statistics: inter-prototype distance summary,
        intra-class distance summary, and their separation ratio.
    """
    # Pairwise prototype distances with the diagonal (self-distances) masked out.
    pairwise = torch.cdist(prototypes, prototypes, p=2)
    off_diag = pairwise[~torch.eye(len(prototypes), dtype=bool)]

    stats = {
        'mean_inter_prototype_distance': off_diag.mean().item(),
        'std_inter_prototype_distance': off_diag.std().item(),
        'min_inter_prototype_distance': off_diag.min().item(),
        'max_inter_prototype_distance': off_diag.max().item(),
    }

    # Distance from each support example to its own class prototype.
    per_class_distances = [
        torch.norm(support_features[support_labels == c] - prototypes[c], p=2, dim=1)
        for c in torch.unique(support_labels)
    ]
    intra = torch.cat(per_class_distances)
    stats['mean_intra_class_distance'] = intra.mean().item()
    stats['std_intra_class_distance'] = intra.std().item()

    # Separation quality: higher means prototypes sit far apart relative to
    # the spread of examples within each class (epsilon avoids div-by-zero).
    stats['prototype_separation_ratio'] = (
        stats['mean_inter_prototype_distance']
        / (stats['mean_intra_class_distance'] + 1e-8)
    )

    return stats

293 

294 

def analyze_few_shot_performance(model, test_episodes: int = 100, n_way: int = 5,
                                 n_support: int = 5, n_query: int = 15) -> Dict[str, Any]:
    """
    Comprehensive analysis of few-shot learning performance.

    Args:
        model: Few-shot model, called as model(support_x, support_y, query_x);
            may return logits directly or a dict with 'logits' and optionally
            'prototypes'. Assumes the model exposes a `backbone` attribute
            when prototypes are returned -- TODO confirm with model class.
        test_episodes: Number of test episodes.
        n_way: Number of classes per episode.
        n_support: Number of support examples per class.
        n_query: Number of query examples per class.

    Returns:
        Dict with 'accuracy_stats', 'confidence_stats' (None if no episode
        produced confidences), and 'prototype_stats' when prototypes were
        available.
    """
    model.eval()

    episode_accuracies = []
    prototype_stats_list = []
    confidence_scores = []

    with torch.no_grad():
        for episode in range(test_episodes):
            # Sample a synthetic episode.
            support_x, support_y, query_x, query_y = sample_episode(
                "synthetic", n_way, n_support, n_query
            )

            try:
                result = model(support_x, support_y, query_x)
                if isinstance(result, dict):
                    logits = result['logits']
                    prototypes = result.get('prototypes')
                else:
                    logits = result
                    prototypes = None

                predictions = logits.argmax(dim=1)
                accuracy = (predictions == query_y).float().mean().item()
                episode_accuracies.append(accuracy)

                # Analyze prototypes when the model exposes them.
                if prototypes is not None:
                    support_features = model.backbone(support_x)
                    proto_stats = compute_prototype_statistics(
                        prototypes, support_features, support_y
                    )
                    prototype_stats_list.append(proto_stats)

                # Max softmax probability as a per-query confidence proxy.
                probs = F.softmax(logits, dim=-1)
                confidence_scores.extend(probs.max(dim=-1)[0].tolist())

            except Exception as e:
                logger.warning(f"Episode {episode} analysis failed: {e}")
                continue

    # Fixed: np.min/np.max raise ValueError on an empty sequence, so the
    # original crashed whenever every episode failed. Return a well-formed,
    # all-zero summary in that case instead.
    if episode_accuracies:
        accuracy_stats = {
            'mean': np.mean(episode_accuracies),
            'std': np.std(episode_accuracies),
            'min': np.min(episode_accuracies),
            'max': np.max(episode_accuracies),
            'episodes': len(episode_accuracies)
        }
    else:
        accuracy_stats = {'mean': 0.0, 'std': 0.0, 'min': 0.0, 'max': 0.0,
                          'episodes': 0}

    analysis = {
        'accuracy_stats': accuracy_stats,
        'confidence_stats': {
            'mean': np.mean(confidence_scores),
            'std': np.std(confidence_scores),
            'median': np.median(confidence_scores)
        } if confidence_scores else None
    }

    # Aggregate per-episode prototype statistics key by key.
    if prototype_stats_list:
        analysis['prototype_stats'] = {
            key: {
                'mean': np.mean([s[key] for s in prototype_stats_list]),
                'std': np.std([s[key] for s in prototype_stats_list]),
            }
            for key in prototype_stats_list[0]
        }

    return analysis

383 

384 

def create_backbone_network(architecture: str = "conv4", input_channels: int = 3,
                            embedding_dim: int = 512) -> nn.Module:
    """
    Create a backbone network for few-shot learning.

    Args:
        architecture: Backbone architecture. Supported: 'conv4' (the standard
            4-layer CNN used throughout the few-shot literature) and 'simple'
            (a small educational network). Fixed: the docstring previously
            also advertised 'resnet', which was never implemented.
        input_channels: Number of input channels.
        embedding_dim: Output embedding dimension.

    Returns:
        Backbone mapping [B, C, H, W] images to [B, embedding_dim] embeddings.

    Raises:
        ValueError: If `architecture` is not a supported option.
    """
    def conv_block(in_ch: int, out_ch: int) -> list:
        # One standard Conv-4 unit: Conv -> BN -> ReLU -> 2x2 max-pool.
        return [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        ]

    if architecture == "conv4":
        # Standard 4-layer CNN backbone used in few-shot learning.
        backbone = nn.Sequential(
            *conv_block(input_channels, 64),
            *conv_block(64, 64),
            *conv_block(64, 64),
            *conv_block(64, 64),
            # Global average pooling, then project to embedding dimension.
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, embedding_dim),
        )

    elif architecture == "simple":
        # Simple backbone for educational purposes (no batch norm).
        backbone = nn.Sequential(
            nn.Conv2d(input_channels, 32, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, embedding_dim),
        )

    else:
        raise ValueError(f"Unknown backbone architecture: {architecture}")

    return backbone