Coverage for src/harnessutils/turn/state_machine.py: 89%
36 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-12 22:41 -0600
"""Tool state machine for managing tool execution states.

State transitions::

    pending → running → completed
                      ↘ error

``completed`` and ``error`` are terminal: no event moves a tool out of them.
"""

from typing import Literal

# Every status a tool call can occupy.
ToolStatus = Literal["pending", "running", "completed", "error"]


def transition_state(
    current: ToolStatus,
    event: Literal["start", "complete", "fail"],
) -> ToolStatus:
    """Return the status reached by applying ``event`` to ``current``.

    Args:
        current: The tool's present status.
        event: The lifecycle event being applied.

    Returns:
        The status the tool moves into.

    Raises:
        ValueError: If ``event`` is not allowed from ``current``, if
            ``current`` is a terminal status, or if ``current`` is not a
            recognised status at all.
    """
    # Each legal move, written as (source status, triggering event, destination).
    legal_moves = (
        ("pending", "start", "running"),
        ("running", "complete", "completed"),
        ("running", "fail", "error"),
    )

    # Only these two statuses have outgoing moves.
    for source in ("pending", "running"):
        if current != source:
            continue
        for origin, trigger, destination in legal_moves:
            if origin == source and event == trigger:
                return destination
        # The status is live, but this particular event is not allowed from it.
        raise ValueError(f"Invalid transition from {source} with event {event}")

    if current in ("completed", "error"):
        raise ValueError(f"Cannot transition from terminal state {current}")

    raise ValueError(f"Unknown state {current}")
class ToolStateMachine:
    """Record and advance the execution status of individual tool calls.

    Statuses live in ``self.states`` keyed by call id. Every change goes
    through ``transition_state``, so an illegal move raises ``ValueError``
    and leaves the stored status as it was.
    """

    # Statuses from which no further transition is possible.
    _TERMINAL_STATES = ("completed", "error")

    def __init__(self) -> None:
        """Start with no tool calls tracked."""
        # call id -> current status; an id is added the first time it starts.
        self.states: dict[str, ToolStatus] = {}

    def _apply_event(
        self, call_id: str, event: Literal["start", "complete", "fail"]
    ) -> None:
        """Advance an already-tracked call by ``event``.

        Raises:
            ValueError: If ``call_id`` is not tracked, or if
                ``transition_state`` rejects the move.
        """
        if call_id not in self.states:
            raise ValueError(f"Tool {call_id} not found")
        self.states[call_id] = transition_state(self.states[call_id], event)

    def start_tool(self, call_id: str) -> None:
        """Move a tool call into ``running``.

        A call id that has not been seen before is treated as ``pending``.

        Args:
            call_id: Identifier of the tool call.

        Raises:
            ValueError: If the call is already running or has finished.
        """
        previous = self.states.get(call_id, "pending")
        self.states[call_id] = transition_state(previous, "start")

    def complete_tool(self, call_id: str) -> None:
        """Move a running tool call into ``completed``.

        Args:
            call_id: Identifier of the tool call.

        Raises:
            ValueError: If ``call_id`` is unknown or not currently running.
        """
        self._apply_event(call_id, "complete")

    def fail_tool(self, call_id: str) -> None:
        """Move a running tool call into ``error``.

        Args:
            call_id: Identifier of the tool call.

        Raises:
            ValueError: If ``call_id`` is unknown or not currently running.
        """
        self._apply_event(call_id, "fail")

    def get_state(self, call_id: str) -> ToolStatus:
        """Report the status of a tool call.

        Args:
            call_id: Identifier of the tool call.

        Returns:
            The recorded status, or ``"pending"`` for an untracked call id.
        """
        if call_id in self.states:
            return self.states[call_id]
        return "pending"

    def is_terminal(self, call_id: str) -> bool:
        """Tell whether a tool call has finished, successfully or not.

        Args:
            call_id: Identifier of the tool call.

        Returns:
            True when the status is ``completed`` or ``error``.
        """
        return self.get_state(call_id) in self._TERMINAL_STATES