Coverage for pattern_lens\activations.py: 67%

114 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2025-01-03 23:21 -0700

1"""computing and saving activations given a model and prompts 

2 

3 

4# Usage: 

5 

6from the command line: 

7 

8```bash 

9python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples> 

10``` 

11 

12from a script: 

13 

14```python 

15from pattern_lens.activations import activations_main 

16activations_main( 

17 model_name="gpt2", 

18	save_path="demo/",

19 prompts_path="data/pile_1k.jsonl", 

20) 

21``` 

22 

23""" 

24 

25import argparse 

26import functools 

27import json 

28from dataclasses import asdict 

29from pathlib import Path 

30import re 

31from typing import Callable, Literal, overload 

32 

33import numpy as np 

34import torch 

35import tqdm 

36from muutils.spinner import SpinnerContext 

37from muutils.misc.numerical import shorten_numerical_to_str 

38from muutils.json_serialize import json_serialize 

39from transformer_lens import HookedTransformer, HookedTransformerConfig # type: ignore[import-untyped] 

40 

41from pattern_lens.consts import ( 

42 ATTN_PATTERN_REGEX, 

43 DATA_DIR, 

44 ActivationCacheNp, 

45 SPINNER_KWARGS, 

46 DIVIDER_S1, 

47 DIVIDER_S2, 

48) 

49from pattern_lens.indexes import ( 

50 generate_models_jsonl, 

51 generate_prompts_jsonl, 

52 write_html_index, 

53) 

54from pattern_lens.load_activations import ( 

55 ActivationsMissingError, 

56 augment_prompt_with_hash, 

57 load_activations, 

58) 

59from pattern_lens.prompts import load_text_data 

60 

61 

def compute_activations(
    prompt: dict,
    model: HookedTransformer | None = None,
    save_path: Path = Path(DATA_DIR),
    return_cache: bool = True,
    names_filter: Callable[[str], bool] | re.Pattern = ATTN_PATTERN_REGEX,
) -> tuple[Path, ActivationCacheNp | None]:
    """compute activations for a prompt with a model, saving them to disk

    the prompt must already carry its hash (see `augment_prompt_with_hash`),
    since the output directory is keyed by `prompt["hash"]`

    # Parameters:
    - `prompt : dict`
        must contain a 'text' key and a 'hash' key
    - `model : HookedTransformer | None`
        model to run. required despite the `None` default -- `ValueError`
        is raised if it is not passed
        (defaults to `None`)
    - `save_path : Path`
        root directory to save activations under
        (defaults to `Path(DATA_DIR)`)
    - `return_cache : bool`
        will return `None` as the second element if `False`
        (defaults to `True`)
    - `names_filter : Callable[[str], bool]|re.Pattern`
        a filter for the names of the activations to return. if an `re.Pattern`, will use `lambda key: names_filter.match(key) is not None`
        (defaults to `ATTN_PATTERN_REGEX`)

    # Returns:
    - `tuple[Path, ActivationCacheNp|None]`
        path to the saved `activations.npz`, and the numpy cache if requested

    # Raises:
    - `ValueError` : if `model` is `None` or `prompt` lacks a 'text' key
    """
    # explicit raises instead of `assert` -- asserts are stripped under `python -O`
    if model is None:
        raise ValueError("model must be passed")
    if "text" not in prompt:
        raise ValueError("prompt must contain 'text' key")
    prompt_str: str = prompt["text"]

    # compute or get prompt metadata
    # (avoid tokenizing when tokens are already present -- `dict.get`'s
    # default argument would be evaluated eagerly)
    prompt_tokenized: list[str] = (
        prompt["tokens"]
        if "tokens" in prompt
        else model.tokenizer.tokenize(prompt_str)
    )
    prompt.update(
        dict(
            n_tokens=len(prompt_tokenized),
            tokens=prompt_tokenized,
        )
    )

    # save metadata (directory is keyed by the prompt hash)
    prompt_dir: Path = save_path / model.model_name / "prompts" / prompt["hash"]
    prompt_dir.mkdir(parents=True, exist_ok=True)
    with open(prompt_dir / "prompt.json", "w") as f:
        json.dump(prompt, f)

    # set up names filter: normalize an `re.Pattern` into a predicate
    names_filter_fn: Callable[[str], bool]
    if isinstance(names_filter, re.Pattern):
        names_filter_fn = lambda key: names_filter.match(key) is not None  # noqa: E731
    else:
        names_filter_fn = names_filter

    # compute activations (no grad: inference only)
    with torch.no_grad():
        # TODO: batching?
        _, cache = model.run_with_cache(
            prompt_str,
            names_filter=names_filter_fn,
            return_type=None,
        )

    # move tensors to cpu numpy arrays for serialization
    cache_np: ActivationCacheNp = {
        k: v.detach().cpu().numpy() for k, v in cache.items()
    }

    # save activations
    activations_path: Path = prompt_dir / "activations.npz"
    np.savez_compressed(
        activations_path,
        **cache_np,
    )

    # return path and (optionally) the cache
    if return_cache:
        return activations_path, cache_np
    else:
        return activations_path, None

143 

144 

# overload order matters: a call that omits `return_cache` must match the
# `Literal[True]` overload first, since the runtime default is `True`
# (previously the `Literal[False] = False` overload came first, so type
# checkers inferred `tuple[Path, None]` for the default call)
@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal[True] = True,
) -> tuple[Path, ActivationCacheNp]: ...
@overload
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: Literal[False] = False,
) -> tuple[Path, None]: ...
def get_activations(
    prompt: dict,
    model: HookedTransformer | str,
    save_path: Path = Path(DATA_DIR),
    allow_disk_cache: bool = True,
    return_cache: bool = True,
) -> tuple[Path, ActivationCacheNp | None]:
    """given a prompt and a model, save or load activations

    # Parameters:
    - `prompt : dict`
        expected to contain the 'text' key
    - `model : HookedTransformer | str`
        either a `HookedTransformer` or a string model name, to be loaded with `HookedTransformer.from_pretrained`
    - `save_path : Path`
        path to save the activations to (and load from)
        (defaults to `Path(DATA_DIR)`)
    - `allow_disk_cache : bool`
        whether to allow loading from disk cache
        (defaults to `True`)
    - `return_cache : bool`
        whether to return the cache. if `False`, will return `None` as the second element
        (defaults to `True`)

    # Returns:
    - `tuple[Path, ActivationCacheNp | None]`
        the path to the activations and the cache if `return_cache` is `True`

    """
    # add hash to prompt (in place)
    augment_prompt_with_hash(prompt)

    # resolve the model name without forcing a model load
    model_name: str = (
        model.model_name if isinstance(model, HookedTransformer) else model
    )

    # try the disk cache first
    if allow_disk_cache:
        try:
            path, cache = load_activations(
                model_name=model_name,
                prompt=prompt,
                save_path=save_path,
            )
            return (path, cache) if return_cache else (path, None)
        except ActivationsMissingError:
            # not cached yet -- fall through and compute
            pass

    # load the model if only a name was given
    if isinstance(model, str):
        model = HookedTransformer.from_pretrained(model_name)

    # BUGFIX: previously this always passed `return_cache=True`, so callers
    # requesting `return_cache=False` still received the cache, contradicting
    # the `tuple[Path, None]` overload
    return compute_activations(
        prompt=prompt,
        model=model,
        save_path=save_path,
        return_cache=return_cache,
    )

223 

224 

def activations_main(
    model_name: str,
    save_path: str,
    prompts_path: str | None,
    raw_prompts: bool,
    min_chars: int,
    max_chars: int,
    force: bool,
    n_samples: int | None,
    no_index_html: bool,
    shuffle: bool = False,
) -> None:
    """main function for computing activations

    # Parameters:
    - `model_name : str`
        name of a model to load with `HookedTransformer.from_pretrained`
    - `save_path : str`
        path to save the activations to
    - `prompts_path : str | None`
        path to the prompts file. only read when `raw_prompts=True`
    - `raw_prompts : bool`
        whether the prompts are raw, not filtered by length. `load_text_data`
        will be called on `prompts_path` if `True`; otherwise prompts are read
        from `<save_path>/<model_name>/prompts.jsonl`
    - `min_chars : int`
        minimum number of characters for a prompt (only used when `raw_prompts=True`)
    - `max_chars : int`
        maximum number of characters for a prompt (only used when `raw_prompts=True`)
    - `force : bool`
        whether to overwrite existing files (disables the disk cache)
    - `n_samples : int | None`
        maximum number of samples to process, all if `None`
    - `no_index_html : bool`
        whether to skip writing an index.html file
    - `shuffle : bool`
        whether to shuffle the prompts
        (defaults to `False`)
    """

    # load the model, and stamp the requested name onto it
    with SpinnerContext(message="loading model", **SPINNER_KWARGS):
        model: HookedTransformer = HookedTransformer.from_pretrained(model_name)
        model.model_name = model_name
        model.cfg.model_name = model_name
        n_params: int = sum(p.numel() for p in model.parameters())
    print(
        f"loaded {model_name} with {shorten_numerical_to_str(n_params)} ({n_params}) parameters"
    )

    # save model config
    save_path_p: Path = Path(save_path)
    save_path_p.mkdir(parents=True, exist_ok=True)
    model_path: Path = save_path_p / model_name
    with SpinnerContext(
        message=f"saving model info to {model_path.as_posix()}", **SPINNER_KWARGS
    ):
        model_cfg: HookedTransformerConfig
        model_cfg = model.cfg
        model_path.mkdir(parents=True, exist_ok=True)
        with open(model_path / "model_cfg.json", "w") as f:
            json.dump(json_serialize(asdict(model_cfg)), f)

    # load prompts
    with SpinnerContext(
        message=f"loading prompts from {prompts_path = }", **SPINNER_KWARGS
    ):
        prompts: list[dict]
        if raw_prompts:
            prompts = load_text_data(
                Path(prompts_path),
                min_chars=min_chars,
                max_chars=max_chars,
                shuffle=shuffle,
            )
        else:
            # NOTE(review): this reads the model directory's prompts.jsonl,
            # not `prompts_path` -- confirm this is intentional before
            # relying on `prompts_path` in this branch
            with open(model_path / "prompts.jsonl", "r") as f:
                prompts = [json.loads(line) for line in f]
        # truncate to n_samples (slicing with `None` keeps everything)
        prompts = prompts[:n_samples]

    print(f"{len(prompts)} prompts loaded")

    # write index.html (check the flag before spinning up the spinner)
    if not no_index_html:
        with SpinnerContext(message="writing index.html", **SPINNER_KWARGS):
            write_html_index(save_path_p)

    # compute (or load cached) activations for every prompt. results are
    # written to disk by `get_activations`, so the returned list is discarded
    list(
        tqdm.tqdm(
            map(
                functools.partial(
                    get_activations,
                    model=model,
                    save_path=save_path_p,
                    allow_disk_cache=not force,
                    return_cache=False,
                ),
                prompts,
            ),
            total=len(prompts),
            desc="Computing activations",
            unit="prompt",
        )
    )

    # refresh the jsonl indexes so the frontend can find the new data
    with SpinnerContext(
        message="updating jsonl metadata for models and prompts", **SPINNER_KWARGS
    ):
        generate_models_jsonl(save_path_p)
        generate_prompts_jsonl(save_path_p / model_name)

333 

334 

def main():
    """command-line entry point: parse arguments, then run `activations_main` once per model"""
    print(DIVIDER_S1)
    with SpinnerContext(message="parsing args", **SPINNER_KWARGS):
        parser: argparse.ArgumentParser = argparse.ArgumentParser()

        # input and output
        parser.add_argument(
            "--model",
            "-m",
            type=str,
            required=True,
            help="The model name(s) to use. comma separated with no whitespace if multiple",
        )
        parser.add_argument(
            "--prompts",
            "-p",
            type=str,
            required=False,
            help="The path to the prompts file (jsonl with 'text' key on each line). If `None`, expects that `--figures` is passed and will generate figures for all prompts in the model directory",
            default=None,
        )
        parser.add_argument(
            "--save-path",
            "-s",
            type=str,
            required=False,
            help="The path to save the attention patterns",
            default=DATA_DIR,
        )

        # prompt length bounds
        parser.add_argument(
            "--min-chars",
            type=int,
            required=False,
            help="The minimum number of characters for a prompt",
            default=100,
        )
        parser.add_argument(
            "--max-chars",
            type=int,
            required=False,
            help="The maximum number of characters for a prompt",
            default=1000,
        )

        # sample count
        parser.add_argument(
            "--n-samples",
            "-n",
            type=int,
            required=False,
            help="The max number of samples to process, do all in the file if None",
            default=None,
        )

        # boolean flags
        parser.add_argument(
            "--force",
            "-f",
            action="store_true",
            help="If passed, will overwrite existing files",
        )
        parser.add_argument(
            "--no-index-html",
            action="store_true",
            help="If passed, will not write an index.html file for the model",
        )
        parser.add_argument(
            "--raw-prompts",
            "-r",
            action="store_true",
            help="pass if the prompts have not been split and tokenized (still needs keys 'text' and 'meta' for each item)",
        )
        parser.add_argument(
            "--shuffle",
            action="store_true",
            help="If passed, will shuffle the prompts",
        )

        args: argparse.Namespace = parser.parse_args()

    print(f"args parsed: {args}")

    # split the comma-separated model list; a name with no comma yields itself
    models: list[str] = args.model.split(",")

    n_models: int = len(models)
    for idx, model in enumerate(models):
        print(DIVIDER_S2)
        print(f"processing model {idx+1} / {n_models}: {model}")
        print(DIVIDER_S2)

        activations_main(
            model_name=model,
            save_path=args.save_path,
            prompts_path=args.prompts,
            raw_prompts=args.raw_prompts,
            min_chars=args.min_chars,
            max_chars=args.max_chars,
            force=args.force,
            n_samples=args.n_samples,
            no_index_html=args.no_index_html,
            shuffle=args.shuffle,
        )

    print(DIVIDER_S1)

452 

453 

# allow running as a script / `python -m pattern_lens.activations`
if __name__ == "__main__":
    main()