Coverage for src/harnessutils/manager.py: 92%
120 statements
coverage.py v7.13.2, created at 2026-02-12 22:41 -0600
1"""Main ConversationManager API for harness-utils."""
3import time
4from typing import Any
6from harnessutils.compaction.pruning import prune_tool_outputs
7from harnessutils.compaction.summarization import is_overflow, summarize_conversation
8from harnessutils.compaction.truncation import truncate_output
9from harnessutils.config import HarnessConfig
10from harnessutils.conversion.to_model import to_model_messages
11from harnessutils.models.conversation import Conversation
12from harnessutils.models.message import Message
13from harnessutils.models.usage import Usage
14from harnessutils.storage.memory import MemoryStorage
15from harnessutils.types import LLMClient, StorageBackend
16from harnessutils.utils.ids import generate_id


class ConversationManager:
    """Main interface for managing conversations with context window management.

    Provides a high-level API for:
    - Creating and managing conversations
    - Adding messages
    - Automatic context compaction (truncation, pruning, summarization)
    - Message storage and retrieval
    """

    def __init__(
        self,
        storage: StorageBackend | None = None,
        config: HarnessConfig | None = None,
    ):
        """Initialize conversation manager.

        Args:
            storage: Storage backend (uses in-memory if None)
            config: Configuration (uses defaults if None)
        """
        self.config = config or HarnessConfig()
        self.storage = storage or MemoryStorage()
        self._message_cache: dict[str, list[Message]] = {}
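
    # Example (a minimal sketch; with no arguments, storage defaults to
    # MemoryStorage and config to HarnessConfig):
    #
    #     manager = ConversationManager()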

    def create_conversation(
        self,
        conversation_id: str | None = None,
        project_id: str | None = None,
    ) -> Conversation:
        """Create a new conversation.

        Args:
            conversation_id: Optional conversation ID (generated if None)
            project_id: Optional project ID for grouping

        Returns:
            New conversation object
        """
        if conversation_id is None:
            conversation_id = generate_id("conv")

        now = int(time.time() * 1000)
        conversation = Conversation(
            id=conversation_id,
            project_id=project_id,
            created=now,
            updated=now,
        )

        self.storage.save_conversation(conversation_id, conversation.to_dict())
        self._message_cache[conversation_id] = []

        return conversation
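
    # Example ("proj_demo" is a hypothetical project ID; when conversation_id
    # is omitted, an ID prefixed "conv" comes from generate_id("conv")):
    #
    #     conv = manager.create_conversation(project_id="proj_demo")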

    def add_message(self, conversation_id: str, message: Message) -> None:
        """Add a message to a conversation.

        Args:
            conversation_id: Conversation to add message to
            message: Message to add
        """
        self.storage.save_message(conversation_id, message.id, message.to_dict())

        if conversation_id not in self._message_cache:
            self._message_cache[conversation_id] = []
        self._message_cache[conversation_id].append(message)

        # Load the conversation and refresh its updated timestamp
        conv_data = self.storage.load_conversation(conversation_id)
        conv = Conversation.from_dict(conv_data)
        conv.updated = int(time.time() * 1000)

        # Track velocity if message has token count
        if message.tokens:
            tokens_added = message.tokens.total
            conv.update_velocity(tokens_added)

        self.storage.save_conversation(conversation_id, conv.to_dict())
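
    # Example (a sketch only; Message's constructor is not shown in this file,
    # so assume `msg` is a Message built elsewhere, e.g. via Message.from_dict):
    #
    #     manager.add_message(conv.id, msg)
    #     # If msg.tokens is set, the conversation's token velocity is updated too.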

    def get_messages(self, conversation_id: str) -> list[Message]:
        """Get all messages for a conversation.

        Args:
            conversation_id: Conversation ID

        Returns:
            List of messages in chronological order
        """
        if conversation_id in self._message_cache:
            return self._message_cache[conversation_id]

        message_ids = self.storage.list_messages(conversation_id)
        messages = [
            Message.from_dict(self.storage.load_message(conversation_id, msg_id))
            for msg_id in message_ids
        ]

        self._message_cache[conversation_id] = messages
        return messages

    def prune_before_turn(
        self,
        conversation_id: str,
        auto_mode: bool = False,
    ) -> dict[str, Any]:
        """Proactively prune old tool outputs before processing a turn.

        This is Tier 2 compaction - removes old tool outputs while
        preserving conversation structure.

        Args:
            conversation_id: Conversation to prune
            auto_mode: Whether this was auto-triggered

        Returns:
            Detailed pruning result with token tracking and breakdown:
            - pruned: Total outputs removed
            - tokens_saved: Total tokens saved
            - tokens_before: Token count before pruning
            - tokens_after: Token count after pruning
            - duplicates_pruned: Outputs removed due to duplication
            - importance_pruned: Outputs removed due to low importance
            - duplicate_tokens_saved: Tokens saved from deduplication
            - importance_tokens_saved: Tokens saved from importance pruning
            - reduction_percent: Percentage reduction in token usage
        """
        if not self.config.compaction.prune and auto_mode:
            return {
                "pruned": 0,
                "tokens_saved": 0,
                "tokens_before": 0,
                "tokens_after": 0,
                "duplicates_pruned": 0,
                "importance_pruned": 0,
                "duplicate_tokens_saved": 0,
                "importance_tokens_saved": 0,
                "reduction_percent": 0,
            }

        messages = self.get_messages(conversation_id)
        result = prune_tool_outputs(
            messages,
            self.config.pruning,
        )

        for msg in messages:
            self.storage.save_message(conversation_id, msg.id, msg.to_dict())

        return result.to_dict()
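
    # Example (reading the Tier 2 result; keys follow the docstring above):
    #
    #     result = manager.prune_before_turn(conv.id, auto_mode=True)
    #     print(f"pruned {result['pruned']} outputs, saved {result['tokens_saved']} "
    #           f"tokens ({result['reduction_percent']}% reduction)")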

    def predict_overflow(
        self,
        conversation_id: str,
        current_usage: Usage,
    ) -> bool:
        """Predict if conversation will overflow in next N turns.

        Args:
            conversation_id: Conversation to check
            current_usage: Current token usage

        Returns:
            True if overflow predicted within lookahead window
        """
        if not self.config.compaction.use_predictive:
            return False

        # Load conversation and get velocity
        conv_data = self.storage.load_conversation(conversation_id)
        conv = Conversation.from_dict(conv_data)
        velocity = conv.get_velocity()

        if velocity is None or not velocity.turn_deltas:
            return False  # No velocity data yet

        # Project tokens ahead
        lookahead = self.config.compaction.predictive_lookahead
        predicted_growth = velocity.predict_tokens_ahead(lookahead)

        # Calculate current total and projected total
        current_total = current_usage.input + current_usage.cache.read
        projected_total = current_total + predicted_growth

        # Check against safety margin
        usable_space = (
            self.config.model_limits.default_context_limit
            - self.config.model_limits.default_output_limit
        )
        safety_threshold = usable_space * self.config.compaction.predictive_safety_margin

        return projected_total > safety_threshold
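
    # Worked example (illustrative numbers, not the real config defaults):
    # with a 200_000-token context limit, an 8_000-token output reserve, and a
    # safety margin of 0.8, usable_space = 192_000 and safety_threshold = 153_600.
    # If current_total = 140_000 and velocity projects 20_000 more tokens over the
    # lookahead window, projected_total = 160_000 > 153_600, so this returns True.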

    def needs_compaction(
        self,
        conversation_id: str,
        usage: Usage,
    ) -> bool:
        """Check if conversation needs summarization (Tier 3).

        Uses both reactive (overflow) and predictive checks.

        Args:
            conversation_id: Conversation to check
            usage: Token usage from last turn

        Returns:
            True if summarization needed
        """
        # Reactive check: already overflowed
        if is_overflow(
            usage,
            self.config.model_limits.default_context_limit,
            self.config.model_limits.default_output_limit,
        ):
            return True

        # Predictive check: will overflow soon
        return self.predict_overflow(conversation_id, usage)

    def compact(
        self,
        conversation_id: str,
        llm_client: LLMClient,
        parent_message_id: str,
        model: str | None = None,
        auto_mode: bool = False,
    ) -> dict[str, Any]:
        """Compact conversation using LLM summarization (Tier 3).

        Args:
            conversation_id: Conversation to compact
            llm_client: LLM client for summarization
            parent_message_id: Message that triggered compaction
            model: Optional model to use for summarization
            auto_mode: Whether this was auto-triggered

        Returns:
            Compaction result with summary message and metrics
        """
        if not self.config.compaction.auto and auto_mode:
            return {"summarized": False}

        messages = self.get_messages(conversation_id)
        summary_id = generate_id("msg")

        result = summarize_conversation(
            messages=messages,
            llm_client=llm_client,
            parent_message_id=parent_message_id,
            message_id=summary_id,
            model=model,
            auto_mode=auto_mode,
            config=self.config.summarization,
        )

        self.add_message(conversation_id, result.summary_message)

        return {
            "summarized": True,
            "summary_message_id": summary_id,
            "tokens_used": result.tokens_used.total,
            "cost": result.cost,
        }
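
    # Example (Tier 3 flow; `llm` is an assumed LLMClient, `usage` the Usage
    # from the last turn, and `last_msg_id` the message that triggered compaction):
    #
    #     if manager.needs_compaction(conv.id, usage):
    #         result = manager.compact(
    #             conv.id, llm_client=llm, parent_message_id=last_msg_id, auto_mode=True
    #         )
    #         if result["summarized"]:
    #             print(f"summary {result['summary_message_id']} cost {result['cost']}")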

    def to_model_format(self, conversation_id: str) -> list[dict[str, Any]]:
        """Convert conversation messages to model format for LLM requests.

        Args:
            conversation_id: Conversation to convert

        Returns:
            List of messages in model format
        """
        messages = self.get_messages(conversation_id)
        return to_model_messages(messages)
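
    # Example (the request itself is up to the caller; `payload` is just a
    # list of dicts suitable for an LLM client's `messages` field):
    #
    #     payload = manager.to_model_format(conv.id)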

    def calculate_context_usage(
        self,
        conversation_id: str,
        model: str = "claude-3-5-sonnet-20241022",
    ) -> int:
        """Calculate exact token count for conversation using tiktoken.

        This counts ALL tokens in the conversation (user messages, assistant
        responses, tool outputs, etc.) that will be sent to the model.

        Args:
            conversation_id: Conversation to calculate usage for
            model: Model name for tokenizer selection

        Returns:
            Exact token count that will be used in context window
        """
        from harnessutils.tokens.exact import count_tokens_exact

        messages = self.to_model_format(conversation_id)
        return count_tokens_exact(messages, model)
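
    # Example:
    #
    #     tokens = manager.calculate_context_usage(conv.id)
    #     print(f"{tokens} tokens will be sent to the model")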

    def get_tool_output_tokens(self, conversation_id: str) -> dict[str, Any]:
        """Get detailed breakdown of token usage for tool outputs.

        Args:
            conversation_id: Conversation to analyze

        Returns:
            Dictionary with token breakdown:
            - total: Total tokens in tool outputs
            - by_tool: Token count per tool type
            - prunable: Tokens that could be pruned
            - protected: Tokens in protected outputs
        """
        from harnessutils.compaction.pruning import calculate_context_tokens
        from harnessutils.models.parts import ToolPart
        from harnessutils.tokens.exact import count_tokens_fast

        messages = self.get_messages(conversation_id)

        total = calculate_context_tokens(messages)
        by_tool: dict[str, int] = {}
        prunable = 0
        protected = 0
        turns_skipped = 0

        for msg in reversed(messages):
            if msg.role == "user":
                turns_skipped += 1

            for part in msg.parts:
                if not isinstance(part, ToolPart):
                    continue

                if part.state.status != "completed":
                    continue

                if not part.state.output:
                    continue

                tokens = count_tokens_fast(part.state.output)

                # Track by tool type
                by_tool[part.tool] = by_tool.get(part.tool, 0) + tokens

                # Determine if prunable
                is_protected = (
                    turns_skipped < self.config.pruning.protect_turns
                    or part.tool in self.config.pruning.protected_tools
                    or (part.state.time and part.state.time.compacted)
                )

                if is_protected:
                    protected += tokens
                else:
                    prunable += tokens

        return {
            "total": total,
            "by_tool": by_tool,
            "prunable": prunable,
            "protected": protected,
            "prunability_percent": round((prunable / total * 100) if total > 0 else 0, 1),
        }
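
    # Example (keys follow the docstring above; "bash" is a hypothetical tool name):
    #
    #     breakdown = manager.get_tool_output_tokens(conv.id)
    #     print(breakdown["by_tool"].get("bash", 0), "tokens from bash outputs")
    #     print(f"{breakdown['prunability_percent']}% of tool-output tokens are prunable")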

    def truncate_tool_output(
        self,
        output: str,
        tool_name: str,
    ) -> str:
        """Truncate tool output if it exceeds limits (Tier 1).

        Args:
            output: Tool output to truncate
            tool_name: Name of the tool

        Returns:
            Potentially truncated output
        """
        output_id = generate_id(f"output_{tool_name}")

        result = truncate_output(
            output=output,
            config=self.config.truncation,
            output_id=output_id,
        )

        if result.truncated and result.output_path:
            self.storage.save_truncated_output(result.output_path, output)

        return result.content
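
    # Example (Tier 1; `raw` is an assumed oversized tool output string):
    #
    #     safe = manager.truncate_tool_output(raw, tool_name="bash")
    #     # When truncation happens, the full output is saved via
    #     # storage.save_truncated_output for later retrieval.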