docs for pattern_lens v0.1.0
View Source on GitHub

pattern_lens.activations

computing and saving activations given a model and prompts

Usage:

from the command line:

python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>

from a script:

from pattern_lens.activations import activations_main
activations_main(
    model_name="gpt2",
    save_path="demo/",
    prompts_path="data/pile_1k.jsonl",
)

  1"""computing and saving activations given a model and prompts
  2
  3
  4# Usage:
  5
  6from the command line:
  7
  8```bash
  9python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>
 10```
 11
 12from a script:
 13
 14```python
 15from pattern_lens.activations import activations_main
 16activations_main(
 17    model_name="gpt2",
 18    save_path="demo/",
 19    prompts_path="data/pile_1k.jsonl",
 20)
 21```
 22
 23"""
 24
 25import argparse
 26import functools
 27import json
 28from dataclasses import asdict
 29from pathlib import Path
 30import re
 31from typing import Callable, Literal, overload
 32
 33import numpy as np
 34import torch
 35import tqdm
 36from muutils.spinner import SpinnerContext
 37from muutils.misc.numerical import shorten_numerical_to_str
 38from muutils.json_serialize import json_serialize
 39from transformer_lens import HookedTransformer, HookedTransformerConfig  # type: ignore[import-untyped]
 40
 41from pattern_lens.consts import (
 42    ATTN_PATTERN_REGEX,
 43    DATA_DIR,
 44    ActivationCacheNp,
 45    SPINNER_KWARGS,
 46    DIVIDER_S1,
 47    DIVIDER_S2,
 48)
 49from pattern_lens.indexes import (
 50    generate_models_jsonl,
 51    generate_prompts_jsonl,
 52    write_html_index,
 53)
 54from pattern_lens.load_activations import (
 55    ActivationsMissingError,
 56    augment_prompt_with_hash,
 57    load_activations,
 58)
 59from pattern_lens.prompts import load_text_data
 60
 61
 62def compute_activations(
 63    prompt: dict,
 64    model: HookedTransformer | None = None,
 65    save_path: Path = Path(DATA_DIR),
 66    return_cache: bool = True,
 67    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
 68) -> tuple[Path, ActivationCacheNp | None]:
 69    """get activations for a given model and prompt, possibly from a cache
 70
 71    if from a cache, prompt_meta must be passed and contain the prompt hash
 72
 73    # Parameters:
 74     - `prompt : dict | None`
 75       (defaults to `None`)
 76     - `model : HookedTransformer`
 77     - `save_path : Path`
 78       (defaults to `Path(DATA_DIR)`)
 79     - `return_cache : bool`
 80       will return `None` as the second element if `False`
 81       (defaults to `True`)
 82     - `names_filter : Callable[[str], bool]|re.Pattern`
 83       a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None`
 84       (defaults to `ATTN_PATTERN_REGEX`)
 85
 86    # Returns:
 87     - `tuple[Path, ActivationCacheNp|None]`
 88    """
 89    assert model is not None, "model must be passed"
 90    assert "text" in prompt, "prompt must contain 'text' key"
 91    prompt_str: str = prompt["text"]
 92
 93    # compute or get prompt metadata
 94    prompt_tokenized: list[str] = prompt.get(
 95        "tokens",
 96        model.tokenizer.tokenize(prompt_str),
 97    )
 98    prompt.update(
 99        dict(
100            n_tokens=len(prompt_tokenized),
101            tokens=prompt_tokenized,
102        )
103    )
104
105    # save metadata
106    prompt_dir: Path = save_path / model.model_name / "prompts" / prompt["hash"]
107    prompt_dir.mkdir(parents=True, exist_ok=True)
108    with open(prompt_dir / "prompt.json", "w") as f:
109        json.dump(prompt, f)
110
111    # set up names filter
112    names_filter_fn: Callable[[str], bool]
113    if isinstance(names_filter, re.Pattern):
114        names_filter_fn = lambda key: names_filter.match(key) is not None  # noqa: E731
115    else:
116        names_filter_fn = names_filter
117
118    # compute activations
119    with torch.no_grad():
120        # TODO: batching?
121        _, cache = model.run_with_cache(
122            prompt_str,
123            names_filter=names_filter_fn,
124            return_type=None,
125        )
126
127    cache_np: ActivationCacheNp = {
128        k: v.detach().cpu().numpy() for k, v in cache.items()
129    }
130
131    # save activations
132    activations_path: Path = prompt_dir / "activations.npz"
133    np.savez_compressed(
134        activations_path,
135        **cache_np,
136    )
137
138    # return path and cache
139    if return_cache:
140        return activations_path, cache_np
141    else:
142        return activations_path, None
143
144
@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal[False] = False,
) -> tuple[Path, None]: ...
@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal[True] = True,
) -> tuple[Path, ActivationCacheNp]: ...
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: bool = True,
) -> tuple[Path, ActivationCacheNp | None]:
    """given a prompt and a model, save or load activations

    # Parameters:
     - `prompt : dict`
        expected to contain the 'text' key
     - `model : HookedTransformer | str`
        either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained`
     - `save_path : Path`
        path to save the activations to (and load from)
       (defaults to `Path(DATA_DIR)`)
     - `allow_disk_cache : bool`
        whether to allow loading from disk cache
       (defaults to `True`)
     - `return_cache : bool`
        whether to return the cache. if `False`, will return `None` as the second element
       (defaults to `True`)

    # Returns:
     - `tuple[Path, ActivationCacheNp | None]`
         the path to the activations and the cache if `return_cache` is `True`

    """
    # add hash to prompt
    augment_prompt_with_hash(prompt)

    # get the model name (works for both a loaded model and a plain name)
    model_name: str = (
        model.model_name if isinstance(model, HookedTransformer) else model
    )

    # try the disk cache first
    if allow_disk_cache:
        try:
            path, cache = load_activations(
                model_name=model_name,
                prompt=prompt,
                save_path=save_path,
            )
            if return_cache:
                return path, cache
            else:
                return path, None
        except ActivationsMissingError:
            # not cached yet -- fall through to computing
            pass

    # load the model if we were only given a name
    if isinstance(model, str):
        model = HookedTransformer.from_pretrained(model_name)

    # BUG FIX: previously this always passed `return_cache=True`, so the
    # `Literal[False]` overload contract (second element is `None`) was
    # violated on the compute path. forward the caller's flag instead.
    return compute_activations(
        prompt=prompt,
        model=model,
        save_path=save_path,
        return_cache=return_cache,
    )
223
224
def activations_main(
    model_name: str,
    save_path: str,
    prompts_path: str,
    raw_prompts: bool,
    min_chars: int,
    max_chars: int,
    force: bool,
    n_samples: int,
    no_index_html: bool,
    shuffle: bool = False,
) -> None:
    """main function for computing activations

    # Parameters:
     - `model_name : str`
        name of a model to load with `HookedTransformer.from_pretrained`
     - `save_path : str`
        path to save the activations to
     - `prompts_path : str`
        path to the prompts file
     - `raw_prompts : bool`
        whether the prompts are raw, not filtered by length. `load_text_data` will be called if `True`, otherwise just load the "text" field from each line in `prompts_path`
     - `min_chars : int`
        minimum number of characters for a prompt
     - `max_chars : int`
        maximum number of characters for a prompt
     - `force : bool`
        whether to overwrite existing files
     - `n_samples : int`
        maximum number of samples to process (may be `None` to process all)
     - `no_index_html : bool`
        whether to write an index.html file
     - `shuffle : bool`
        whether to shuffle the prompts
       (defaults to `False`)
    """

    # load the model
    with SpinnerContext(message="loading model", **SPINNER_KWARGS):
        model: HookedTransformer = HookedTransformer.from_pretrained(model_name)
        # pin the user-supplied name on the model and its config, since
        # downstream paths are derived from `model.model_name`
        model.model_name = model_name
        model.cfg.model_name = model_name
        n_params: int = sum(p.numel() for p in model.parameters())
    print(
        f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters"
    )

    # save model config
    save_path_p: Path = Path(save_path)
    save_path_p.mkdir(parents=True, exist_ok=True)
    model_path: Path = save_path_p / model_name
    with SpinnerContext(
        message=f"saving model info to {model_path.as_posix()}", **SPINNER_KWARGS
    ):
        model_cfg: HookedTransformerConfig
        model_cfg = model.cfg
        model_path.mkdir(parents=True, exist_ok=True)
        with open(model_path / "model_cfg.json", "w") as f:
            json.dump(json_serialize(asdict(model_cfg)), f)

    # load prompts
    with SpinnerContext(
        message=f"loading prompts from {prompts_path = }", **SPINNER_KWARGS
    ):
        prompts: list[dict]
        if raw_prompts:
            prompts = load_text_data(
                Path(prompts_path),
                min_chars=min_chars,
                max_chars=max_chars,
                shuffle=shuffle,
            )
        else:
            # BUG FIX: previously read `model_path / "prompts.jsonl"`, silently
            # ignoring `prompts_path` -- contradicting this function's docstring
            # ("load the 'text' field from each line in `prompts_path`")
            with open(Path(prompts_path), "r") as f:
                prompts = [json.loads(line) for line in f.readlines()]
        # truncate to n_samples (slicing with `None` keeps everything)
        prompts = prompts[:n_samples]

    print(f"{len(prompts)} prompts loaded")

    # write index.html unless disabled (previously the spinner ran even when
    # `no_index_html` was set, doing nothing)
    if not no_index_html:
        with SpinnerContext(message="writing index.html", **SPINNER_KWARGS):
            write_html_index(save_path_p)

    # get activations for every prompt (results discarded; files are the output)
    list(
        tqdm.tqdm(
            map(
                functools.partial(
                    get_activations,
                    model=model,
                    save_path=save_path_p,
                    allow_disk_cache=not force,
                    return_cache=False,
                ),
                prompts,
            ),
            total=len(prompts),
            desc="Computing activations",
            unit="prompt",
        )
    )

    # refresh index metadata
    with SpinnerContext(
        message="updating jsonl metadata for models and prompts", **SPINNER_KWARGS
    ):
        generate_models_jsonl(save_path_p)
        generate_prompts_jsonl(save_path_p / model_name)
333
334
def main():
    """command-line entry point: parse arguments, then run `activations_main` for each requested model"""
    print(DIVIDER_S1)
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
        # input and output
        arg_parser.add_argument(
            "--model",
            "-m",
            type=str,
            required=True,
            help="The model name(s) to use. comma separated with no whitespace if multiple",
        )
        arg_parser.add_argument(
            "--prompts",
            "-p",
            type=str,
            required=False,
            help="The path to the prompts file (jsonl with 'text' key on each line). If `None`, expects that `--figures` is passed and will generate figures for all prompts in the model directory",
            default=None,
        )
        arg_parser.add_argument(
            "--save-path",
            "-s",
            type=str,
            required=False,
            help="The path to save the attention patterns",
            default=DATA_DIR,
        )
        # min and max prompt lengths
        arg_parser.add_argument(
            "--min-chars",
            type=int,
            required=False,
            help="The minimum number of characters for a prompt",
            default=100,
        )
        arg_parser.add_argument(
            "--max-chars",
            type=int,
            required=False,
            help="The maximum number of characters for a prompt",
            default=1000,
        )
        # number of samples
        arg_parser.add_argument(
            "--n-samples",
            "-n",
            type=int,
            required=False,
            help="The max number of samples to process, do all in the file if None",
            default=None,
        )
        # force overwrite
        arg_parser.add_argument(
            "--force",
            "-f",
            action="store_true",
            help="If passed, will overwrite existing files",
        )
        # no index html
        arg_parser.add_argument(
            "--no-index-html",
            action="store_true",
            help="If passed, will not write an index.html file for the model",
        )
        # raw prompts
        arg_parser.add_argument(
            "--raw-prompts",
            "-r",
            action="store_true",
            help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)",
        )
        # shuffle
        arg_parser.add_argument(
            "--shuffle",
            action="store_true",
            help="If passed, will shuffle the prompts",
        )

        args: argparse.Namespace = arg_parser.parse_args()

    print(f"args parsed: {args}")

    # `"name".split(",")` is `["name"]`, so a single model needs no special case
    models: list[str] = args.model.split(",")
    n_models: int = len(models)

    for model_idx, model_name in enumerate(models, start=1):
        print(DIVIDER_S2)
        print(f"processing model {model_idx} / {n_models}: {model_name}")
        print(DIVIDER_S2)

        activations_main(
            model_name=model_name,
            save_path=args.save_path,
            prompts_path=args.prompts,
            raw_prompts=args.raw_prompts,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            force=args.force,
            n_samples=args.n_samples,
            no_index_html=args.no_index_html,
            shuffle=args.shuffle,
        )

    print(DIVIDER_S1)


if __name__ == "__main__":
    main()

def compute_activations( prompt: dict, model: transformer_lens.HookedTransformer.HookedTransformer | None = None, save_path: pathlib.Path = WindowsPath('attn_data'), return_cache: bool = True, names_filter: Union[Callable[[str], bool], re.Pattern] = re.compile('blocks\\.(\\d+)\\.attn\\.hook_pattern')) -> tuple[pathlib.Path, dict[str, numpy.ndarray] | None]:
 63def compute_activations(
 64    prompt: dict,
 65    model: HookedTransformer | None = None,
 66    save_path: Path = Path(DATA_DIR),
 67    return_cache: bool = True,
 68    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
 69) -> tuple[Path, ActivationCacheNp | None]:
 70    """get activations for a given model and prompt, possibly from a cache
 71
 72    if from a cache, prompt_meta must be passed and contain the prompt hash
 73
 74    # Parameters:
 75     - `prompt : dict | None`
 76       (defaults to `None`)
 77     - `model : HookedTransformer`
 78     - `save_path : Path`
 79       (defaults to `Path(DATA_DIR)`)
 80     - `return_cache : bool`
 81       will return `None` as the second element if `False`
 82       (defaults to `True`)
 83     - `names_filter : Callable[[str], bool]|re.Pattern`
 84       a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None`
 85       (defaults to `ATTN_PATTERN_REGEX`)
 86
 87    # Returns:
 88     - `tuple[Path, ActivationCacheNp|None]`
 89    """
 90    assert model is not None, "model must be passed"
 91    assert "text" in prompt, "prompt must contain 'text' key"
 92    prompt_str: str = prompt["text"]
 93
 94    # compute or get prompt metadata
 95    prompt_tokenized: list[str] = prompt.get(
 96        "tokens",
 97        model.tokenizer.tokenize(prompt_str),
 98    )
 99    prompt.update(
100        dict(
101            n_tokens=len(prompt_tokenized),
102            tokens=prompt_tokenized,
103        )
104    )
105
106    # save metadata
107    prompt_dir: Path = save_path / model.model_name / "prompts" / prompt["hash"]
108    prompt_dir.mkdir(parents=True, exist_ok=True)
109    with open(prompt_dir / "prompt.json", "w") as f:
110        json.dump(prompt, f)
111
112    # set up names filter
113    names_filter_fn: Callable[[str], bool]
114    if isinstance(names_filter, re.Pattern):
115        names_filter_fn = lambda key: names_filter.match(key) is not None  # noqa: E731
116    else:
117        names_filter_fn = names_filter
118
119    # compute activations
120    with torch.no_grad():
121        # TODO: batching?
122        _, cache = model.run_with_cache(
123            prompt_str,
124            names_filter=names_filter_fn,
125            return_type=None,
126        )
127
128    cache_np: ActivationCacheNp = {
129        k: v.detach().cpu().numpy() for k, v in cache.items()
130    }
131
132    # save activations
133    activations_path: Path = prompt_dir / "activations.npz"
134    np.savez_compressed(
135        activations_path,
136        **cache_np,
137    )
138
139    # return path and cache
140    if return_cache:
141        return activations_path, cache_np
142    else:
143        return activations_path, None

get activations for a given model and prompt, possibly from a cache

if from a cache, prompt_meta must be passed and contain the prompt hash

Parameters:

  • prompt : dict — required; must contain the 'text' key (and a 'hash' key)
  • model : HookedTransformer
  • save_path : Path (defaults to Path(DATA_DIR))
  • return_cache : bool will return None as the second element if False (defaults to True)
  • names_filter : Callable[[str], bool]|re.Pattern a filter for the names of the activations to return. if an re.Pattern, will use lambda key: names_filter.match(key) is not None (defaults to ATTN_PATTERN_REGEX)

Returns:

  • tuple[Path, ActivationCacheNp|None]
def get_activations( prompt: dict, model: transformer_lens.HookedTransformer.HookedTransformer | str, save_path: pathlib.Path = WindowsPath('attn_data'), allow_disk_cache: bool = True, return_cache: bool = True) -> tuple[pathlib.Path, dict[str, numpy.ndarray] | None]:
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: bool = True,
) -> tuple[Path, ActivationCacheNp | None]:
    """given a prompt and a model, save or load activations

    # Parameters:
     - `prompt : dict`
        expected to contain the 'text' key
     - `model : HookedTransformer | str`
        either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained`
     - `save_path : Path`
        path to save the activations to (and load from)
       (defaults to `Path(DATA_DIR)`)
     - `allow_disk_cache : bool`
        whether to allow loading from disk cache
       (defaults to `True`)
     - `return_cache : bool`
        whether to return the cache. if `False`, will return `None` as the second element
       (defaults to `True`)

    # Returns:
     - `tuple[Path, ActivationCacheNp | None]`
         the path to the activations and the cache if `return_cache` is `True`

    """
    # add hash to prompt
    augment_prompt_with_hash(prompt)

    # get the model name (works for both a loaded model and a plain name)
    model_name: str = (
        model.model_name if isinstance(model, HookedTransformer) else model
    )

    # try the disk cache first
    if allow_disk_cache:
        try:
            path, cache = load_activations(
                model_name=model_name,
                prompt=prompt,
                save_path=save_path,
            )
            if return_cache:
                return path, cache
            else:
                return path, None
        except ActivationsMissingError:
            # not cached yet -- fall through to computing
            pass

    # load the model if we were only given a name
    if isinstance(model, str):
        model = HookedTransformer.from_pretrained(model_name)

    # BUG FIX: previously this always passed `return_cache=True`, so the
    # `Literal[False]` overload contract (second element is `None`) was
    # violated on the compute path. forward the caller's flag instead.
    return compute_activations(
        prompt=prompt,
        model=model,
        save_path=save_path,
        return_cache=return_cache,
    )

given a prompt and a model, save or load activations

Parameters:

  • prompt : dict expected to contain the 'text' key
  • model : HookedTransformer | str either a HookedTransformer or a string model name, to be loaded with HookedTransformer.from_pretrained
  • save_path : Path path to save the activations to (and load from) (defaults to Path(DATA_DIR))
  • allow_disk_cache : bool whether to allow loading from disk cache (defaults to True)
  • return_cache : bool whether to return the cache. if False, will return None as the second element (defaults to True)

Returns:

  • tuple[Path, ActivationCacheNp | None] the path to the activations and the cache if return_cache is True
def activations_main( model_name: str, save_path: str, prompts_path: str, raw_prompts: bool, min_chars: int, max_chars: int, force: bool, n_samples: int, no_index_html: bool, shuffle: bool = False) -> None:
def activations_main(
    model_name: str,
    save_path: str,
    prompts_path: str,
    raw_prompts: bool,
    min_chars: int,
    max_chars: int,
    force: bool,
    n_samples: int,
    no_index_html: bool,
    shuffle: bool = False,
) -> None:
    """main function for computing activations

    # Parameters:
     - `model_name : str`
        name of a model to load with `HookedTransformer.from_pretrained`
     - `save_path : str`
        path to save the activations to
     - `prompts_path : str`
        path to the prompts file
     - `raw_prompts : bool`
        whether the prompts are raw, not filtered by length. `load_text_data` will be called if `True`, otherwise just load the "text" field from each line in `prompts_path`
     - `min_chars : int`
        minimum number of characters for a prompt
     - `max_chars : int`
        maximum number of characters for a prompt
     - `force : bool`
        whether to overwrite existing files
     - `n_samples : int`
        maximum number of samples to process (may be `None` to process all)
     - `no_index_html : bool`
        whether to write an index.html file
     - `shuffle : bool`
        whether to shuffle the prompts
       (defaults to `False`)
    """

    # load the model
    with SpinnerContext(message="loading model", **SPINNER_KWARGS):
        model: HookedTransformer = HookedTransformer.from_pretrained(model_name)
        # pin the user-supplied name on the model and its config, since
        # downstream paths are derived from `model.model_name`
        model.model_name = model_name
        model.cfg.model_name = model_name
        n_params: int = sum(p.numel() for p in model.parameters())
    print(
        f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters"
    )

    # save model config
    save_path_p: Path = Path(save_path)
    save_path_p.mkdir(parents=True, exist_ok=True)
    model_path: Path = save_path_p / model_name
    with SpinnerContext(
        message=f"saving model info to {model_path.as_posix()}", **SPINNER_KWARGS
    ):
        model_cfg: HookedTransformerConfig
        model_cfg = model.cfg
        model_path.mkdir(parents=True, exist_ok=True)
        with open(model_path / "model_cfg.json", "w") as f:
            json.dump(json_serialize(asdict(model_cfg)), f)

    # load prompts
    with SpinnerContext(
        message=f"loading prompts from {prompts_path = }", **SPINNER_KWARGS
    ):
        prompts: list[dict]
        if raw_prompts:
            prompts = load_text_data(
                Path(prompts_path),
                min_chars=min_chars,
                max_chars=max_chars,
                shuffle=shuffle,
            )
        else:
            # BUG FIX: previously read `model_path / "prompts.jsonl"`, silently
            # ignoring `prompts_path` -- contradicting this function's docstring
            # ("load the 'text' field from each line in `prompts_path`")
            with open(Path(prompts_path), "r") as f:
                prompts = [json.loads(line) for line in f.readlines()]
        # truncate to n_samples (slicing with `None` keeps everything)
        prompts = prompts[:n_samples]

    print(f"{len(prompts)} prompts loaded")

    # write index.html unless disabled (previously the spinner ran even when
    # `no_index_html` was set, doing nothing)
    if not no_index_html:
        with SpinnerContext(message="writing index.html", **SPINNER_KWARGS):
            write_html_index(save_path_p)

    # get activations for every prompt (results discarded; files are the output)
    list(
        tqdm.tqdm(
            map(
                functools.partial(
                    get_activations,
                    model=model,
                    save_path=save_path_p,
                    allow_disk_cache=not force,
                    return_cache=False,
                ),
                prompts,
            ),
            total=len(prompts),
            desc="Computing activations",
            unit="prompt",
        )
    )

    # refresh index metadata
    with SpinnerContext(
        message="updating jsonl metadata for models and prompts", **SPINNER_KWARGS
    ):
        generate_models_jsonl(save_path_p)
        generate_prompts_jsonl(save_path_p / model_name)

main function for computing activations

Parameters:

  • model_name : str name of a model to load with HookedTransformer.from_pretrained
  • save_path : str path to save the activations to
  • prompts_path : str path to the prompts file
  • raw_prompts : bool whether the prompts are raw, not filtered by length. load_text_data will be called if True, otherwise just load the "text" field from each line in prompts_path
  • min_chars : int minimum number of characters for a prompt
  • max_chars : int maximum number of characters for a prompt
  • force : bool whether to overwrite existing files
  • n_samples : int maximum number of samples to process
  • no_index_html : bool whether to write an index.html file
  • shuffle : bool whether to shuffle the prompts (defaults to False)
def main():
def main():
    """command-line entry point: parse arguments, then run `activations_main` for each requested model"""
    print(DIVIDER_S1)
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
        # input and output
        arg_parser.add_argument(
            "--model",
            "-m",
            type=str,
            required=True,
            help="The model name(s) to use. comma separated with no whitespace if multiple",
        )
        arg_parser.add_argument(
            "--prompts",
            "-p",
            type=str,
            required=False,
            help="The path to the prompts file (jsonl with 'text' key on each line). If `None`, expects that `--figures` is passed and will generate figures for all prompts in the model directory",
            default=None,
        )
        arg_parser.add_argument(
            "--save-path",
            "-s",
            type=str,
            required=False,
            help="The path to save the attention patterns",
            default=DATA_DIR,
        )
        # min and max prompt lengths
        arg_parser.add_argument(
            "--min-chars",
            type=int,
            required=False,
            help="The minimum number of characters for a prompt",
            default=100,
        )
        arg_parser.add_argument(
            "--max-chars",
            type=int,
            required=False,
            help="The maximum number of characters for a prompt",
            default=1000,
        )
        # number of samples
        arg_parser.add_argument(
            "--n-samples",
            "-n",
            type=int,
            required=False,
            help="The max number of samples to process, do all in the file if None",
            default=None,
        )
        # force overwrite
        arg_parser.add_argument(
            "--force",
            "-f",
            action="store_true",
            help="If passed, will overwrite existing files",
        )
        # no index html
        arg_parser.add_argument(
            "--no-index-html",
            action="store_true",
            help="If passed, will not write an index.html file for the model",
        )
        # raw prompts
        arg_parser.add_argument(
            "--raw-prompts",
            "-r",
            action="store_true",
            help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)",
        )
        # shuffle
        arg_parser.add_argument(
            "--shuffle",
            action="store_true",
            help="If passed, will shuffle the prompts",
        )

        args: argparse.Namespace = arg_parser.parse_args()

    print(f"args parsed: {args}")

    # `"name".split(",")` is `["name"]`, so a single model needs no special case
    models: list[str] = args.model.split(",")
    n_models: int = len(models)

    for model_idx, model_name in enumerate(models, start=1):
        print(DIVIDER_S2)
        print(f"processing model {model_idx} / {n_models}: {model_name}")
        print(DIVIDER_S2)

        activations_main(
            model_name=model_name,
            save_path=args.save_path,
            prompts_path=args.prompts,
            raw_prompts=args.raw_prompts,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            force=args.force,
            n_samples=args.n_samples,
            no_index_html=args.no_index_html,
            shuffle=args.shuffle,
        )

    print(DIVIDER_S1)