Coverage for pattern_lens\figure_util.py: 94%
115 statements
« prev ^ index » next coverage.py v7.6.9, created at 2025-01-03 23:21 -0700
« prev ^ index » next coverage.py v7.6.9, created at 2025-01-03 23:21 -0700
1"""implements a bunch of types, default values, and templates which are useful for figure functions
3notably, you can use the decorators `matplotlib_figure_saver`, `save_matrix_wrapper` to make your functions save figures
4"""
6from pathlib import Path
7from typing import Callable, Literal, overload, Union
8import functools
9import base64
10import gzip
11import io
13from PIL import Image
14import numpy as np
15from jaxtyping import Float, UInt8
16import matplotlib
17import matplotlib.pyplot as plt
18from matplotlib.colors import Colormap
20from pattern_lens.consts import AttentionMatrix
# ---------------------------------------------------------------------------
# type aliases
# ---------------------------------------------------------------------------

AttentionMatrixFigureFunc = Callable[[AttentionMatrix, Path], None]
"Type alias for a function that, given an attention matrix, saves a figure"

Matrix2D = Float[np.ndarray, "n m"]
"Type alias for a 2D matrix (plottable)"

Matrix2Drgb = UInt8[np.ndarray, "n m rgb=3"]
"Type alias for a 2D matrix with 3 channels (RGB)"

AttentionMatrixToMatrixFunc = Callable[[AttentionMatrix], Matrix2D]
"Type alias for a function that, given an attention matrix, returns a 2D matrix"

# ---------------------------------------------------------------------------
# defaults for figures rendered through matplotlib
# ---------------------------------------------------------------------------

MATPLOTLIB_FIGURE_FMT: str = "svgz"
"format for saving matplotlib figures"

# ---------------------------------------------------------------------------
# defaults for matrices saved directly as images (no matplotlib figure)
# ---------------------------------------------------------------------------

MatrixSaveFormat = Literal["png", "svg", "svgz"]
"Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure"

MATRIX_SAVE_NORMALIZE: bool = False
"default for whether to normalize the matrix to range [0, 1]"

MATRIX_SAVE_CMAP: str = "viridis"
"default colormap for saving matrices"

MATRIX_SAVE_FMT: MatrixSaveFormat = "svgz"
"default format for saving matrices"

MATRIX_SAVE_SVG_TEMPLATE: str = """<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" viewBox="0 0 {m} {n}" image-rendering="pixelated"> <image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>"""
"template for saving an `n` by `m` matrix as an svg/svgz"
@overload  # without keyword arguments, returns decorated function
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None],
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> AttentionMatrixFigureFunc: ...
@overload  # with keyword arguments, returns decorator
def matplotlib_figure_saver(
    func: None = None,
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    [Callable[[AttentionMatrix, plt.Axes], None], str], AttentionMatrixFigureFunc
]: ...
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Union[
    AttentionMatrixFigureFunc,
    Callable[
        [Callable[[AttentionMatrix, plt.Axes], None], str], AttentionMatrixFigureFunc
    ],
]:
    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure

    The wrapped function creates a 10x10 inch figure, calls your function with
    the attention matrix and the figure's axes, and saves the result to
    `save_dir / f"{func.__name__}.{fmt}"`. The figure is always closed
    afterwards, even if your function or the save raises.

    # Parameters:
     - `func : Callable[[AttentionMatrix, plt.Axes], None]`
       your function, which should take an attention matrix and predefined `ax` object
     - `fmt : str`
       format for saving matplotlib figures
       (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
     - `AttentionMatrixFigureFunc`
       your function, after we wrap it to save a figure
       (or, when `func` is not given, a decorator producing such a function).
       The wrapped function carries the chosen format in its
       `figure_save_fmt` attribute.

    # Raises:
     - `AssertionError` : if any extra positional arguments are passed

    # Usage:
    ```python
    @register_attn_figure_func
    @matplotlib_figure_saver
    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")
    ```

    """

    assert len(args) == 0, "This decorator only supports keyword arguments"

    def decorator(
        func: Callable[[AttentionMatrix, plt.Axes], None],
        fmt: str = fmt,
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"

            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                func(attn_matrix, ax)
                # operate on `fig` explicitly rather than on pyplot's notion of
                # the "current" figure, which `func` may have changed
                fig.tight_layout()
                fig.savefig(fig_path)
            finally:
                # always release the figure, even if plotting or saving failed
                plt.close(fig)

        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator
def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    # Parameters:
     - `matrix : Matrix2D`
       input matrix
     - `normalize : bool`
       whether to normalize the matrix to range [0, 1].
       a constant matrix (max equal to min) is mapped to all zeros
       (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix. strings are looked up in `matplotlib.colormaps`
       (defaults to `MATRIX_SAVE_CMAP`)

    # Returns:
     - `Matrix2Drgb`
       `uint8` array of shape `(n, m, 3)`, the colormapped matrix with the alpha channel dropped

    # Raises:
     - `AssertionError` : if the matrix is not 2D, is empty, or (when `normalize` is `False`) has values outside [0, 1]
     - `TypeError` : if `cmap` is neither a `str` nor a `Colormap`
    """
    # check dims
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    # Normalize the matrix to range [0, 1]
    normalized_matrix: Matrix2D
    if normalize:
        max_val, min_val = matrix.max(), matrix.min()
        value_range = max_val - min_val
        if value_range == 0:
            # constant matrix: dividing would give 0/0 = NaN everywhere,
            # so map every entry to the bottom of the colormap instead
            normalized_matrix = np.zeros(matrix.shape, dtype=float)
        else:
            normalized_matrix = (matrix - min_val) / value_range
    else:
        assert matrix.min() >= 0 and matrix.max() <= 1, (
            "Matrix values must be in range [0, 1], or normalize must be True. "
            f"got: {matrix.min() = }, {matrix.max() = }"
        )
        normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = matplotlib.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        raise TypeError(
            f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        )

    # Apply the colormap (gives RGBA floats in [0, 1]), drop the alpha
    # channel, and rescale to 8-bit RGB
    rgb_matrix: Matrix2Drgb = (cmap_(normalized_matrix)[:, :, :3] * 255).astype(
        np.uint8
    )

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: None = None) -> bytes: ...
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: io.BytesIO) -> None: ...
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb, buffer: io.BytesIO | None = None
) -> bytes | None:
    """Encode a `Matrix2Drgb` as PNG via PIL

    Two calling modes:

     - with a `buffer`: the PNG is written into it and nothing is returned
     - without a `buffer`: the PNG is returned as `bytes`

    # Parameters:
     - `matrix : Matrix2Drgb`
       RGB image data to encode
     - `buffer : io.BytesIO | None`
       destination for the PNG data
       (defaults to `None`, in which case the PNG bytes are returned)

    # Returns:
     - `bytes|None`
       the PNG bytes if `buffer` is `None`, otherwise `None`
    """

    image: Image.Image = Image.fromarray(matrix, mode="RGB")

    # caller supplied a destination: write into it and hand nothing back
    if buffer is not None:
        image.save(buffer, format="PNG")
        return None

    # no destination given: encode into a scratch buffer and return its contents
    scratch = io.BytesIO()
    image.save(scratch, format="PNG")
    return scratch.getvalue()
def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap=MATRIX_SAVE_CMAP,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    The matrix is colormapped to RGB, encoded as a PNG, and embedded as a
    base64 data URI inside `MATRIX_SAVE_SVG_TEMPLATE`.

    # Parameters:
     - `matrix : Float[np.ndarray, 'n m']`
       a 2D matrix (`n` rows, `m` columns) to convert to an SVG image
     - `normalize : bool`
       whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
       (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str`
       the colormap to use for the matrix -- passed through to `matrix_to_image_preprocess`
       (defaults to `MATRIX_SAVE_CMAP`)

    # Returns:
     - `str`
       the SVG content for the matrix, with width `m` and height `n`
    """
    # Get the dimensions of the matrix: shape is (rows, columns), and the
    # template uses `m` for the width (columns) and `n` for the height (rows)
    n, m = matrix.shape

    # Preprocess the matrix into an RGB image
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix, normalize=normalize, cmap=cmap
    )

    # Convert the RGB image to PNG bytes
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

    # Encode the PNG bytes as base64
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

    return svg_content
@overload  # with keyword arguments, returns decorator
def save_matrix_wrapper(
    func: None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str = MATRIX_SAVE_CMAP,
) -> Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]: ...
@overload  # without keyword arguments, returns decorated function
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str = MATRIX_SAVE_CMAP,
) -> AttentionMatrixFigureFunc: ...
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap=MATRIX_SAVE_CMAP,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """
    Decorator for functions that process an attention matrix and save it as a PNG, SVG, or SVGZ image.
    Can handle both argumentless usage and with arguments.

    The wrapped function calls your function on the attention matrix and
    writes the result to `save_dir / f"{func.__name__}.{fmt}"`.

    # Parameters:

     - `func : AttentionMatrixToMatrixFunc|None`
        Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
     - `fmt : MatrixSaveFormat, keyword-only`
        The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
     - `normalize : bool, keyword-only`
        Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
     - `cmap : str, keyword-only`
        The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

     - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
     - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

    In both cases the final wrapped function carries the chosen format in its
    `figure_save_fmt` attribute.

    # Raises:

     - `AssertionError` : if extra positional arguments are passed, or `fmt` is not a valid `MatrixSaveFormat`

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        return matrix * 2

    ```
    """

    assert len(args) == 0, "This decorator only supports keyword arguments"

    # validated once, at decoration time, rather than on every save
    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)

            else:
                svg_content: str = matrix_as_svg(
                    processed_matrix, normalize=normalize, cmap=cmap
                )

                if fmt == "svgz":
                    # pin the encoding so the compressed output matches the
                    # plain-svg branch instead of depending on the locale
                    with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                        f.write(svg_content)

                else:
                    fig_path.write_text(svg_content, encoding="utf-8")

        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator