Coverage for pattern_lens\figure_util.py: 94%

115 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2025-01-03 23:21 -0700

1"""implements a bunch of types, default values, and templates which are useful for figure functions 

2 

3notably, you can use the decorators `matplotlib_figure_saver`, `save_matrix_wrapper` to make your functions save figures 

4""" 

5 

6from pathlib import Path 

7from typing import Callable, Literal, overload, Union 

8import functools 

9import base64 

10import gzip 

11import io 

12 

13from PIL import Image 

14import numpy as np 

15from jaxtyping import Float, UInt8 

16import matplotlib 

17import matplotlib.pyplot as plt 

18from matplotlib.colors import Colormap 

19 

20from pattern_lens.consts import AttentionMatrix 

21 

# --- type aliases ---------------------------------------------------------
# NOTE: the bare string after each assignment is that name's description
# (attribute-docstring convention used throughout this module).

# the `Path` argument is the directory the figure is written into
# (see how the wrappers in this module build `save_dir / "<name>.<fmt>"`)
AttentionMatrixFigureFunc = Callable[[AttentionMatrix, Path], None]
"Type alias for a function that, given an attention matrix, saves a figure"

Matrix2D = Float[np.ndarray, "n m"]
"Type alias for a 2D matrix (plottable)"

Matrix2Drgb = UInt8[np.ndarray, "n m rgb=3"]
"Type alias for a 2D matrix with 3 channels (RGB)"

AttentionMatrixToMatrixFunc = Callable[[AttentionMatrix], Matrix2D]
"Type alias for a function that, given an attention matrix, returns a 2D matrix"

# --- defaults for matplotlib-based figure saving --------------------------

# used as the file extension of the figure written by `matplotlib_figure_saver`
MATPLOTLIB_FIGURE_FMT: str = "svgz"
"format for saving matplotlib figures"

# --- defaults for raw matrix saving (no matplotlib figure) ----------------

# `save_matrix_wrapper` validates its `fmt` argument against these literals
MatrixSaveFormat = Literal["png", "svg", "svgz"]
"Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure"

# when this is `False`, `matrix_to_image_preprocess` asserts the values are already in [0, 1]
MATRIX_SAVE_NORMALIZE: bool = False
"default for whether to normalize the matrix to range [0, 1]"

# string names are looked up in `matplotlib.colormaps`
MATRIX_SAVE_CMAP: str = "viridis"
"default colormap for saving matrices"

MATRIX_SAVE_FMT: MatrixSaveFormat = "svgz"
"default format for saving matrices"

# placeholders filled via `str.format`:
#   {m}          -> used for `width` (and the first viewBox extent)
#   {n}          -> used for `height` (and the second viewBox extent)
#   {png_base64} -> base64 text of a PNG, embedded as a `data:` URI
MATRIX_SAVE_SVG_TEMPLATE: str = """<svg xmlns="http://www.w3.org/2000/svg" width="{m}" height="{n}" viewBox="0 0 {m} {n}" image-rendering="pixelated"> <image href="data:image/png;base64,{png_base64}" width="{m}" height="{n}" /> </svg>"""
"template for saving an `n` by `m` matrix as an svg/svgz"

51 

52 

@overload  # without keyword arguments, returns decorated function
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None],
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> AttentionMatrixFigureFunc: ...
@overload  # with keyword arguments, returns decorator
def matplotlib_figure_saver(
    func: None = None,
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Callable[
    [Callable[[AttentionMatrix, plt.Axes], None], str], AttentionMatrixFigureFunc
]: ...
def matplotlib_figure_saver(
    func: Callable[[AttentionMatrix, plt.Axes], None] | None = None,
    *args,
    fmt: str = MATPLOTLIB_FIGURE_FMT,
) -> Union[
    AttentionMatrixFigureFunc,
    Callable[
        [Callable[[AttentionMatrix, plt.Axes], None], str], AttentionMatrixFigureFunc
    ],
]:
    """decorator for functions which take an attention matrix and predefined `ax` object, making it save a figure

    The wrapped function creates a 10x10 inch figure, hands its `ax` to `func`,
    and writes the result to `save_dir / f"{func.__name__}.{fmt}"`.
    The figure is always closed afterwards, even if `func` or saving raises.

    # Parameters:
     - `func : Callable[[AttentionMatrix, plt.Axes], None]`
       your function, which should take an attention matrix and predefined `ax` object
       (`None` when the decorator is used with keyword arguments)
     - `fmt : str`
       format for saving matplotlib figures, used as the file extension
       (defaults to `MATPLOTLIB_FIGURE_FMT`)

    # Returns:
     - `AttentionMatrixFigureFunc`
       your function, after we wrap it to save a figure.
       If `func` is not given, the decorator itself is returned instead.
       The wrapped function carries a `figure_save_fmt` attribute set to `fmt`.

    # Raises:
     - `AssertionError` : if any extra positional arguments are passed

    # Usage:
    ```python
    @register_attn_figure_func
    @matplotlib_figure_saver
    def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("Raw Attention Pattern")
        ax.axis("off")
    ```

    """

    assert len(args) == 0, "This decorator only supports keyword arguments"

    def decorator(
        func: Callable[[AttentionMatrix, plt.Axes], None],
        fmt: str = fmt,
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"

            fig, ax = plt.subplots(figsize=(10, 10))
            try:
                func(attn_matrix, ax)
                # operate on `fig` explicitly rather than pyplot's "current figure",
                # so we save the figure we created even if `func` touched global state
                fig.tight_layout()
                fig.savefig(fig_path)
            finally:
                # always release the figure -- otherwise a failing `func`
                # leaves it registered with pyplot and leaks memory
                plt.close(fig)

        # expose the format so callers can find out which extension was written
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator

128 

129 

def matrix_to_image_preprocess(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
) -> Matrix2Drgb:
    """preprocess a 2D matrix into a plottable heatmap image

    # Parameters:
     - `matrix : Matrix2D`
       input matrix
     - `normalize : bool`
       whether to normalize the matrix to range [0, 1].
       a constant matrix (max == min) is mapped to all zeros instead of dividing by zero
       (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix -- strings are looked up in `matplotlib.colormaps`
       (defaults to `MATRIX_SAVE_CMAP`)

    # Returns:
     - `Matrix2Drgb`
       `uint8` array of shape `(n, m, 3)`; the colormap's alpha channel is dropped

    # Raises:
     - `AssertionError` : if the matrix is not 2D, is empty,
       or has values outside [0, 1] while `normalize` is `False`
     - `TypeError` : if `cmap` is neither a `str` nor a `Colormap`
    """
    # check dims
    assert matrix.ndim == 2, f"Matrix must be 2D, got {matrix.ndim = }"

    # check matrix is not empty
    assert matrix.size > 0, "Matrix cannot be empty"

    min_val, max_val = matrix.min(), matrix.max()

    # Normalize the matrix to range [0, 1]
    normalized_matrix: Matrix2D
    if normalize:
        value_range = max_val - min_val
        if value_range == 0:
            # constant matrix: avoid 0/0 -> NaN. map everything to 0,
            # the same convention `matplotlib.colors.Normalize` uses when vmin == vmax
            normalized_matrix = np.zeros_like(matrix, dtype=float)
        else:
            normalized_matrix = (matrix - min_val) / value_range
    else:
        assert min_val >= 0 and max_val <= 1, (
            "Matrix values must be in range [0, 1], or normalize must be True. "
            f"got: min: {matrix.min() = }, max: {matrix.max() = }"
        )
        normalized_matrix = matrix

    # get the colormap
    cmap_: Colormap
    if isinstance(cmap, str):
        cmap_ = matplotlib.colormaps[cmap]
    elif isinstance(cmap, Colormap):
        cmap_ = cmap
    else:
        raise TypeError(
            f"Invalid type for {cmap = }, {type(cmap) = }, must be str or Colormap"
        )

    # Apply the colormap (yields RGBA floats), drop the alpha channel,
    # then scale to 0..255 and truncate to uint8
    rgb_matrix: Matrix2Drgb = (cmap_(normalized_matrix)[:, :, :3] * 255).astype(
        np.uint8
    )

    assert rgb_matrix.shape == (
        matrix.shape[0],
        matrix.shape[1],
        3,
    ), f"Matrix after colormap must have 3 channels, got {rgb_matrix.shape = }"

    return rgb_matrix

190 

191 

@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: None = None) -> bytes: ...
@overload
def matrix2drgb_to_png_bytes(matrix: Matrix2Drgb, buffer: io.BytesIO) -> None: ...
def matrix2drgb_to_png_bytes(
    matrix: Matrix2Drgb, buffer: io.BytesIO | None = None
) -> bytes | None:
    """Encode a `Matrix2Drgb` as PNG using PIL

    Two calling modes:

    - `buffer` given: the PNG is written into it and `None` is returned
      (the buffer position is left wherever PIL's `save` leaves it, no rewind is done)
    - `buffer` omitted: a temporary in-memory buffer is used and its contents are returned

    # Parameters:
     - `matrix : Matrix2Drgb`
       RGB image data to encode
     - `buffer : io.BytesIO | None`
       optional destination for the PNG bytes
       (defaults to `None`, in which case the PNG bytes are returned)

    # Returns:
     - `bytes|None`
       the PNG bytes when `buffer` is `None`, otherwise `None`
    """

    image: Image.Image = Image.fromarray(matrix, mode="RGB")

    # caller owns the destination: just write into it
    if buffer is not None:
        image.save(buffer, format="PNG")
        return None

    # no destination supplied: encode into a scratch buffer and hand back its contents
    scratch = io.BytesIO()
    image.save(scratch, format="PNG")
    return scratch.getvalue()

223 

224 

def matrix_as_svg(
    matrix: Matrix2D,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
) -> str:
    """quickly convert a 2D matrix to an SVG image, without matplotlib

    The matrix is colormapped to RGB, encoded as PNG, base64-encoded,
    and embedded into `MATRIX_SAVE_SVG_TEMPLATE`.

    # Parameters:
     - `matrix : Float[np.ndarray, 'n m']`
       a 2D matrix to convert to an SVG image
     - `normalize : bool`
       whether to normalize the matrix to range [0, 1]. if it's not in the range [0, 1], this must be `True` or it will raise an `AssertionError`
       (defaults to `MATRIX_SAVE_NORMALIZE`)
     - `cmap : str|Colormap`
       the colormap to use for the matrix -- will look up in `matplotlib.colormaps` if it's a string
       (defaults to `MATRIX_SAVE_CMAP`)

    # Returns:
     - `str`
       the SVG content for the matrix
    """
    # Get the dimensions of the matrix: `n` rows by `m` columns.
    # the template uses `m` as the width and `n` as the height, so rows must
    # map to `n` -- otherwise non-square matrices get a transposed canvas
    n, m = matrix.shape

    # Preprocess the matrix into an RGB image
    matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
        matrix, normalize=normalize, cmap=cmap
    )

    # Convert the RGB image to PNG bytes
    image_data: bytes = matrix2drgb_to_png_bytes(matrix_rgb)

    # Encode the PNG bytes as base64
    png_base64: str = base64.b64encode(image_data).decode("utf-8")

    # Generate the SVG content
    svg_content: str = MATRIX_SAVE_SVG_TEMPLATE.format(m=m, n=n, png_base64=png_base64)

    return svg_content

264 

265 

@overload  # with keyword arguments, returns decorator
def save_matrix_wrapper(
    func: None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str = MATRIX_SAVE_CMAP,
) -> Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]: ...
@overload  # without keyword arguments, returns decorated function
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str = MATRIX_SAVE_CMAP,
) -> AttentionMatrixFigureFunc: ...
def save_matrix_wrapper(
    func: AttentionMatrixToMatrixFunc | None = None,
    *args,
    fmt: MatrixSaveFormat = MATRIX_SAVE_FMT,
    normalize: bool = MATRIX_SAVE_NORMALIZE,
    cmap: str | Colormap = MATRIX_SAVE_CMAP,
) -> (
    AttentionMatrixFigureFunc
    | Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]
):
    """
    Decorator for functions that process an attention matrix and save it as a PNG, SVG, or SVGZ image.
    Can handle both argumentless usage and with arguments.

    The wrapped function writes to `save_dir / f"{func.__name__}.{fmt}"`.

    # Parameters:

     - `func : AttentionMatrixToMatrixFunc|None`
        Either the function to decorate (in the no-arguments case) or `None` when used with arguments.
     - `fmt : MatrixSaveFormat, keyword-only`
        The format to save the matrix as. Defaults to `MATRIX_SAVE_FMT`.
     - `normalize : bool, keyword-only`
        Whether to normalize the matrix to range [0, 1]. Defaults to `MATRIX_SAVE_NORMALIZE`.
     - `cmap : str|Colormap, keyword-only`
        The colormap to use for the matrix. Defaults to `MATRIX_SAVE_CMAP`.

    # Returns:

    `AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]`

    - `AttentionMatrixFigureFunc` if `func` is `AttentionMatrixToMatrixFunc` (no arguments case)
    - `Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]` if `func` is `None` -- returns the decorator which will then be applied to the function (with arguments case)

    In both cases the wrapped function carries a `figure_save_fmt` attribute set to `fmt`.

    # Raises:

    - `AssertionError` : if extra positional arguments are passed, or `fmt` is not a valid `MatrixSaveFormat`

    # Usage:

    ```python
    @save_matrix_wrapper
    def identity_matrix(matrix):
        return matrix

    @save_matrix_wrapper(normalize=True, fmt="png")
    def scale_matrix(matrix):
        return matrix * 2

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix_plasma(matrix):
        return matrix * 2

    ```
    """

    assert len(args) == 0, "This decorator only supports keyword arguments"

    assert (
        fmt in MatrixSaveFormat.__args__  # type: ignore[attr-defined]
    ), f"Invalid format {fmt = }, must be one of {MatrixSaveFormat.__args__}"  # type: ignore[attr-defined]

    def decorator(
        func: Callable[[AttentionMatrix], Matrix2D],
    ) -> AttentionMatrixFigureFunc:
        @functools.wraps(func)
        def wrapped(attn_matrix: AttentionMatrix, save_dir: Path) -> None:
            fig_path: Path = save_dir / f"{func.__name__}.{fmt}"
            processed_matrix: Matrix2D = func(attn_matrix)

            if fmt == "png":
                # raw PNG: colormap the matrix and write the encoded bytes directly
                processed_matrix_rgb: Matrix2Drgb = matrix_to_image_preprocess(
                    processed_matrix,
                    normalize=normalize,
                    cmap=cmap,
                )
                image_data: bytes = matrix2drgb_to_png_bytes(processed_matrix_rgb)
                fig_path.write_bytes(image_data)

            else:
                # "svg" and "svgz" share the same SVG text, only the container differs
                svg_content: str = matrix_as_svg(
                    processed_matrix, normalize=normalize, cmap=cmap
                )

                if fmt == "svgz":
                    # explicit encoding so the output does not depend on the
                    # platform's locale, matching the plain-svg branch below
                    with gzip.open(fig_path, "wt", encoding="utf-8") as f:
                        f.write(svg_content)

                else:
                    fig_path.write_text(svg_content, encoding="utf-8")

        # expose the format so callers can find out which extension was written
        wrapped.figure_save_fmt = fmt  # type: ignore[attr-defined]

        return wrapped

    if callable(func):
        # Handle no-arguments case
        return decorator(func)
    else:
        # Handle arguments case
        return decorator