Coverage for src/harnessutils/tokens/exact.py: 63%
46 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-12 22:41 -0600
1"""Exact token counting using tiktoken."""
3import json
4from typing import Any
6import tiktoken
8# Global encoding cache for performance
9_encoding_cache: tiktoken.Encoding | None = None
def count_tokens_exact(
    messages: list[dict[str, Any]], model: str = "claude-3-5-sonnet-20241022"
) -> int:
    """Count exact tokens for messages using tiktoken.

    Args:
        messages: Messages in model format; each message has a "role" and a
            "content" that is either a string or a list of structured parts
            (text, tool-use, and tool-result dicts).
        model: Model name, kept for interface compatibility. Every model
            family this function supported (GPT-4, GPT-3.5, Claude) mapped
            to the same cl100k_base encoding, so it no longer affects the
            encoding choice.

    Returns:
        Exact token count, including a small per-message structural overhead.
    """
    # The previous if/elif chain assigned "cl100k_base" in every branch
    # (dead logic), so resolve the encoding through the module's cached
    # accessor instead of reloading it on each call.
    encoding = get_encoding()

    total_tokens = 0

    for message in messages:
        # Count role
        total_tokens += len(encoding.encode(message.get("role", "")))

        # Count content
        content = message.get("content", "")
        if isinstance(content, str):
            total_tokens += len(encoding.encode(content))
        elif isinstance(content, list):
            # Handle structured content (parts): count each textual field.
            for part in content:
                if not isinstance(part, dict):
                    continue
                if "text" in part:
                    total_tokens += len(encoding.encode(part["text"]))
                if "type" in part:
                    total_tokens += len(encoding.encode(part["type"]))
                # Tool use/result structures
                if "tool_use_id" in part:
                    total_tokens += len(encoding.encode(part["tool_use_id"]))
                if "name" in part:
                    total_tokens += len(encoding.encode(part["name"]))
                if "input" in part:
                    # Tool input is usually a dict; serialize before encoding.
                    total_tokens += len(encoding.encode(json.dumps(part["input"])))
                if "content" in part:
                    total_tokens += len(encoding.encode(str(part["content"])))

        # Message formatting overhead (approximate): each message carries
        # some structural tokens beyond its text.
        total_tokens += 4

    return total_tokens
def get_encoding() -> tiktoken.Encoding:
    """Return the module-wide cached tiktoken encoding.

    Loading an encoding is comparatively expensive, so the cl100k_base
    encoding (GPT-4/Claude tokenizer) is created once on first use and
    reused by every subsequent call.

    Returns:
        Cached tiktoken encoding
    """
    global _encoding_cache
    cached = _encoding_cache
    if cached is None:
        cached = tiktoken.get_encoding("cl100k_base")
        _encoding_cache = cached
    return cached
def count_tokens_fast(text: str) -> int:
    """Fast exact token counting for plain text.

    Optimized for pruning decisions where many outputs are counted;
    relies on the module's cached encoding to avoid repeated loading.

    Args:
        text: Plain text to count tokens for

    Returns:
        Exact token count
    """
    if text:
        return len(get_encoding().encode(text))
    return 0