Coverage for src / harnessutils / snapshots.py: 91%

150 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-18 10:56 -0600

1"""Context snapshot management for reproducibility and debugging. 

2 

3Snapshots capture full conversation state at a point in time, enabling: 

4- Reproducible debugging (save state before/after changes) 

5- A/B testing (compare different context strategies) 

6- Time travel (restore to previous state) 

7- Version control (export snapshots to git) 

8""" 

9 

10import json 

11import time 

12from dataclasses import dataclass, field 

13from typing import Any 

14 

15from harnessutils.models.conversation import Conversation 

16from harnessutils.models.message import Message 

17 

18 

@dataclass
class Snapshot:
    """Point-in-time capture of a conversation, held in serialized form.

    A snapshot bundles the pieces needed to bring a conversation back later:
    - the serialized messages
    - the serialized conversation record
    - the configuration that was in effect
    - free-form metadata about the snapshot itself (when, why)
    """

    snapshot_id: str
    conversation_id: str
    timestamp: int  # Unix ms
    messages: list[dict[str, Any]]  # Serialized messages
    conversation: dict[str, Any]  # Serialized conversation
    config: dict[str, Any]  # Configuration snapshot
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert the snapshot into a plain dictionary.

        The values are the snapshot's own objects (no copies are made).

        Returns:
            Dictionary with one entry per field, in field order
        """
        exported = (
            "snapshot_id",
            "conversation_id",
            "timestamp",
            "messages",
            "conversation",
            "config",
            "metadata",
        )
        return {name: getattr(self, name) for name in exported}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Snapshot":
        """Build a snapshot from its dictionary form.

        Args:
            data: Dictionary as produced by ``to_dict``; ``metadata`` may be
                absent, every other key is required

        Returns:
            Snapshot instance

        Raises:
            KeyError: If a required key is missing from ``data``
        """
        required = (
            "snapshot_id",
            "conversation_id",
            "timestamp",
            "messages",
            "conversation",
            "config",
        )
        kwargs = {name: data[name] for name in required}
        # Older/partial payloads may omit metadata; fall back to an empty dict.
        kwargs["metadata"] = data.get("metadata", {})
        return cls(**kwargs)

73 

74 

@dataclass
class SnapshotDiff:
    """Result of comparing an earlier snapshot with a later one.

    All fields are required and are supplied by whoever performs the
    comparison; this class only carries the values.
    """

    # Message-level summary counts.
    messages_added: int
    messages_removed: int

    # Change in token usage between the two snapshots.
    tokens_delta: int

    # Itemised change records, one dict per change entry.
    message_changes: list[dict[str, Any]]
    config_changes: list[dict[str, Any]]

    # Information about the compared snapshots themselves.
    metadata_changes: dict[str, Any]

85 

86 

87class SnapshotManager: 

88 """Manages conversation snapshots. 

89 

90 Handles snapshot creation, storage, restoration, and comparison. 

91 """ 

92 

93 def __init__(self, storage_backend: Any = None): 

94 """Initialize snapshot manager. 

95 

96 Args: 

97 storage_backend: Optional storage backend for persistence 

98 """ 

99 self.storage = storage_backend 

100 self._snapshots: dict[str, Snapshot] = {} 

101 

102 def create_snapshot( 

103 self, 

104 conversation_id: str, 

105 messages: list[Message], 

106 conversation: Conversation, 

107 config: dict[str, Any], 

108 snapshot_id: str | None = None, 

109 metadata: dict[str, Any] | None = None, 

110 ) -> Snapshot: 

111 """Create snapshot of current conversation state. 

112 

113 Args: 

114 conversation_id: Conversation ID 

115 messages: Current messages 

116 conversation: Conversation metadata 

117 config: Configuration dict 

118 snapshot_id: Optional snapshot ID (auto-generated if None) 

119 metadata: Optional metadata (e.g., {"reason": "before_refactor"}) 

120 

121 Returns: 

122 Created snapshot 

123 """ 

124 from harnessutils.utils.ids import generate_id 

125 

126 snapshot_id = snapshot_id or generate_id("snap") 

127 now_ms = int(time.time() * 1000) 

128 

129 # Serialize messages with parts 

130 from dataclasses import asdict 

131 

132 serialized_messages = [] 

133 for msg in messages: 

134 msg_dict = msg.to_dict() 

135 # Add parts - they're not in to_dict() since normally stored separately 

136 msg_dict["parts"] = [asdict(part) for part in msg.parts] 

137 serialized_messages.append(msg_dict) 

138 

139 # Create snapshot 

140 snapshot = Snapshot( 

141 snapshot_id=snapshot_id, 

142 conversation_id=conversation_id, 

143 timestamp=now_ms, 

144 messages=serialized_messages, 

145 conversation=conversation.to_dict(), 

146 config=config, 

147 metadata=metadata or {}, 

148 ) 

149 

150 # Store in memory 

151 self._snapshots[snapshot_id] = snapshot 

152 

153 # Optionally persist to storage 

154 if self.storage: 

155 self.storage.save_snapshot(snapshot_id, snapshot.to_dict()) 

156 

157 return snapshot 

158 

159 def get_snapshot(self, snapshot_id: str) -> Snapshot | None: 

160 """Retrieve snapshot by ID. 

161 

162 Args: 

163 snapshot_id: Snapshot ID 

164 

165 Returns: 

166 Snapshot if found, None otherwise 

167 """ 

168 # Check memory first 

169 if snapshot_id in self._snapshots: 

170 return self._snapshots[snapshot_id] 

171 

172 # Try storage 

173 if self.storage: 

174 try: 

175 data = self.storage.load_snapshot(snapshot_id) 

176 snapshot = Snapshot.from_dict(data) 

177 self._snapshots[snapshot_id] = snapshot 

178 return snapshot 

179 except (KeyError, FileNotFoundError): 

180 pass 

181 

182 return None 

183 

184 def restore_snapshot(self, snapshot: Snapshot) -> tuple[list[Message], Conversation]: 

185 """Restore messages and conversation from snapshot. 

186 

187 Args: 

188 snapshot: Snapshot to restore 

189 

190 Returns: 

191 Tuple of (messages, conversation) 

192 """ 

193 from harnessutils.models.parts import ( 

194 CompactionPart, 

195 PatchPart, 

196 ReasoningPart, 

197 StepFinishPart, 

198 StepStartPart, 

199 SubtaskPart, 

200 TextPart, 

201 TimeInfo, 

202 ToolPart, 

203 ToolState, 

204 ) 

205 

206 def _deserialize_part(part_data: dict[str, Any]) -> Any: 

207 """Deserialize a part from dict.""" 

208 part_type = part_data.get("type") 

209 

210 # Extract time if present 

211 time_info = None 

212 if part_data.get("time"): 

213 time_data = part_data["time"] 

214 time_info = TimeInfo( 

215 start=time_data.get("start", 0), 

216 end=time_data.get("end", 0), 

217 compacted=time_data.get("compacted"), 

218 ) 

219 

220 # Common fields 

221 metadata = part_data.get("metadata", {}) 

222 

223 if part_type == "text": 

224 return TextPart( 

225 text=part_data.get("text", ""), 

226 ignored=part_data.get("ignored", False), 

227 time=time_info, 

228 metadata=metadata, 

229 ) 

230 elif part_type == "reasoning": 

231 return ReasoningPart( 

232 text=part_data.get("text", ""), 

233 time=time_info, 

234 metadata=metadata, 

235 ) 

236 elif part_type == "tool": 

237 state_data = part_data.get("state", {}) 

238 state_time = None 

239 if state_data.get("time"): 

240 st = state_data["time"] 

241 state_time = TimeInfo( 

242 start=st.get("start", 0), 

243 end=st.get("end", 0), 

244 compacted=st.get("compacted"), 

245 ) 

246 return ToolPart( 

247 tool=part_data.get("tool", ""), 

248 call_id=part_data.get("call_id", ""), 

249 state=ToolState( 

250 status=state_data.get("status", "pending"), 

251 input=state_data.get("input", {}), 

252 output=state_data.get("output", ""), 

253 title=state_data.get("title", ""), 

254 metadata=state_data.get("metadata", {}), 

255 error=state_data.get("error"), 

256 time=state_time, 

257 attachments=state_data.get("attachments", []), 

258 ), 

259 time=time_info, 

260 metadata=metadata, 

261 ) 

262 elif part_type == "step-start": 

263 return StepStartPart( 

264 snapshot=part_data.get("snapshot", ""), 

265 time=time_info, 

266 metadata=metadata, 

267 ) 

268 elif part_type == "step-finish": 

269 return StepFinishPart( 

270 reason=part_data.get("reason", "stop"), 

271 snapshot=part_data.get("snapshot", ""), 

272 tokens=part_data.get("tokens", {}), 

273 cost=part_data.get("cost", 0.0), 

274 time=time_info, 

275 metadata=metadata, 

276 ) 

277 elif part_type == "compaction": 

278 return CompactionPart( 

279 auto=part_data.get("auto", False), 

280 time=time_info, 

281 metadata=metadata, 

282 ) 

283 elif part_type == "patch": 

284 return PatchPart( 

285 hash=part_data.get("hash", ""), 

286 files=part_data.get("files", []), 

287 time=time_info, 

288 metadata=metadata, 

289 ) 

290 elif part_type == "subtask": 

291 return SubtaskPart( 

292 prompt=part_data.get("prompt", ""), 

293 description=part_data.get("description", ""), 

294 agent=part_data.get("agent", ""), 

295 model=part_data.get("model", {}), 

296 time=time_info, 

297 metadata=metadata, 

298 ) 

299 else: 

300 # Unknown part type - create a basic TextPart 

301 return TextPart(text="", time=time_info, metadata=metadata) 

302 

303 # Deserialize messages 

304 messages = [] 

305 for msg_data in snapshot.messages: 

306 msg = Message.from_dict(msg_data) 

307 # Restore parts 

308 if "parts" in msg_data: 

309 for part_data in msg_data["parts"]: 

310 part = _deserialize_part(part_data) 

311 msg.add_part(part) 

312 messages.append(msg) 

313 

314 # Deserialize conversation 

315 conversation = Conversation.from_dict(snapshot.conversation) 

316 

317 return messages, conversation 

318 

319 def compare_snapshots( 

320 self, snapshot1_id: str, snapshot2_id: str 

321 ) -> SnapshotDiff | None: 

322 """Compare two snapshots to see what changed. 

323 

324 Args: 

325 snapshot1_id: First snapshot ID (earlier) 

326 snapshot2_id: Second snapshot ID (later) 

327 

328 Returns: 

329 SnapshotDiff describing changes, or None if snapshots not found 

330 """ 

331 snap1 = self.get_snapshot(snapshot1_id) 

332 snap2 = self.get_snapshot(snapshot2_id) 

333 

334 if not snap1 or not snap2: 

335 return None 

336 

337 # Compare message counts 

338 messages_added = len(snap2.messages) - len(snap1.messages) 

339 messages_removed = max(0, len(snap1.messages) - len(snap2.messages)) 

340 

341 # Calculate token delta (simplified - just count message changes) 

342 # In a real implementation, would calculate actual token delta 

343 tokens_delta = 0 

344 

345 # Find message changes 

346 message_changes = [] 

347 snap1_ids = {msg["id"] for msg in snap1.messages} 

348 snap2_ids = {msg["id"] for msg in snap2.messages} 

349 

350 added_ids = snap2_ids - snap1_ids 

351 removed_ids = snap1_ids - snap2_ids 

352 

353 if added_ids: 

354 message_changes.append({ 

355 "type": "added", 

356 "count": len(added_ids), 

357 "ids": list(added_ids), 

358 }) 

359 

360 if removed_ids: 

361 message_changes.append({ 

362 "type": "removed", 

363 "count": len(removed_ids), 

364 "ids": list(removed_ids), 

365 }) 

366 

367 # Compare configs (simplified) 

368 config_changes = [] 

369 if snap1.config != snap2.config: 

370 config_changes.append({ 

371 "type": "config_modified", 

372 "before": snap1.config, 

373 "after": snap2.config, 

374 }) 

375 

376 # Metadata changes 

377 metadata_changes = { 

378 "snapshot1_timestamp": snap1.timestamp, 

379 "snapshot2_timestamp": snap2.timestamp, 

380 "time_delta_ms": snap2.timestamp - snap1.timestamp, 

381 } 

382 

383 return SnapshotDiff( 

384 messages_added=max(0, messages_added), 

385 messages_removed=messages_removed, 

386 tokens_delta=tokens_delta, 

387 message_changes=message_changes, 

388 config_changes=config_changes, 

389 metadata_changes=metadata_changes, 

390 ) 

391 

392 def export_snapshot(self, snapshot_id: str, file_path: str) -> None: 

393 """Export snapshot to JSON file. 

394 

395 Useful for: 

396 - Version control (commit snapshots to git) 

397 - Sharing test cases 

398 - Long-term archival 

399 

400 Args: 

401 snapshot_id: Snapshot to export 

402 file_path: Path to write JSON file 

403 

404 Raises: 

405 FileNotFoundError: If snapshot not found 

406 """ 

407 snapshot = self.get_snapshot(snapshot_id) 

408 if not snapshot: 

409 raise FileNotFoundError(f"Snapshot {snapshot_id} not found") 

410 

411 with open(file_path, "w") as f: 

412 json.dump(snapshot.to_dict(), f, indent=2) 

413 

414 def import_snapshot(self, file_path: str) -> Snapshot: 

415 """Import snapshot from JSON file. 

416 

417 Args: 

418 file_path: Path to JSON file 

419 

420 Returns: 

421 Imported snapshot 

422 """ 

423 with open(file_path) as f: 

424 data = json.load(f) 

425 

426 snapshot = Snapshot.from_dict(data) 

427 self._snapshots[snapshot.snapshot_id] = snapshot 

428 

429 return snapshot 

430 

431 def list_snapshots(self, conversation_id: str | None = None) -> list[Snapshot]: 

432 """List all snapshots, optionally filtered by conversation. 

433 

434 Args: 

435 conversation_id: Optional conversation ID filter 

436 

437 Returns: 

438 List of snapshots (sorted by timestamp, newest first) 

439 """ 

440 snapshots = list(self._snapshots.values()) 

441 

442 if conversation_id: 

443 snapshots = [s for s in snapshots if s.conversation_id == conversation_id] 

444 

445 # Sort by timestamp, newest first 

446 snapshots.sort(key=lambda s: s.timestamp, reverse=True) 

447 

448 return snapshots 

449 

450 def delete_snapshot(self, snapshot_id: str) -> bool: 

451 """Delete snapshot. 

452 

453 Args: 

454 snapshot_id: Snapshot to delete 

455 

456 Returns: 

457 True if deleted, False if not found 

458 """ 

459 if snapshot_id in self._snapshots: 

460 del self._snapshots[snapshot_id] 

461 

462 # Also delete from storage if present 

463 if self.storage: 

464 try: 

465 self.storage.delete_snapshot(snapshot_id) 

466 except (KeyError, FileNotFoundError): 

467 pass 

468 

469 return True 

470 

471 return False