pattern_lens.figures
code for generating figures from attention patterns, using the functions decorated with register_attn_figure_func
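for orientation, here is a minimal sketch of the kind of function this module consumes. anything decorated with `register_attn_figure_func` (assumed here to live in `pattern_lens.attn_figure_funcs`, next to `ATTENTION_MATRIX_FIGURE_FUNCS`) ends up being called by `process_single_head` below as `func(attn_pattern, save_dir)`; the matplotlib body is purely illustrative:

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from pattern_lens.attn_figure_funcs import register_attn_figure_func


@register_attn_figure_func
def raw_heatmap(attn_matrix: np.ndarray, save_dir: Path) -> None:
    # plot the (n_ctx, n_ctx) attention pattern for one head as a heatmap
    fig, ax = plt.subplots(figsize=(4, 4))
    ax.matshow(attn_matrix, cmap="viridis")
    ax.axis("off")
    # the file name starts with the function name, so process_single_head's
    # save_dir.glob(f"{func_name}.*") check can find and skip it on later runs
    fig.savefig(save_dir / "raw_heatmap.png", bbox_inches="tight")
    plt.close(fig)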
1"""code for generating figures from attention patterns, using the functions decorated with `register_attn_figure_func`""" 2 3import argparse 4from collections import defaultdict 5import functools 6import itertools 7import json 8import warnings 9from pathlib import Path 10 11from muutils.json_serialize import json_serialize 12from muutils.spinner import SpinnerContext 13from muutils.parallel import run_maybe_parallel 14 15from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS 16from pattern_lens.consts import ( 17 DATA_DIR, 18 AttentionMatrix, 19 SPINNER_KWARGS, 20 ActivationCacheNp, 21 DIVIDER_S1, 22 DIVIDER_S2, 23) 24from pattern_lens.indexes import ( 25 generate_functions_jsonl, 26 generate_models_jsonl, 27 generate_prompts_jsonl, 28) 29from pattern_lens.load_activations import load_activations 30 31 32class HTConfigMock: 33 """Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json 34 35 can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes: 36 - `n_layers: int` 37 - `n_heads: int` 38 - `model_name: str` 39 """ 40 41 def __init__(self, **kwargs): 42 self.n_layers: int 43 self.n_heads: int 44 self.model_name: str 45 self.__dict__.update(kwargs) 46 47 def serialize(self): 48 """serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`""" 49 return json_serialize(self.__dict__) 50 51 @classmethod 52 def load(cls, data: dict): 53 "try to load a config from a dict, using the `__init__` method" 54 return cls(**data) 55 56 57def process_single_head( 58 layer_idx: int, 59 head_idx: int, 60 attn_pattern: AttentionMatrix, 61 save_dir: Path, 62 force_overwrite: bool = False, 63) -> dict[str, bool | Exception]: 64 """process a single head's attention pattern, running all the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` on the attention pattern 65 66 # Parameters: 67 - `layer_idx : int` 68 - `head_idx : int` 69 - `attn_pattern : AttentionMatrix` 70 attention pattern for the head 71 - `save_dir : Path` 72 directory to save the figures to 73 - `force_overwrite : bool` 74 whether to overwrite existing figures. 
if `False`, will skip any functions which have already saved a figure 75 (defaults to `False`) 76 77 # Returns: 78 - `dict[str, bool | Exception]` 79 a dictionary of the status of each function, with the function name as the key and the status as the value 80 """ 81 funcs_status: dict[str, bool | Exception] = dict() 82 83 for func in ATTENTION_MATRIX_FIGURE_FUNCS: 84 func_name: str = func.__name__ 85 fig_path: list[Path] = list(save_dir.glob(f"{func_name}.*")) 86 87 if not force_overwrite and len(fig_path) > 0: 88 funcs_status[func_name] = True 89 continue 90 91 try: 92 func(attn_pattern, save_dir) 93 funcs_status[func_name] = True 94 95 except Exception as e: 96 error_file = save_dir / f"{func.__name__}.error.txt" 97 error_file.write_text(str(e)) 98 warnings.warn( 99 f"Error in {func.__name__} for L{layer_idx}H{head_idx}: {str(e)}" 100 ) 101 funcs_status[func_name] = e 102 103 return funcs_status 104 105 106def compute_and_save_figures( 107 model_cfg: "HookedTransformerConfig|HTConfigMock", # type: ignore[name-defined] # noqa: F821 108 activations_path: Path, 109 cache: ActivationCacheNp, 110 save_path: Path = Path(DATA_DIR), 111 force_overwrite: bool = False, 112 track_results: bool = False, 113) -> None: 114 """compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` 115 116 # Parameters: 117 - `model_cfg : HookedTransformerConfig|HTConfigMock` 118 - `cache : ActivationCacheNp` 119 - `save_path : Path` 120 (defaults to `Path(DATA_DIR)`) 121 - `force_overwrite : bool` 122 force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure 123 (defaults to `False`) 124 - `track_results : bool` 125 whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO 126 (defaults to `False`) 127 """ 128 prompt_dir: Path = activations_path.parent 129 130 if track_results: 131 results: defaultdict[ 132 str, # func name 133 dict[ 134 tuple[int, int], # layer, head 135 bool | Exception, # success or exception 136 ], 137 ] = defaultdict(dict) 138 139 for layer_idx, head_idx in itertools.product( 140 range(model_cfg.n_layers), 141 range(model_cfg.n_heads), 142 ): 143 attn_pattern: AttentionMatrix = cache[f"blocks.{layer_idx}.attn.hook_pattern"][ 144 0, head_idx 145 ] 146 save_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}" 147 save_dir.mkdir(parents=True, exist_ok=True) 148 head_res: dict[str, bool | Exception] = process_single_head( 149 layer_idx=layer_idx, 150 head_idx=head_idx, 151 attn_pattern=attn_pattern, 152 save_dir=save_dir, 153 force_overwrite=force_overwrite, 154 ) 155 156 if track_results: 157 for func_name, status in head_res.items(): 158 results[func_name][(layer_idx, head_idx)] = status 159 160 # TODO: do something with results 161 162 generate_prompts_jsonl(save_path / model_cfg.model_name) 163 164 165def process_prompt( 166 prompt: dict, 167 model_cfg: "HookedTransformerConfig|HTConfigMock", # type: ignore[name-defined] # noqa: F821 168 save_path: Path, 169 force_overwrite: bool = False, 170) -> None: 171 """process a single prompt, loading the activations and computing and saving the figures 172 173 basically just calls `load_activations` and then `compute_and_save_figures` 174 175 # Parameters: 176 - `prompt : dict` 177 - `model_cfg : HookedTransformerConfig|HTConfigMock` 178 - `force_overwrite : bool` 179 (defaults to `False`) 180 """ 181 activations_path: Path 182 cache: ActivationCacheNp 183 activations_path, cache = load_activations( 
184 model_name=model_cfg.model_name, 185 prompt=prompt, 186 save_path=save_path, 187 return_fmt="numpy", 188 ) 189 190 compute_and_save_figures( 191 model_cfg=model_cfg, 192 activations_path=activations_path, 193 cache=cache, 194 save_path=save_path, 195 force_overwrite=force_overwrite, 196 ) 197 198 199def figures_main( 200 model_name: str, 201 save_path: str, 202 n_samples: int, 203 force: bool, 204 parallel: bool | int = True, 205) -> None: 206 """main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` 207 208 # Parameters: 209 - `model_name : str` 210 model name to use, used for loading the model config, prompts, activations, and saving the figures 211 - `save_path : str` 212 base path to look in 213 - `n_samples : int` 214 max number of samples to process 215 - `force : bool` 216 force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure 217 - `parallel : bool | int` 218 whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores 219 (defaults to `True`) 220 """ 221 with SpinnerContext(message="setting up paths", **SPINNER_KWARGS): 222 # save model info or check if it exists 223 save_path_p: Path = Path(save_path) 224 model_path: Path = save_path_p / model_name 225 with open(model_path / "model_cfg.json", "r") as f: 226 model_cfg = HTConfigMock.load(json.load(f)) 227 228 with SpinnerContext(message="loading prompts", **SPINNER_KWARGS): 229 # load prompts 230 with open(model_path / "prompts.jsonl", "r") as f: 231 prompts: list[dict] = [json.loads(line) for line in f.readlines()] 232 # truncate to n_samples 233 prompts = prompts[:n_samples] 234 235 print(f"{len(prompts)} prompts loaded") 236 237 print(f"{len(ATTENTION_MATRIX_FIGURE_FUNCS)} figure functions loaded") 238 print("\t" + ", ".join([func.__name__ for func in ATTENTION_MATRIX_FIGURE_FUNCS])) 239 240 list( 241 run_maybe_parallel( 242 func=functools.partial( 243 process_prompt, 244 model_cfg=model_cfg, 245 save_path=save_path_p, 246 force_overwrite=force, 247 ), 248 iterable=prompts, 249 parallel=parallel, 250 pbar="tqdm", 251 pbar_kwargs=dict( 252 desc="Making figures", 253 unit="prompt", 254 ), 255 ) 256 ) 257 258 with SpinnerContext( 259 message="updating jsonl metadata for models and functions", **SPINNER_KWARGS 260 ): 261 generate_models_jsonl(save_path_p) 262 generate_functions_jsonl(save_path_p) 263 264 265def main(): 266 print(DIVIDER_S1) 267 with SpinnerContext(message="parsing args", **SPINNER_KWARGS): 268 arg_parser: argparse.ArgumentParser = argparse.ArgumentParser() 269 # input and output 270 arg_parser.add_argument( 271 "--model", 272 "-m", 273 type=str, 274 required=True, 275 help="The model name(s) to use. 
comma separated with no whitespace if multiple", 276 ) 277 arg_parser.add_argument( 278 "--save-path", 279 "-s", 280 type=str, 281 required=False, 282 help="The path to save the attention patterns", 283 default=DATA_DIR, 284 ) 285 # number of samples 286 arg_parser.add_argument( 287 "--n-samples", 288 "-n", 289 type=int, 290 required=False, 291 help="The max number of samples to process, do all in the file if None", 292 default=None, 293 ) 294 # force overwrite of existing figures 295 arg_parser.add_argument( 296 "--force", 297 "-f", 298 type=bool, 299 required=False, 300 help="Force overwrite of existing figures", 301 default=False, 302 ) 303 304 args: argparse.Namespace = arg_parser.parse_args() 305 306 print(f"args parsed: {args}") 307 308 models: list[str] 309 if "," in args.model: 310 models = args.model.split(",") 311 else: 312 models = [args.model] 313 314 n_models: int = len(models) 315 for idx, model in enumerate(models): 316 print(DIVIDER_S2) 317 print(f"processing model {idx+1} / {n_models}: {model}") 318 print(DIVIDER_S2) 319 figures_main( 320 model_name=model, 321 save_path=args.save_path, 322 n_samples=args.n_samples, 323 force=args.force, 324 ) 325 326 print(DIVIDER_S1) 327 328 329if __name__ == "__main__": 330 main()
class HTConfigMock:
    """Mock of `transformer_lens.HookedTransformerConfig` for type hinting and loading config json

    can be initialized with any kwargs, and will update its `__dict__` with them. does, however, require the following attributes:
    - `n_layers: int`
    - `n_heads: int`
    - `model_name: str`
    """

    def __init__(self, **kwargs):
        self.n_layers: int
        self.n_heads: int
        self.model_name: str
        self.__dict__.update(kwargs)

    def serialize(self):
        """serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
        return json_serialize(self.__dict__)

    @classmethod
    def load(cls, data: dict):
        "try to load a config from a dict, using the `__init__` method"
        return cls(**data)
Mock of transformer_lens.HookedTransformerConfig for type hinting and loading config json
can be initialized with any kwargs, and will update its __dict__ with them. does, however, require the following attributes:
- n_layers: int
- n_heads: int
- model_name: str
def serialize(self):
    """serialize the config to json. values which aren't serializable will be converted via `muutils.json_serialize.json_serialize`"""
    return json_serialize(self.__dict__)
serialize the config to json. values which aren't serializable will be converted via muutils.json_serialize.json_serialize
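a quick sketch of round-tripping the mock; the path is illustrative, following the {save_path}/{model_name}/model_cfg.json layout that figures_main reads:

import json
from pathlib import Path

from pattern_lens.figures import HTConfigMock

# illustrative path to a previously saved config
cfg_path: Path = Path("attn_data") / "gpt2" / "model_cfg.json"
model_cfg = HTConfigMock.load(json.loads(cfg_path.read_text()))

print(model_cfg.n_layers, model_cfg.n_heads, model_cfg.model_name)

# back to a json-safe dict (non-serializable values go through muutils' json_serialize)
print(model_cfg.serialize())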
def process_single_head(
    layer_idx: int,
    head_idx: int,
    attn_pattern: AttentionMatrix,
    save_dir: Path,
    force_overwrite: bool = False,
) -> dict[str, bool | Exception]:
    """process a single head's attention pattern, running all the functions in `ATTENTION_MATRIX_FIGURE_FUNCS` on the attention pattern

    # Parameters:
    - `layer_idx : int`
    - `head_idx : int`
    - `attn_pattern : AttentionMatrix`
        attention pattern for the head
    - `save_dir : Path`
        directory to save the figures to
    - `force_overwrite : bool`
        whether to overwrite existing figures. if `False`, will skip any functions which have already saved a figure
        (defaults to `False`)

    # Returns:
    - `dict[str, bool | Exception]`
        a dictionary of the status of each function, with the function name as the key and the status as the value
    """
    funcs_status: dict[str, bool | Exception] = dict()

    for func in ATTENTION_MATRIX_FIGURE_FUNCS:
        func_name: str = func.__name__
        fig_path: list[Path] = list(save_dir.glob(f"{func_name}.*"))

        if not force_overwrite and len(fig_path) > 0:
            funcs_status[func_name] = True
            continue

        try:
            func(attn_pattern, save_dir)
            funcs_status[func_name] = True

        except Exception as e:
            error_file = save_dir / f"{func.__name__}.error.txt"
            error_file.write_text(str(e))
            warnings.warn(
                f"Error in {func.__name__} for L{layer_idx}H{head_idx}: {str(e)}"
            )
            funcs_status[func_name] = e

    return funcs_status
process a single head's attention pattern, running all the functions in ATTENTION_MATRIX_FIGURE_FUNCS on the attention pattern
Parameters:
- layer_idx : int
- head_idx : int
- attn_pattern : AttentionMatrix
  attention pattern for the head
- save_dir : Path
  directory to save the figures to
- force_overwrite : bool
  whether to overwrite existing figures. if False, will skip any functions which have already saved a figure
  (defaults to False)
Returns:
- dict[str, bool | Exception]
  a dictionary of the status of each function, with the function name as the key and the status as the value
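a hedged sketch of calling it directly on a synthetic pattern; the directory and array are made up, and in normal use `compute_and_save_figures` drives this loop:

from pathlib import Path

import numpy as np

from pattern_lens.figures import process_single_head

# synthetic causal attention pattern for one head (rows sum to 1)
n_ctx: int = 16
attn = np.tril(np.random.rand(n_ctx, n_ctx))
attn = attn / attn.sum(axis=-1, keepdims=True)

save_dir: Path = Path("tmp_figures") / "L0" / "H0"
save_dir.mkdir(parents=True, exist_ok=True)

status = process_single_head(
    layer_idx=0,
    head_idx=0,
    attn_pattern=attn,
    save_dir=save_dir,
    force_overwrite=True,
)
# maps each registered figure function's name to True, or to the Exception it raised
print(status)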
def compute_and_save_figures(
    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
    activations_path: Path,
    cache: ActivationCacheNp,
    save_path: Path = Path(DATA_DIR),
    force_overwrite: bool = False,
    track_results: bool = False,
) -> None:
    """compute and save figures for all heads in the model, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Parameters:
    - `model_cfg : HookedTransformerConfig|HTConfigMock`
    - `activations_path : Path`
        path to the prompt's activations file. figures are saved in `L{layer}/H{head}` subdirectories of its parent directory
    - `cache : ActivationCacheNp`
    - `save_path : Path`
        (defaults to `Path(DATA_DIR)`)
    - `force_overwrite : bool`
        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
        (defaults to `False`)
    - `track_results : bool`
        whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO
        (defaults to `False`)
    """
    prompt_dir: Path = activations_path.parent

    if track_results:
        results: defaultdict[
            str,  # func name
            dict[
                tuple[int, int],  # layer, head
                bool | Exception,  # success or exception
            ],
        ] = defaultdict(dict)

    for layer_idx, head_idx in itertools.product(
        range(model_cfg.n_layers),
        range(model_cfg.n_heads),
    ):
        attn_pattern: AttentionMatrix = cache[f"blocks.{layer_idx}.attn.hook_pattern"][
            0, head_idx
        ]
        save_dir: Path = prompt_dir / f"L{layer_idx}" / f"H{head_idx}"
        save_dir.mkdir(parents=True, exist_ok=True)
        head_res: dict[str, bool | Exception] = process_single_head(
            layer_idx=layer_idx,
            head_idx=head_idx,
            attn_pattern=attn_pattern,
            save_dir=save_dir,
            force_overwrite=force_overwrite,
        )

        if track_results:
            for func_name, status in head_res.items():
                results[func_name][(layer_idx, head_idx)] = status

    # TODO: do something with results

    generate_prompts_jsonl(save_path / model_cfg.model_name)
compute and save figures for all heads in the model, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS
Parameters:
- model_cfg : HookedTransformerConfig|HTConfigMock
- activations_path : Path
  path to the prompt's activations file. figures are saved in L{layer}/H{head} subdirectories of its parent directory
- cache : ActivationCacheNp
- save_path : Path
  (defaults to Path(DATA_DIR))
- force_overwrite : bool
  force overwrite of existing figures. if False, will skip any functions which have already saved a figure
  (defaults to False)
- track_results : bool
  whether to track the results of each function for each head. Isn't used for anything yet, but this is a TODO
  (defaults to False)
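a minimal sketch with a fake config and cache, just to show the keys and shapes the loop above expects (cache["blocks.{layer}.attn.hook_pattern"] with shape (1, n_heads, n_ctx, n_ctx)); the paths are invented, and on a synthetic tree like this the final prompts-index regeneration will have nothing real to index:

from pathlib import Path

import numpy as np

from pattern_lens.figures import HTConfigMock, compute_and_save_figures

model_cfg = HTConfigMock(n_layers=2, n_heads=4, model_name="demo-model")
n_ctx: int = 8

# keys and shapes mirror what the loop indexes: cache[f"blocks.{layer}.attn.hook_pattern"][0, head]
cache: dict = {}
for layer in range(model_cfg.n_layers):
    pattern = np.random.rand(1, model_cfg.n_heads, n_ctx, n_ctx)
    cache[f"blocks.{layer}.attn.hook_pattern"] = pattern / pattern.sum(axis=-1, keepdims=True)

save_path: Path = Path("attn_data")
# figures land in L{layer}/H{head} subdirectories next to this (invented) activations file
activations_path: Path = save_path / model_cfg.model_name / "prompts" / "demo" / "activations.npz"
activations_path.parent.mkdir(parents=True, exist_ok=True)

compute_and_save_figures(
    model_cfg=model_cfg,
    activations_path=activations_path,
    cache=cache,
    save_path=save_path,
    force_overwrite=True,
)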
def process_prompt(
    prompt: dict,
    model_cfg: "HookedTransformerConfig|HTConfigMock",  # type: ignore[name-defined] # noqa: F821
    save_path: Path,
    force_overwrite: bool = False,
) -> None:
    """process a single prompt, loading the activations and computing and saving the figures

    basically just calls `load_activations` and then `compute_and_save_figures`

    # Parameters:
    - `prompt : dict`
        prompt record, as a single parsed line of `prompts.jsonl`
    - `model_cfg : HookedTransformerConfig|HTConfigMock`
    - `save_path : Path`
        base path where activations are loaded from and figures are saved
    - `force_overwrite : bool`
        (defaults to `False`)
    """
    activations_path: Path
    cache: ActivationCacheNp
    activations_path, cache = load_activations(
        model_name=model_cfg.model_name,
        prompt=prompt,
        save_path=save_path,
        return_fmt="numpy",
    )

    compute_and_save_figures(
        model_cfg=model_cfg,
        activations_path=activations_path,
        cache=cache,
        save_path=save_path,
        force_overwrite=force_overwrite,
    )
process a single prompt, loading the activations and computing and saving the figures
basically just calls load_activations and then compute_and_save_figures
Parameters:
- prompt : dict
  prompt record, as a single parsed line of prompts.jsonl
- model_cfg : HookedTransformerConfig|HTConfigMock
- save_path : Path
  base path where activations are loaded from and figures are saved
- force_overwrite : bool
  (defaults to False)
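a sketch of doing this for a single prompt outside the CLI, reading the config and prompt record the same way figures_main does; the base path and model name are illustrative, and the activations for that prompt must already exist on disk:

import json
from pathlib import Path

from pattern_lens.figures import HTConfigMock, process_prompt

save_path: Path = Path("attn_data")  # illustrative base path
model_path: Path = save_path / "gpt2"

model_cfg = HTConfigMock.load(json.loads((model_path / "model_cfg.json").read_text()))

# take the first prompt record from the same prompts.jsonl that figures_main reads
with open(model_path / "prompts.jsonl", "r") as f:
    prompt: dict = json.loads(f.readline())

process_prompt(
    prompt=prompt,
    model_cfg=model_cfg,
    save_path=save_path,
    force_overwrite=False,
)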
def figures_main(
    model_name: str,
    save_path: str,
    n_samples: int,
    force: bool,
    parallel: bool | int = True,
) -> None:
    """main function for generating figures from attention patterns, using the functions in `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Parameters:
    - `model_name : str`
        model name to use, used for loading the model config, prompts, activations, and saving the figures
    - `save_path : str`
        base path to look in
    - `n_samples : int`
        max number of samples to process
    - `force : bool`
        force overwrite of existing figures. if `False`, will skip any functions which have already saved a figure
    - `parallel : bool | int`
        whether to run in parallel. if `True`, will use all available cores. if `False`, will run in serial. if an int, will try to use that many cores
        (defaults to `True`)
    """
    with SpinnerContext(message="setting up paths", **SPINNER_KWARGS):
        # save model info or check if it exists
        save_path_p: Path = Path(save_path)
        model_path: Path = save_path_p / model_name
        with open(model_path / "model_cfg.json", "r") as f:
            model_cfg = HTConfigMock.load(json.load(f))

    with SpinnerContext(message="loading prompts", **SPINNER_KWARGS):
        # load prompts
        with open(model_path / "prompts.jsonl", "r") as f:
            prompts: list[dict] = [json.loads(line) for line in f.readlines()]
        # truncate to n_samples
        prompts = prompts[:n_samples]

    print(f"{len(prompts)} prompts loaded")

    print(f"{len(ATTENTION_MATRIX_FIGURE_FUNCS)} figure functions loaded")
    print("\t" + ", ".join([func.__name__ for func in ATTENTION_MATRIX_FIGURE_FUNCS]))

    list(
        run_maybe_parallel(
            func=functools.partial(
                process_prompt,
                model_cfg=model_cfg,
                save_path=save_path_p,
                force_overwrite=force,
            ),
            iterable=prompts,
            parallel=parallel,
            pbar="tqdm",
            pbar_kwargs=dict(
                desc="Making figures",
                unit="prompt",
            ),
        )
    )

    with SpinnerContext(
        message="updating jsonl metadata for models and functions", **SPINNER_KWARGS
    ):
        generate_models_jsonl(save_path_p)
        generate_functions_jsonl(save_path_p)
main function for generating figures from attention patterns, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS
Parameters:
- model_name : str
  model name to use, used for loading the model config, prompts, activations, and saving the figures
- save_path : str
  base path to look in
- n_samples : int
  max number of samples to process
- force : bool
  force overwrite of existing figures. if False, will skip any functions which have already saved a figure
- parallel : bool | int
  whether to run in parallel. if True, will use all available cores. if False, will run in serial. if an int, will try to use that many cores
  (defaults to True)
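calling it programmatically is a one-liner; a sketch, assuming activations for "gpt2" already live under ./attn_data:

from pattern_lens.figures import figures_main

figures_main(
    model_name="gpt2",
    save_path="attn_data",
    n_samples=32,   # only the first 32 prompts from prompts.jsonl
    force=False,    # skip figures that already exist
    parallel=4,     # 4 worker processes; True = all cores, False = serial
)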
def main():
    print(DIVIDER_S1)
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        arg_parser: argparse.ArgumentParser = argparse.ArgumentParser()
        # input and output
        arg_parser.add_argument(
            "--model",
            "-m",
            type=str,
            required=True,
            help="The model name(s) to use. comma separated with no whitespace if multiple",
        )
        arg_parser.add_argument(
            "--save-path",
            "-s",
            type=str,
            required=False,
            help="The path to save the attention patterns",
            default=DATA_DIR,
        )
        # number of samples
        arg_parser.add_argument(
            "--n-samples",
            "-n",
            type=int,
            required=False,
            help="The max number of samples to process, do all in the file if None",
            default=None,
        )
        # force overwrite of existing figures
        # boolean flag via store_true, so passing `--force` enables it
        # (argparse's type=bool would treat any non-empty string, including "False", as truthy)
        arg_parser.add_argument(
            "--force",
            "-f",
            action="store_true",
            help="Force overwrite of existing figures",
        )

        args: argparse.Namespace = arg_parser.parse_args()

    print(f"args parsed: {args}")

    models: list[str]
    if "," in args.model:
        models = args.model.split(",")
    else:
        models = [args.model]

    n_models: int = len(models)
    for idx, model in enumerate(models):
        print(DIVIDER_S2)
        print(f"processing model {idx+1} / {n_models}: {model}")
        print(DIVIDER_S2)
        figures_main(
            model_name=model,
            save_path=args.save_path,
            n_samples=args.n_samples,
            force=args.force,
        )

    print(DIVIDER_S1)
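main() is wired up as the module's script entry point, so the CLI is typically invoked as a module, for example `python -m pattern_lens.figures --model gpt2 --save-path attn_data --n-samples 100` (model name and path illustrative); multiple models go in as a comma-separated list, e.g. `--model gpt2,pythia-70m`.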