Coverage for src / harnessutils / compaction / pruning.py: 89%
228 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-12 22:41 -0600
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-12 22:41 -0600
1"""Tier 2: Selective pruning of tool outputs.
3Removes old tool outputs while preserving conversation structure.
4Cost: Cheap (~50ms), Latency: ~50ms.
5"""
import hashlib
import math
import time
from dataclasses import dataclass

from harnessutils.config import PruningConfig
from harnessutils.models.message import Message
from harnessutils.models.parts import ToolPart
from harnessutils.tokens.exact import count_tokens_fast
17@dataclass
18class PruningResult:
19 """Result of pruning operation with detailed token tracking."""
21 pruned: int # Total outputs pruned
22 tokens_saved: int # Total tokens saved
23 tokens_before: int = 0 # Token count before pruning
24 tokens_after: int = 0 # Token count after pruning
25 duplicates_pruned: int = 0 # Outputs pruned due to duplication
26 importance_pruned: int = 0 # Outputs pruned due to low importance
27 duplicate_tokens_saved: int = 0 # Tokens saved from deduplication
28 importance_tokens_saved: int = 0 # Tokens saved from importance pruning
30 def to_dict(self) -> dict[str, int]:
31 """Convert to dictionary for reporting.
33 Returns:
34 Dictionary with all pruning metrics
35 """
36 return {
37 "pruned": self.pruned,
38 "tokens_saved": self.tokens_saved,
39 "tokens_before": self.tokens_before,
40 "tokens_after": self.tokens_after,
41 "duplicates_pruned": self.duplicates_pruned,
42 "importance_pruned": self.importance_pruned,
43 "duplicate_tokens_saved": self.duplicate_tokens_saved,
44 "importance_tokens_saved": self.importance_tokens_saved,
45 "reduction_percent": round(
46 (self.tokens_saved / self.tokens_before * 100) if self.tokens_before > 0 else 0,
47 1,
48 ),
49 }
51 def __str__(self) -> str:
52 """Human-readable summary of pruning results."""
53 if self.pruned == 0:
54 return "No pruning needed"
56 lines = [
57 f"Pruned {self.pruned} outputs, saved {self.tokens_saved:,} tokens",
58 f" Before: {self.tokens_before:,} tokens",
59 f" After: {self.tokens_after:,} tokens",
60 f" Reduction: {self.to_dict()['reduction_percent']}%",
61 ]
63 if self.duplicates_pruned > 0:
64 lines.append(
65 f" Duplicates: {self.duplicates_pruned} removed "
66 f"({self.duplicate_tokens_saved:,} tokens)"
67 )
69 if self.importance_pruned > 0:
70 lines.append(
71 f" Low importance: {self.importance_pruned} removed "
72 f"({self.importance_tokens_saved:,} tokens)"
73 )
75 return "\n".join(lines)
@dataclass
class OutputImportance:
    """Importance score components for a tool output."""

    recency_score: float  # Exponential decay based on age
    size_penalty: float  # Penalty for large outputs
    semantic_score: float  # Content-based importance (errors, warnings)
    tool_priority: float  # Tool type importance
    token_count: int  # Actual token count

    @property
    def total_score(self) -> float:
        """Weighted total importance score.

        Higher score = more important = keep longer.
        Lower score = less important = prune first.
        """
        return sum(
            (
                self.recency_score,
                self.size_penalty,
                self.semantic_score,
                self.tool_priority,
            )
        )
def generate_shingles(text: str, n: int = 5) -> set[str]:
    """Generate word-level n-grams (shingles) for similarity detection.

    Args:
        text: Text to generate shingles from
        n: Size of n-grams (default: 5 words)

    Returns:
        Set of n-gram strings
    """
    # Lowercase + whitespace split normalizes the text into word tokens.
    words = text.lower().split()

    if not words:
        return set()

    if len(words) < n:
        # Too few words for a full n-gram: the whole text is one shingle.
        return {" ".join(words)}

    # One shingle per window of n consecutive words.
    return {" ".join(words[i : i + n]) for i in range(len(words) - n + 1)}
def jaccard_similarity(set1: set[str], set2: set[str]) -> float:
    """Calculate Jaccard similarity between two sets.

    Jaccard similarity = |intersection| / |union|
    Returns value between 0.0 (no overlap) and 1.0 (identical).

    Args:
        set1: First set
        set2: Second set

    Returns:
        Similarity score [0.0, 1.0]
    """
    if not set1 or not set2:
        return 0.0

    # Both sets are non-empty here, so the union can never be empty;
    # the old `union == 0` guard was unreachable and has been removed.
    return len(set1 & set2) / len(set1 | set2)
def compute_content_hash(text: str) -> str:
    """Compute fast hash of text content for exact duplicate detection.

    MD5 is used purely as a fast fingerprint, not for security;
    ``usedforsecurity=False`` documents that and keeps this working on
    FIPS-restricted Python builds (available since Python 3.9).

    Args:
        text: Text to hash

    Returns:
        MD5 hash hex string
    """
    return hashlib.md5(text.encode("utf-8"), usedforsecurity=False).hexdigest()
def find_duplicate_output(
    part: ToolPart,
    recent_parts: list[tuple[ToolPart, set[str]]],
    similarity_threshold: float = 0.8,
) -> ToolPart | None:
    """Find if output is duplicate of a recent output.

    Args:
        part: Tool part to check for duplication
        recent_parts: List of (part, shingles) tuples to compare against
        similarity_threshold: Minimum similarity to consider duplicate (default: 0.8)

    Returns:
        The duplicate part if found, None otherwise
    """
    # Restrict comparison to the same tool type: different tools are
    # unlikely to produce identical outputs, and cross-tool matches
    # would be false positives.
    candidates = [
        (candidate, candidate_shingles)
        for candidate, candidate_shingles in recent_parts
        if candidate.tool == part.tool
    ]
    if not candidates:
        return None

    # Fast path: byte-exact duplicates via content hashing.
    target_hash = compute_content_hash(part.state.output)
    for candidate, _ in candidates:
        if compute_content_hash(candidate.state.output) == target_hash:
            return candidate

    # Slow path: near-duplicates via shingle Jaccard similarity.
    target_shingles = generate_shingles(part.state.output)
    for candidate, candidate_shingles in candidates:
        if jaccard_similarity(target_shingles, candidate_shingles) >= similarity_threshold:
            return candidate

    return None
def calculate_context_tokens(messages: list[Message]) -> int:
    """Calculate total token count for all tool outputs in conversation.

    Only completed tool parts with non-empty outputs contribute.

    Args:
        messages: All conversation messages

    Returns:
        Total token count
    """
    total = 0
    for message in messages:
        for part in message.parts:
            if not isinstance(part, ToolPart):
                continue
            if part.state.status != "completed":
                continue
            if part.state.output:  # skip empty (e.g. already-pruned) outputs
                total += count_tokens_fast(part.state.output)
    return total
def calculate_turn_age(messages: list[Message], target_msg: Message) -> int:
    """Calculate how many turns ago a message was created.

    Args:
        messages: All conversation messages
        target_msg: Message to calculate age for

    Returns:
        Number of user turns since this message (0 = current turn).
        If the message is not found, the total user-turn count is returned.
    """
    age = 0
    # Walk newest-first, counting user turns until we hit the target.
    for candidate in reversed(messages):
        if candidate.id == target_msg.id:
            break
        if candidate.role == "user":
            age += 1
    return age
def score_tool_output(
    part: ToolPart,
    message: Message,
    messages: list[Message],
    config: PruningConfig,
) -> OutputImportance:
    """Calculate importance score for a tool output.

    Args:
        part: Tool part to score
        message: Message containing this part
        messages: All conversation messages
        config: Pruning configuration with scoring weights

    Returns:
        OutputImportance with all score components
    """
    token_count = count_tokens_fast(part.state.output)

    # 1. Recency score: exponential decay by turn age, scaled to ~[0, 100].
    age = calculate_turn_age(messages, message)
    recency = math.exp(-age * config.recency_decay) * 100 * config.recency_weight

    # 2. Size term, log-scaled by token count.
    # NOTE(review): this term is *added* into total_score, so a positive
    # size_weight makes large outputs look MORE important; a negative
    # size_weight is required to actually "prefer removing large outputs"
    # as the original comment claimed — confirm against config defaults.
    size_penalty = math.log(token_count + 1) * 10 * config.size_weight

    # 3. Semantic importance: boost outputs carrying errors/warnings or
    # explicitly requested by the user.
    semantic = 0.0
    output_lower = part.state.output.lower()

    if any(
        keyword in output_lower
        for keyword in ("error", "exception", "traceback", "failed")
    ):
        semantic += config.error_boost

    # "warn" is a substring of "warning", so one test covers both;
    # the previous `"warning" in x or "warn" in x` was redundant.
    if "warn" in output_lower:
        semantic += config.warning_boost

    if part.state.metadata and part.state.metadata.get("user_requested"):
        semantic += config.user_requested_boost

    semantic *= config.semantic_weight

    # 4. Tool type priority (50.0 default for tools without a configured weight).
    tool_priority = config.tool_importance.get(part.tool, 50.0) * config.tool_priority_weight

    return OutputImportance(
        recency_score=recency,
        size_penalty=size_penalty,
        semantic_score=semantic,
        tool_priority=tool_priority,
        token_count=token_count,
    )
def prune_tool_outputs(
    messages: list[Message],
    config: PruningConfig,
) -> PruningResult:
    """Prune tool outputs from conversation history.

    Selectively removes old tool outputs while preserving:
    - Tool call metadata (name, input, title, timing)
    - Recent outputs (within protection window)
    - Protected tool outputs
    - Last N turns

    Uses exact token counting via tiktoken for accurate pruning decisions.

    First runs duplicate detection (if enabled), then applies either
    importance-based or FIFO pruning strategy.

    Args:
        messages: Conversation messages (newest first recommended)
        config: Pruning configuration

    Returns:
        PruningResult with detailed token tracking and breakdown
    """
    tokens_before = calculate_context_tokens(messages)

    # Phase 1: deduplication — runs regardless of the scoring strategy.
    if config.detect_duplicates:
        dedup = _detect_and_prune_duplicates(messages, config)
    else:
        dedup = PruningResult(pruned=0, tokens_saved=0)

    # Phase 2: budget-based pruning, importance-scored or FIFO.
    budget = (
        _prune_with_importance_scoring_only(messages, config)
        if config.use_importance_scoring
        else _prune_simple(messages, config)
    )

    tokens_after = calculate_context_tokens(messages)

    # Merge both phases into one report with a per-phase breakdown.
    return PruningResult(
        pruned=dedup.pruned + budget.pruned,
        tokens_saved=dedup.tokens_saved + budget.tokens_saved,
        tokens_before=tokens_before,
        tokens_after=tokens_after,
        duplicates_pruned=dedup.pruned,
        importance_pruned=budget.pruned,
        duplicate_tokens_saved=dedup.tokens_saved,
        importance_tokens_saved=budget.tokens_saved,
    )
def _prune_simple(
    messages: list[Message],
    config: PruningConfig,
) -> PruningResult:
    """Simple FIFO pruning (original algorithm).

    Walks messages newest-first, protecting the most recent
    ``config.protect_turns`` user turns and the first ``prune_protect``
    tokens of output. Everything beyond that budget is marked, and is
    pruned only if at least ``prune_minimum`` tokens would be saved.
    (Fix: the ``import time`` previously executed inside the pruning loop
    on every iteration; ``time`` is now a module-level import.)
    """
    total_tokens = 0
    prunable_tokens = 0
    to_prune: list[tuple[Message, ToolPart]] = []
    turns_skipped = 0

    for msg in reversed(messages):
        if msg.role == "user":
            turns_skipped += 1

        # Keep the most recent turns untouched.
        if turns_skipped < config.protect_turns:
            continue

        # Stop at a summary boundary: older content is already compacted.
        if msg.summary:
            break

        for part in msg.parts:
            if not isinstance(part, ToolPart):
                continue
            if part.state.status != "completed":
                continue
            if part.tool in config.protected_tools:
                continue
            if part.state.time and part.state.time.compacted:
                continue  # already pruned in a previous pass

            token_count = count_tokens_fast(part.state.output)
            total_tokens += token_count

            # Once past the protected token budget, mark for pruning.
            if total_tokens > config.prune_protect:
                prunable_tokens += token_count
                to_prune.append((msg, part))

    # Only act when the savings clear the minimum threshold.
    if prunable_tokens > config.prune_minimum:
        for _msg, part in to_prune:
            part.state.output = ""
            part.state.attachments = []
            if part.state.time:
                part.state.time.compacted = int(time.time() * 1000)

        return PruningResult(pruned=len(to_prune), tokens_saved=prunable_tokens)

    return PruningResult(pruned=0, tokens_saved=0)
def _prune_with_importance_scoring_only(
    messages: list[Message],
    config: PruningConfig,
) -> PruningResult:
    """Smart pruning using importance scoring.

    Scores all prunable outputs and prunes lowest-value outputs first
    until the total drops to ``config.prune_protect``, but only when the
    combined savings reach ``config.prune_minimum``. Preserves critical
    outputs (errors, warnings, user-requested) via their higher scores.
    (Fix: the ``import time`` previously executed inside the pruning loop
    on every iteration; ``time`` is now a module-level import.)

    Note: Duplicate detection happens separately before this function is called.
    """
    # Collect all prunable outputs with their importance scores.
    scored_outputs: list[tuple[Message, ToolPart, OutputImportance]] = []
    turns_skipped = 0

    for msg in reversed(messages):
        if msg.role == "user":
            turns_skipped += 1

        if turns_skipped < config.protect_turns:
            continue

        if msg.summary:
            break

        for part in msg.parts:
            if not isinstance(part, ToolPart):
                continue
            if part.state.status != "completed":
                continue
            if part.tool in config.protected_tools:
                continue
            if part.state.time and part.state.time.compacted:
                continue

            importance = score_tool_output(part, msg, messages, config)
            scored_outputs.append((msg, part, importance))

    total_tokens = sum(imp.token_count for _, _, imp in scored_outputs)

    # Under budget: nothing to prune.
    if total_tokens <= config.prune_protect:
        return PruningResult(pruned=0, tokens_saved=0)

    # Lowest importance first = pruned first.
    scored_outputs.sort(key=lambda entry: entry[2].total_score)

    # Mark outputs until the running total falls back under budget.
    pruned_outputs: list[tuple[ToolPart, int]] = []
    current_tokens = total_tokens
    for _msg, part, importance in scored_outputs:
        if current_tokens <= config.prune_protect:
            break
        pruned_outputs.append((part, importance.token_count))
        current_tokens -= importance.token_count

    # Only prune if savings meet the minimum threshold.
    tokens_saved = sum(tokens for _, tokens in pruned_outputs)
    if tokens_saved >= config.prune_minimum:
        for part, _tokens in pruned_outputs:
            part.state.output = ""
            part.state.attachments = []
            if part.state.time:
                part.state.time.compacted = int(time.time() * 1000)

        return PruningResult(pruned=len(pruned_outputs), tokens_saved=tokens_saved)

    return PruningResult(pruned=0, tokens_saved=0)
def _detect_and_prune_duplicates(
    messages: list[Message],
    config: PruningConfig,
) -> PruningResult:
    """Detect and prune duplicate tool outputs.

    Uses similarity hashing (shingles + Jaccard) to find near-duplicates.
    Because the scan runs newest-first, the most recent copy is kept and
    older duplicates are pruned.
    (Fixes: the ``import time`` previously executed inside the pruning loop
    on every iteration — ``time`` is now a module-level import; the index
    also carried an unused ``Message`` per entry, forcing a list rebuild
    before every ``find_duplicate_output`` call.)

    Args:
        messages: Conversation messages
        config: Pruning configuration

    Returns:
        PruningResult with duplicates pruned
    """
    # Index of unique outputs seen so far, each paired with its shingles.
    recent_outputs: list[tuple[ToolPart, set[str]]] = []
    duplicates_to_prune: list[tuple[ToolPart, int]] = []
    turns_skipped = 0

    for msg in reversed(messages):  # Newest first
        if msg.role == "user":
            turns_skipped += 1

        if turns_skipped < config.protect_turns:
            continue

        if msg.summary:
            break

        for part in msg.parts:
            if not isinstance(part, ToolPart):
                continue
            if part.state.status != "completed":
                continue
            if part.tool in config.protected_tools:
                continue
            if part.state.time and part.state.time.compacted:
                continue

            if recent_outputs:
                # Limited lookback window keeps comparison cost bounded.
                lookback = recent_outputs[-config.duplicate_lookback :]
                duplicate_of = find_duplicate_output(
                    part,
                    lookback,
                    config.similarity_threshold,
                )

                if duplicate_of:
                    # Duplicate of a newer output — mark for pruning and
                    # keep it out of the index.
                    token_count = count_tokens_fast(part.state.output)
                    duplicates_to_prune.append((part, token_count))
                    continue

            # Not a duplicate — add to the index.
            recent_outputs.append((part, generate_shingles(part.state.output)))

    # Prune the marked duplicates, recording the timestamp of compaction.
    tokens_saved = 0
    for part, token_count in duplicates_to_prune:
        part.state.output = ""
        part.state.attachments = []
        if part.state.time:
            part.state.time.compacted = int(time.time() * 1000)

        tokens_saved += token_count

    return PruningResult(
        pruned=len(duplicates_to_prune),
        tokens_saved=tokens_saved,
    )