Coverage for pattern_lens\indexes.py: 96%
57 statements
« prev ^ index » next coverage.py v7.6.9, created at 2025-01-03 23:21 -0700
« prev ^ index » next coverage.py v7.6.9, created at 2025-01-03 23:21 -0700
1"""writes indexes to the model directory for the frontend to use or for record keeping"""
3import inspect
4import json
5from pathlib import Path
6import importlib.resources
7from typing import Callable
9import pattern_lens
10from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS
def generate_prompts_jsonl(model_dir: Path):
    """creates a `prompts.jsonl` file with all the prompts in the model directory

    looks in all directories in `{model_dir}/prompts` for a `prompt.json` file, and
    writes the contents of each one found as a single JSON line of
    `{model_dir}/prompts.jsonl`. Entries are ordered by prompt directory name so the
    output is deterministic. All files are read and written as UTF-8.

    # Parameters:
     - `model_dir : Path`
        directory containing a `prompts` subdirectory

    # Raises:
     - `FileNotFoundError` : if `{model_dir}/prompts` does not exist
    """
    prompts: list[dict] = list()
    # `iterdir()` order is filesystem-dependent; sort so the index is reproducible
    for prompt_dir in sorted((model_dir / "prompts").iterdir()):
        prompt_file: Path = prompt_dir / "prompt.json"
        if prompt_file.exists():
            # explicit UTF-8: the platform default (e.g. cp1252 on Windows)
            # mis-decodes or rejects non-ASCII prompt text
            with open(prompt_file, "r", encoding="utf-8") as f:
                prompt_data: dict = json.load(f)
            prompts.append(prompt_data)

    with open(model_dir / "prompts.jsonl", "w", encoding="utf-8") as f:
        for prompt in prompts:
            f.write(json.dumps(prompt))
            f.write("\n")
def generate_models_jsonl(path: Path):
    """creates a `models.jsonl` file with all the models

    looks in each entry directly inside `path` for a `model_cfg.json` file, and
    writes the contents of each one found as a single JSON line of
    `{path}/models.jsonl`. Entries are ordered by model directory name so the
    output is deterministic. All files are read and written as UTF-8.

    # Parameters:
     - `path : Path`
        directory whose subdirectories are model directories
    """
    models: list[dict] = list()
    # `iterdir()` order is filesystem-dependent; sort so the index is reproducible
    for model_dir in sorted(path.iterdir()):
        model_cfg_path: Path = model_dir / "model_cfg.json"
        if model_cfg_path.exists():
            # explicit UTF-8 instead of the platform-dependent default encoding
            with open(model_cfg_path, "r", encoding="utf-8") as f:
                model_cfg: dict = json.load(f)
            models.append(model_cfg)

    with open(path / "models.jsonl", "w", encoding="utf-8") as f:
        for model in models:
            f.write(json.dumps(model))
            f.write("\n")
48def get_func_metadata(func: Callable) -> dict[str, str | None]:
49 """get metadata for a function
51 # Parameters:
52 - `func : Callable`
54 # Returns:
56 `dict[str, str | None]`
57 dictionary:
59 - `name : str` : the name of the function
60 - `doc : str` : the docstring of the function
61 - `figure_save_fmt : str | None` : the format of the figure that the function saves, using the `figure_save_fmt` attribute of the function. `None` if the attribute does not exist
62 - `source : str | None` : the source file of the function
63 - `code : str | None` : the source code of the function, split by line. `None` if the source file cannot be read
65 """
66 source_file: str | None = inspect.getsourcefile(func)
67 output: dict[str, str | None] = dict(
68 name=func.__name__,
69 doc=func.__doc__,
70 figure_save_fmt=getattr(func, "figure_save_fmt", None),
71 source=Path(source_file).as_posix() if source_file else None,
72 )
74 try:
75 output["code"] = inspect.getsource(func)
76 except OSError:
77 output["code"] = None
79 return output
def generate_functions_jsonl(path: Path):
    """unions all functions from file and current `ATTENTION_MATRIX_FIGURE_FUNCS` into a `functions.jsonl` file

    entries already present in `{path}/functions.jsonl` (if it exists) are kept.
    Metadata for every function currently in `ATTENTION_MATRIX_FIGURE_FUNCS` is
    regenerated and replaces any existing entry with the same name. The merged
    entries are written back, one JSON object per line, sorted by name.

    # Parameters:
     - `path : Path`
        directory in which `functions.jsonl` is read and written
    """
    functions_file: Path = path / "functions.jsonl"
    existing_functions: dict[str, dict] = dict()

    if functions_file.exists():
        with open(functions_file, "r", encoding="utf-8") as f:
            for line in f:
                # tolerate blank lines instead of failing inside `json.loads`
                if not line.strip():
                    continue
                func_data: dict = json.loads(line)
                existing_functions[func_data["name"]] = func_data

    # metadata for every currently registered function; these override stale
    # entries of the same name read from the file
    new_functions: dict[str, dict] = {
        func.__name__: get_func_metadata(func) for func in ATTENTION_MATRIX_FIGURE_FUNCS
    }

    all_functions: list[dict] = list(
        {
            **existing_functions,
            **new_functions,
        }.values()
    )

    with open(functions_file, "w", encoding="utf-8") as f:
        for func_meta in sorted(all_functions, key=lambda x: x["name"]):
            json.dump(func_meta, f)
            f.write("\n")
def write_html_index(path: Path):
    """copy the `frontend/index.html` resource bundled with `pattern_lens` into `path`

    the resource is read as UTF-8 text and written to `{path}/index.html`,
    also as UTF-8, overwriting any existing file
    """
    package_root = importlib.resources.files(pattern_lens)
    index_resource = package_root.joinpath("frontend/index.html")
    contents: str = index_resource.read_text(encoding="utf-8")
    destination: Path = path / "index.html"
    destination.write_text(contents, encoding="utf-8")