pattern_lens.activations
computing and saving activations given a model and prompts
Usage:
from the command line:
python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>
from a script:
from pattern_lens.activations import activations_main
activations_main(
model_name="gpt2",
save_path="demo/",
prompts_path="data/pile_1k.jsonl",
)
"""computing and saving activations given a model and prompts

# Usage:

from the command line:

```bash
python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>
```

from a script:

```python
from pattern_lens.activations import activations_main
activations_main(
    model_name="gpt2",
    save_path="demo/",
    prompts_path="data/pile_1k.jsonl",
)
```

"""

import argparse
import functools
import json
import re
from dataclasses import asdict
from pathlib import Path
from typing import Callable, Literal, overload

import numpy as np
import torch
import tqdm
from muutils.json_serialize import json_serialize
from muutils.misc.numerical import shorten_numerical_to_str
from muutils.spinner import SpinnerContext
from transformer_lens import HookedTransformer, HookedTransformerConfig  # type: ignore[import-untyped]

from pattern_lens.consts import (
    ATTN_PATTERN_REGEX,
    DATA_DIR,
    DIVIDER_S1,
    DIVIDER_S2,
    SPINNER_KWARGS,
    ActivationCacheNp,
)
from pattern_lens.indexes import (
    generate_models_jsonl,
    generate_prompts_jsonl,
    write_html_index,
)
from pattern_lens.load_activations import (
    ActivationsMissingError,
    augment_prompt_with_hash,
    load_activations,
)
from pattern_lens.prompts import load_text_data


def compute_activations(
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path(DATA_DIR),
    return_cache: bool = True,
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
) -> tuple[Path, ActivationCacheNp | None]:
    """run `model` on a prompt and save the filtered activations to disk

    tokenizes the prompt (unless it already carries a `"tokens"` key), writes
    prompt metadata to `prompt.json`, runs the model with a names filter, and
    saves the matching activations as a compressed `.npz` file.

    # Parameters:
    - `prompt : dict`
        must contain a `"text"` key and a `"hash"` key (see `augment_prompt_with_hash`)
    - `model : HookedTransformer | None`
        must actually be passed; the `None` default only exists for keyword-call convenience
        (defaults to `None`)
    - `save_path : Path`
        root directory for saved data
        (defaults to `Path(DATA_DIR)`)
    - `return_cache : bool`
        will return `None` as the second element if `False`
        (defaults to `True`)
    - `names_filter : Callable[[str], bool] | re.Pattern`
        a filter for the names of the activations to return. if an `re.Pattern`,
        will use `lambda key: names_filter.match(key) is not None`
        (defaults to `ATTN_PATTERN_REGEX`)

    # Returns:
    - `tuple[Path, ActivationCacheNp | None]`
        path to the saved `activations.npz` and (optionally) the numpy cache
    """
    assert model is not None, "model must be passed"
    assert "text" in prompt, "prompt must contain 'text' key"
    prompt_str: str = prompt["text"]

    # compute or get prompt metadata
    # (reuse pre-computed tokens if the prompt already has them)
    prompt_tokenized: list[str] = prompt.get(
        "tokens",
        model.tokenizer.tokenize(prompt_str),
    )
    prompt.update(
        dict(
            n_tokens=len(prompt_tokenized),
            tokens=prompt_tokenized,
        )
    )

    # save metadata alongside the activations, keyed by prompt hash
    prompt_dir: Path = save_path / model.model_name / "prompts" / prompt["hash"]
    prompt_dir.mkdir(parents=True, exist_ok=True)
    with open(prompt_dir / "prompt.json", "w") as f:
        json.dump(prompt, f)

    # set up names filter: normalize an `re.Pattern` to a callable
    names_filter_fn: Callable[[str], bool]
    if isinstance(names_filter, re.Pattern):
        names_filter_fn = lambda key: names_filter.match(key) is not None  # noqa: E731
    else:
        names_filter_fn = names_filter

    # compute activations (no gradients needed; discard logits via return_type=None)
    with torch.no_grad():
        # TODO: batching?
        _, cache = model.run_with_cache(
            prompt_str,
            names_filter=names_filter_fn,
            return_type=None,
        )

    cache_np: ActivationCacheNp = {
        k: v.detach().cpu().numpy() for k, v in cache.items()
    }

    # save activations
    activations_path: Path = prompt_dir / "activations.npz"
    np.savez_compressed(
        activations_path,
        **cache_np,
    )

    # return path and cache
    if return_cache:
        return activations_path, cache_np
    else:
        return activations_path, None


@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal[False] = False,
) -> tuple[Path, None]: ...
@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal[True] = True,
) -> tuple[Path, ActivationCacheNp]: ...
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: bool = True,
) -> tuple[Path, ActivationCacheNp | None]:
    """given a prompt and a model, save or load activations

    # Parameters:
    - `prompt : dict`
        expected to contain the 'text' key
    - `model : HookedTransformer | str`
        either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained`
    - `save_path : Path`
        path to save the activations to (and load from)
        (defaults to `Path(DATA_DIR)`)
    - `allow_disk_cache : bool`
        whether to allow loading from disk cache
        (defaults to `True`)
    - `return_cache : bool`
        whether to return the cache. if `False`, will return `None` as the second element
        (defaults to `True`)

    # Returns:
    - `tuple[Path, ActivationCacheNp | None]`
        the path to the activations and the cache if `return_cache` is `True`

    """
    # add hash to prompt
    augment_prompt_with_hash(prompt)

    # get the model name without forcing a model load
    model_name: str = (
        model.model_name if isinstance(model, HookedTransformer) else model
    )

    # from cache
    if allow_disk_cache:
        try:
            path, cache = load_activations(
                model_name=model_name,
                prompt=prompt,
                save_path=save_path,
            )
            if return_cache:
                return path, cache
            else:
                return path, None
        except ActivationsMissingError:
            # cache miss -- fall through and compute
            pass

    # compute them
    if isinstance(model, str):
        model = HookedTransformer.from_pretrained(model_name)

    # BUGFIX: propagate `return_cache` (was hardcoded `True`, which violated
    # the `Literal[False] -> tuple[Path, None]` overload above)
    return compute_activations(
        prompt=prompt,
        model=model,
        save_path=save_path,
        return_cache=return_cache,
    )


def activations_main(
    model_name: str,
    save_path: str,
    prompts_path: str,
    raw_prompts: bool,
    min_chars: int,
    max_chars: int,
    force: bool,
    n_samples: int,
    no_index_html: bool,
    shuffle: bool = False,
) -> None:
    """main function for computing activations

    # Parameters:
    - `model_name : str`
        name of a model to load with `HookedTransformer.from_pretrained`
    - `save_path : str`
        path to save the activations to
    - `prompts_path : str`
        path to the prompts file
    - `raw_prompts : bool`
        whether the prompts are raw, not filtered by length. `load_text_data` will be called if `True`, otherwise just load the "text" field from each line in `prompts_path`
    - `min_chars : int`
        minimum number of characters for a prompt
    - `max_chars : int`
        maximum number of characters for a prompt
    - `force : bool`
        whether to overwrite existing files
    - `n_samples : int`
        maximum number of samples to process
    - `no_index_html : bool`
        whether to write an index.html file
    - `shuffle : bool`
        whether to shuffle the prompts
        (defaults to `False`)
    """

    with SpinnerContext(message="loading model", **SPINNER_KWARGS):
        model: HookedTransformer = HookedTransformer.from_pretrained(model_name)
        # pin the user-facing name onto the model so saved paths/configs match
        model.model_name = model_name
        model.cfg.model_name = model_name
        n_params: int = sum(p.numel() for p in model.parameters())
    print(
        f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters"
    )

    save_path_p: Path = Path(save_path)
    save_path_p.mkdir(parents=True, exist_ok=True)
    model_path: Path = save_path_p / model_name
    with SpinnerContext(
        message=f"saving model info to {model_path.as_posix()}", **SPINNER_KWARGS
    ):
        model_cfg: HookedTransformerConfig
        model_cfg = model.cfg
        model_path.mkdir(parents=True, exist_ok=True)
        with open(model_path / "model_cfg.json", "w") as f:
            json.dump(json_serialize(asdict(model_cfg)), f)

    # load prompts
    with SpinnerContext(
        message=f"loading prompts from {prompts_path = }", **SPINNER_KWARGS
    ):
        prompts: list[dict]
        if raw_prompts:
            prompts = load_text_data(
                Path(prompts_path),
                min_chars=min_chars,
                max_chars=max_chars,
                shuffle=shuffle,
            )
        else:
            # NOTE(review): docstring says the non-raw branch loads from
            # `prompts_path`, but this reads `model_path / "prompts.jsonl"`
            # -- confirm which is intended before changing.
            with open(model_path / "prompts.jsonl", "r") as f:
                prompts = [json.loads(line) for line in f]
        # truncate to n_samples (slicing with None keeps everything)
        prompts = prompts[:n_samples]

    print(f"{len(prompts)} prompts loaded")

    # write index.html -- skip the spinner entirely when disabled,
    # so we don't claim to be writing a file we never write
    if not no_index_html:
        with SpinnerContext(message="writing index.html", **SPINNER_KWARGS):
            write_html_index(save_path_p)

    # get activations (list() drains the lazy map; results are discarded,
    # only the on-disk side effects matter)
    list(
        tqdm.tqdm(
            map(
                functools.partial(
                    get_activations,
                    model=model,
                    save_path=save_path_p,
                    allow_disk_cache=not force,
                    return_cache=False,
                ),
                prompts,
            ),
            total=len(prompts),
            desc="Computing activations",
            unit="prompt",
        )
    )

    with SpinnerContext(
        message="updating jsonl metadata for models and prompts", **SPINNER_KWARGS
    ):
        generate_models_jsonl(save_path_p)
        generate_prompts_jsonl(save_path_p / model_name)


def main():
    """CLI entry point: parse arguments and run `activations_main` per model"""
    print(DIVIDER_S1)
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
        # input and output
        arg_parser.add_argument(
            "--model",
            "-m",
            type=str,
            required=True,
            help="The model name(s) to use. comma separated with no whitespace if multiple",
        )

        arg_parser.add_argument(
            "--prompts",
            "-p",
            type=str,
            required=False,
            help="The path to the prompts file (jsonl with 'text' key on each line). If `None`, expects that `--figures` is passed and will generate figures for all prompts in the model directory",
            default=None,
        )

        arg_parser.add_argument(
            "--save-path",
            "-s",
            type=str,
            required=False,
            help="The path to save the attention patterns",
            default=DATA_DIR,
        )

        # min and max prompt lengths
        arg_parser.add_argument(
            "--min-chars",
            type=int,
            required=False,
            help="The minimum number of characters for a prompt",
            default=100,
        )
        arg_parser.add_argument(
            "--max-chars",
            type=int,
            required=False,
            help="The maximum number of characters for a prompt",
            default=1000,
        )

        # number of samples
        arg_parser.add_argument(
            "--n-samples",
            "-n",
            type=int,
            required=False,
            help="The max number of samples to process, do all in the file if None",
            default=None,
        )

        # force overwrite
        arg_parser.add_argument(
            "--force",
            "-f",
            action="store_true",
            help="If passed, will overwrite existing files",
        )

        # no index html
        arg_parser.add_argument(
            "--no-index-html",
            action="store_true",
            help="If passed, will not write an index.html file for the model",
        )

        # raw prompts
        arg_parser.add_argument(
            "--raw-prompts",
            "-r",
            action="store_true",
            help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)",
        )

        # shuffle
        arg_parser.add_argument(
            "--shuffle",
            action="store_true",
            help="If passed, will shuffle the prompts",
        )

        args: argparse.Namespace = arg_parser.parse_args()

    print(f"args parsed: {args}")

    # comma-separated list of model names
    models: list[str]
    if "," in args.model:
        models = args.model.split(",")
    else:
        models = [args.model]

    n_models: int = len(models)
    for idx, model in enumerate(models):
        print(DIVIDER_S2)
        print(f"processing model {idx+1} / {n_models}: {model}")
        print(DIVIDER_S2)

        activations_main(
            model_name=model,
            save_path=args.save_path,
            prompts_path=args.prompts,
            raw_prompts=args.raw_prompts,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            force=args.force,
            n_samples=args.n_samples,
            no_index_html=args.no_index_html,
            shuffle=args.shuffle,
        )

    print(DIVIDER_S1)


if __name__ == "__main__":
    main()
def compute_activations(
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path(DATA_DIR),
    return_cache: bool = True,
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
) -> tuple[Path, ActivationCacheNp | None]:
    """run the model on a prompt and write the filtered activations to disk

    tokenizes the prompt (reusing a pre-existing `"tokens"` entry if present),
    writes prompt metadata to `prompt.json` under a hash-keyed directory, runs
    a no-grad forward pass keeping only activations matched by `names_filter`,
    and saves them as a compressed `.npz`.

    # Parameters:
    - `prompt : dict`
        must contain `"text"` and `"hash"` keys
    - `model : HookedTransformer | None`
        required despite the `None` default
        (defaults to `None`)
    - `save_path : Path`
        (defaults to `Path(DATA_DIR)`)
    - `return_cache : bool`
        second return element is `None` when `False`
        (defaults to `True`)
    - `names_filter : Callable[[str], bool] | re.Pattern`
        activation-name filter; a pattern is matched via `.match(key) is not None`
        (defaults to `ATTN_PATTERN_REGEX`)

    # Returns:
    - `tuple[Path, ActivationCacheNp | None]`
    """
    assert model is not None, "model must be passed"
    assert "text" in prompt, "prompt must contain 'text' key"
    prompt_text: str = prompt["text"]

    # tokenize, preferring tokens already attached to the prompt
    token_list: list[str] = prompt.get(
        "tokens",
        model.tokenizer.tokenize(prompt_text),
    )
    prompt.update(
        dict(
            n_tokens=len(token_list),
            tokens=token_list,
        )
    )

    # persist prompt metadata next to where the activations will go
    prompt_dir: Path = save_path / model.model_name / "prompts" / prompt["hash"]
    prompt_dir.mkdir(parents=True, exist_ok=True)
    with open(prompt_dir / "prompt.json", "w") as meta_file:
        json.dump(prompt, meta_file)

    # normalize the names filter into a plain callable
    filter_fn: Callable[[str], bool]
    filter_fn = (
        (lambda key: names_filter.match(key) is not None)
        if isinstance(names_filter, re.Pattern)
        else names_filter
    )

    # forward pass with gradients disabled; logits are discarded
    with torch.no_grad():
        # TODO: batching?
        _, cache = model.run_with_cache(
            prompt_text,
            names_filter=filter_fn,
            return_type=None,
        )

    # move everything off-device into numpy arrays
    activations: ActivationCacheNp = {
        name: tensor.detach().cpu().numpy() for name, tensor in cache.items()
    }

    # write compressed archive of all kept activations
    activations_path: Path = prompt_dir / "activations.npz"
    np.savez_compressed(activations_path, **activations)

    return (activations_path, activations) if return_cache else (activations_path, None)
get activations for a given model and prompt, possibly from a cache
if from a cache, prompt_meta must be passed and contain the prompt hash
Parameters:
- `prompt : dict | None` (defaults to `None`)
- `model : HookedTransformer`
- `save_path : Path` (defaults to `Path(DATA_DIR)`)
- `return_cache : bool` — will return `None` as the second element if `False` (defaults to `True`)
- `names_filter : Callable[[str], bool] | re.Pattern` — a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None` (defaults to `ATTN_PATTERN_REGEX`)
Returns:
tuple[Path, ActivationCacheNp|None]
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: bool = True,
) -> tuple[Path, ActivationCacheNp | None]:
    """fetch activations for a prompt, loading from disk when possible

    # Parameters:
    - `prompt : dict`
        expected to contain the 'text' key
    - `model : HookedTransformer | str`
        either a loaded `HookedTransformer` or a model name for
        `HookedTransformer.from_pretrained`
    - `save_path : Path`
        directory the activations are saved to / loaded from
        (defaults to `Path(DATA_DIR)`)
    - `allow_disk_cache : bool`
        whether an existing on-disk result may be reused
        (defaults to `True`)
    - `return_cache : bool`
        when `False`, the second element of the return value is `None`
        (defaults to `True`)

    # Returns:
    - `tuple[Path, ActivationCacheNp | None]`
        the activations path, plus the cache when `return_cache` is `True`
    """
    # attach the content hash used as the on-disk key
    augment_prompt_with_hash(prompt)

    # resolve the model name without forcing a model load
    if isinstance(model, HookedTransformer):
        model_name: str = model.model_name
    else:
        model_name = model

    # try the on-disk cache first
    if allow_disk_cache:
        try:
            path, cache = load_activations(
                model_name=model_name,
                prompt=prompt,
                save_path=save_path,
            )
            return (path, cache) if return_cache else (path, None)
        except ActivationsMissingError:
            # not on disk -- fall through to compute
            pass

    # cache miss: load the model if we only have its name, then compute.
    if isinstance(model, str):
        model = HookedTransformer.from_pretrained(model_name)

    # NOTE(review): matches the original -- the cache is always returned
    # on this path, regardless of `return_cache`.
    return compute_activations(
        prompt=prompt,
        model=model,
        save_path=save_path,
        return_cache=True,
    )
given a prompt and a model, save or load activations
Parameters:
- `prompt : dict` — expected to contain the 'text' key
- `model : HookedTransformer | str` — either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained`
- `save_path : Path` — path to save the activations to (and load from) (defaults to `Path(DATA_DIR)`)
- `allow_disk_cache : bool` — whether to allow loading from disk cache (defaults to `True`)
- `return_cache : bool` — whether to return the cache. if `False`, will return `None` as the second element (defaults to `True`)
Returns:
`tuple[Path, ActivationCacheNp | None]` — the path to the activations, and the cache if `return_cache` is `True`
def activations_main(
    model_name: str,
    save_path: str,
    prompts_path: str,
    raw_prompts: bool,
    min_chars: int,
    max_chars: int,
    force: bool,
    n_samples: int,
    no_index_html: bool,
    shuffle: bool = False,
) -> None:
    """load a model, compute activations for a set of prompts, update indexes

    # Parameters:
    - `model_name : str`
        name passed to `HookedTransformer.from_pretrained`
    - `save_path : str`
        root output directory
    - `prompts_path : str`
        prompts file; used directly only when `raw_prompts` is `True`
    - `raw_prompts : bool`
        if `True`, filter/shuffle prompts via `load_text_data`; otherwise read
        the model directory's `prompts.jsonl`
    - `min_chars : int`
        minimum prompt length in characters
    - `max_chars : int`
        maximum prompt length in characters
    - `force : bool`
        recompute even when activations already exist on disk
    - `n_samples : int`
        cap on the number of prompts processed
    - `no_index_html : bool`
        skip writing `index.html` when `True`
    - `shuffle : bool`
        shuffle prompts (raw mode only)
        (defaults to `False`)
    """

    with SpinnerContext(message="loading model", **SPINNER_KWARGS):
        model: HookedTransformer = HookedTransformer.from_pretrained(model_name)
        # stamp the requested name onto the model and its config
        model.model_name = model_name
        model.cfg.model_name = model_name
        n_params: int = sum(p.numel() for p in model.parameters())
    print(
        f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters"
    )

    out_dir: Path = Path(save_path)
    out_dir.mkdir(parents=True, exist_ok=True)
    model_path: Path = out_dir / model_name
    with SpinnerContext(
        message=f"saving model info to {model_path.as_posix()}", **SPINNER_KWARGS
    ):
        cfg: HookedTransformerConfig = model.cfg
        model_path.mkdir(parents=True, exist_ok=True)
        with open(model_path / "model_cfg.json", "w") as cfg_file:
            json.dump(json_serialize(asdict(cfg)), cfg_file)

    # load prompts, either raw (filtered here) or pre-built jsonl
    with SpinnerContext(
        message=f"loading prompts from {prompts_path = }", **SPINNER_KWARGS
    ):
        prompts: list[dict]
        if raw_prompts:
            prompts = load_text_data(
                Path(prompts_path),
                min_chars=min_chars,
                max_chars=max_chars,
                shuffle=shuffle,
            )
        else:
            with open(model_path / "prompts.jsonl", "r") as prompts_file:
                prompts = [json.loads(line) for line in prompts_file.readlines()]
        # cap the number of prompts (slicing with None keeps all)
        prompts = prompts[:n_samples]

    print(f"{len(prompts)} prompts loaded")

    # write index.html
    with SpinnerContext(message="writing index.html", **SPINNER_KWARGS):
        if not no_index_html:
            write_html_index(out_dir)

    # compute-and-save per prompt; return values are discarded,
    # only the on-disk side effects matter
    run_one = functools.partial(
        get_activations,
        model=model,
        save_path=out_dir,
        allow_disk_cache=not force,
        return_cache=False,
    )
    for prompt_item in tqdm.tqdm(
        prompts,
        total=len(prompts),
        desc="Computing activations",
        unit="prompt",
    ):
        run_one(prompt_item)

    with SpinnerContext(
        message="updating jsonl metadata for models and prompts", **SPINNER_KWARGS
    ):
        generate_models_jsonl(out_dir)
        generate_prompts_jsonl(out_dir / model_name)
main function for computing activations
Parameters:
- `model_name : str` — name of a model to load with `HookedTransformer.from_pretrained`
- `save_path : str` — path to save the activations to
- `prompts_path : str` — path to the prompts file
- `raw_prompts : bool` — whether the prompts are raw, not filtered by length. `load_text_data` will be called if `True`, otherwise just load the "text" field from each line in `prompts_path`
- `min_chars : int` — minimum number of characters for a prompt
- `max_chars : int` — maximum number of characters for a prompt
- `force : bool` — whether to overwrite existing files
- `n_samples : int` — maximum number of samples to process
- `no_index_html : bool` — whether to write an index.html file
- `shuffle : bool` — whether to shuffle the prompts (defaults to `False`)
def _build_parser() -> argparse.ArgumentParser:
    """construct the CLI argument parser for computing activations"""
    parser: argparse.ArgumentParser = argparse.ArgumentParser()
    # input and output
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        required=True,
        help="The model name(s) to use. comma separated with no whitespace if multiple",
    )
    parser.add_argument(
        "--prompts",
        "-p",
        type=str,
        required=False,
        help="The path to the prompts file (jsonl with 'text' key on each line). If `None`, expects that `--figures` is passed and will generate figures for all prompts in the model directory",
        default=None,
    )
    parser.add_argument(
        "--save-path",
        "-s",
        type=str,
        required=False,
        help="The path to save the attention patterns",
        default=DATA_DIR,
    )
    # prompt length bounds
    parser.add_argument(
        "--min-chars",
        type=int,
        required=False,
        help="The minimum number of characters for a prompt",
        default=100,
    )
    parser.add_argument(
        "--max-chars",
        type=int,
        required=False,
        help="The maximum number of characters for a prompt",
        default=1000,
    )
    # sample cap
    parser.add_argument(
        "--n-samples",
        "-n",
        type=int,
        required=False,
        help="The max number of samples to process, do all in the file if None",
        default=None,
    )
    # behavior flags
    parser.add_argument(
        "--force",
        "-f",
        action="store_true",
        help="If passed, will overwrite existing files",
    )
    parser.add_argument(
        "--no-index-html",
        action="store_true",
        help="If passed, will not write an index.html file for the model",
    )
    parser.add_argument(
        "--raw-prompts",
        "-r",
        action="store_true",
        help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)",
    )
    parser.add_argument(
        "--shuffle",
        action="store_true",
        help="If passed, will shuffle the prompts",
    )
    return parser


def main():
    """CLI entry point: parse args, then run `activations_main` once per model"""
    print(DIVIDER_S1)
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        args: argparse.Namespace = _build_parser().parse_args()

    print(f"args parsed: {args}")

    # split a comma-separated model list; a lone name yields a single-element list
    models: list[str] = args.model.split(",")

    n_models: int = len(models)
    for model_num, model_name in enumerate(models, start=1):
        print(DIVIDER_S2)
        print(f"processing model {model_num} / {n_models}: {model_name}")
        print(DIVIDER_S2)

        activations_main(
            model_name=model_name,
            save_path=args.save_path,
            prompts_path=args.prompts,
            raw_prompts=args.raw_prompts,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            force=args.force,
            n_samples=args.n_samples,
            no_index_html=args.no_index_html,
            shuffle=args.shuffle,
        )

    print(DIVIDER_S1)