Coverage for pattern_lens\indexes.py: 96%

57 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2025-01-03 23:21 -0700

1"""writes indexes to the model directory for the frontend to use or for record keeping""" 

2 

3import inspect 

4import json 

5from pathlib import Path 

6import importlib.resources 

7from typing import Callable 

8 

9import pattern_lens 

10from pattern_lens.attn_figure_funcs import ATTENTION_MATRIX_FIGURE_FUNCS 

11 

12 

def generate_prompts_jsonl(model_dir: Path):
    """creates a `prompts.jsonl` file with all the prompts in the model directory

    looks in every entry of `{model_dir}/prompts` for a `prompt.json` file and
    writes each one found as a single JSON line to `{model_dir}/prompts.jsonl`.
    entries are visited in sorted path order so the output is reproducible

    # Parameters:
     - `model_dir : Path`
       directory that contains the `prompts` subdirectory

    # Raises:
     - `FileNotFoundError` : if `{model_dir}/prompts` does not exist
    """
    prompts: list[dict] = list()
    # `Path.iterdir` yields entries in arbitrary order; sort for a stable index
    for prompt_dir in sorted((model_dir / "prompts").iterdir()):
        prompt_file: Path = prompt_dir / "prompt.json"
        if prompt_file.exists():
            # explicit encoding so the result does not depend on the locale
            with open(prompt_file, "r", encoding="utf-8") as f:
                prompts.append(json.load(f))

    with open(model_dir / "prompts.jsonl", "w", encoding="utf-8") as f:
        for prompt in prompts:
            f.write(json.dumps(prompt))
            f.write("\n")

30 

31 

def generate_models_jsonl(path: Path):
    """creates a `models.jsonl` file with all the models

    looks in every entry of `path` for a `model_cfg.json` file and writes each
    one found as a single JSON line to `{path}/models.jsonl`. entries are
    visited in sorted path order so the output is reproducible

    # Parameters:
     - `path : Path`
       directory whose entries are searched for `model_cfg.json`

    # Raises:
     - `FileNotFoundError` : if `path` does not exist
    """
    models: list[dict] = list()
    # `Path.iterdir` yields entries in arbitrary order; sort for a stable index
    for model_dir in sorted(path.iterdir()):
        model_cfg_path: Path = model_dir / "model_cfg.json"
        if model_cfg_path.exists():
            # explicit encoding so the result does not depend on the locale
            with open(model_cfg_path, "r", encoding="utf-8") as f:
                models.append(json.load(f))

    with open(path / "models.jsonl", "w", encoding="utf-8") as f:
        for model in models:
            f.write(json.dumps(model))
            f.write("\n")

46 

47 

48def get_func_metadata(func: Callable) -> dict[str, str | None]: 

49 """get metadata for a function 

50 

51 # Parameters: 

52 - `func : Callable` 

53 

54 # Returns: 

55 

56 `dict[str, str | None]` 

57 dictionary: 

58 

59 - `name : str` : the name of the function 

60 - `doc : str` : the docstring of the function 

61 - `figure_save_fmt : str | None` : the format of the figure that the function saves, using the `figure_save_fmt` attribute of the function. `None` if the attribute does not exist 

62 - `source : str | None` : the source file of the function 

63 - `code : str | None` : the source code of the function, split by line. `None` if the source file cannot be read 

64 

65 """ 

66 source_file: str | None = inspect.getsourcefile(func) 

67 output: dict[str, str | None] = dict( 

68 name=func.__name__, 

69 doc=func.__doc__, 

70 figure_save_fmt=getattr(func, "figure_save_fmt", None), 

71 source=Path(source_file).as_posix() if source_file else None, 

72 ) 

73 

74 try: 

75 output["code"] = inspect.getsource(func) 

76 except OSError: 

77 output["code"] = None 

78 

79 return output 

80 

81 

def generate_functions_jsonl(path: Path):
    """merges the functions recorded in `{path}/functions.jsonl` with the current `ATTENTION_MATRIX_FIGURE_FUNCS` and rewrites that file

    entries are keyed by function name; metadata for a function currently in
    `ATTENTION_MATRIX_FIGURE_FUNCS` replaces any entry of the same name read
    from the file. the result is written sorted by name, one JSON object per line

    # Parameters:
     - `path : Path`
       directory containing (or to contain) `functions.jsonl`
    """
    functions_file: Path = path / "functions.jsonl"

    # start from whatever was recorded on a previous run, keyed by name
    merged: dict[str, dict] = dict()
    if functions_file.exists():
        with open(functions_file, "r") as f:
            for line in f:
                record: dict = json.loads(line)
                merged[record["name"]] = record

    # freshly computed metadata for the currently registered functions;
    # applied second so it takes precedence over entries loaded from the file
    current: dict[str, dict] = dict()
    for func in ATTENTION_MATRIX_FIGURE_FUNCS:
        current[func.__name__] = get_func_metadata(func)
    merged.update(current)

    with open(functions_file, "w") as f:
        for entry in sorted(merged.values(), key=lambda item: item["name"]):
            json.dump(entry, f)
            f.write("\n")

109 

110 

def write_html_index(path: Path):
    """writes an `index.html` file to `path`

    the contents are read from the `frontend/index.html` resource of the
    `pattern_lens` package and written out unchanged, both as utf-8

    # Parameters:
     - `path : Path`
       directory in which `index.html` is created
    """
    # locate the resource inside the `pattern_lens` package
    packaged_index = importlib.resources.files(pattern_lens).joinpath(
        "frontend/index.html"
    )
    html_index: str = packaged_index.read_text(encoding="utf-8")

    destination: Path = path / "index.html"
    destination.write_text(html_index, encoding="utf-8")