Coverage for src/harnessutils/snapshots.py: 91%
150 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-18 10:56 -0600
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-18 10:56 -0600
"""Context snapshot management for reproducibility and debugging.

Snapshots capture full conversation state at a point in time, enabling:
- Reproducible debugging (save state before/after changes)
- A/B testing (compare different context strategies)
- Time travel (restore to previous state)
- Version control (export snapshots to git)
"""
import copy
import json
import time
from dataclasses import dataclass, field
from typing import Any

from harnessutils.models.conversation import Conversation
from harnessutils.models.message import Message
@dataclass
class Snapshot:
    """Point-in-time capture of a conversation.

    Holds the serialized messages, the serialized conversation record, the
    configuration in effect, and free-form metadata describing the snapshot
    itself (for example when or why it was taken).
    """

    snapshot_id: str
    conversation_id: str
    timestamp: int  # Unix ms
    messages: list[dict[str, Any]]  # Serialized messages
    conversation: dict[str, Any]  # Serialized conversation
    config: dict[str, Any]  # Configuration snapshot
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize snapshot to dictionary.

        The values are the snapshot's own attribute objects (no copies are
        made), keyed by attribute name in declaration order.

        Returns:
            Dictionary representation for JSON export
        """
        keys = (
            "snapshot_id",
            "conversation_id",
            "timestamp",
            "messages",
            "conversation",
            "config",
            "metadata",
        )
        return {key: getattr(self, key) for key in keys}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Snapshot":
        """Deserialize snapshot from dictionary.

        Args:
            data: Dictionary representation, as produced by ``to_dict()``.
                ``"metadata"`` is optional and defaults to an empty dict.

        Returns:
            Snapshot instance

        Raises:
            KeyError: If any key other than ``"metadata"`` is missing.
        """
        required = (
            "snapshot_id",
            "conversation_id",
            "timestamp",
            "messages",
            "conversation",
            "config",
        )
        kwargs = {name: data[name] for name in required}
        kwargs["metadata"] = data.get("metadata", {})
        return cls(**kwargs)
@dataclass
class SnapshotDiff:
    """Result of comparing two snapshots.

    A plain record; all fields are required and supplied by whoever performs
    the comparison.
    """

    # Message-level counters.
    messages_added: int
    messages_removed: int
    # Change in token count between the two snapshots.
    tokens_delta: int
    # Itemised change records, one dict per change.
    message_changes: list[dict[str, Any]]
    config_changes: list[dict[str, Any]]
    # Differences in snapshot-level metadata.
    metadata_changes: dict[str, Any]
class SnapshotManager:
    """Manages conversation snapshots.

    Handles snapshot creation, storage, restoration, and comparison.

    Every snapshot is kept in an in-memory cache keyed by snapshot ID. When a
    storage backend is supplied, snapshots are additionally persisted through
    it. The backend is duck-typed: this class calls
    ``save_snapshot(snapshot_id, data)``, ``load_snapshot(snapshot_id)`` and
    ``delete_snapshot(snapshot_id)`` on it, and treats ``KeyError`` /
    ``FileNotFoundError`` raised by the latter two as "not found".
    """

    def __init__(self, storage_backend: Any = None):
        """Initialize snapshot manager.

        Args:
            storage_backend: Optional storage backend for persistence
        """
        self.storage = storage_backend
        self._snapshots: dict[str, Snapshot] = {}

    def create_snapshot(
        self,
        conversation_id: str,
        messages: list[Message],
        conversation: Conversation,
        config: dict[str, Any],
        snapshot_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> Snapshot:
        """Create snapshot of current conversation state.

        The snapshot is cached in memory and, when a storage backend is
        configured, also saved through it.

        Args:
            conversation_id: Conversation ID
            messages: Current messages
            conversation: Conversation metadata
            config: Configuration dict
            snapshot_id: Optional snapshot ID (auto-generated if None)
            metadata: Optional metadata (e.g., {"reason": "before_refactor"})

        Returns:
            Created snapshot
        """
        from dataclasses import asdict

        from harnessutils.utils.ids import generate_id

        snapshot_id = snapshot_id or generate_id("snap")
        now_ms = int(time.time() * 1000)

        # Parts are not included in Message.to_dict() (they are normally
        # stored separately), so attach them explicitly.
        serialized_messages = []
        for msg in messages:
            msg_dict = msg.to_dict()
            msg_dict["parts"] = [asdict(part) for part in msg.parts]
            serialized_messages.append(msg_dict)

        snapshot = Snapshot(
            snapshot_id=snapshot_id,
            conversation_id=conversation_id,
            timestamp=now_ms,
            messages=serialized_messages,
            conversation=conversation.to_dict(),
            # Deep-copy caller-owned dicts so that later mutation by the
            # caller cannot silently alter an already-taken snapshot.
            config=copy.deepcopy(config),
            metadata=copy.deepcopy(metadata) if metadata else {},
        )

        self._snapshots[snapshot_id] = snapshot

        if self.storage:
            self.storage.save_snapshot(snapshot_id, snapshot.to_dict())

        return snapshot

    def get_snapshot(self, snapshot_id: str) -> Snapshot | None:
        """Retrieve snapshot by ID.

        Looks in the in-memory cache first, then in the storage backend (if
        any). A snapshot loaded from storage is added to the cache.

        Args:
            snapshot_id: Snapshot ID

        Returns:
            Snapshot if found, None otherwise
        """
        if snapshot_id in self._snapshots:
            return self._snapshots[snapshot_id]

        if self.storage:
            try:
                data = self.storage.load_snapshot(snapshot_id)
                snapshot = Snapshot.from_dict(data)
            except (KeyError, FileNotFoundError):
                # Backend has no such snapshot. A KeyError from from_dict()
                # on incomplete data is also reported as "not found".
                return None
            self._snapshots[snapshot_id] = snapshot
            return snapshot

        return None

    def restore_snapshot(self, snapshot: Snapshot) -> tuple[list[Message], Conversation]:
        """Restore messages and conversation from snapshot.

        Args:
            snapshot: Snapshot to restore

        Returns:
            Tuple of (messages, conversation)
        """
        from harnessutils.models.parts import (
            CompactionPart,
            PatchPart,
            ReasoningPart,
            StepFinishPart,
            StepStartPart,
            SubtaskPart,
            TextPart,
            TimeInfo,
            ToolPart,
            ToolState,
        )

        def _time_info(raw: dict[str, Any] | None) -> Any:
            """Build a TimeInfo from its dict form, or None if absent/empty."""
            if not raw:
                return None
            return TimeInfo(
                start=raw.get("start", 0),
                end=raw.get("end", 0),
                compacted=raw.get("compacted"),
            )

        def _deserialize_part(part_data: dict[str, Any]) -> Any:
            """Rebuild one part object from its dict form, by "type"."""
            part_type = part_data.get("type")
            time_info = _time_info(part_data.get("time"))
            metadata = part_data.get("metadata", {})

            if part_type == "text":
                return TextPart(
                    text=part_data.get("text", ""),
                    ignored=part_data.get("ignored", False),
                    time=time_info,
                    metadata=metadata,
                )
            if part_type == "reasoning":
                return ReasoningPart(
                    text=part_data.get("text", ""),
                    time=time_info,
                    metadata=metadata,
                )
            if part_type == "tool":
                state_data = part_data.get("state", {})
                return ToolPart(
                    tool=part_data.get("tool", ""),
                    call_id=part_data.get("call_id", ""),
                    state=ToolState(
                        status=state_data.get("status", "pending"),
                        input=state_data.get("input", {}),
                        output=state_data.get("output", ""),
                        title=state_data.get("title", ""),
                        metadata=state_data.get("metadata", {}),
                        error=state_data.get("error"),
                        time=_time_info(state_data.get("time")),
                        attachments=state_data.get("attachments", []),
                    ),
                    time=time_info,
                    metadata=metadata,
                )
            if part_type == "step-start":
                return StepStartPart(
                    snapshot=part_data.get("snapshot", ""),
                    time=time_info,
                    metadata=metadata,
                )
            if part_type == "step-finish":
                return StepFinishPart(
                    reason=part_data.get("reason", "stop"),
                    snapshot=part_data.get("snapshot", ""),
                    tokens=part_data.get("tokens", {}),
                    cost=part_data.get("cost", 0.0),
                    time=time_info,
                    metadata=metadata,
                )
            if part_type == "compaction":
                return CompactionPart(
                    auto=part_data.get("auto", False),
                    time=time_info,
                    metadata=metadata,
                )
            if part_type == "patch":
                return PatchPart(
                    hash=part_data.get("hash", ""),
                    files=part_data.get("files", []),
                    time=time_info,
                    metadata=metadata,
                )
            if part_type == "subtask":
                return SubtaskPart(
                    prompt=part_data.get("prompt", ""),
                    description=part_data.get("description", ""),
                    agent=part_data.get("agent", ""),
                    model=part_data.get("model", {}),
                    time=time_info,
                    metadata=metadata,
                )
            # Unknown part type: fall back to an empty TextPart so the
            # message keeps the same number of parts.
            return TextPart(text="", time=time_info, metadata=metadata)

        messages = []
        for msg_data in snapshot.messages:
            msg = Message.from_dict(msg_data)
            # "parts" may be missing (or None) in hand-written snapshots.
            for part_data in msg_data.get("parts") or []:
                msg.add_part(_deserialize_part(part_data))
            messages.append(msg)

        conversation = Conversation.from_dict(snapshot.conversation)

        return messages, conversation

    def compare_snapshots(
        self, snapshot1_id: str, snapshot2_id: str
    ) -> SnapshotDiff | None:
        """Compare two snapshots to see what changed.

        Messages are matched by their ``"id"`` value. ``messages_added`` and
        ``messages_removed`` count the IDs that appear in only one of the two
        snapshots, so they always agree with the entries in
        ``message_changes``. ``tokens_delta`` is not computed yet and is
        always 0.

        Args:
            snapshot1_id: First snapshot ID (earlier)
            snapshot2_id: Second snapshot ID (later)

        Returns:
            SnapshotDiff describing changes, or None if snapshots not found
        """
        snap1 = self.get_snapshot(snapshot1_id)
        snap2 = self.get_snapshot(snapshot2_id)

        if snap1 is None or snap2 is None:
            return None

        ids_before = [msg["id"] for msg in snap1.messages]
        ids_after = [msg["id"] for msg in snap2.messages]
        seen_before = set(ids_before)
        seen_after = set(ids_after)

        # dict.fromkeys de-duplicates while keeping message order, so the
        # reported ID lists are deterministic.
        added_ids = list(
            dict.fromkeys(i for i in ids_after if i not in seen_before)
        )
        removed_ids = list(
            dict.fromkeys(i for i in ids_before if i not in seen_after)
        )

        message_changes = []
        if added_ids:
            message_changes.append({
                "type": "added",
                "count": len(added_ids),
                "ids": added_ids,
            })
        if removed_ids:
            message_changes.append({
                "type": "removed",
                "count": len(removed_ids),
                "ids": removed_ids,
            })

        # Simplified config comparison: a single before/after record.
        config_changes = []
        if snap1.config != snap2.config:
            config_changes.append({
                "type": "config_modified",
                "before": snap1.config,
                "after": snap2.config,
            })

        metadata_changes = {
            "snapshot1_timestamp": snap1.timestamp,
            "snapshot2_timestamp": snap2.timestamp,
            "time_delta_ms": snap2.timestamp - snap1.timestamp,
        }

        return SnapshotDiff(
            messages_added=len(added_ids),
            messages_removed=len(removed_ids),
            tokens_delta=0,  # placeholder: token accounting not implemented
            message_changes=message_changes,
            config_changes=config_changes,
            metadata_changes=metadata_changes,
        )

    def export_snapshot(self, snapshot_id: str, file_path: str) -> None:
        """Export snapshot to JSON file.

        Useful for:
        - Version control (commit snapshots to git)
        - Sharing test cases
        - Long-term archival

        Args:
            snapshot_id: Snapshot to export
            file_path: Path to write JSON file

        Raises:
            FileNotFoundError: If snapshot not found
        """
        snapshot = self.get_snapshot(snapshot_id)
        if snapshot is None:
            raise FileNotFoundError(f"Snapshot {snapshot_id} not found")

        # Explicit encoding keeps exported files portable across locales.
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(snapshot.to_dict(), f, indent=2)

    def import_snapshot(self, file_path: str) -> Snapshot:
        """Import snapshot from JSON file.

        The imported snapshot is added to the in-memory cache only; it is not
        written to the storage backend.

        Args:
            file_path: Path to JSON file

        Returns:
            Imported snapshot
        """
        with open(file_path, encoding="utf-8") as f:
            data = json.load(f)

        snapshot = Snapshot.from_dict(data)
        self._snapshots[snapshot.snapshot_id] = snapshot

        return snapshot

    def list_snapshots(self, conversation_id: str | None = None) -> list[Snapshot]:
        """List all snapshots, optionally filtered by conversation.

        Only snapshots currently held in the in-memory cache are listed.

        Args:
            conversation_id: Optional conversation ID filter

        Returns:
            List of snapshots (sorted by timestamp, newest first)
        """
        snapshots = list(self._snapshots.values())

        if conversation_id:
            snapshots = [s for s in snapshots if s.conversation_id == conversation_id]

        return sorted(snapshots, key=lambda s: s.timestamp, reverse=True)

    def delete_snapshot(self, snapshot_id: str) -> bool:
        """Delete snapshot.

        Removes the snapshot from the in-memory cache and, when a storage
        backend is configured, from storage as well. A snapshot that exists
        only in storage (never loaded into memory) is also deleted.

        Args:
            snapshot_id: Snapshot to delete

        Returns:
            True if deleted, False if not found
        """
        in_memory = self._snapshots.pop(snapshot_id, None) is not None

        if not self.storage:
            return in_memory

        if not in_memory:
            # Probe the backend the same way get_snapshot() does to learn
            # whether a persisted copy exists.
            try:
                self.storage.load_snapshot(snapshot_id)
            except (KeyError, FileNotFoundError):
                return False

        try:
            self.storage.delete_snapshot(snapshot_id)
        except (KeyError, FileNotFoundError):
            # Already absent from storage; nothing left to remove.
            pass

        return True