Coverage for src / harnessutils / tokens / exact.py: 63%

46 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-12 22:41 -0600

1"""Exact token counting using tiktoken.""" 

2 

3import json 

4from typing import Any 

5 

6import tiktoken 

7 

8# Global encoding cache for performance 

9_encoding_cache: tiktoken.Encoding | None = None 

10 

11 

def count_tokens_exact(
    messages: list[dict[str, Any]], model: str = "claude-3-5-sonnet-20241022"
) -> int:
    """Count exact tokens for messages using tiktoken.

    All currently supported model families (GPT-4, GPT-3.5, Claude) resolve
    to the same ``cl100k_base`` encoding, so the tokenizer choice is uniform.
    The ``model`` parameter is retained for interface stability and future
    model-specific encodings.

    Args:
        messages: Messages in model format. Each message's ``content`` may be
            a plain string or a list of structured parts (text / tool_use /
            tool_result dicts).
        model: Model name for tokenizer selection. Currently every known
            family maps to ``cl100k_base``, so this does not affect the count.

    Returns:
        Exact token count, plus a flat 4-token-per-message overhead that
        approximates structural formatting tokens.
    """
    # Every model family previously handled (GPT-4, GPT-3.5, Claude, and the
    # unknown-model fallback) mapped to cl100k_base; the old if/elif chain
    # assigned the identical value on every branch and has been collapsed.
    # Reuse the module-level cached encoding instead of reloading it on
    # every call -- that cache exists exactly for this purpose.
    encoding = get_encoding()

    total_tokens = 0

    for message in messages:
        # Count role tokens (missing role counts as empty string -> 0).
        total_tokens += len(encoding.encode(message.get("role", "")))

        # Count content tokens.
        content = message.get("content", "")
        if isinstance(content, str):
            total_tokens += len(encoding.encode(content))
        elif isinstance(content, list):
            # Structured content (parts): sum tokens over the known
            # string-valued fields of each dict part.
            for part in content:
                if not isinstance(part, dict):
                    continue
                if "text" in part:
                    total_tokens += len(encoding.encode(part["text"]))
                if "type" in part:
                    total_tokens += len(encoding.encode(part["type"]))
                # Tool use/result structures.
                if "tool_use_id" in part:
                    total_tokens += len(encoding.encode(part["tool_use_id"]))
                if "name" in part:
                    total_tokens += len(encoding.encode(part["name"]))
                if "input" in part:
                    # Tool input is usually a dict; serialize before encoding.
                    total_tokens += len(
                        encoding.encode(json.dumps(part["input"]))
                    )
                if "content" in part:
                    total_tokens += len(encoding.encode(str(part["content"])))

        # Message formatting overhead (approximate): each message carries
        # some structural tokens beyond its role and content.
        total_tokens += 4

    return total_tokens

76 

77 

def get_encoding() -> tiktoken.Encoding:
    """Return the module-wide cached tiktoken encoding.

    The encoding is loaded lazily on first use and stored in the module
    global ``_encoding_cache`` so repeated callers never pay the loading
    cost again. Uses cl100k_base (GPT-4/Claude tokenizer).

    Returns:
        The shared tiktoken encoding instance.
    """
    global _encoding_cache
    cached = _encoding_cache
    if cached is None:
        # First call: load the encoding once and publish it for everyone.
        cached = tiktoken.get_encoding("cl100k_base")
        _encoding_cache = cached
    return cached

91 

92 

def count_tokens_fast(text: str) -> int:
    """Fast exact token counting for plain text.

    Intended for pruning decisions where many outputs are counted in a
    row; relies on the module's cached encoding so no tokenizer loading
    happens per call.

    Args:
        text: Plain text to count tokens for

    Returns:
        Exact token count
    """
    # Empty (or otherwise falsy) text trivially contains zero tokens.
    if text:
        return len(get_encoding().encode(text))
    return 0