Coverage for pattern_lens\activations.py: 67%
114 statements
« prev ^ index » next coverage.py v7.6.9, created at 2025-01-03 23:21 -0700
1"""computing and saving activations given a model and prompts
4# Usage:
6from the command line:
8```bash
9python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>
10```
12from a script:
14```python
15from pattern_lens.activations import activations_main
16activations_main(
17 model_name="gpt2",
18     save_path="demo/",
19 prompts_path="data/pile_1k.jsonl",
20)
21```
23"""
25import argparse
26import functools
27import json
28from dataclasses import asdict
29from pathlib import Path
30import re
31from typing import Callable, Literal, overload
33import numpy as np
34import torch
35import tqdm
36from muutils.spinner import SpinnerContext
37from muutils.misc.numerical import shorten_numerical_to_str
38from muutils.json_serialize import json_serialize
39from transformer_lens import HookedTransformer, HookedTransformerConfig # type: ignore[import-untyped]
41from pattern_lens.consts import (
42 ATTN_PATTERN_REGEX,
43 DATA_DIR,
44 ActivationCacheNp,
45 SPINNER_KWARGS,
46 DIVIDER_S1,
47 DIVIDER_S2,
48)
49from pattern_lens.indexes import (
50 generate_models_jsonl,
51 generate_prompts_jsonl,
52 write_html_index,
53)
54from pattern_lens.load_activations import (
55 ActivationsMissingError,
56 augment_prompt_with_hash,
57 load_activations,
58)
59from pattern_lens.prompts import load_text_data
def compute_activations(
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path(DATA_DIR),
    return_cache: bool = True,
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
) -> tuple[Path, ActivationCacheNp | None]:
    """compute activations for a prompt with a model, saving them to disk

    writes `prompt.json` and `activations.npz` under
    `save_path / model.model_name / "prompts" / prompt["hash"]`

    # Parameters:
    - `prompt : dict`
        required; must contain a 'text' key and a 'hash' key
        (the hash names the output directory)
    - `model : HookedTransformer | None`
        required despite the `None` default -- an assertion fires if omitted
        (defaults to `None`)
    - `save_path : Path`
        (defaults to `Path(DATA_DIR)`)
    - `return_cache : bool`
        will return `None` as the second element if `False`
        (defaults to `True`)
    - `names_filter : Callable[[str], bool] | re.Pattern`
        a filter for the names of the activations to return. if an
        `re.Pattern`, matches keys via `names_filter.match(key) is not None`
        (defaults to `ATTN_PATTERN_REGEX`)

    # Returns:
    - `tuple[Path, ActivationCacheNp | None]`
        path to the saved `activations.npz`, and the numpy cache if requested
    """
    assert model is not None, "model must be passed"
    assert "text" in prompt, "prompt must contain 'text' key"
    prompt_str: str = prompt["text"]

    # compute or get prompt metadata
    # reuse pre-tokenized prompt if present, otherwise tokenize here
    prompt_tokenized: list[str] = prompt.get(
        "tokens",
        model.tokenizer.tokenize(prompt_str),
    )
    prompt.update(
        dict(
            n_tokens=len(prompt_tokenized),
            tokens=prompt_tokenized,
        )
    )

    # save metadata
    prompt_dir: Path = save_path / model.model_name / "prompts" / prompt["hash"]
    prompt_dir.mkdir(parents=True, exist_ok=True)
    with open(prompt_dir / "prompt.json", "w") as f:
        json.dump(prompt, f)

    # set up names filter -- a def instead of an assigned lambda (E731)
    names_filter_fn: Callable[[str], bool]
    if isinstance(names_filter, re.Pattern):

        def names_filter_fn(key: str) -> bool:
            return names_filter.match(key) is not None
    else:
        names_filter_fn = names_filter

    # compute activations
    with torch.no_grad():
        # TODO: batching?
        _, cache = model.run_with_cache(
            prompt_str,
            names_filter=names_filter_fn,
            return_type=None,
        )

    # move tensors to cpu numpy arrays for serialization
    cache_np: ActivationCacheNp = {
        k: v.detach().cpu().numpy() for k, v in cache.items()
    }

    # save activations
    activations_path: Path = prompt_dir / "activations.npz"
    np.savez_compressed(
        activations_path,
        **cache_np,
    )

    # return path and cache
    if return_cache:
        return activations_path, cache_np
    else:
        return activations_path, None
@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal[False] = False,
) -> tuple[Path, None]: ...
@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal[True] = True,
) -> tuple[Path, ActivationCacheNp]: ...
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: bool = True,
) -> tuple[Path, ActivationCacheNp | None]:
    """given a prompt and a model, save or load activations

    # Parameters:
    - `prompt : dict`
        expected to contain the 'text' key
    - `model : HookedTransformer | str`
        either a `HookedTransformer` or a string model name, to be loaded with
        `HookedTransformer.from_pretrained`
    - `save_path : Path`
        path to save the activations to (and load from)
        (defaults to `Path(DATA_DIR)`)
    - `allow_disk_cache : bool`
        whether to allow loading from disk cache
        (defaults to `True`)
    - `return_cache : bool`
        whether to return the cache. if `False`, will return `None` as the
        second element
        (defaults to `True`)

    # Returns:
    - `tuple[Path, ActivationCacheNp | None]`
        the path to the activations and the cache if `return_cache` is `True`
    """
    # add hash to prompt (mutates `prompt` in place)
    augment_prompt_with_hash(prompt)

    # get the model name without forcing a model load
    model_name: str = (
        model.model_name if isinstance(model, HookedTransformer) else model
    )

    # from cache
    if allow_disk_cache:
        try:
            path, cache = load_activations(
                model_name=model_name,
                prompt=prompt,
                save_path=save_path,
            )
            if return_cache:
                return path, cache
            else:
                return path, None
        except ActivationsMissingError:
            # cache miss -- fall through and compute
            pass

    # compute them, loading the model first if only a name was given
    if isinstance(model, str):
        model = HookedTransformer.from_pretrained(model_name)

    # BUG FIX: forward `return_cache` instead of hard-coding `True`, so the
    # second element is `None` when `return_cache=False`, matching the
    # `Literal[False] -> tuple[Path, None]` overload
    return compute_activations(
        prompt=prompt,
        model=model,
        save_path=save_path,
        return_cache=return_cache,
    )
def activations_main(
    model_name: str,
    save_path: str,
    prompts_path: str,
    raw_prompts: bool,
    min_chars: int,
    max_chars: int,
    force: bool,
    n_samples: int,
    no_index_html: bool,
    shuffle: bool = False,
) -> None:
    """main function for computing activations

    loads the model, saves its config, loads prompts, computes+saves
    activations per prompt, then regenerates the jsonl index metadata

    # Parameters:
    - `model_name : str`
        name of a model to load with `HookedTransformer.from_pretrained`
    - `save_path : str`
        path to save the activations to
    - `prompts_path : str`
        path to the prompts file.
        NOTE(review): only read when `raw_prompts=True` -- otherwise prompts
        come from `<save_path>/<model_name>/prompts.jsonl`; confirm intended
    - `raw_prompts : bool`
        whether the prompts are raw, not filtered by length. `load_text_data`
        will be called if `True`, otherwise just load the "text" field from
        each line in `prompts_path`
    - `min_chars : int`
        minimum number of characters for a prompt
    - `max_chars : int`
        maximum number of characters for a prompt
    - `force : bool`
        whether to overwrite existing files (disables the disk cache below)
    - `n_samples : int`
        maximum number of samples to process
    - `no_index_html : bool`
        whether to write an index.html file
    - `shuffle : bool`
        whether to shuffle the prompts
        (defaults to `False`)
    """
    # load the model, then stamp the requested name onto both the model and
    # its config -- `from_pretrained` may normalize the name otherwise
    with SpinnerContext(message="loading model", **SPINNER_KWARGS):
        model: HookedTransformer = HookedTransformer.from_pretrained(model_name)
        model.model_name = model_name
        model.cfg.model_name = model_name

    n_params: int = sum(p.numel() for p in model.parameters())
    print(
        f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters"
    )

    save_path_p: Path = Path(save_path)
    save_path_p.mkdir(parents=True, exist_ok=True)
    model_path: Path = save_path_p / model_name

    # serialize the model config to <save_path>/<model_name>/model_cfg.json
    with SpinnerContext(
        message=f"saving model info to {model_path.as_posix()}", **SPINNER_KWARGS
    ):
        model_cfg: HookedTransformerConfig
        model_cfg = model.cfg
        model_path.mkdir(parents=True, exist_ok=True)
        with open(model_path / "model_cfg.json", "w") as f:
            json.dump(json_serialize(asdict(model_cfg)), f)

    # load prompts
    with SpinnerContext(
        message=f"loading prompts from {prompts_path = }", **SPINNER_KWARGS
    ):
        prompts: list[dict]
        if raw_prompts:
            prompts = load_text_data(
                Path(prompts_path),
                min_chars=min_chars,
                max_chars=max_chars,
                shuffle=shuffle,
            )
        else:
            # NOTE(review): reads the model dir's prompts.jsonl, not
            # `prompts_path` -- verify this is the intended source
            with open(model_path / "prompts.jsonl", "r") as f:
                prompts = [json.loads(line) for line in f.readlines()]
        # truncate to n_samples (a `None` n_samples keeps the full list)
        prompts = prompts[:n_samples]

    print(f"{len(prompts)} prompts loaded")

    # write index.html
    # NOTE(review): the spinner message shows even when `no_index_html`
    # skips the write
    with SpinnerContext(message="writing index.html", **SPINNER_KWARGS):
        if not no_index_html:
            write_html_index(save_path_p)

    # get activations -- `list(...)` forces the lazy map so tqdm advances;
    # `force=True` disables the disk cache, recomputing everything
    list(
        tqdm.tqdm(
            map(
                functools.partial(
                    get_activations,
                    model=model,
                    save_path=save_path_p,
                    allow_disk_cache=not force,
                    return_cache=False,
                ),
                prompts,
            ),
            total=len(prompts),
            desc="Computing activations",
            unit="prompt",
        )
    )

    # regenerate the jsonl indexes the web UI reads
    with SpinnerContext(
        message="updating jsonl metadata for models and prompts", **SPINNER_KWARGS
    ):
        generate_models_jsonl(save_path_p)
        generate_prompts_jsonl(save_path_p / model_name)
def main():
    """CLI entry point: parse arguments, then run `activations_main` for
    each requested model (comma-separated in `--model`)"""
    print(DIVIDER_S1)

    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        parser: argparse.ArgumentParser = argparse.ArgumentParser()

        # input and output
        parser.add_argument(
            "--model",
            "-m",
            type=str,
            required=True,
            help="The model name(s) to use. comma separated with no whitespace if multiple",
        )
        parser.add_argument(
            "--prompts",
            "-p",
            type=str,
            required=False,
            help="The path to the prompts file (jsonl with 'text' key on each line). If `None`, expects that `--figures` is passed and will generate figures for all prompts in the model directory",
            default=None,
        )
        parser.add_argument(
            "--save-path",
            "-s",
            type=str,
            required=False,
            help="The path to save the attention patterns",
            default=DATA_DIR,
        )

        # min and max prompt lengths
        parser.add_argument(
            "--min-chars",
            type=int,
            required=False,
            help="The minimum number of characters for a prompt",
            default=100,
        )
        parser.add_argument(
            "--max-chars",
            type=int,
            required=False,
            help="The maximum number of characters for a prompt",
            default=1000,
        )

        # number of samples
        parser.add_argument(
            "--n-samples",
            "-n",
            type=int,
            required=False,
            help="The max number of samples to process, do all in the file if None",
            default=None,
        )

        # force overwrite
        parser.add_argument(
            "--force",
            "-f",
            action="store_true",
            help="If passed, will overwrite existing files",
        )

        # no index html
        parser.add_argument(
            "--no-index-html",
            action="store_true",
            help="If passed, will not write an index.html file for the model",
        )

        # raw prompts
        parser.add_argument(
            "--raw-prompts",
            "-r",
            action="store_true",
            help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)",
        )

        # shuffle
        parser.add_argument(
            "--shuffle",
            action="store_true",
            help="If passed, will shuffle the prompts",
        )

        args: argparse.Namespace = parser.parse_args()

    print(f"args parsed: {args}")

    # `"name".split(",")` is `["name"]`, so this covers the single-model case
    models: list[str] = args.model.split(",")
    n_models: int = len(models)

    for idx, model in enumerate(models, start=1):
        print(DIVIDER_S2)
        print(f"processing model {idx} / {n_models}: {model}")
        print(DIVIDER_S2)

        activations_main(
            model_name=model,
            save_path=args.save_path,
            prompts_path=args.prompts,
            raw_prompts=args.raw_prompts,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            force=args.force,
            n_samples=args.n_samples,
            no_index_html=args.no_index_html,
            shuffle=args.shuffle,
        )

    print(DIVIDER_S1)


if __name__ == "__main__":
    main()