Source code for mindroot.lib.token_counter
import os
import json
import re
import time
import asyncio
import aiofiles
import aiofiles.os
from typing import Dict, List, Optional
from mindroot.lib.chatlog import ChatLog

async def find_chatlog_file(log_id: str) -> Optional[str]:
    """
    Find a chatlog file by its log_id.

    Args:
        log_id: The log ID to search for

    Returns:
        The full path to the chatlog file if found, None otherwise
    """
chat_dir = os.environ.get('CHATLOG_DIR', 'data/chat')
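    # The target file is expected to be named chatlog_<log_id>.json somewhere
    # under chat_dir; the exact subdirectory layout is left to ChatLog, so we
    # search recursively rather than assume a fixed path.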
    # os.walk only creates a lazy generator, so wrapping the os.walk call
    # itself in asyncio.to_thread would not move the directory I/O off the
    # event loop; run the whole traversal in a worker thread instead.
    def _search():
        for root, _dirs, files in os.walk(chat_dir):
            if f"chatlog_{log_id}.json" in files:
                return os.path.join(root, f"chatlog_{log_id}.json")
        return None

    return await asyncio.to_thread(_search)

async def get_cache_dir() -> str:
    """
    Get the directory for token count cache files.

    Creates the directory if it doesn't exist.
    """
    cache_dir = os.environ.get('TOKEN_CACHE_DIR', 'data/token_cache')
    # exist_ok avoids the check-then-create race when two requests arrive at once
    await aiofiles.os.makedirs(cache_dir, exist_ok=True)
    return cache_dir

async def get_cache_path(log_id: str) -> str:
"""
Get the path to the cache file for a specific log_id.
"""
cache_dir = await get_cache_dir()
return os.path.join(cache_dir, f"tokens_{log_id}.json")

async def get_cached_token_counts(log_id: str, log_path: str) -> Optional[Dict[str, int]]:
    """
    Get cached token counts if available and valid.

    Args:
        log_id: The log ID
        log_path: Path to the actual log file

    Returns:
        Cached token counts if valid, None otherwise
    """
cache_path = await get_cache_path(log_id)
# If cache doesn't exist, return None
if not await aiofiles.os.path.exists(cache_path):
return None
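    # Cache policy, based on the mtimes compared below:
    #   1. log modified after the cache was written  -> stale, recount
    #   2. cache written within the last 3 minutes   -> reuse (throttles recounts)
    #   3. log untouched for over an hour            -> "finished", always reuse
    # Any other combination falls through and returns None, forcing a recount.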
try:
# Get modification times
log_mtime = await aiofiles.os.path.getmtime(log_path)
cache_mtime = await aiofiles.os.path.getmtime(cache_path)
current_time = time.time()
# If log was modified after cache was created, cache is invalid
if log_mtime > cache_mtime:
return None
# Don't recalculate sooner than 3 minutes after last calculation
if current_time - cache_mtime < 180: # 3 minutes in seconds
async with aiofiles.open(cache_path, 'r') as f:
content = await f.read()
return json.loads(content)
# For logs that haven't been modified in over an hour, consider them "finished"
# and use the cache regardless of when it was last calculated
if current_time - log_mtime > 3600: # 1 hour in seconds
async with aiofiles.open(cache_path, 'r') as f:
content = await f.read()
return json.loads(content)
except (json.JSONDecodeError, IOError) as e:
print(f"Error reading token cache: {e}")
return None

async def save_token_counts_to_cache(log_id: str, token_counts: Dict[str, int]) -> None:
"""
Save token counts to cache.
"""
cache_path = await get_cache_path(log_id)
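    # Note: this write is not atomic, so a reader racing with it may see a
    # truncated file; get_cached_token_counts tolerates that by treating a
    # JSONDecodeError as a cache miss.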
async with aiofiles.open(cache_path, 'w') as f:
await f.write(json.dumps(token_counts))

async def count_tokens_for_log_id(log_id: str) -> Optional[Dict[str, int]]:
    """
    Count tokens for a chat log identified by log_id, including any delegated tasks.

    Args:
        log_id: The log ID to count tokens for

    Returns:
        Dictionary with token counts, or None if the log is not found
    """
# Find the chatlog file
chatlog_path = await find_chatlog_file(log_id)
if not chatlog_path:
return None
# Check cache first
cached_counts = await get_cached_token_counts(log_id, chatlog_path)
    if cached_counts is not None:
print(f"Using cached token counts for {log_id}")
return cached_counts
print(f"Calculating token counts for {log_id}")
# Load the chat log
async with aiofiles.open(chatlog_path, 'r') as f:
content = await f.read()
log_data = json.loads(content)
# Create a temporary ChatLog instance to count tokens
temp_log = ChatLog(log_id=log_id, user="system", agent=log_data.get('agent', 'unknown'))
temp_log.messages = log_data.get('messages', [])
# Count tokens for this log
parent_counts = temp_log.count_tokens()
    # Create combined counts (starting with parent counts)
    combined_counts = {
        'input_tokens_sequence': parent_counts['input_tokens_sequence'],
        'output_tokens_sequence': parent_counts['output_tokens_sequence'],
        'input_tokens_total': parent_counts['input_tokens_total'],
    }
# Find delegated task log IDs
delegated_log_ids = extract_delegate_task_log_ids(temp_log.messages)
# Recursively count tokens for delegated tasks
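    # (assumes delegation forms a tree; a cycle of logs delegating back to an
    # ancestor would recurse without bound)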
for delegated_id in delegated_log_ids:
delegated_counts = await count_tokens_for_log_id(delegated_id)
if delegated_counts:
combined_counts['input_tokens_sequence'] += delegated_counts['input_tokens_sequence']
combined_counts['output_tokens_sequence'] += delegated_counts['output_tokens_sequence']
combined_counts['input_tokens_total'] += delegated_counts['input_tokens_total']
    # Create final result with both parent-only and combined counts
    token_counts = {
        # Parent session only
        'input_tokens_sequence': parent_counts['input_tokens_sequence'],
        'output_tokens_sequence': parent_counts['output_tokens_sequence'],
        'input_tokens_total': parent_counts['input_tokens_total'],
        # Parent + all delegated subtasks
        'combined_input_tokens_sequence': combined_counts['input_tokens_sequence'],
        'combined_output_tokens_sequence': combined_counts['output_tokens_sequence'],
        'combined_input_tokens_total': combined_counts['input_tokens_total'],
    }
# Save to cache
await save_token_counts_to_cache(log_id, token_counts)
return token_counts
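
# Minimal usage sketch (illustrative, not part of the module; assumes a log
# file chatlog_<log_id>.json exists somewhere under CHATLOG_DIR, default
# data/chat, and that "example-log-id" is a placeholder):
#
#   import asyncio
#
#   counts = asyncio.run(count_tokens_for_log_id("example-log-id"))
#   if counts is not None:
#       print(counts['input_tokens_total'],
#             counts['combined_input_tokens_total'])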