docs for pattern_lens v0.1.0
View Source on GitHub

pattern_lens.figure_util

Implements types, default values, and templates that are useful for figure functions.

Notably, you can use the decorators matplotlib_figure_saver and save_matrix_wrapper to make your functions save figures.


  1"""implements a bunch of types, default values, and templates which are useful for figure functions
  2
  3notably, you can use the decorators `matplotlib_figure_saver`, `save_matrix_wrapper` to make your functions save figures
  4"""
  5
  6from pathlib import Path
  7from typing import Callable, Literal, overload, Union
  8import functools
  9import base64
 10import gzip
 11import io
 12
 13from PIL import Image
 14import numpy as np
 15from jaxtyping import Float, UInt8
 16import matplotlib
 17import matplotlib.pyplot as plt
 18from matplotlib.colors import Colormap
 19
 20from pattern_lens.consts import AttentionMatrix
 21
# ---- type aliases for figure functions ----
# a figure function receives an attention matrix and the directory to write its output into
AttentionMatrixFigureFunc = Callable[[AttentionMatrix, Path], None]
"Type alias for a function that, given an attention matrix, saves a figure"

# dims are named "n m": `n` rows by `m` columns
Matrix2D = Float[np.ndarray, "n m"]
"Type alias for a 2D matrix (plottable)"

# same "n m" layout as `Matrix2D`, plus a trailing axis of exactly 3 uint8 channels
Matrix2Drgb = UInt8[np.ndarray, "n m rgb=3"]
"Type alias for a 2D matrix with 3 channels (RGB)"

# the kind of function `save_matrix_wrapper` decorates
AttentionMatrixToMatrixFunc = Callable[[AttentionMatrix], Matrix2D]
"Type alias for a function that, given an attention matrix, returns a 2D matrix"

# ---- defaults for matplotlib-based figures ----
# also used as the file extension of the saved figure
MATPLOTLIB_FIGURE_FMT: str = "svgz"
"format for saving matplotlib figures"

# ---- defaults for raw-matrix figures (no matplotlib figure involved) ----
# `save_matrix_wrapper` validates its `fmt` against `MatrixSaveFormat.__args__`
MatrixSaveFormat = Literal["png", "svg", "svgz"]
"Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure"

MATRIX_SAVE_NORMALIZE: bool = False
"default for whether to normalize the matrix to range [0, 1]"

# looked up in `matplotlib.colormaps` when given as a string
MATRIX_SAVE_CMAP: str = "viridis"
"default colormap for saving matrices"

MATRIX_SAVE_FMT: MatrixSaveFormat = "svgz"
"default format for saving matrices"

# filled via `str.format` with `m` (width, i.e. columns under the "n m" convention),
# `n` (height, i.e. rows) and `png_base64` (the base64-encoded PNG of the matrix);
# `image-rendering="pixelated"` is presumably there to keep cells crisp when scaled
MATRIX_SAVE_SVG_TEMPLATE: str = """<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" viewBox="0 0 {m} {n}" image-rendering="pixelated"> <image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>"""
"template for saving an `n` by `m` matrix as an svg/svgz"
 51
 52
 53@overload  # without keyword arguments, returns decorated function
 54def matplotlib_figure_saver(
 55    func: Callable[[AttentionMatrix, plt.Axes], None],
 56    *args,
 57    fmt: str = MATPLOTLIB_FIGURE_FMT,
 58) -> AttentionMatrixFigureFunc: ...
 59@overload  # with keyword arguments, returns decorator
 60def matplotlib_figure_saver(
 61    func: None = None,
 62    *args,
 63    fmt: str = MATPLOTLIB_FIGURE_FMT,
 64) -> Callable[
 65    [Callable[[AttentionMatrix, plt.Axes], None], str], AttentionMatrixFigureFunc
 66]: ...
 67def matplotlib_figure_saver(
 68    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
 69    *args,
 70    fmt: str = MATPLOTLIB_FIGURE_FMT,
 71) -> Union[
 72    AttentionMatrixFigureFunc,
 73    Callable[
 74        [Callable[[AttentionMatrix, plt.Axes], None], str], AttentionMatrixFigureFunc
 75    ],
 76]:
 77    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure
 78
 79    # Parameters:
 80     - `func : Callable[[AttentionMatrix, plt.Axes], None]`
 81       your function, which should take an attention matrix and predefined `ax` object
 82     - `fmt : str`
 83       format for saving matplotlib figures
 84       (defaults to `MATPLOTLIB_FIGURE_FMT`)
 85
 86    # Returns:
 87     - `AttentionMatrixFigureFunc`
 88       your function, after we wrap it to save a figure
 89
 90    # Usage:
 91    ```python
 92    @register_attn_figure_func
 93    @matplotlib_figure_saver
 94    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
 95        ax.matshow(attn_matrix, cmap="viridis")
 96        ax.set_title("Raw Attention Pattern")
 97        ax.axis("off")
 98    ```
 99
100    """
101
102    assert len(args) == 0, "This decorator only supports keyword arguments"
103
104    def decorator(
105        func: Callable[[AttentionMatrix, plt.Axes], None],
106        fmt: str = fmt,
107    ) -> AttentionMatrixFigureFunc:
108        @functools.wraps(func)
109        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
110            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
111
112            fig, ax = plt.subplots(figsize=(10, 10))
113            func(attn_matrix, ax)
114            plt.tight_layout()
115            plt.savefig(fig_path)
116            plt.close(fig)
117
118        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]
119
120        return wrapped
121
122    if callable(func):
123        # Handle no-arguments case
124        return decorator(func)
125    else:
126        # Handle arguments case
127        return decorator
128
129
def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    # Parameters:
     - `matrix : Matrix2D`
       input matrix, must be 2D and non-empty
     - `normalize : bool`
       whether to min-max normalize the matrix to range [0, 1]. a constant matrix
       (max == min) is mapped to all zeros instead of dividing by zero.
       if `False`, the values must already lie in [0, 1]
       (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix -- looked up in `matplotlib.colormaps` if a string
       (defaults to `MATRIX_SAVE_CMAP`)

    # Returns:
     - `Matrix2Drgb`
       the colormapped image, same height/width as `matrix`, with 3 uint8 channels (alpha dropped)

    # Raises:
     - `AssertionError` if `matrix` is not 2D, is empty, or (with `normalize=False`) has values outside [0, 1]
     - `TypeError` if `cmap` is neither a `str` nor a `Colormap`
    """
    # check dims
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    # Normalize the matrix to range [0, 1]
    normalized_matrix: Matrix2D
    if normalize:
        max_val, min_val = matrix.max(), matrix.min()
        value_range = max_val - min_val
        if value_range == 0:
            # constant matrix: min-max scaling would be 0/0 -> NaN, which the
            # colormap renders as its "bad" color; map everything to 0 instead
            normalized_matrix = np.zeros(matrix.shape, dtype=np.float64)
        else:
            normalized_matrix = (matrix - min_val) / value_range
    else:
        assert (
            matrix.min() >= 0 and matrix.max() <= 1
        ), f"Matrix values must be in range [0, 1], or normalize must be True. got: min: {matrix.min() = }, max: {matrix.max() = }"
        normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = matplotlib.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        raise TypeError(
            f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        )

    # Apply the colormap: it yields RGBA floats in [0, 1]; keep RGB and scale to bytes
    rgb_matrix: Matrix2Drgb = (
        cmap_(normalized_matrix)[:, :, :3] * 255
    ).astype(np.uint8)  # Drop alpha channel

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix
190
191
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: None = None) -> bytes: ...
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: io.BytesIO) -> None: ...
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb, buffer: io.BytesIO | None = None
) -> bytes | None:
    """Encode an RGB `Matrix2Drgb` image as PNG using PIL

    two modes of operation:
    - `buffer` given: the PNG data is written into it, and `None` is returned
    - `buffer` omitted: the PNG data is returned as `bytes`

    # Parameters:
     - `matrix : Matrix2Drgb`
       image to encode, interpreted by PIL in `"RGB"` mode
     - `buffer : io.BytesIO | None`
       optional destination for the PNG data
       (defaults to `None`, meaning the bytes are returned instead)

    # Returns:
     - `bytes|None`
       the PNG data when `buffer` is `None`, otherwise `None`
    """
    image: Image.Image = Image.fromarray(matrix, mode="RGB")

    # write into the caller's buffer if given, else into a private scratch buffer
    destination: io.BytesIO = io.BytesIO() if buffer is None else buffer
    image.save(destination, format="PNG")

    if buffer is not None:
        return None
    return destination.getvalue()
223
224
def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap=MATRIX_SAVE_CMAP,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    the matrix is colormapped, encoded as a PNG, and embedded (base64) in an SVG whose
    width is the number of columns and height the number of rows of `matrix`

    # Parameters:
     - `matrix : Float[np.ndarray, 'n m']`
       a 2D matrix (`n` rows, `m` columns) to convert to an SVG image
     - `normalize : bool`
       whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
       (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
       (defaults to `MATRIX_SAVE_CMAP`)

    # Returns:
     - `str`
       the SVG content for the matrix
    """
    # `Matrix2D` is "n m": n rows (image height) by m columns (image width).
    # PIL encodes the PNG with width = columns, and the template puts `m` in
    # `width` and `n` in `height`, so unpack in (n, m) order -- the reverse
    # order would transpose the SVG's declared size for non-square matrices
    n, m = matrix.shape

    # Preprocess the matrix into an RGB image
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix, normalize=normalize, cmap=cmap
    )

    # Convert the RGB image to PNG bytes
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

    # Encode the PNG bytes as base64
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

    return svg_content
264
265
@overload  # with keyword arguments, returns decorator
def save_matrix_wrapper(
    func: None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str = MATRIX_SAVE_CMAP,
) -> Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]: ...
@overload  # without keyword arguments, returns decorated function
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str = MATRIX_SAVE_CMAP,
) -> AttentionMatrixFigureFunc: ...
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap=MATRIX_SAVE_CMAP,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """
    Decorator for functions that turn an attention matrix into a 2D matrix, making them save that matrix as an image.
    The image is a raw PNG, or an SVG/SVGZ embedding that PNG, depending on `fmt`.
    Can handle both argumentless usage and with arguments.

    # Parameters:

     - `func : AttentionMatrixToMatrixFunc|None`
        Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
     - `fmt : MatrixSaveFormat, keyword-only`
        The format to save the matrix as, one of `"png"`, `"svg"`, `"svgz"`. Defaults to `MATRIX_SAVE_FMT`.
     - `normalize : bool, keyword-only`
        Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
     - `cmap : str, keyword-only`
        The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

    - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
    - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the decorated function (with arguments case)

    The wrapped function writes `save_dir / f"{func.__name__}.{fmt}"` and has a
    `figure_save_fmt` attribute set to `fmt`.

    # Raises:

    - `AssertionError` if positional arguments besides `func` are passed, or `fmt` is not a valid `MatrixSaveFormat`

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix_plasma(matrix):
        return matrix * 2

    ```
    """

    assert len(args) == 0, "This decorator only supports keyword arguments"

    # validate eagerly, at decoration time, rather than on the first save
    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)

            else:
                svg_content: str = matrix_as_svg(
                    processed_matrix, normalize=normalize, cmap=cmap
                )

                if fmt == "svgz":
                    # explicit encoding: text-mode gzip otherwise falls back to the
                    # locale's default, unlike the plain-svg branch below
                    with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                        f.write(svg_content)

                else:
                    fig_path.write_text(svg_content, encoding="utf-8")

        # lets callers discover the output extension without calling the function
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator

AttentionMatrixFigureFunc = typing.Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]

Type alias for a function that, given an attention matrix, saves a figure

Matrix2D = <class 'jaxtyping.Float[ndarray, 'n m']'>

Type alias for a 2D matrix (plottable)

Matrix2Drgb = <class 'jaxtyping.UInt8[ndarray, 'n m rgb=3']'>

Type alias for a 2D matrix with 3 channels (RGB)

AttentionMatrixToMatrixFunc = typing.Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']]

Type alias for a function that, given an attention matrix, returns a 2D matrix

MATPLOTLIB_FIGURE_FMT: str = 'svgz'

format for saving matplotlib figures

MatrixSaveFormat = typing.Literal['png', 'svg', 'svgz']

Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure

MATRIX_SAVE_NORMALIZE: bool = False

default for whether to normalize the matrix to range [0, 1]

MATRIX_SAVE_CMAP: str = 'viridis'

default colormap for saving matrices

MATRIX_SAVE_FMT: Literal['png', 'svg', 'svgz'] = 'svgz'

default format for saving matrices

MATRIX_SAVE_SVG_TEMPLATE: str = '<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" viewBox="0 0 {m} {n}" image-rendering="pixelated"> <image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>'

template for saving an n by m matrix as an svg/svgz

def matplotlib_figure_saver( func: Optional[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], matplotlib.axes._axes.Axes], NoneType]] = None, *args, fmt: str = 'svgz') -> Union[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType], Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], matplotlib.axes._axes.Axes], NoneType], str], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]]]:
 68def matplotlib_figure_saver(
 69    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
 70    *args,
 71    fmt: str = MATPLOTLIB_FIGURE_FMT,
 72) -> Union[
 73    AttentionMatrixFigureFunc,
 74    Callable[
 75        [Callable[[AttentionMatrix, plt.Axes], None], str], AttentionMatrixFigureFunc
 76    ],
 77]:
 78    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure
 79
 80    # Parameters:
 81     - `func : Callable[[AttentionMatrix, plt.Axes], None]`
 82       your function, which should take an attention matrix and predefined `ax` object
 83     - `fmt : str`
 84       format for saving matplotlib figures
 85       (defaults to `MATPLOTLIB_FIGURE_FMT`)
 86
 87    # Returns:
 88     - `AttentionMatrixFigureFunc`
 89       your function, after we wrap it to save a figure
 90
 91    # Usage:
 92    ```python
 93    @register_attn_figure_func
 94    @matplotlib_figure_saver
 95    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
 96        ax.matshow(attn_matrix, cmap="viridis")
 97        ax.set_title("Raw Attention Pattern")
 98        ax.axis("off")
 99    ```
100
101    """
102
103    assert len(args) == 0, "This decorator only supports keyword arguments"
104
105    def decorator(
106        func: Callable[[AttentionMatrix, plt.Axes], None],
107        fmt: str = fmt,
108    ) -> AttentionMatrixFigureFunc:
109        @functools.wraps(func)
110        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
111            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
112
113            fig, ax = plt.subplots(figsize=(10, 10))
114            func(attn_matrix, ax)
115            plt.tight_layout()
116            plt.savefig(fig_path)
117            plt.close(fig)
118
119        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]
120
121        return wrapped
122
123    if callable(func):
124        # Handle no-arguments case
125        return decorator(func)
126    else:
127        # Handle arguments case
128        return decorator

decorator for functions that take an attention matrix and a predefined ax object, making them save a figure

Parameters:

  • func : Callable[[AttentionMatrix, plt.Axes], None] your function, which should take an attention matrix and predefined ax object
  • fmt : str format for saving matplotlib figures (defaults to MATPLOTLIB_FIGURE_FMT)

Returns:

  • AttentionMatrixFigureFunc your function, after we wrap it to save a figure

Usage:

@register_attn_figure_func
@matplotlib_figure_saver
def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
    ax.matshow(attn_matrix, cmap="viridis")
    ax.set_title("Raw Attention Pattern")
    ax.axis("off")
def matrix_to_image_preprocess( matrix: jaxtyping.Float[ndarray, 'n m'], normalize: bool = False, cmap: str | matplotlib.colors.Colormap = 'viridis') -> jaxtyping.UInt8[ndarray, 'n m rgb=3']:
def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    # Parameters:
     - `matrix : Matrix2D`
       input matrix, must be 2D and non-empty
     - `normalize : bool`
       whether to min-max normalize the matrix to range [0, 1]. a constant matrix
       (max == min) is mapped to all zeros instead of dividing by zero.
       if `False`, the values must already lie in [0, 1]
       (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix -- looked up in `matplotlib.colormaps` if a string
       (defaults to `MATRIX_SAVE_CMAP`)

    # Returns:
     - `Matrix2Drgb`
       the colormapped image, same height/width as `matrix`, with 3 uint8 channels (alpha dropped)

    # Raises:
     - `AssertionError` if `matrix` is not 2D, is empty, or (with `normalize=False`) has values outside [0, 1]
     - `TypeError` if `cmap` is neither a `str` nor a `Colormap`
    """
    # check dims
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    # Normalize the matrix to range [0, 1]
    normalized_matrix: Matrix2D
    if normalize:
        max_val, min_val = matrix.max(), matrix.min()
        value_range = max_val - min_val
        if value_range == 0:
            # constant matrix: min-max scaling would be 0/0 -> NaN, which the
            # colormap renders as its "bad" color; map everything to 0 instead
            normalized_matrix = np.zeros(matrix.shape, dtype=np.float64)
        else:
            normalized_matrix = (matrix - min_val) / value_range
    else:
        assert (
            matrix.min() >= 0 and matrix.max() <= 1
        ), f"Matrix values must be in range [0, 1], or normalize must be True. got: min: {matrix.min() = }, max: {matrix.max() = }"
        normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = matplotlib.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        raise TypeError(
            f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        )

    # Apply the colormap: it yields RGBA floats in [0, 1]; keep RGB and scale to bytes
    rgb_matrix: Matrix2Drgb = (
        cmap_(normalized_matrix)[:, :, :3] * 255
    ).astype(np.uint8)  # Drop alpha channel

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix

preprocess a 2D matrix into a plottable heatmap image

Parameters:

  • matrix : Matrix2D input matrix
  • normalize : bool whether to normalize the matrix to range [0, 1] (defaults to MATRIX_SAVE_NORMALIZE)
  • cmap : str|Colormap the colormap to use for the matrix (defaults to MATRIX_SAVE_CMAP)

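Returns:

  • Matrix2Drgb the colormapped image, with the same height and width as matrix and 3 uint8 channels (RGB)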

def matrix2drgb_to_png_bytes( matrix: jaxtyping.UInt8[ndarray, 'n m rgb=3'], buffer: _io.BytesIO | None = None) -> bytes | None:
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb, buffer: io.BytesIO | None = None
) -> bytes | None:
    """Encode an RGB `Matrix2Drgb` image as PNG using PIL

    two modes of operation:
    - `buffer` given: the PNG data is written into it, and `None` is returned
    - `buffer` omitted: the PNG data is returned as `bytes`

    # Parameters:
     - `matrix : Matrix2Drgb`
       image to encode, interpreted by PIL in `"RGB"` mode
     - `buffer : io.BytesIO | None`
       optional destination for the PNG data
       (defaults to `None`, meaning the bytes are returned instead)

    # Returns:
     - `bytes|None`
       the PNG data when `buffer` is `None`, otherwise `None`
    """
    image: Image.Image = Image.fromarray(matrix, mode="RGB")

    # write into the caller's buffer if given, else into a private scratch buffer
    destination: io.BytesIO = io.BytesIO() if buffer is None else buffer
    image.save(destination, format="PNG")

    if buffer is not None:
        return None
    return destination.getvalue()

Convert a Matrix2Drgb to valid PNG bytes via PIL

  • if buffer is provided, it will write the PNG bytes to the buffer and return None
  • if buffer is not provided, it will return the PNG bytes

Parameters:

  • matrix : Matrix2Drgb
  • buffer : io.BytesIO | None (defaults to None, in which case it will return the PNG bytes)

Returns:

  • bytes|None bytes if buffer is None, otherwise None
def matrix_as_svg( matrix: jaxtyping.Float[ndarray, 'n m'], normalize: bool = False, cmap='viridis') -> str:
def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap=MATRIX_SAVE_CMAP,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    the matrix is colormapped, encoded as a PNG, and embedded (base64) in an SVG whose
    width is the number of columns and height the number of rows of `matrix`

    # Parameters:
     - `matrix : Float[np.ndarray, 'n m']`
       a 2D matrix (`n` rows, `m` columns) to convert to an SVG image
     - `normalize : bool`
       whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
       (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
       (defaults to `MATRIX_SAVE_CMAP`)

    # Returns:
     - `str`
       the SVG content for the matrix
    """
    # `Matrix2D` is "n m": n rows (image height) by m columns (image width).
    # PIL encodes the PNG with width = columns, and the template puts `m` in
    # `width` and `n` in `height`, so unpack in (n, m) order -- the reverse
    # order would transpose the SVG's declared size for non-square matrices
    n, m = matrix.shape

    # Preprocess the matrix into an RGB image
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix, normalize=normalize, cmap=cmap
    )

    # Convert the RGB image to PNG bytes
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

    # Encode the PNG bytes as base64
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

    return svg_content

quickly convert a 2D matrix to an SVG image, without matplotlib

Parameters:

  • matrix : Float[np.ndarray, 'n m'] a 2D matrix to convert to an SVG image
  • normalize : bool whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be True or it will raise an AssertionError (defaults to False)
  • cmap : str the colormap to use for the matrix -- will look up in matplotlib.colormaps if it's a string (defaults to "viridis")

Returns:

  • str the SVG content for the matrix
def save_matrix_wrapper( func: Optional[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']]] = None, *args, fmt: Literal['png', 'svg', 'svgz'] = 'svgz', normalize: bool = False, cmap='viridis') -> Union[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType], Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']]], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]]]:
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap=MATRIX_SAVE_CMAP,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """
    Decorator for functions that turn an attention matrix into a 2D matrix, making them save that matrix as an image.
    The image is a raw PNG, or an SVG/SVGZ embedding that PNG, depending on `fmt`.
    Can handle both argumentless usage and with arguments.

    # Parameters:

     - `func : AttentionMatrixToMatrixFunc|None`
        Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
     - `fmt : MatrixSaveFormat, keyword-only`
        The format to save the matrix as, one of `"png"`, `"svg"`, `"svgz"`. Defaults to `MATRIX_SAVE_FMT`.
     - `normalize : bool, keyword-only`
        Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
     - `cmap : str, keyword-only`
        The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

    - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
    - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the decorated function (with arguments case)

    The wrapped function writes `save_dir / f"{func.__name__}.{fmt}"` and has a
    `figure_save_fmt` attribute set to `fmt`.

    # Raises:

    - `AssertionError` if positional arguments besides `func` are passed, or `fmt` is not a valid `MatrixSaveFormat`

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix_plasma(matrix):
        return matrix * 2

    ```
    """

    assert len(args) == 0, "This decorator only supports keyword arguments"

    # validate eagerly, at decoration time, rather than on the first save
    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)

            else:
                svg_content: str = matrix_as_svg(
                    processed_matrix, normalize=normalize, cmap=cmap
                )

                if fmt == "svgz":
                    # explicit encoding: text-mode gzip otherwise falls back to the
                    # locale's default, unlike the plain-svg branch below
                    with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                        f.write(svg_content)

                else:
                    fig_path.write_text(svg_content, encoding="utf-8")

        # lets callers discover the output extension without calling the function
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator

Decorator for functions that process an attention matrix and save the result as an image (PNG, SVG, or SVGZ, depending on fmt). Can be used both without arguments and with keyword arguments.

Parameters:

  • func : AttentionMatrixToMatrixFunc|None Either the function to decorate (in the no-arguments case) or None when used with arguments.
  • fmt : MatrixSaveFormat, keyword-only The format to save the matrix as. Defaults to MATRIX_SAVE_FMT.
  • normalize : bool, keyword-only Whether to normalize the matrix to range [0, 1]. Defaults to False.
  • cmap : str, keyword-only The colormap to use for the matrix. Defaults to MATRIX_SAVE_CMAP.

Returns:

AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]

Usage:

@save_matrix_wrapper
def identity_matrix(matrix):
    return matrix

@save_matrix_wrapper(normalize=True, fmt="png")
def scale_matrix(matrix):
    return matrix * 2

@save_matrix_wrapper(normalize=True, cmap="plasma")
def scale_matrix(matrix):
    return matrix * 2