Coverage for pattern_lens\attn_figure_funcs.py: 100%
10 statements
« prev ^ index » next coverage.py v7.6.9, created at 2025-01-03 23:21 -0700
"""Default figure functions.

- If you are making a PR, add your new figure function here.
- If you are using this as a library, you can find examples here.

Note that for `pattern_lens.figures` to recognize your function, you need to use the
`register_attn_figure_func` decorator, which adds your function to
`ATTENTION_MATRIX_FIGURE_FUNCS`.
"""
12from pattern_lens.consts import AttentionMatrix
13from pattern_lens.figure_util import (
14 AttentionMatrixFigureFunc,
15 save_matrix_wrapper,
16 Matrix2D,
17)
# Module-level registry of figure functions. `register_attn_figure_func`
# appends to it, so entries appear in the order the decorator was applied.
ATTENTION_MATRIX_FIGURE_FUNCS: list[AttentionMatrixFigureFunc] = []
def register_attn_figure_func(
    func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
    """decorator for registering attention matrix figure function

    if you want to add a new figure function, you should use this decorator

    # Parameters:
     - `func : AttentionMatrixFigureFunc`
        your function, which should take an attention matrix and path
        (wrappers such as `save_matrix_wrapper` presumably adapt simpler
        signatures to this one -- confirm in `pattern_lens.figure_util`)

    # Returns:
     - `AttentionMatrixFigureFunc`
        the same function object, unchanged, after it has been appended to
        `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Notes:
     - no deduplication is performed: decorating the same function twice
       registers it twice

    # Usage:
    ```python
    @register_attn_figure_func
    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        fig.savefig(path / "my_new_figure_func", format="svgz")
        plt.close(fig)
    ```
    """
    # No `global` declaration is needed: the registry list is only mutated
    # in place, never rebound, so the module-level name resolves normally.
    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
    return func
@register_attn_figure_func
@save_matrix_wrapper(fmt="png")
def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
    """Identity figure: hand back the attention matrix exactly as received.

    All of the saving logic lives in `save_matrix_wrapper(fmt="png")`, which
    presumably writes the returned matrix out as a PNG image (confirm in
    `pattern_lens.figure_util`). Registration via `register_attn_figure_func`
    adds the wrapped function to `ATTENTION_MATRIX_FIGURE_FUNCS`.
    """
    return attn_matrix
# Some more examples (kept commented out so they are not registered):
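# NOTE: this example needs `matplotlib_figure_saver` and `plt` (matplotlib.pyplot)
# to be imported, and uses a distinct name so it does not rebind `raw` above.
# @register_attn_figure_func
# @matplotlib_figure_saver
# def raw_matplotlib(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
#     ax.matshow(attn_matrix, cmap="viridis")
#     ax.set_title("Raw Attention Pattern")
#     ax.axis("off")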
# @register_attn_figure_func
# @save_matrix_wrapper(fmt="svg")
# def raw_svg(attn_matrix: AttentionMatrix) -> Matrix2D:
#     return attn_matrix
# @register_attn_figure_func
# @save_matrix_wrapper(fmt="svgz")
# def raw_svgz(attn_matrix: AttentionMatrix) -> Matrix2D:
#     return attn_matrix