Coverage for tests\unit\test_figure_util.py: 100%

176 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2025-01-03 23:21 -0700

1import jaxtyping 

2import pytest 

3import numpy as np 

4from pathlib import Path 

5import gzip 

6import re 

7import base64 

8import io 

9 

10from PIL import Image 

11import matplotlib.pyplot as plt 

12 

13 

14from pattern_lens.figure_util import ( 

15 MATPLOTLIB_FIGURE_FMT, 

16 matplotlib_figure_saver, 

17 matrix_as_svg, 

18 save_matrix_wrapper, 

19) 

20 

21 

# Scratch directory (relative to the working directory) where the tests below
# ask the saver decorators to write their output files.
TEMP_DIR: Path = Path("tests", "_temp")

23 

24 

def test_matplotlib_figure_saver():
    """The bare decorator saves the plot as ``plot_matrix.<MATPLOTLIB_FIGURE_FMT>``.

    Other tests in this module write files named ``plot_matrix.*`` into the
    same directory, and earlier runs leave files behind. The expected output
    is therefore deleted first, so the existence check can only pass if this
    call actually wrote it.
    """
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    @matplotlib_figure_saver
    def plot_matrix(attn_matrix, ax):
        ax.matshow(attn_matrix, cmap="viridis")
        ax.axis("off")

    saved_file = TEMP_DIR / f"plot_matrix.{MATPLOTLIB_FIGURE_FMT}"
    # drop stale output so the assertion below is not satisfied vacuously
    saved_file.unlink(missing_ok=True)

    attn_matrix = np.random.rand(10, 10).astype(np.float32)
    plot_matrix(attn_matrix, TEMP_DIR)

    assert saved_file.exists(), "Matplotlib figure file was not saved"

38 

39 

def test_matplotlib_figure_saver_exception():
    """An exception raised inside the decorated plot function reaches the caller unchanged."""
    TEMP_DIR.mkdir(exist_ok=True, parents=True)

    message = "Intentional failure for testing"

    @matplotlib_figure_saver
    def faulty_plot(attn_matrix, ax):
        raise ValueError(message)

    data = np.random.rand(10, 10).astype(np.float32)
    with pytest.raises(ValueError, match=message):
        faulty_plot(data, TEMP_DIR)

50 

51 

def test_matrix_as_svg_normalization():
    """With normalize=True, a matrix whose values exceed 1 still yields an SVG with an inline base64 PNG."""
    svg_content = matrix_as_svg(
        np.array([[2, 4], [6, 8]], dtype=np.float32),
        normalize=True,
    )
    for needle, failure_msg in (
        ("image href=", "SVG content is malformed"),
        ("data:image/png;base64,", "Base64 encoding is missing"),
    ):
        assert needle in svg_content, failure_msg

57 

58 

def test_matrix_as_svg_no_normalization():
    """With normalize=False, a matrix already inside [0, 1] renders to an SVG with an inline PNG."""
    in_range = np.array([[0.1, 0.4], [0.6, 0.9]], dtype=np.float32)
    svg = matrix_as_svg(in_range, normalize=False)

    has_image_tag = "image href=" in svg
    has_png_payload = "data:image/png;base64," in svg
    assert has_image_tag, "SVG content is malformed"
    assert has_png_payload, "Base64 encoding is missing"

64 

65 

def test_matrix_as_svg_invalid_range():
    """A negative entry with normalize=False trips the [0, 1] range assertion."""
    negative_entry = np.array([[-1, 2], [3, 4]], dtype=np.float32)
    expected = r"Matrix values must be in range \[0, 1\], or normalize must be True"
    with pytest.raises(AssertionError, match=expected):
        matrix_as_svg(negative_entry, normalize=False)

73 

74 

def test_matrix_as_svg_invalid_dims():
    """A 3-D input is rejected; either AssertionError or jaxtyping.TypeCheckError is accepted."""
    cube = np.random.rand(5, 5, 5).astype(np.float32)
    accepted_errors = (AssertionError, jaxtyping.TypeCheckError)
    with pytest.raises(accepted_errors):
        matrix_as_svg(cube, normalize=True)

79 

80 

def test_matrix_as_svg_invalid_cmap_fixed():
    """An unknown colormap name raises KeyError whose message names the bad cmap.

    NOTE(review): the expected text looks like it comes from matplotlib's
    colormap lookup and may differ across matplotlib versions. Confirm this if
    the test starts failing after an upgrade.
    """
    values = np.array([[0.1, 0.4], [0.6, 0.9]], dtype=np.float32)
    bad_cmap = "invalid_cmap"
    with pytest.raises(KeyError, match="'invalid_cmap' is not a known colormap name"):
        matrix_as_svg(values, cmap=bad_cmap)

85 

86 

# Test with only `fmt` passed (normalize and cmap are not given)
def test_save_matrix_as_svgz_wrapper_no_args():
    """Decorator configured only with ``fmt="svgz"`` writes ``no_op.svgz`` to the target dir.

    The expected file is deleted first, so output left by a previous run cannot
    satisfy the existence check.
    """
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    @save_matrix_wrapper(fmt="svgz")
    def no_op(matrix):
        return matrix

    saved_file = TEMP_DIR / "no_op.svgz"
    # remove output from earlier runs so the check reflects this call only
    saved_file.unlink(missing_ok=True)

    test_matrix = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
    no_op(test_matrix, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved in no-args case"

100 

101 

# Test with keyword-only arguments
def test_save_matrix_as_svgz_wrapper_with_args():
    """Decorator configured with ``normalize``/``cmap`` keywords writes ``scale_matrix.svgz``.

    No ``fmt`` is passed; the assertion expects the ``.svgz`` extension anyway.
    The file is deleted first so stale output cannot satisfy the check.
    """
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        # doubling pushes values above 1, so normalize=True is required
        return matrix * 2

    saved_file = TEMP_DIR / "scale_matrix.svgz"
    # remove output from earlier runs so the check reflects this call only
    saved_file.unlink(missing_ok=True)

    test_matrix = np.array([[0.5, 0.6], [0.7, 0.8]], dtype=np.float32)
    scale_matrix(test_matrix, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved with keyword-only arguments"

115 

116 

# Test exception handling
def test_save_matrix_as_svgz_wrapper_exceptions():
    """normalize=False with an out-of-range result makes the wrapped call raise AssertionError."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    @save_matrix_wrapper(normalize=False)
    def invalid_range(matrix):
        return matrix * 2

    # after doubling, the values span [4, 10], well outside [0, 1]
    out_of_range = np.array([[2, 3], [4, 5]], dtype=np.float32)
    pattern = r"Matrix values must be in range \[0, 1\], or normalize must be True\. got: min: .*?, max: .*?"
    with pytest.raises(AssertionError, match=pattern):
        invalid_range(out_of_range, TEMP_DIR)

131 

132 

# Test keyword-only configuration (enforcement itself is not yet checked)
def test_save_matrix_as_svgz_wrapper_keyword_only():
    """Keyword-configured decorator writes ``scale_matrix.svgz`` on this call.

    This test writes the same file as ``test_save_matrix_as_svgz_wrapper_with_args``.
    The file is deleted first so the assertion cannot pass on that test's output.

    NOTE(review): despite its name, nothing here checks that positional
    decorator arguments are rejected. TODO: add that check once the intended
    error type is confirmed against ``save_matrix_wrapper``.
    """
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        return matrix * 2

    saved_file = TEMP_DIR / "scale_matrix.svgz"
    # another test writes this exact path; clear it so this call must produce it
    saved_file.unlink(missing_ok=True)

    test_matrix = np.array([[0.5, 0.6], [0.7, 0.8]], dtype=np.float32)
    scale_matrix(test_matrix, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved with keyword-only arguments"

146 

147 

# Test multiple calls to the decorator
def test_save_matrix_as_svgz_wrapper_multiple():
    """Every call of one decorated function writes ``scale_by_factor.svgz``.

    The file is removed before each call and checked after it. This way each
    call is shown to write its own output, rather than the final check being
    satisfied by an earlier call or an earlier run.
    """
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    @save_matrix_wrapper(normalize=True)
    def scale_by_factor(matrix):
        return matrix * 3

    matrix_1 = np.array([[0.1, 0.5], [0.7, 0.9]], dtype=np.float32)
    matrix_2 = np.array([[0.2, 0.6], [0.8, 1.0]], dtype=np.float32)

    saved_file = TEMP_DIR / "scale_by_factor.svgz"
    for matrix in (matrix_1, matrix_2):
        # remove the previous output so this call must write the file itself
        saved_file.unlink(missing_ok=True)
        scale_by_factor(matrix, TEMP_DIR)
        assert saved_file.exists(), "SVGZ file was not saved for multiple calls"

165 

166 

# Validate behavior when normalize is False and values are in range
def test_save_matrix_as_svgz_wrapper_no_normalization():
    """normalize=False with values already in [0, 1] saves ``pass_through.svgz`` without error.

    The expected file is deleted first so stale output cannot satisfy the check.
    """
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    @save_matrix_wrapper(normalize=False)
    def pass_through(matrix):
        return matrix

    saved_file = TEMP_DIR / "pass_through.svgz"
    # remove output from earlier runs so the check reflects this call only
    saved_file.unlink(missing_ok=True)

    test_matrix = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
    pass_through(test_matrix, TEMP_DIR)

    assert (
        saved_file.exists()
    ), "SVGZ file was not saved when normalization was not applied"

182 

183 

# Test with a non-trivially transformed matrix
def test_save_matrix_as_svgz_wrapper_complex_matrix():
    """A decorated function applying ``np.sin`` saves ``complex_processing.svgz``.

    normalize=True is used, so the exact range of ``sin`` over [0, pi] in
    float32 (which may dip marginally below 0 near pi) does not matter. The
    expected file is deleted first so stale output cannot satisfy the check.
    """
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    @save_matrix_wrapper(normalize=True, cmap="viridis")
    def complex_processing(matrix):
        return np.sin(matrix)

    saved_file = TEMP_DIR / "complex_processing.svgz"
    # remove output from earlier runs so the check reflects this call only
    saved_file.unlink(missing_ok=True)

    test_matrix = np.linspace(0, np.pi, 16).reshape(4, 4).astype(np.float32)
    complex_processing(test_matrix, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved for complex matrix processing"

197 

198 

def test_matrix_as_svg_dimensions():
    """SVG ``width``/``height``/``viewBox`` follow the matrix shape for square and non-square inputs.

    NOTE(review): these assertions tie ``width`` to the row count (shape[0])
    and ``height`` to the column count (shape[1]). An image of shape
    (rows, cols) is conventionally cols wide and rows tall, so confirm this
    orientation is intended before relying on it.
    """
    # shapes: non-square, small square, large non-square
    matrices = [np.random.rand(*shape) for shape in ((5, 10), (3, 3), (100, 50))]

    for matrix in matrices:
        rows, cols = matrix.shape
        svg = matrix_as_svg(matrix, normalize=True)
        expected_attrs = (
            f'width="{rows}"',
            f'height="{cols}"',
            f'viewBox="0 0 {rows} {cols}"',
        )
        for attr in expected_attrs:
            assert attr in svg

213 

214 

def test_save_matrix_as_svgz_wrapper_content():
    """The saved ``identity.svgz`` is gzip-compressed SVG text embedding base64 image data.

    The file is removed before the call so that stale output from a previous
    run cannot be read back. It is decoded as UTF-8 explicitly rather than
    with the platform's locale encoding.
    """
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    @save_matrix_wrapper(normalize=True)
    def identity(matrix):
        return matrix

    saved_file = TEMP_DIR / "identity.svgz"
    # remove output from earlier runs so we inspect what this call wrote
    saved_file.unlink(missing_ok=True)

    test_matrix = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
    identity(test_matrix, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved"
    with gzip.open(saved_file, "rt", encoding="utf-8") as f:
        content = f.read()
    assert "svg" in content
    assert "image href=" in content
    assert "base64" in content

231 

232 

def test_matplotlib_figure_saver_formats():
    """``matplotlib_figure_saver(fmt=...)`` writes ``plot_matrix.<fmt>`` for png, pdf and svg.

    Each target file is deleted before its call, because other tests in this
    module (and earlier runs) also write ``plot_matrix.*`` into the same
    directory.
    """
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    formats = ["png", "pdf", "svg"]

    for fmt in formats:

        @matplotlib_figure_saver(fmt=fmt)
        def plot_matrix(attn_matrix, ax):
            ax.matshow(attn_matrix)
            ax.axis("off")

        saved_file = TEMP_DIR / f"plot_matrix.{fmt}"
        # clear stale output so each format is checked against this call only
        saved_file.unlink(missing_ok=True)

        matrix = np.random.rand(5, 5)
        plot_matrix(matrix, TEMP_DIR)
        assert saved_file.exists(), f"File not saved for format {fmt}"

248 

249 

def test_matrix_as_svg_empty():
    """A 0x0 matrix is rejected with the "Matrix cannot be empty" assertion."""
    zero_by_zero = np.empty((0, 0), dtype=np.float32)
    with pytest.raises(AssertionError, match="Matrix cannot be empty"):
        matrix_as_svg(zero_by_zero)

254 

255 

def test_matplotlib_figure_saver_cleanup():
    """After a decorated plot call, the number of open pyplot figures is unchanged."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    figures_before = len(plt.get_fignums())

    @matplotlib_figure_saver
    def plot_matrix(attn_matrix, ax):
        ax.matshow(attn_matrix)

    plot_matrix(np.random.rand(5, 5), TEMP_DIR)

    # no figure objects may remain open after the call
    figures_after = len(plt.get_fignums())
    assert figures_after == figures_before, "Figure not properly cleaned up"

269 

270 

def test_matrix_as_svg_non_numeric():
    """A string-valued 2x2 matrix makes matrix_as_svg raise TypeError."""
    letters = np.array([list("ab"), list("cd")])
    with pytest.raises(TypeError):
        matrix_as_svg(letters)

275 

276 

def test_matrix_as_svg_format():
    """matrix_as_svg returns one ``<svg>...</svg>`` element embedding a decodable PNG data URI."""
    # create a small 2x2 matrix
    matrix = np.array([[0.0, 0.5], [1.0, 0.75]], dtype=float)

    svg_str = matrix_as_svg(matrix)

    # ensure it's got the correct SVG wrapper
    assert svg_str.startswith("<svg"), "SVG should start with <svg>"
    assert svg_str.endswith("</svg>"), "SVG should end with </svg>"

    # find the embedded base64 image data
    match = re.search(r'data:image/png;base64,([^"]+)', svg_str)
    assert match is not None, "Expected an embedded PNG in data URI format"

    png_data = base64.b64decode(match.group(1))

    # Image.open only parses the header lazily, so a corrupt payload would pass
    # unnoticed; load() forces a full decode. The with-block closes the image.
    with Image.open(io.BytesIO(png_data)) as img:
        assert img.format == "PNG", f"Embedded image is {img.format}, not PNG"
        img.load()