pattern_lens.attn_figure_funcs
default figure functions
- If you are making a PR, add your new figure function here.
- If you are using this as a library, you can find example figure functions here.
Note that for `pattern_lens.figures` to recognize your function, you need to use the `register_attn_figure_func` decorator,
which adds your function to `ATTENTION_MATRIX_FIGURE_FUNCS`.
"""default figure functions

- If you are making a PR, add your new figure function here.
- If you are using this as a library, you can find example figure functions here.

Note that for `pattern_lens.figures` to recognize your function, you need to use
the `register_attn_figure_func` decorator, which adds your function to
`ATTENTION_MATRIX_FIGURE_FUNCS`.

"""

from pattern_lens.consts import AttentionMatrix
from pattern_lens.figure_util import (
    AttentionMatrixFigureFunc,
    save_matrix_wrapper,
    Matrix2D,
)


# registry of figure functions; filled in (at import time) by every use of
# the `register_attn_figure_func` decorator below
ATTENTION_MATRIX_FIGURE_FUNCS: list[AttentionMatrixFigureFunc] = []


def register_attn_figure_func(
    func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
    """decorator for registering attention matrix figure function

    if you want to add a new figure function, you should use this decorator

    # Parameters:
     - `func : AttentionMatrixFigureFunc`
       your function, which should take an attention matrix and path

    # Returns:
     - `AttentionMatrixFigureFunc`
       your function, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Usage:
    ```python
    @register_attn_figure_func
    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        plt.savefig(path / "my_new_figure_func", format="svgz")
        plt.close(fig)
    ```

    """
    # the registry is only mutated in place, never rebound, so no `global`
    # declaration is needed here
    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)

    # hand the function back unchanged so the decorated name stays usable
    return func


@register_attn_figure_func
@save_matrix_wrapper(fmt="png")
def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
    """return the attention matrix unchanged

    the `save_matrix_wrapper(fmt="png")` decorator wraps this function before it
    is registered; presumably it handles writing the returned matrix to disk as
    a png -- see `pattern_lens.figure_util.save_matrix_wrapper` to confirm
    """
    return attn_matrix


# some more examples:

# @register_attn_figure_func
# @matplotlib_figure_saver
# def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
#     ax.matshow(attn_matrix, cmap="viridis")
#     ax.set_title("Raw Attention Pattern")
#     ax.axis("off")

# @register_attn_figure_func
# @save_matrix_wrapper(fmt="svg")
# def raw_svg(attn_matrix: AttentionMatrix) -> Matrix2D:
#     return attn_matrix

# @register_attn_figure_func
# @save_matrix_wrapper(fmt="svgz")
# def raw_svgz(attn_matrix: AttentionMatrix) -> Matrix2D:
#     return attn_matrix
ATTENTION_MATRIX_FIGURE_FUNCS: list[typing.Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]] =
[<function raw>]
def
register_attn_figure_func( func: Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]) -> Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]:
def register_attn_figure_func(
    func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
    """decorator for registering attention matrix figure function

    if you want to add a new figure function, you should use this decorator

    # Parameters:
     - `func : AttentionMatrixFigureFunc`
       your function, which should take an attention matrix and path

    # Returns:
     - `AttentionMatrixFigureFunc`
       your function, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Usage:
    ```python
    @register_attn_figure_func
    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        plt.savefig(path / "my_new_figure_func", format="svgz")
        plt.close(fig)
    ```

    """
    # the registry list is only mutated in place (never rebound), so a
    # `global` declaration is unnecessary
    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)

    # return the function unchanged so the decorated name stays usable
    return func
decorator for registering attention matrix figure function
if you want to add a new figure function, you should use this decorator
# Parameters:
- `func : AttentionMatrixFigureFunc`
your function, which should take an attention matrix and path
# Returns:
- `AttentionMatrixFigureFunc`
your function, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`
# Usage:
```python
@register_attn_figure_func
def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.matshow(attn_matrix, cmap="viridis")
    ax.set_title("My New Figure Function")
    ax.axis("off")
    plt.savefig(path / "my_new_figure_func", format="svgz")
    plt.close(fig)
```
@register_attn_figure_func
@save_matrix_wrapper(fmt='png')
def
raw( attn_matrix: jaxtyping.Float[ndarray, 'n_ctx n_ctx']) -> jaxtyping.Float[ndarray, 'n m']: