Coverage for src / harnessutils / turn / state_machine.py: 89%

36 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-12 22:41 -0600

"""Tool state machine for managing tool execution states.

State transitions (event names in parentheses)::

    pending --(start)--> running --(complete)--> completed
                            |
                            +-----(fail)-------> error

``completed`` and ``error`` are terminal: they accept no further events.
Any transition not shown above is rejected with ``ValueError``.
"""

from typing import Literal

# Closed set of statuses a tool call can be in (see the diagram above).
ToolStatus = Literal["pending", "running", "completed", "error"]

11 

12 

def transition_state(
    current: ToolStatus,
    event: Literal["start", "complete", "fail"],
) -> ToolStatus:
    """Compute the status a tool moves to when ``event`` occurs in ``current``.

    The only legal moves are ``pending --start--> running``,
    ``running --complete--> completed`` and ``running --fail--> error``.
    ``completed`` and ``error`` are terminal and accept no events.

    Args:
        current: Status the tool is in before the event.
        event: The event being applied.

    Returns:
        The status the tool is in after the event.

    Raises:
        ValueError: If ``event`` is not legal in ``current``, if ``current``
            is terminal, or if ``current`` is not a recognised status.
    """
    # Each entry is (source status, triggering event, destination status).
    legal_moves = (
        ("pending", "start", "running"),
        ("running", "complete", "completed"),
        ("running", "fail", "error"),
    )
    for source, trigger, destination in legal_moves:
        if current == source and event == trigger:
            return destination

    # No legal move matched; report the most specific reason.
    if current == "pending":
        raise ValueError(f"Invalid transition from pending with event {event}")
    if current == "running":
        raise ValueError(f"Invalid transition from running with event {event}")
    if current in ("completed", "error"):
        raise ValueError(f"Cannot transition from terminal state {current}")
    raise ValueError(f"Unknown state {current}")

45 

46 

class ToolStateMachine:
    """Track the execution status of tool calls, keyed by call id.

    Every status change is delegated to ``transition_state``; the new status
    is stored only after that call returns, so an event it rejects (by
    raising ``ValueError``) leaves the recorded status unchanged.
    """

    def __init__(self) -> None:
        """Create a machine that is not yet tracking any tool calls."""
        # call_id -> current status; entries are added by start_tool.
        self.states: dict[str, ToolStatus] = {}

    def _apply_event(
        self, call_id: str, event: Literal["start", "complete", "fail"]
    ) -> None:
        """Advance an already-tracked tool call by applying ``event``.

        Args:
            call_id: Tool call identifier; must already be present in ``states``.
            event: Event to feed to ``transition_state``.

        Raises:
            ValueError: If ``call_id`` is not tracked, or if
                ``transition_state`` rejects the event.
        """
        if call_id not in self.states:
            raise ValueError(f"Tool {call_id} not found")
        # Compute first, then store: a rejected event never overwrites the
        # previously recorded status.
        self.states[call_id] = transition_state(self.states[call_id], event)

    def start_tool(self, call_id: str) -> None:
        """Mark tool as started.

        An id that is not yet tracked is first registered as ``"pending"``,
        so it can be started without a separate registration step.

        Args:
            call_id: Tool call identifier

        Raises:
            ValueError: If the tool is already past ``"pending"``.
        """
        self.states.setdefault(call_id, "pending")
        self._apply_event(call_id, "start")

    def complete_tool(self, call_id: str) -> None:
        """Mark tool as completed.

        Args:
            call_id: Tool call identifier

        Raises:
            ValueError: If the tool is not tracked or the transition is illegal.
        """
        self._apply_event(call_id, "complete")

    def fail_tool(self, call_id: str) -> None:
        """Mark tool as failed.

        Args:
            call_id: Tool call identifier

        Raises:
            ValueError: If the tool is not tracked or the transition is illegal.
        """
        self._apply_event(call_id, "fail")

    def get_state(self, call_id: str) -> ToolStatus:
        """Get current state of tool.

        Args:
            call_id: Tool call identifier

        Returns:
            The recorded status, or ``"pending"`` for an id that is not
            tracked (this lookup never registers the id).
        """
        return self.states.get(call_id, "pending")

    def is_terminal(self, call_id: str) -> bool:
        """Check if tool is in terminal state.

        Args:
            call_id: Tool call identifier

        Returns:
            True if tool is completed or errored; False otherwise, including
            for ids that are not tracked.
        """
        return self.get_state(call_id) in ("completed", "error")