Coverage for src/harnessutils/manager.py: 92%

120 statements  

coverage.py v7.13.2, created at 2026-02-12 22:41 -0600

1"""Main ConversationManager API for harness-utils.""" 

2 

3import time 

4from typing import Any 

5 

6from harnessutils.compaction.pruning import prune_tool_outputs 

7from harnessutils.compaction.summarization import is_overflow, summarize_conversation 

8from harnessutils.compaction.truncation import truncate_output 

9from harnessutils.config import HarnessConfig 

10from harnessutils.conversion.to_model import to_model_messages 

11from harnessutils.models.conversation import Conversation 

12from harnessutils.models.message import Message 

13from harnessutils.models.usage import Usage 

14from harnessutils.storage.memory import MemoryStorage 

15from harnessutils.types import LLMClient, StorageBackend 

16from harnessutils.utils.ids import generate_id 


class ConversationManager:
    """Main interface for managing conversations with context window management.

    Provides a high-level API for:
    - Creating and managing conversations
    - Adding messages
    - Automatic context compaction (truncation, pruning, summarization)
    - Message storage and retrieval
    """

    def __init__(
        self,
        storage: StorageBackend | None = None,
        config: HarnessConfig | None = None,
    ):
        """Initialize conversation manager.

        Args:
            storage: Storage backend (uses in-memory if None)
            config: Configuration (uses defaults if None)
        """
        self.config = config or HarnessConfig()
        self.storage = storage or MemoryStorage()
        self._message_cache: dict[str, list[Message]] = {}

    def create_conversation(
        self,
        conversation_id: str | None = None,
        project_id: str | None = None,
    ) -> Conversation:
        """Create a new conversation.

        Args:
            conversation_id: Optional conversation ID (generated if None)
            project_id: Optional project ID for grouping

        Returns:
            New conversation object
        """
        if conversation_id is None:
            conversation_id = generate_id("conv")

        now = int(time.time() * 1000)
        conversation = Conversation(
            id=conversation_id,
            project_id=project_id,
            created=now,
            updated=now,
        )

        self.storage.save_conversation(conversation_id, conversation.to_dict())
        self._message_cache[conversation_id] = []

        return conversation
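    # Usage sketch (illustrative; "proj_demo" is a made-up project ID):
    #
    #     manager = ConversationManager()  # in-memory storage, default config
    #     conv = manager.create_conversation(project_id="proj_demo")
    #     print(conv.id)  # a generated ID, prefixed via generate_id("conv")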

    def add_message(self, conversation_id: str, message: Message) -> None:
        """Add a message to a conversation.

        Args:
            conversation_id: Conversation to add message to
            message: Message to add
        """
        self.storage.save_message(conversation_id, message.id, message.to_dict())

        if conversation_id not in self._message_cache:
            self._message_cache[conversation_id] = []
        self._message_cache[conversation_id].append(message)

        # Load the conversation record and refresh its updated timestamp
        conv_data = self.storage.load_conversation(conversation_id)
        conv = Conversation.from_dict(conv_data)
        conv.updated = int(time.time() * 1000)

        # Track velocity if message has token count
        if message.tokens:
            tokens_added = message.tokens.total
            conv.update_velocity(tokens_added)

        self.storage.save_conversation(conversation_id, conv.to_dict())
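    # Usage sketch (illustrative; the Message constructor fields shown are
    # assumptions based on how they are read above, not a documented signature):
    #
    #     msg = Message(id=generate_id("msg"), role="user", parts=[...])
    #     manager.add_message(conv.id, msg)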

    def get_messages(self, conversation_id: str) -> list[Message]:
        """Get all messages for a conversation.

        Args:
            conversation_id: Conversation ID

        Returns:
            List of messages in chronological order
        """
        if conversation_id in self._message_cache:
            return self._message_cache[conversation_id]

        message_ids = self.storage.list_messages(conversation_id)
        messages = [
            Message.from_dict(self.storage.load_message(conversation_id, msg_id))
            for msg_id in message_ids
        ]

        self._message_cache[conversation_id] = messages
        return messages
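    # Caveat sketch: the cache above hands back the shared list object, so
    # copy before mutating.
    #
    #     history = list(manager.get_messages(conv.id))  # defensive copy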

    def prune_before_turn(
        self,
        conversation_id: str,
        auto_mode: bool = False,
    ) -> dict[str, Any]:
        """Proactively prune old tool outputs before processing a turn.

        This is Tier 2 compaction - removes old tool outputs while
        preserving conversation structure.

        Args:
            conversation_id: Conversation to prune
            auto_mode: Whether this was auto-triggered

        Returns:
            Detailed pruning result with token tracking and breakdown:
            - pruned: Total outputs removed
            - tokens_saved: Total tokens saved
            - tokens_before: Token count before pruning
            - tokens_after: Token count after pruning
            - duplicates_pruned: Outputs removed due to duplication
            - importance_pruned: Outputs removed due to low importance
            - duplicate_tokens_saved: Tokens saved from deduplication
            - importance_tokens_saved: Tokens saved from importance pruning
            - reduction_percent: Percentage reduction in token usage
        """
        if not self.config.compaction.prune and auto_mode:
            return {
                "pruned": 0,
                "tokens_saved": 0,
                "tokens_before": 0,
                "tokens_after": 0,
                "duplicates_pruned": 0,
                "importance_pruned": 0,
                "duplicate_tokens_saved": 0,
                "importance_tokens_saved": 0,
                "reduction_percent": 0,
            }

        messages = self.get_messages(conversation_id)
        result = prune_tool_outputs(
            messages,
            self.config.pruning,
        )

        # Persist the (possibly pruned) messages back to storage
        for msg in messages:
            self.storage.save_message(conversation_id, msg.id, msg.to_dict())

        return result.to_dict()
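    # Usage sketch (illustrative; the keys match the docstring above):
    #
    #     report = manager.prune_before_turn(conv.id, auto_mode=True)
    #     print(f"pruned {report['pruned']} outputs, saved "
    #           f"{report['tokens_saved']} tokens "
    #           f"({report['reduction_percent']}% reduction)")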

    def predict_overflow(
        self,
        conversation_id: str,
        current_usage: Usage,
    ) -> bool:
        """Predict if conversation will overflow in next N turns.

        Args:
            conversation_id: Conversation to check
            current_usage: Current token usage

        Returns:
            True if overflow predicted within lookahead window
        """
        if not self.config.compaction.use_predictive:
            return False

        # Load conversation and get velocity
        conv_data = self.storage.load_conversation(conversation_id)
        conv = Conversation.from_dict(conv_data)
        velocity = conv.get_velocity()

        if velocity is None or not velocity.turn_deltas:
            return False  # No velocity data yet

        # Project tokens ahead
        lookahead = self.config.compaction.predictive_lookahead
        predicted_growth = velocity.predict_tokens_ahead(lookahead)

        # Calculate current total and projected total
        current_total = current_usage.input + current_usage.cache.read
        projected_total = current_total + predicted_growth

        # Check against safety margin
        usable_space = (
            self.config.model_limits.default_context_limit
            - self.config.model_limits.default_output_limit
        )
        safety_threshold = usable_space * self.config.compaction.predictive_safety_margin

        return projected_total > safety_threshold
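    # Worked example (hypothetical numbers, not library defaults): with a
    # 200_000-token context limit, an 8_000-token output reserve, and a 0.85
    # safety margin, usable_space = 192_000 and safety_threshold = 163_200.
    # If current_total = 150_000 and the velocity model predicts 20_000 more
    # tokens within the lookahead, projected_total = 170_000 > 163_200, so
    # this returns True.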

    def needs_compaction(
        self,
        conversation_id: str,
        usage: Usage,
    ) -> bool:
        """Check if conversation needs summarization (Tier 3).

        Uses both reactive (overflow) and predictive checks.

        Args:
            conversation_id: Conversation to check
            usage: Token usage from last turn

        Returns:
            True if summarization needed
        """
        # Reactive check: already overflowed
        if is_overflow(
            usage,
            self.config.model_limits.default_context_limit,
            self.config.model_limits.default_output_limit,
        ):
            return True

        # Predictive check: will overflow soon
        return self.predict_overflow(conversation_id, usage)

    def compact(
        self,
        conversation_id: str,
        llm_client: LLMClient,
        parent_message_id: str,
        model: str | None = None,
        auto_mode: bool = False,
    ) -> dict[str, Any]:
        """Compact conversation using LLM summarization (Tier 3).

        Args:
            conversation_id: Conversation to compact
            llm_client: LLM client for summarization
            parent_message_id: Message that triggered compaction
            model: Optional model to use for summarization
            auto_mode: Whether this was auto-triggered

        Returns:
            Compaction result with summary message and metrics
        """
        if not self.config.compaction.auto and auto_mode:
            return {"summarized": False}

        messages = self.get_messages(conversation_id)
        summary_id = generate_id("msg")

        result = summarize_conversation(
            messages=messages,
            llm_client=llm_client,
            parent_message_id=parent_message_id,
            message_id=summary_id,
            model=model,
            auto_mode=auto_mode,
            config=self.config.summarization,
        )

        self.add_message(conversation_id, result.summary_message)

        return {
            "summarized": True,
            "summary_message_id": summary_id,
            "tokens_used": result.tokens_used.total,
            "cost": result.cost,
        }
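    # Turn-loop sketch (illustrative; `client`, `last_usage`, and `msg` stand
    # in for a real LLMClient, the Usage from the previous model response, and
    # the message that triggered compaction):
    #
    #     if manager.needs_compaction(conv.id, last_usage):
    #         outcome = manager.compact(
    #             conv.id, client, parent_message_id=msg.id, auto_mode=True
    #         )
    #         if outcome["summarized"]:
    #             print("summary:", outcome["summary_message_id"])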

    def to_model_format(self, conversation_id: str) -> list[dict[str, Any]]:
        """Convert conversation messages to model format for LLM requests.

        Args:
            conversation_id: Conversation to convert

        Returns:
            List of messages in model format
        """
        messages = self.get_messages(conversation_id)
        return to_model_messages(messages)
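    # Usage sketch (illustrative; `client.complete` is a placeholder, not a
    # documented LLMClient method):
    #
    #     payload = manager.to_model_format(conv.id)
    #     response = client.complete(messages=payload)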

    def calculate_context_usage(
        self,
        conversation_id: str,
        model: str = "claude-3-5-sonnet-20241022",
    ) -> int:
        """Calculate exact token count for conversation using tiktoken.

        This counts ALL tokens in the conversation (user messages, assistant
        responses, tool outputs, etc.) that will be sent to the model.

        Args:
            conversation_id: Conversation to calculate usage for
            model: Model name for tokenizer selection

        Returns:
            Exact token count that will be used in context window
        """
        from harnessutils.tokens.exact import count_tokens_exact

        messages = self.to_model_format(conversation_id)
        return count_tokens_exact(messages, model)
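    # Usage sketch (illustrative). One hedge: tiktoken ships OpenAI
    # tokenizers, so for Anthropic models like the default above the count is
    # a close estimate rather than the provider's own tally.
    #
    #     used = manager.calculate_context_usage(conv.id)
    #     print(f"context holds ~{used} tokens")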

    def get_tool_output_tokens(self, conversation_id: str) -> dict[str, Any]:
        """Get detailed breakdown of token usage for tool outputs.

        Args:
            conversation_id: Conversation to analyze

        Returns:
            Dictionary with token breakdown:
            - total: Total tokens in tool outputs
            - by_tool: Token count per tool type
            - prunable: Tokens that could be pruned
            - protected: Tokens in protected outputs
        """
        from harnessutils.compaction.pruning import calculate_context_tokens
        from harnessutils.models.parts import ToolPart
        from harnessutils.tokens.exact import count_tokens_fast

        messages = self.get_messages(conversation_id)

        total = calculate_context_tokens(messages)
        by_tool: dict[str, int] = {}
        prunable = 0
        protected = 0
        turns_skipped = 0

        # Walk newest-to-oldest so the most recent turns can be protected
        for msg in reversed(messages):
            if msg.role == "user":
                turns_skipped += 1

            for part in msg.parts:
                if not isinstance(part, ToolPart):
                    continue

                if part.state.status != "completed":
                    continue

                if not part.state.output:
                    continue

                tokens = count_tokens_fast(part.state.output)

                # Track by tool type
                by_tool[part.tool] = by_tool.get(part.tool, 0) + tokens

                # Determine if prunable
                is_protected = (
                    turns_skipped < self.config.pruning.protect_turns
                    or part.tool in self.config.pruning.protected_tools
                    or (part.state.time and part.state.time.compacted)
                )

                if is_protected:
                    protected += tokens
                else:
                    prunable += tokens

        return {
            "total": total,
            "by_tool": by_tool,
            "prunable": prunable,
            "protected": protected,
            "prunability_percent": round((prunable / total * 100) if total > 0 else 0, 1),
        }
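    # Usage sketch (illustrative; the tool names in by_tool depend entirely
    # on which tools the conversation actually ran):
    #
    #     stats = manager.get_tool_output_tokens(conv.id)
    #     for tool, tokens in sorted(stats["by_tool"].items()):
    #         print(f"{tool}: {tokens} tokens")
    #     print(f"{stats['prunability_percent']}% of tool output is prunable")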

    def truncate_tool_output(
        self,
        output: str,
        tool_name: str,
    ) -> str:
        """Truncate tool output if it exceeds limits (Tier 1).

        Args:
            output: Tool output to truncate
            tool_name: Name of the tool

        Returns:
            Potentially truncated output
        """
        output_id = generate_id(f"output_{tool_name}")

        result = truncate_output(
            output=output,
            config=self.config.truncation,
            output_id=output_id,
        )

        # Keep the full output retrievable even after truncation
        if result.truncated and result.output_path:
            self.storage.save_truncated_output(result.output_path, output)

        return result.content
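    # Usage sketch (illustrative; "grep" and the oversized string are
    # placeholders):
    #
    #     raw = "\n".join(f"match {i}" for i in range(100_000))
    #     shown = manager.truncate_tool_output(raw, tool_name="grep")
    #     # `shown` fits the configured limit; the full text stays in storage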