Coverage for src / harnessutils / tokens / exact.py: 100%
39 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-18 09:29 -0600
1"""Exact token counting using tiktoken."""
3import json
4from typing import Any
6import tiktoken
8# Module-level cache for the tiktoken encoding so it is loaded at most once
9# per process; populated lazily by get_encoding() below.
_encoding_cache: tiktoken.Encoding | None = None
def count_tokens_exact(
    messages: list[dict[str, Any]], model: str | None = None
) -> int:
    """Count exact tokens for messages using tiktoken.

    Uses cl100k_base encoding (GPT-4/Claude tokenizer) which provides accurate
    counts for most modern LLMs. Model parameter is optional and primarily for
    future extensibility.

    Args:
        messages: Messages in model format; each dict may carry a "role"
            string and a "content" that is either a plain string or a list
            of structured parts (text / tool_use / tool_result dicts)
        model: Optional model name (currently unused, defaults to cl100k_base)

    Returns:
        Exact token count, including ~4 tokens of per-message framing overhead
    """
    # Use the module's cached encoding instead of calling
    # tiktoken.get_encoding() on every invocation — the cache in
    # get_encoding() exists precisely to avoid repeated encoding loads.
    # (cl100k_base covers Claude 3/3.5/4.x, GPT-4/GPT-3.5-turbo, and most
    # other modern transformer models; model-specific encodings could be
    # added here later if needed.)
    encoding = get_encoding()

    total_tokens = 0
    for message in messages:
        # Role is a short string ("user", "assistant", ...); missing role
        # counts as empty.
        total_tokens += len(
            encoding.encode(message.get("role", ""), disallowed_special=())
        )
        total_tokens += _count_content_tokens(encoding, message.get("content", ""))
        # Message formatting overhead (approximate): each message carries
        # some structural/separator tokens in the wire format.
        total_tokens += 4

    return total_tokens


def _count_content_tokens(encoding: tiktoken.Encoding, content: Any) -> int:
    """Count tokens for one message's content (string or structured parts).

    Non-string, non-list content contributes zero tokens, matching the
    original behavior of silently skipping unrecognized shapes.
    """
    if isinstance(content, str):
        return len(encoding.encode(content, disallowed_special=()))
    if not isinstance(content, list):
        return 0

    total = 0
    for part in content:
        if not isinstance(part, dict):
            continue
        # Plain-string fields: text parts plus tool use/result identifiers.
        for key in ("text", "type", "tool_use_id", "name"):
            if key in part:
                total += len(encoding.encode(part[key], disallowed_special=()))
        if "input" in part:
            # Tool input is usually a dict; serialize to JSON before counting.
            total += len(
                encoding.encode(json.dumps(part["input"]), disallowed_special=())
            )
        if "content" in part:
            # Nested content (e.g. tool results) may be any type; stringify.
            total += len(
                encoding.encode(str(part["content"]), disallowed_special=())
            )
    return total
def get_encoding() -> tiktoken.Encoding:
    """Return the shared cl100k_base encoding, loading it on first use.

    The encoding is stored in a module-level cache so repeated callers
    (e.g. per-output pruning loops) never pay the load cost twice.

    Returns:
        Cached tiktoken encoding
    """
    global _encoding_cache
    cached = _encoding_cache
    if cached is None:
        cached = tiktoken.get_encoding("cl100k_base")
        _encoding_cache = cached
    return cached
def count_tokens_fast(text: str) -> int:
    """Count exact tokens in a plain-text string.

    Optimized for pruning decisions where many outputs are counted: the
    encoding comes from the module-level cache, and empty input short-circuits
    without touching the tokenizer.

    Args:
        text: Plain text to count tokens for

    Returns:
        Exact token count
    """
    if text:
        return len(get_encoding().encode(text, disallowed_special=()))
    return 0