Coverage for src / harnessutils / compaction / pruning.py: 89%

228 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-12 22:41 -0600

1"""Tier 2: Selective pruning of tool outputs. 

2 

3Removes old tool outputs while preserving conversation structure. 

4Cost: Cheap (~50ms), Latency: ~50ms. 

5""" 

6 

7import hashlib 

8import math 

9from dataclasses import dataclass 

10 

11from harnessutils.config import PruningConfig 

12from harnessutils.models.message import Message 

13from harnessutils.models.parts import ToolPart 

14from harnessutils.tokens.exact import count_tokens_fast 

15 

16 

17@dataclass 

18class PruningResult: 

19 """Result of pruning operation with detailed token tracking.""" 

20 

21 pruned: int # Total outputs pruned 

22 tokens_saved: int # Total tokens saved 

23 tokens_before: int = 0 # Token count before pruning 

24 tokens_after: int = 0 # Token count after pruning 

25 duplicates_pruned: int = 0 # Outputs pruned due to duplication 

26 importance_pruned: int = 0 # Outputs pruned due to low importance 

27 duplicate_tokens_saved: int = 0 # Tokens saved from deduplication 

28 importance_tokens_saved: int = 0 # Tokens saved from importance pruning 

29 

30 def to_dict(self) -> dict[str, int]: 

31 """Convert to dictionary for reporting. 

32 

33 Returns: 

34 Dictionary with all pruning metrics 

35 """ 

36 return { 

37 "pruned": self.pruned, 

38 "tokens_saved": self.tokens_saved, 

39 "tokens_before": self.tokens_before, 

40 "tokens_after": self.tokens_after, 

41 "duplicates_pruned": self.duplicates_pruned, 

42 "importance_pruned": self.importance_pruned, 

43 "duplicate_tokens_saved": self.duplicate_tokens_saved, 

44 "importance_tokens_saved": self.importance_tokens_saved, 

45 "reduction_percent": round( 

46 (self.tokens_saved / self.tokens_before * 100) if self.tokens_before > 0 else 0, 

47 1, 

48 ), 

49 } 

50 

51 def __str__(self) -> str: 

52 """Human-readable summary of pruning results.""" 

53 if self.pruned == 0: 

54 return "No pruning needed" 

55 

56 lines = [ 

57 f"Pruned {self.pruned} outputs, saved {self.tokens_saved:,} tokens", 

58 f" Before: {self.tokens_before:,} tokens", 

59 f" After: {self.tokens_after:,} tokens", 

60 f" Reduction: {self.to_dict()['reduction_percent']}%", 

61 ] 

62 

63 if self.duplicates_pruned > 0: 

64 lines.append( 

65 f" Duplicates: {self.duplicates_pruned} removed " 

66 f"({self.duplicate_tokens_saved:,} tokens)" 

67 ) 

68 

69 if self.importance_pruned > 0: 

70 lines.append( 

71 f" Low importance: {self.importance_pruned} removed " 

72 f"({self.importance_tokens_saved:,} tokens)" 

73 ) 

74 

75 return "\n".join(lines) 

76 

77 

@dataclass
class OutputImportance:
    """Importance score components for a tool output."""

    recency_score: float  # Exponential decay based on age
    size_penalty: float  # Penalty for large outputs
    semantic_score: float  # Content-based importance (errors, warnings)
    tool_priority: float  # Tool type importance
    token_count: int  # Actual token count

    @property
    def total_score(self) -> float:
        """Weighted total importance score.

        Higher score = more important = keep longer.
        Lower score = less important = prune first.

        NOTE(review): all four components are summed, so a positive
        ``size_penalty`` *raises* the score despite its name — confirm
        the intended sign convention against ``PruningConfig.size_weight``.
        """
        components = (
            self.recency_score,
            self.size_penalty,
            self.semantic_score,
            self.tool_priority,
        )
        return sum(components)

101 

102 

def generate_shingles(text: str, n: int = 5) -> set[str]:
    """Generate word-level n-grams (shingles) for similarity detection.

    Args:
        text: Text to generate shingles from
        n: Size of n-grams (default: 5 words)

    Returns:
        Set of n-gram strings
    """
    # Normalize: lowercase and split on whitespace.
    tokens = text.lower().split()

    # Too short for a full n-gram: fall back to the whole (normalized) text.
    if len(tokens) < n:
        return {" ".join(tokens)} if tokens else set()

    # One shingle per sliding window of n consecutive words.
    return {
        " ".join(tokens[start : start + n])
        for start in range(len(tokens) - n + 1)
    }

127 

128 

def jaccard_similarity(set1: set[str], set2: set[str]) -> float:
    """Calculate Jaccard similarity between two sets.

    Jaccard similarity = |intersection| / |union|
    Returns value between 0.0 (no overlap) and 1.0 (identical).

    Args:
        set1: First set
        set2: Second set

    Returns:
        Similarity score [0.0, 1.0]
    """
    # Empty sets carry no similarity signal; this also guards the
    # division below (the previous extra `union == 0` check was
    # unreachable once both sets are known non-empty, so it is removed).
    if not set1 or not set2:
        return 0.0

    return len(set1 & set2) / len(set1 | set2)

152 

153 

154def compute_content_hash(text: str) -> str: 

155 """Compute fast hash of text content for exact duplicate detection. 

156 

157 Args: 

158 text: Text to hash 

159 

160 Returns: 

161 MD5 hash hex string 

162 """ 

163 return hashlib.md5(text.encode("utf-8")).hexdigest() 

164 

165 

def find_duplicate_output(
    part: ToolPart,
    recent_parts: list[tuple[ToolPart, set[str]]],
    similarity_threshold: float = 0.8,
) -> ToolPart | None:
    """Find if output is duplicate of a recent output.

    Args:
        part: Tool part to check for duplication
        recent_parts: List of (part, shingles) tuples to compare against
        similarity_threshold: Minimum similarity to consider duplicate (default: 0.8)

    Returns:
        The duplicate part if found, None otherwise
    """
    # Restrict comparison to outputs produced by the same tool: different
    # tools rarely emit identical text, and cross-tool matches would
    # mostly be false positives.
    candidates = [entry for entry in recent_parts if entry[0].tool == part.tool]
    if not candidates:
        return None

    # Fast path: byte-identical outputs, detected via content hash.
    target_hash = compute_content_hash(part.state.output)
    for candidate, _ in candidates:
        if compute_content_hash(candidate.state.output) == target_hash:
            return candidate

    # Slow path: near-duplicate detection via shingle similarity.
    target_shingles = generate_shingles(part.state.output)
    for candidate, candidate_shingles in candidates:
        score = jaccard_similarity(target_shingles, candidate_shingles)
        if score >= similarity_threshold:
            return candidate

    return None

209 

210 

def calculate_context_tokens(messages: list[Message]) -> int:
    """Calculate total token count for all tool outputs in conversation.

    Only completed tool parts with non-empty output contribute to the
    total.

    Args:
        messages: All conversation messages

    Returns:
        Total token count
    """
    return sum(
        count_tokens_fast(part.state.output)
        for msg in messages
        for part in msg.parts
        if isinstance(part, ToolPart)
        and part.state.status == "completed"
        and part.state.output  # skip empty (already-pruned) outputs
    )

227 

228 

def calculate_turn_age(messages: list[Message], target_msg: Message) -> int:
    """Calculate how many turns ago a message was created.

    Walks the conversation newest-first, counting user turns until the
    target message is reached.

    Args:
        messages: All conversation messages
        target_msg: Message to calculate age for

    Returns:
        Number of user turns since this message (0 = current turn)
    """
    age = 0
    for candidate in reversed(messages):
        if candidate.id == target_msg.id:
            return age
        if candidate.role == "user":
            age += 1
    # Target not found: report the total number of user turns seen.
    return age

246 

247 

def score_tool_output(
    part: ToolPart,
    message: Message,
    messages: list[Message],
    config: PruningConfig,
) -> OutputImportance:
    """Calculate importance score for a tool output.

    Args:
        part: Tool part to score
        message: Message containing this part
        messages: All conversation messages
        config: Pruning configuration with scoring weights

    Returns:
        OutputImportance with all score components
    """
    tokens = count_tokens_fast(part.state.output)

    # Component 1 — recency: exponential decay over the number of user
    # turns since the containing message.
    turns_old = calculate_turn_age(messages, message)
    recency = math.exp(-turns_old * config.recency_decay) * 100 * config.recency_weight

    # Component 2 — size: grows logarithmically with the token count.
    # NOTE(review): this value is *added* into the total score, so larger
    # outputs score higher — confirm the sign convention of size_weight.
    size_component = math.log(tokens + 1) * 10 * config.size_weight

    # Component 3 — semantic: boost outputs that mention errors or
    # warnings, or that the user explicitly requested.
    text = part.state.output.lower()
    semantic = 0.0
    if any(marker in text for marker in ("error", "exception", "traceback", "failed")):
        semantic += config.error_boost
    if "warning" in text or "warn" in text:
        semantic += config.warning_boost
    if part.state.metadata and part.state.metadata.get("user_requested"):
        semantic += config.user_requested_boost
    semantic *= config.semantic_weight

    # Component 4 — tool priority: per-tool baseline (default 50.0).
    priority = config.tool_importance.get(part.tool, 50.0) * config.tool_priority_weight

    return OutputImportance(
        recency_score=recency,
        size_penalty=size_component,
        semantic_score=semantic,
        tool_priority=priority,
        token_count=tokens,
    )

305 

306 

def prune_tool_outputs(
    messages: list[Message],
    config: PruningConfig,
) -> PruningResult:
    """Prune tool outputs from conversation history.

    Selectively removes old tool outputs while preserving:
    - Tool call metadata (name, input, title, timing)
    - Recent outputs (within protection window)
    - Protected tool outputs
    - Last N turns

    Uses exact token counting via tiktoken for accurate pruning decisions.

    First runs duplicate detection (if enabled), then applies either
    importance-based or FIFO pruning strategy.

    Args:
        messages: Conversation messages (newest first recommended)
        config: Pruning configuration

    Returns:
        PruningResult with detailed token tracking and breakdown
    """
    # Snapshot usage before any mutation so the result can report savings.
    tokens_before = calculate_context_tokens(messages)

    # Phase 1: near-duplicate removal, independent of the scoring strategy.
    if config.detect_duplicates:
        dedup = _detect_and_prune_duplicates(messages, config)
    else:
        dedup = PruningResult(pruned=0, tokens_saved=0)

    # Phase 2: importance-based scoring or simple FIFO pruning.
    if config.use_importance_scoring:
        phase_two = _prune_with_importance_scoring_only(messages, config)
    else:
        phase_two = _prune_simple(messages, config)

    tokens_after = calculate_context_tokens(messages)

    # Merge both phases into a single result with per-phase breakdown.
    return PruningResult(
        pruned=dedup.pruned + phase_two.pruned,
        tokens_saved=dedup.tokens_saved + phase_two.tokens_saved,
        tokens_before=tokens_before,
        tokens_after=tokens_after,
        duplicates_pruned=dedup.pruned,
        importance_pruned=phase_two.pruned,
        duplicate_tokens_saved=dedup.tokens_saved,
        importance_tokens_saved=phase_two.tokens_saved,
    )

359 

360 

def _prune_simple(
    messages: list[Message],
    config: PruningConfig,
) -> PruningResult:
    """Simple FIFO pruning (original algorithm).

    Walks the conversation newest-first, fully protecting the most
    recent ``config.protect_turns`` user turns and the first
    ``config.prune_protect`` tokens of tool output. Outputs beyond that
    budget are pruned, but only when the total savings exceed
    ``config.prune_minimum``.

    Returns:
        PruningResult covering this phase only (no before/after totals).
    """
    # Hoisted out of the pruning loop: the original re-imported `time`
    # on every pruned part. One timestamp also keeps `compacted` marks
    # consistent across a single pruning pass.
    import time

    total_tokens = 0
    prunable_tokens = 0
    to_prune: list[ToolPart] = []
    turns_skipped = 0

    for msg in reversed(messages):
        if msg.role == "user":
            turns_skipped += 1

        # Protect the most recent turns entirely.
        if turns_skipped < config.protect_turns:
            continue

        # Stop at a summary: older history is already compacted.
        if msg.summary:
            break

        for part in msg.parts:
            if not isinstance(part, ToolPart):
                continue
            if part.state.status != "completed":
                continue
            if part.tool in config.protected_tools:
                continue
            # Skip outputs pruned by an earlier pass.
            if part.state.time and part.state.time.compacted:
                continue

            token_count = count_tokens_fast(part.state.output)
            total_tokens += token_count

            # Everything beyond the protected token budget is prunable.
            if total_tokens > config.prune_protect:
                prunable_tokens += token_count
                to_prune.append(part)

    # Only act when the savings clear the minimum threshold.
    if prunable_tokens > config.prune_minimum:
        compacted_at = int(time.time() * 1000)
        for part in to_prune:
            part.state.output = ""
            part.state.attachments = []
            if part.state.time:
                part.state.time.compacted = compacted_at

        return PruningResult(pruned=len(to_prune), tokens_saved=prunable_tokens)

    return PruningResult(pruned=0, tokens_saved=0)

416 

417 

def _prune_with_importance_scoring_only(
    messages: list[Message],
    config: PruningConfig,
) -> PruningResult:
    """Smart pruning using importance scoring.

    Scores all outputs by importance and prunes lowest-value first.
    Preserves critical outputs (errors, warnings, user-requested).

    Note: Duplicate detection happens separately before this function is called.

    Returns:
        PruningResult covering this phase only (no before/after totals).
    """
    # Hoisted out of the pruning loop: the original re-imported `time`
    # per pruned part. One timestamp also keeps `compacted` marks
    # consistent across a single pruning pass.
    import time

    # Collect all prunable outputs with their importance scores.
    scored_outputs: list[tuple[ToolPart, OutputImportance]] = []
    turns_skipped = 0

    for msg in reversed(messages):
        if msg.role == "user":
            turns_skipped += 1

        # Protect the most recent turns entirely.
        if turns_skipped < config.protect_turns:
            continue

        # Stop at a summary: older history is already compacted.
        if msg.summary:
            break

        for part in msg.parts:
            if not isinstance(part, ToolPart):
                continue
            if part.state.status != "completed":
                continue
            if part.tool in config.protected_tools:
                continue
            # Skip outputs pruned by an earlier pass.
            if part.state.time and part.state.time.compacted:
                continue

            importance = score_tool_output(part, msg, messages, config)
            scored_outputs.append((part, importance))

    total_tokens = sum(imp.token_count for _, imp in scored_outputs)

    # Under budget: nothing to prune.
    if total_tokens <= config.prune_protect:
        return PruningResult(pruned=0, tokens_saved=0)

    # Lowest importance first = pruned first.
    scored_outputs.sort(key=lambda entry: entry[1].total_score)

    # Prune until the remaining outputs fit within the budget.
    pruned_outputs: list[tuple[ToolPart, int]] = []
    current_tokens = total_tokens
    for part, importance in scored_outputs:
        if current_tokens <= config.prune_protect:
            break
        pruned_outputs.append((part, importance.token_count))
        current_tokens -= importance.token_count

    # Only act when the savings clear the minimum threshold.
    tokens_saved = sum(tokens for _, tokens in pruned_outputs)
    if tokens_saved >= config.prune_minimum:
        compacted_at = int(time.time() * 1000)
        for part, _ in pruned_outputs:
            part.state.output = ""
            part.state.attachments = []
            if part.state.time:
                part.state.time.compacted = compacted_at

        return PruningResult(pruned=len(pruned_outputs), tokens_saved=tokens_saved)

    return PruningResult(pruned=0, tokens_saved=0)

497 

498 

def _detect_and_prune_duplicates(
    messages: list[Message],
    config: PruningConfig,
) -> PruningResult:
    """Detect and prune duplicate tool outputs.

    Uses similarity hashing (shingles + Jaccard) to find near-duplicates.
    Aggressively prunes older duplicates while keeping the most recent.

    Args:
        messages: Conversation messages
        config: Pruning configuration

    Returns:
        PruningResult with duplicates pruned
    """
    # Hoisted out of the pruning loop: the original re-imported `time`
    # per pruned part. One timestamp also keeps `compacted` marks
    # consistent across a single pruning pass.
    import time

    # Index of unique outputs seen so far as (part, shingles); the
    # original also carried the containing Message, which was never used
    # and forced a throwaway re-mapping before find_duplicate_output.
    recent_outputs: list[tuple[ToolPart, set[str]]] = []
    duplicates_to_prune: list[tuple[ToolPart, int]] = []
    turns_skipped = 0

    for msg in reversed(messages):  # Newest first
        if msg.role == "user":
            turns_skipped += 1

        # Protect the most recent turns entirely.
        if turns_skipped < config.protect_turns:
            continue

        # Stop at a summary: older history is already compacted.
        if msg.summary:
            break

        for part in msg.parts:
            if not isinstance(part, ToolPart):
                continue
            if part.state.status != "completed":
                continue
            if part.tool in config.protected_tools:
                continue
            # Skip outputs pruned by an earlier pass.
            if part.state.time and part.state.time.compacted:
                continue

            # Compare only against a bounded window of recent outputs.
            if recent_outputs:
                lookback = recent_outputs[-config.duplicate_lookback :]
                duplicate_of = find_duplicate_output(
                    part,
                    lookback,
                    config.similarity_threshold,
                )

                if duplicate_of:
                    # Duplicate: mark for pruning, don't index it.
                    token_count = count_tokens_fast(part.state.output)
                    duplicates_to_prune.append((part, token_count))
                    continue

            # Unique output: add it to the index.
            recent_outputs.append((part, generate_shingles(part.state.output)))

    # Prune the duplicates found above.
    compacted_at = int(time.time() * 1000)
    tokens_saved = 0
    for part, token_count in duplicates_to_prune:
        part.state.output = ""
        part.state.attachments = []
        if part.state.time:
            part.state.time.compacted = compacted_at
        tokens_saved += token_count

    return PruningResult(
        pruned=len(duplicates_to_prune),
        tokens_saved=tokens_saved,
    )