docs for pattern_lens v0.1.0
View Source on GitHub

pattern_lens.attn_figure_funcs

default figure functions

  • If you are making a PR, add your new figure function here.
  • If you are using this as a library, you can see examples here.

Note that for pattern_lens.figures to recognize your function, you need to use the register_attn_figure_func decorator, which adds your function to ATTENTION_MATRIX_FIGURE_FUNCS.


 1"""default figure functions
 2
 3- If you are making a PR, add your new figure function here.
 4- if you are using this as a library, then you can see examples here
 5
 6
 7note that for `pattern_lens.figures` to recognize your function, you need to use the `register_attn_figure_func` decorator
 8which adds your function to `ATTENTION_MATRIX_FIGURE_FUNCS`
 9
10"""
11
12from pattern_lens.consts import AttentionMatrix
13from pattern_lens.figure_util import (
14    AttentionMatrixFigureFunc,
15    save_matrix_wrapper,
16    Matrix2D,
17)
18
19
# registry of attention-matrix figure functions: `register_attn_figure_func`
# appends to this list, and `pattern_lens.figures` only recognizes functions
# that have been added here
ATTENTION_MATRIX_FIGURE_FUNCS: list[AttentionMatrixFigureFunc] = []
21
22
def register_attn_figure_func(
    func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
    """decorator for registering attention matrix figure function

    if you want to add a new figure function, you should use this decorator

    # Parameters:
     - `func : AttentionMatrixFigureFunc`
       your function, which should take an attention matrix and path

    # Returns:
     - `AttentionMatrixFigureFunc`
       your function, unchanged, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Usage:
    ```python
    @register_attn_figure_func
    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        # with an explicit `format`, matplotlib saves to exactly this name,
        # so include the file extension yourself
        plt.savefig(path / "my_new_figure_func.svgz", format="svgz")
        plt.close(fig)
    ```

    """
    # no `global` statement needed: the module-level list is only mutated
    # in place here, the name itself is never rebound
    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)

    return func
56
57
@register_attn_figure_func
@save_matrix_wrapper(fmt="png")
def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
    """identity figure function: hand the attention matrix back untouched

    the actual saving is delegated to `save_matrix_wrapper(fmt="png")`
    (presumably written out as a PNG image -- see `pattern_lens.figure_util`),
    and `register_attn_figure_func` adds the wrapped function to
    `ATTENTION_MATRIX_FIGURE_FUNCS`
    """
    matrix: Matrix2D = attn_matrix
    return matrix
62
63
64# some more examples (disabled; uncomment to register them as well):
65
66# @register_attn_figure_func
67# @matplotlib_figure_saver
68# def raw_matplotlib(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
69#     ax.matshow(attn_matrix, cmap="viridis")
70#     ax.set_title("Raw Attention Pattern")
71#     ax.axis("off")
72
73# @register_attn_figure_func
74# @save_matrix_wrapper(fmt="svg")
75# def raw_svg(attn_matrix: AttentionMatrix) -> Matrix2D:
76#     return attn_matrix
77
78# @register_attn_figure_func
79# @save_matrix_wrapper(fmt="svgz")
80# def raw_svgz(attn_matrix: AttentionMatrix) -> Matrix2D:
81#     return attn_matrix

ATTENTION_MATRIX_FIGURE_FUNCS: list[typing.Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]] = [<function raw>]
def register_attn_figure_func( func: Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]) -> Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]:
def register_attn_figure_func(
    func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
    """decorator for registering attention matrix figure function

    if you want to add a new figure function, you should use this decorator

    # Parameters:
     - `func : AttentionMatrixFigureFunc`
       your function, which should take an attention matrix and path

    # Returns:
     - `AttentionMatrixFigureFunc`
       your function, unchanged, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Usage:
    ```python
    @register_attn_figure_func
    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        # with an explicit `format`, matplotlib saves to exactly this name,
        # so include the file extension yourself
        plt.savefig(path / "my_new_figure_func.svgz", format="svgz")
        plt.close(fig)
    ```

    """
    # no `global` statement needed: the module-level list is only mutated
    # in place here, the name itself is never rebound
    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)

    return func

decorator for registering attention matrix figure function

if you want to add a new figure function, you should use this decorator

# Parameters:
 - `func : AttentionMatrixFigureFunc`
   your function, which should take an attention matrix and path

# Returns:
 - `AttentionMatrixFigureFunc`
   your function, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`

Usage:

@register_attn_figure_func
def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.matshow(attn_matrix, cmap="viridis")
    ax.set_title("My New Figure Function")
    ax.axis("off")
    plt.savefig(path / "my_new_figure_func.svgz", format="svgz")
    plt.close(fig)
@register_attn_figure_func
@save_matrix_wrapper(fmt='png')
def raw( attn_matrix: jaxtyping.Float[ndarray, 'n_ctx n_ctx']) -> jaxtyping.Float[ndarray, 'n m']:
@register_attn_figure_func
@save_matrix_wrapper(fmt="png")
def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
    """identity figure function: hand the attention matrix back untouched

    the actual saving is delegated to `save_matrix_wrapper(fmt="png")`
    (presumably written out as a PNG image -- see `pattern_lens.figure_util`),
    and `register_attn_figure_func` adds the wrapped function to
    `ATTENTION_MATRIX_FIGURE_FUNCS`
    """
    matrix: Matrix2D = attn_matrix
    return matrix