pattern_lens.figure_util
implements a bunch of types, default values, and templates which are useful for figure functions
notably, you can use the decorators matplotlib_figure_saver, save_matrix_wrapper to make your functions save figures
"""implements a bunch of types, default values, and templates which are useful for figure functions

notably, you can use the decorators `matplotlib_figure_saver`, `save_matrix_wrapper` to make your functions save figures
"""

from pathlib import Path
from typing import Callable, Literal, overload, Union
import functools
import base64
import gzip
import io

from PIL import Image
import numpy as np
from jaxtyping import Float, UInt8
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import Colormap

from pattern_lens.consts import AttentionMatrix

AttentionMatrixFigureFunc = Callable[[AttentionMatrix, Path], None]
"Type alias for a function that, given an attention matrix, saves a figure"

Matrix2D = Float[np.ndarray, "n m"]
"Type alias for a 2D matrix (plottable)"

Matrix2Drgb = UInt8[np.ndarray, "n m rgb=3"]
"Type alias for a 2D matrix with 3 channels (RGB)"

AttentionMatrixToMatrixFunc = Callable[[AttentionMatrix], Matrix2D]
"Type alias for a function that, given an attention matrix, returns a 2D matrix"

MATPLOTLIB_FIGURE_FMT: str = "svgz"
"format for saving matplotlib figures"

MatrixSaveFormat = Literal["png", "svg", "svgz"]
"Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure"

MATRIX_SAVE_NORMALIZE: bool = False
"default for whether to normalize the matrix to range [0, 1]"

MATRIX_SAVE_CMAP: str = "viridis"
"default colormap for saving matrices"

MATRIX_SAVE_FMT: MatrixSaveFormat = "svgz"
"default format for saving matrices"

MATRIX_SAVE_SVG_TEMPLATE: str = """<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" viewBox="0 0 {m} {n}" image-rendering="pixelated"> <image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>"""
"template for saving an `n` by `m` matrix (`n` rows, `m` columns) as an svg/svgz"


@overload  # without keyword arguments, returns decorated function
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None],
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> AttentionMatrixFigureFunc: ...
@overload  # with keyword arguments, returns decorator
def matplotlib_figure_saver(
    func: None = None,
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    [Callable[[AttentionMatrix, plt.Axes], None], str], AttentionMatrixFigureFunc
]: ...
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Union[
    AttentionMatrixFigureFunc,
    Callable[
        [Callable[[AttentionMatrix, plt.Axes], None], str], AttentionMatrixFigureFunc
    ],
]:
    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure

    the figure is written to `save_dir / f"{func.__name__}.{fmt}"`, and the
    format used is recorded on the wrapped function as `figure_save_fmt`

    # Parameters:
     - `func : Callable[[AttentionMatrix, plt.Axes], None]`
       your function, which should take an attention matrix and predefined `ax` object
     - `fmt : str`
       format for saving matplotlib figures
       (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
     - `AttentionMatrixFigureFunc`
       your function, after we wrap it to save a figure
       (or a decorator producing it, when called with keyword arguments only)

    # Raises:
     - `AssertionError` : if extra positional arguments are passed

    # Usage:
    ```python
    @register_attn_figure_func
    @matplotlib_figure_saver
    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")
    ```

    """

    assert len(args) == 0, "This decorator only supports keyword arguments"

    def decorator(
        func: Callable[[AttentionMatrix, plt.Axes], None],
        fmt: str = fmt,
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"

            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                func(attn_matrix, ax)
                # operate on the figure we created, not on pyplot's
                # "current figure", which `func` may have changed
                fig.tight_layout()
                fig.savefig(fig_path)
            finally:
                # close even when plotting or saving fails, so that figures
                # do not accumulate in pyplot's global registry
                plt.close(fig)

        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator


def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    # Parameters:
     - `matrix : Matrix2D`
       input matrix
     - `normalize : bool`
       whether to normalize the matrix to range [0, 1]
       (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix
       (defaults to `MATRIX_SAVE_CMAP`)

    # Returns:
     - `Matrix2Drgb`
       `uint8` array with the same leading shape as `matrix` and 3 channels

    # Raises:
     - `AssertionError` : if the matrix is not 2D, is empty, or has values
       outside [0, 1] while `normalize` is `False`
     - `TypeError` : if `cmap` is neither a `str` nor a `Colormap`
    """
    # check dims
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    # Normalize the matrix to range [0, 1]
    normalized_matrix: Matrix2D
    if normalize:
        max_val, min_val = matrix.max(), matrix.min()
        value_range = max_val - min_val
        if value_range == 0:
            # constant matrix: dividing by zero would fill the result with
            # NaN, so map every entry to the bottom of the colormap instead
            normalized_matrix = np.zeros_like(matrix, dtype=float)
        else:
            normalized_matrix = (matrix - min_val) / value_range
    else:
        assert (
            matrix.min() >= 0 and matrix.max() <= 1
        ), f"Matrix values must be in range [0, 1], or normalize must be True. got: min: {matrix.min() = }, max: {matrix.max() = }"
        normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = matplotlib.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        raise TypeError(
            f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        )

    # Apply the colormap, drop the alpha channel, and scale to 8-bit values
    rgb_matrix: Matrix2Drgb = (cmap_(normalized_matrix)[:, :, :3] * 255).astype(
        np.uint8
    )

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix


@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: None = None) -> bytes: ...
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: io.BytesIO) -> None: ...
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb, buffer: io.BytesIO | None = None
) -> bytes | None:
    """Convert a `Matrix2Drgb` to valid PNG bytes via PIL

    - if `buffer` is provided, it will write the PNG bytes to the buffer and return `None`
    - if `buffer` is not provided, it will return the PNG bytes

    # Parameters:
     - `matrix : Matrix2Drgb`
     - `buffer : io.BytesIO | None`
       (defaults to `None`, in which case it will return the PNG bytes)

    # Returns:
     - `bytes|None`
       `bytes` if `buffer` is `None`, otherwise `None`
    """

    pil_img: Image.Image = Image.fromarray(matrix, mode="RGB")
    if buffer is None:
        buffer = io.BytesIO()
        pil_img.save(buffer, format="PNG")
        buffer.seek(0)
        return buffer.read()
    else:
        pil_img.save(buffer, format="PNG")
        return None


def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap=MATRIX_SAVE_CMAP,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    the matrix is rendered to a PNG, base64-encoded, and embedded in
    `MATRIX_SAVE_SVG_TEMPLATE` with one SVG unit per matrix entry

    # Parameters:
     - `matrix : Float[np.ndarray, 'n m']`
       a 2D matrix to convert to an SVG image
     - `normalize : bool`
       whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
       (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
       (defaults to `MATRIX_SAVE_CMAP`)

    # Returns:
     - `str`
       the SVG content for the matrix
    """
    # rows become the image height (`n`), columns become the width (`m`)
    n, m = matrix.shape

    # Preprocess the matrix into an RGB image
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix, normalize=normalize, cmap=cmap
    )

    # Convert the RGB image to PNG bytes
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

    # Encode the PNG bytes as base64
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

    return svg_content


@overload  # with keyword arguments, returns decorator
def save_matrix_wrapper(
    func: None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str = MATRIX_SAVE_CMAP,
) -> Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]: ...
@overload  # without keyword arguments, returns decorated function
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str = MATRIX_SAVE_CMAP,
) -> AttentionMatrixFigureFunc: ...
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap=MATRIX_SAVE_CMAP,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """
    Decorator for functions that process an attention matrix and save it as an image (`png`, `svg`, or `svgz`).
    Can handle both argumentless usage and with arguments.

    The image is written to `save_dir / f"{func.__name__}.{fmt}"`, and the
    format used is recorded on the wrapped function as `figure_save_fmt`.

    # Parameters:

     - `func : AttentionMatrixToMatrixFunc|None`
        Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
     - `fmt : MatrixSaveFormat, keyword-only`
        The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
     - `normalize : bool, keyword-only`
        Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
     - `cmap : str, keyword-only`
        The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

    - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
    - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

    # Raises:

     - `AssertionError` : if extra positional arguments are passed, or `fmt` is not a valid `MatrixSaveFormat`

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        return matrix * 2

    ```
    """

    assert len(args) == 0, "This decorator only supports keyword arguments"

    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)

            else:
                svg_content: str = matrix_as_svg(
                    processed_matrix, normalize=normalize, cmap=cmap
                )

                if fmt == "svgz":
                    # explicit encoding so the compressed text does not
                    # depend on the platform's locale, matching the svg branch
                    with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                        f.write(svg_content)

                else:
                    fig_path.write_text(svg_content, encoding="utf-8")

        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator
Type alias for a function that, given an attention matrix, saves a figure
Type alias for a 2D matrix (plottable)
Type alias for a 2D matrix with 3 channels (RGB)
Type alias for a function that, given an attention matrix, returns a 2D matrix
format for saving matplotlib figures
Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure
default for whether to normalize the matrix to range [0, 1]
default colormap for saving matrices
default format for saving matrices
template for saving an n by m matrix as an svg/svgz
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Union[
    AttentionMatrixFigureFunc,
    Callable[
        [Callable[[AttentionMatrix, plt.Axes], None], str], AttentionMatrixFigureFunc
    ],
]:
    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure

    the figure is written to `save_dir / f"{func.__name__}.{fmt}"`, and the
    format used is recorded on the wrapped function as `figure_save_fmt`

    # Parameters:
     - `func : Callable[[AttentionMatrix, plt.Axes], None]`
       your function, which should take an attention matrix and predefined `ax` object
     - `fmt : str`
       format for saving matplotlib figures
       (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
     - `AttentionMatrixFigureFunc`
       your function, after we wrap it to save a figure
       (or a decorator producing it, when called with keyword arguments only)

    # Raises:
     - `AssertionError` : if extra positional arguments are passed

    # Usage:
    ```python
    @register_attn_figure_func
    @matplotlib_figure_saver
    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")
    ```

    """

    assert len(args) == 0, "This decorator only supports keyword arguments"

    def decorator(
        func: Callable[[AttentionMatrix, plt.Axes], None],
        fmt: str = fmt,
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"

            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                func(attn_matrix, ax)
                # operate on the figure we created, not on pyplot's
                # "current figure", which `func` may have changed
                fig.tight_layout()
                fig.savefig(fig_path)
            finally:
                # close even when plotting or saving fails, so that figures
                # do not accumulate in pyplot's global registry
                plt.close(fig)

        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator
decorator for functions which take an attention matrix and predefined ax object, making it save a figure
Parameters:
- func : Callable[[AttentionMatrix, plt.Axes], None] -- your function, which should take an attention matrix and predefined ax object
- fmt : str -- format for saving matplotlib figures (defaults to MATPLOTLIB_FIGURE_FMT)
Returns:
- AttentionMatrixFigureFunc -- your function, after we wrap it to save a figure
Usage:
@register_attn_figure_func
@matplotlib_figure_saver
def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
    ax.matshow(attn_matrix, cmap="viridis")
    ax.set_title("Raw Attention Pattern")
    ax.axis("off")
def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    # Parameters:
     - `matrix : Matrix2D`
       input matrix
     - `normalize : bool`
       whether to normalize the matrix to range [0, 1]
       (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix
       (defaults to `MATRIX_SAVE_CMAP`)

    # Returns:
     - `Matrix2Drgb`
       `uint8` array with the same leading shape as `matrix` and 3 channels

    # Raises:
     - `AssertionError` : if the matrix is not 2D, is empty, or has values
       outside [0, 1] while `normalize` is `False`
     - `TypeError` : if `cmap` is neither a `str` nor a `Colormap`
    """
    # check dims
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    # Normalize the matrix to range [0, 1]
    normalized_matrix: Matrix2D
    if normalize:
        max_val, min_val = matrix.max(), matrix.min()
        value_range = max_val - min_val
        if value_range == 0:
            # constant matrix: dividing by zero would fill the result with
            # NaN, so map every entry to the bottom of the colormap instead
            normalized_matrix = np.zeros_like(matrix, dtype=float)
        else:
            normalized_matrix = (matrix - min_val) / value_range
    else:
        assert (
            matrix.min() >= 0 and matrix.max() <= 1
        ), f"Matrix values must be in range [0, 1], or normalize must be True. got: min: {matrix.min() = }, max: {matrix.max() = }"
        normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = matplotlib.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        raise TypeError(
            f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        )

    # Apply the colormap, drop the alpha channel, and scale to 8-bit values
    rgb_matrix: Matrix2Drgb = (cmap_(normalized_matrix)[:, :, :3] * 255).astype(
        np.uint8
    )

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix
preprocess a 2D matrix into a plottable heatmap image
Parameters:
- matrix : Matrix2D -- input matrix
- normalize : bool -- whether to normalize the matrix to range [0, 1] (defaults to MATRIX_SAVE_NORMALIZE)
- cmap : str|Colormap -- the colormap to use for the matrix (defaults to MATRIX_SAVE_CMAP)
Returns:
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb, buffer: io.BytesIO | None = None
) -> bytes | None:
    """Encode a `Matrix2Drgb` as a PNG using PIL

    The destination depends on `buffer`:

    - with a `buffer`, the PNG is written into it and nothing is returned
    - without one, the PNG is encoded in memory and returned as `bytes`

    # Parameters:
     - `matrix : Matrix2Drgb`
       RGB image data to encode
     - `buffer : io.BytesIO | None`
       optional destination for the encoded PNG
       (defaults to `None`, meaning the bytes are returned instead)

    # Returns:
     - `bytes|None`
       the PNG bytes when `buffer` is `None`, otherwise `None`
    """
    image = Image.fromarray(matrix, mode="RGB")

    # caller supplied a destination: write into it and hand nothing back
    if buffer is not None:
        image.save(buffer, format="PNG")
        return None

    # no destination: encode into a scratch buffer and return its contents
    scratch = io.BytesIO()
    image.save(scratch, format="PNG")
    return scratch.getvalue()
Convert a Matrix2Drgb to valid PNG bytes via PIL
- if buffer is provided, it will write the PNG bytes to the buffer and return None
- if buffer is not provided, it will return the PNG bytes
Parameters:
- matrix : Matrix2Drgb
- buffer : io.BytesIO | None (defaults to None, in which case it will return the PNG bytes)
Returns:
- bytes|None -- bytes if buffer is None, otherwise None
def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap=MATRIX_SAVE_CMAP,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    the matrix is rendered to a PNG, base64-encoded, and embedded in
    `MATRIX_SAVE_SVG_TEMPLATE` with one SVG unit per matrix entry

    # Parameters:
     - `matrix : Float[np.ndarray, 'n m']`
       a 2D matrix to convert to an SVG image
     - `normalize : bool`
       whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
       (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
       (defaults to `MATRIX_SAVE_CMAP`)

    # Returns:
     - `str`
       the SVG content for the matrix
    """
    # rows become the image height (`n`), columns become the width (`m`);
    # unpacking these the other way round transposes the SVG dimensions
    # for any non-square matrix
    n, m = matrix.shape

    # Preprocess the matrix into an RGB image
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix, normalize=normalize, cmap=cmap
    )

    # Convert the RGB image to PNG bytes
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

    # Encode the PNG bytes as base64
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

    return svg_content
quickly convert a 2D matrix to an SVG image, without matplotlib
Parameters:
- matrix : Float[np.ndarray, 'n m'] -- a 2D matrix to convert to an SVG image
- normalize : bool -- whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be True or it will raise an AssertionError (defaults to False)
- cmap : str -- the colormap to use for the matrix -- will look up in matplotlib.colormaps if it's a string (defaults to "viridis")
Returns:
- str -- the SVG content for the matrix
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap=MATRIX_SAVE_CMAP,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """
    Decorator for functions that process an attention matrix and save it as an image (`png`, `svg`, or `svgz`).
    Can handle both argumentless usage and with arguments.

    The image is written to `save_dir / f"{func.__name__}.{fmt}"`, and the
    format used is recorded on the wrapped function as `figure_save_fmt`.

    # Parameters:

     - `func : AttentionMatrixToMatrixFunc|None`
        Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
     - `fmt : MatrixSaveFormat, keyword-only`
        The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
     - `normalize : bool, keyword-only`
        Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
     - `cmap : str, keyword-only`
        The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

    - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
    - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

    # Raises:

     - `AssertionError` : if extra positional arguments are passed, or `fmt` is not a valid `MatrixSaveFormat`

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        return matrix * 2

    ```
    """

    assert len(args) == 0, "This decorator only supports keyword arguments"

    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)
                return

            # both remaining formats share the same SVG text
            svg_content: str = matrix_as_svg(
                processed_matrix, normalize=normalize, cmap=cmap
            )

            if fmt == "svgz":
                # explicit encoding so the compressed text does not depend
                # on the platform's locale, matching the plain svg branch
                with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                    f.write(svg_content)
            else:
                fig_path.write_text(svg_content, encoding="utf-8")

        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator
Decorator for functions that process an attention matrix and save it as an image (png, svg, or svgz, depending on fmt). Can handle both argumentless usage and usage with arguments.
Parameters:
- func : AttentionMatrixToMatrixFunc|None -- Either the function to decorate (in the no-arguments case) or None when used with arguments.
- fmt : MatrixSaveFormat, keyword-only -- The format to save the matrix as. Defaults to MATRIX_SAVE_FMT.
- normalize : bool, keyword-only -- Whether to normalize the matrix to range [0, 1]. Defaults to False.
- cmap : str, keyword-only -- The colormap to use for the matrix. Defaults to MATRIX_SAVE_CMAP.
Returns:
AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
- AttentionMatrixFigureFunc if func is AttentionMatrixToMatrixFunc (no arguments case)
- Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc] if func is None -- returns the decorator which will then be applied to the function (with arguments case)
Usage:
@save_matrix_wrapper
def identity_matrix(matrix):
    return matrix

@save_matrix_wrapper(normalize=True, fmt="png")
def scale_matrix(matrix):
    return matrix * 2

@save_matrix_wrapper(normalize=True, cmap="plasma")
def scale_matrix(matrix):
    return matrix * 2