Coverage for src/harnessutils/tokens/exact.py: 100%

39 statements

coverage.py v7.13.2, created at 2026-02-18 09:29 -0600

1"""Exact token counting using tiktoken.""" 

2 

3import json 

4from typing import Any 

5 

6import tiktoken 

7 

8# Global encoding cache for performance 

9_encoding_cache: tiktoken.Encoding | None = None 

10 

11 

def count_tokens_exact(
    messages: list[dict[str, Any]], model: str | None = None
) -> int:
    """Count exact tokens for messages using tiktoken.

    Uses the cl100k_base encoding (the GPT-4/GPT-3.5-turbo tokenizer): counts
    are exact for OpenAI chat models and a close approximation for other
    modern LLMs such as Claude, which use their own tokenizers. The model
    parameter is currently unused and reserved for future extensibility.

    Args:
        messages: Messages in model format
        model: Optional model name (currently unused; defaults to cl100k_base)

    Returns:
        Exact token count under the cl100k_base encoding
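
    Example (illustrative, with a made-up tool_use part; exact counts depend
    on the tiktoken vocabulary, so only a lower bound is asserted):

        >>> msgs = [
        ...     {"role": "user", "content": "ping"},
        ...     {"role": "assistant", "content": [
        ...         {"type": "tool_use", "tool_use_id": "tu_1",
        ...          "name": "echo", "input": {"text": "ping"}},
        ...     ]},
        ... ]
        >>> count_tokens_exact(msgs) > 8  # at least 4 overhead per message
        True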

27 """ 

    # Use the cl100k_base encoding (OpenAI's GPT-4/GPT-3.5-turbo tokenizer)
    # Token counts are:
    # - Exact for GPT-4 and GPT-3.5-turbo
    # - An approximation for Claude models, which use Anthropic's own tokenizer
    # - A reasonable estimate for most other transformer-based models
    encoding_name = "cl100k_base"

    # Future: could add model-specific encodings here if needed.
    # For now, cl100k_base is a workable default; tiktoken caches constructed
    # encodings internally, so this lookup is cheap.
    encoding = tiktoken.get_encoding(encoding_name)

    total_tokens = 0

    for message in messages:
        # Count role
        total_tokens += len(encoding.encode(message.get("role", ""), disallowed_special=()))

        # Count content
        content = message.get("content", "")
        if isinstance(content, str):
            total_tokens += len(encoding.encode(content, disallowed_special=()))
        elif isinstance(content, list):
            # Handle structured content (parts)
            for part in content:
                if isinstance(part, dict):
                    if "text" in part:
                        total_tokens += len(encoding.encode(part["text"], disallowed_special=()))
                    if "type" in part:
                        total_tokens += len(encoding.encode(part["type"], disallowed_special=()))
                    # Tool use/result structures
                    if "tool_use_id" in part:
                        total_tokens += len(
                            encoding.encode(part["tool_use_id"], disallowed_special=())
                        )
                    if "name" in part:
                        total_tokens += len(
                            encoding.encode(part["name"], disallowed_special=())
                        )
                    if "input" in part:
                        # Tool input is usually a dict; serialize before encoding
                        total_tokens += len(
                            encoding.encode(
                                json.dumps(part["input"]), disallowed_special=()
                            )
                        )
                    if "content" in part:
                        total_tokens += len(
                            encoding.encode(str(part["content"]), disallowed_special=())
                        )

        # Message formatting overhead (approximate): chat formats wrap each
        # message in a few structural tokens; 4 is a common per-message estimate
        total_tokens += 4

    return total_tokens


def get_encoding() -> tiktoken.Encoding:
    """Get cached encoding for performance.

    Caches the encoding globally to avoid repeated loading.
    Uses cl100k_base (the GPT-4 tokenizer).

    Returns:
        Cached tiktoken encoding
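
    Example: repeated calls return the same cached Encoding object:

        >>> get_encoding() is get_encoding()
        True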

95 """ 

96 global _encoding_cache 

97 if _encoding_cache is None: 

98 _encoding_cache = tiktoken.get_encoding("cl100k_base") 

99 return _encoding_cache 


def count_tokens_fast(text: str) -> int:
    """Fast exact token counting for plain text.

    Optimized for pruning decisions where we count many outputs.
    Uses the cached encoding to avoid repeated loading.

    Args:
        text: Plain text to count tokens for

    Returns:
        Exact token count
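
    Example (the empty-string case is deterministic; counts for other
    text depend on the tokenizer vocabulary):

        >>> count_tokens_fast("")
        0
        >>> count_tokens_fast("hello world") > 0
        True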

113 """ 

114 if not text: 

115 return 0 

116 

117 encoding = get_encoding() 

118 return len(encoding.encode(text, disallowed_special=()))