Coverage for tests\unit\test_figure_util.py: 100%
176 statements
« prev ^ index » next coverage.py v7.6.9, created at 2025-01-03 23:21 -0700
« prev ^ index » next coverage.py v7.6.9, created at 2025-01-03 23:21 -0700
1import jaxtyping
2import pytest
3import numpy as np
4from pathlib import Path
5import gzip
6import re
7import base64
8import io
10from PIL import Image
11import matplotlib.pyplot as plt
14from pattern_lens.figure_util import (
15 MATPLOTLIB_FIGURE_FMT,
16 matplotlib_figure_saver,
17 matrix_as_svg,
18 save_matrix_wrapper,
19)
# Scratch directory that the decorator tests below write their figures/SVGZ
# files into. The path is relative, so it resolves against the current working
# directory; nothing in this module deletes it between runs.
TEMP_DIR: Path = Path("tests") / "_temp"
def test_matplotlib_figure_saver():
    """The bare ``@matplotlib_figure_saver`` decorator saves ``plot_matrix.<MATPLOTLIB_FIGURE_FMT>`` into the given directory."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    saved_file = TEMP_DIR / f"plot_matrix.{MATPLOTLIB_FIGURE_FMT}"
    # TEMP_DIR persists across runs and other tests write the same file name,
    # so drop any stale output first; otherwise the existence check below could
    # pass even if the decorator wrote nothing.
    saved_file.unlink(missing_ok=True)

    @matplotlib_figure_saver
    def plot_matrix(attn_matrix, ax):
        ax.matshow(attn_matrix, cmap="viridis")
        ax.axis("off")

    attn_matrix = np.random.rand(10, 10).astype(np.float32)
    plot_matrix(attn_matrix, TEMP_DIR)

    assert saved_file.exists(), "Matplotlib figure file was not saved"
def test_matplotlib_figure_saver_exception():
    """An exception raised inside a decorated plot function reaches the caller unchanged."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    @matplotlib_figure_saver
    def faulty_plot(attn_matrix, ax):
        raise ValueError("Intentional failure for testing")

    matrix = np.random.rand(10, 10).astype(np.float32)
    expected_message = "Intentional failure for testing"
    with pytest.raises(ValueError, match=expected_message):
        faulty_plot(matrix, TEMP_DIR)
def test_matrix_as_svg_normalization():
    """With ``normalize=True``, values outside [0, 1] still render to an SVG embedding a base64 PNG."""
    values = np.array([[2, 4], [6, 8]], dtype=np.float32)  # deliberately > 1
    svg = matrix_as_svg(values, normalize=True)

    assert "image href=" in svg, "SVG content is malformed"
    assert "data:image/png;base64," in svg, "Base64 encoding is missing"
def test_matrix_as_svg_no_normalization():
    """With ``normalize=False``, an already-in-range matrix renders to an SVG embedding a base64 PNG."""
    in_range = np.array([[0.1, 0.4], [0.6, 0.9]], dtype=np.float32)
    rendered = matrix_as_svg(in_range, normalize=False)

    assert "image href=" in rendered, "SVG content is malformed"
    assert "data:image/png;base64," in rendered, "Base64 encoding is missing"
def test_matrix_as_svg_invalid_range():
    """Without normalization, values outside [0, 1] are rejected with an AssertionError."""
    out_of_range = np.array([[-1, 2], [3, 4]], dtype=np.float32)
    expected = r"Matrix values must be in range \[0, 1\], or normalize must be True"
    with pytest.raises(AssertionError, match=expected):
        matrix_as_svg(out_of_range, normalize=False)
def test_matrix_as_svg_invalid_dims():
    """A 3-D array is rejected, either by an assertion or by jaxtyping's shape check."""
    cube = np.random.rand(5, 5, 5).astype(np.float32)
    accepted_errors = (AssertionError, jaxtyping.TypeCheckError)
    with pytest.raises(accepted_errors):
        matrix_as_svg(cube, normalize=True)
def test_matrix_as_svg_invalid_cmap_fixed():
    """An unknown colormap name surfaces as a KeyError naming the colormap."""
    data = np.array([[0.1, 0.4], [0.6, 0.9]], dtype=np.float32)
    expected = "'invalid_cmap' is not a known colormap name"
    with pytest.raises(KeyError, match=expected):
        matrix_as_svg(data, cmap="invalid_cmap")
# Test a decorated function that takes no arguments besides the matrix
# (the decorator itself is still given an explicit fmt="svgz").
def test_save_matrix_as_svgz_wrapper_no_args():
    """``save_matrix_wrapper(fmt="svgz")`` saves the returned matrix as ``no_op.svgz``."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    saved_file = TEMP_DIR / "no_op.svgz"
    # Remove output left over from a previous run so the assertion is meaningful.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(fmt="svgz")
    def no_op(matrix):
        return matrix

    test_matrix = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
    no_op(test_matrix, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved in no-args case"
# Test with keyword-only arguments
def test_save_matrix_as_svgz_wrapper_with_args():
    """Passing ``normalize``/``cmap`` to the decorator still saves ``scale_matrix.svgz``."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    saved_file = TEMP_DIR / "scale_matrix.svgz"
    # Remove output left over from a previous run so the assertion is meaningful.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        return matrix * 2

    test_matrix = np.array([[0.5, 0.6], [0.7, 0.8]], dtype=np.float32)
    scale_matrix(test_matrix, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved with keyword-only arguments"
# Test exception handling
def test_save_matrix_as_svgz_wrapper_exceptions():
    """With ``normalize=False``, an out-of-range result raises an AssertionError reporting min/max."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    @save_matrix_wrapper(normalize=False)
    def invalid_range(matrix):
        return matrix * 2

    error_pattern = (
        r"Matrix values must be in range \[0, 1\], or normalize must be True\."
        r" got: min: .*?, max: .*?"
    )
    with pytest.raises(AssertionError, match=error_pattern):
        invalid_range(np.array([[2, 3], [4, 5]], dtype=np.float32), TEMP_DIR)
# Test keyword-argument configuration of the decorator
def test_save_matrix_as_svgz_wrapper_keyword_only():
    """Configuring the decorator by keyword saves ``scale_matrix.svgz``.

    NOTE(review): despite its name, this test does not check that positional
    configuration is *rejected*, and it currently mirrors
    ``test_save_matrix_as_svgz_wrapper_with_args``. TODO: add an explicit
    rejection check once ``save_matrix_wrapper``'s signature is confirmed.
    """
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    saved_file = TEMP_DIR / "scale_matrix.svgz"
    # The with_args test writes this exact file earlier in the same run, so
    # without removing it first this assertion could never fail.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(normalize=True, cmap="plasma")
    def scale_matrix(matrix):
        return matrix * 2

    test_matrix = np.array([[0.5, 0.6], [0.7, 0.8]], dtype=np.float32)
    scale_matrix(test_matrix, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved with keyword-only arguments"
# Test multiple calls to the decorator
def test_save_matrix_as_svgz_wrapper_multiple():
    """Every call of a decorated function (re)writes ``scale_by_factor.svgz``."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)

    @save_matrix_wrapper(normalize=True)
    def scale_by_factor(matrix):
        return matrix * 3

    matrix_1 = np.array([[0.1, 0.5], [0.7, 0.9]], dtype=np.float32)
    matrix_2 = np.array([[0.2, 0.6], [0.8, 1.0]], dtype=np.float32)

    saved_file = TEMP_DIR / "scale_by_factor.svgz"
    # Both calls target the same file name, so clear it before each call;
    # otherwise a file from the first call (or a previous run) would mask a
    # failed save on the second.
    for call_index, matrix in enumerate((matrix_1, matrix_2), start=1):
        saved_file.unlink(missing_ok=True)
        scale_by_factor(matrix, TEMP_DIR)
        assert saved_file.exists(), (
            f"SVGZ file was not saved for multiple calls (call {call_index})"
        )
# Validate behavior when normalize is False and values are in range
def test_save_matrix_as_svgz_wrapper_no_normalization():
    """With ``normalize=False`` and in-range values, ``pass_through.svgz`` is written."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    saved_file = TEMP_DIR / "pass_through.svgz"
    # Remove output left over from a previous run so the assertion is meaningful.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(normalize=False)
    def pass_through(matrix):
        return matrix

    test_matrix = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
    pass_through(test_matrix, TEMP_DIR)

    assert (
        saved_file.exists()
    ), "SVGZ file was not saved when normalization was not applied"
# Test with a non-trivially transformed matrix
def test_save_matrix_as_svgz_wrapper_complex_matrix():
    """A decorated function applying ``np.sin`` to a 4x4 ramp saves ``complex_processing.svgz``."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    saved_file = TEMP_DIR / "complex_processing.svgz"
    # Remove output left over from a previous run so the assertion is meaningful.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(normalize=True, cmap="viridis")
    def complex_processing(matrix):
        return np.sin(matrix)

    test_matrix = np.linspace(0, np.pi, 16).reshape(4, 4).astype(np.float32)
    complex_processing(test_matrix, TEMP_DIR)

    assert saved_file.exists(), "SVGZ file was not saved for complex matrix processing"
def test_matrix_as_svg_dimensions():
    """The SVG ``width``/``height``/``viewBox`` attributes reflect the matrix shape."""
    # Test different matrix shapes
    shapes = [
        (5, 10),  # Non-square
        (3, 3),  # Small square
        (100, 50),  # Large non-square
    ]
    matrices = [np.random.rand(*shape) for shape in shapes]

    for matrix in matrices:
        rows, cols = matrix.shape
        svg_content = matrix_as_svg(matrix, normalize=True)
        # NOTE(review): these checks pin width to the row count and height to
        # the column count; SVG width conventionally tracks columns. Confirm
        # against matrix_as_svg whether this transposition is intended.
        assert f'width="{rows}"' in svg_content
        assert f'height="{cols}"' in svg_content
        assert f'viewBox="0 0 {rows} {cols}"' in svg_content
def test_save_matrix_as_svgz_wrapper_content():
    """The saved ``identity.svgz`` is gzip text containing an SVG with an embedded base64 image."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    saved_file = TEMP_DIR / "identity.svgz"
    # Remove output left over from a previous run so the content checks below
    # inspect the file written by this call, not a stale one.
    saved_file.unlink(missing_ok=True)

    @save_matrix_wrapper(normalize=True)
    def identity(matrix):
        return matrix

    test_matrix = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
    identity(test_matrix, TEMP_DIR)

    with gzip.open(saved_file, "rt") as f:
        content = f.read()
    assert "svg" in content
    assert "image href=" in content
    assert "base64" in content
def test_matplotlib_figure_saver_formats():
    """``@matplotlib_figure_saver(fmt=...)`` writes ``plot_matrix.<fmt>`` for png, pdf and svg."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    formats = ["png", "pdf", "svg"]

    for fmt in formats:
        saved_file = TEMP_DIR / f"plot_matrix.{fmt}"
        # Other tests in this module write plot_matrix.<default fmt> into the
        # same directory, so remove any existing file; otherwise the check for
        # that format could pass without this call saving anything.
        saved_file.unlink(missing_ok=True)

        @matplotlib_figure_saver(fmt=fmt)
        def plot_matrix(attn_matrix, ax):
            ax.matshow(attn_matrix)
            ax.axis("off")

        matrix = np.random.rand(5, 5)
        plot_matrix(matrix, TEMP_DIR)
        assert saved_file.exists(), f"File not saved for format {fmt}"
def test_matrix_as_svg_empty():
    """A 0x0 matrix is rejected with an AssertionError."""
    zero_by_zero = np.empty((0, 0), dtype=np.float32)
    with pytest.raises(AssertionError, match="Matrix cannot be empty"):
        matrix_as_svg(zero_by_zero)
def test_matplotlib_figure_saver_cleanup():
    """Calling a decorated plot function leaves no extra open matplotlib figures behind."""
    TEMP_DIR.mkdir(parents=True, exist_ok=True)
    figures_before = len(plt.get_fignums())

    @matplotlib_figure_saver
    def plot_matrix(attn_matrix, ax):
        ax.matshow(attn_matrix)

    plot_matrix(np.random.rand(5, 5), TEMP_DIR)

    # Check that no figure objects remain
    figures_after = len(plt.get_fignums())
    assert figures_after == figures_before, "Figure not properly cleaned up"
def test_matrix_as_svg_non_numeric():
    """A matrix of strings is rejected with a TypeError."""
    letters = np.array(list("abcd")).reshape(2, 2)
    with pytest.raises(TypeError):
        matrix_as_svg(letters)
def test_matrix_as_svg_format():
    """``matrix_as_svg`` returns a single ``<svg>`` element embedding a valid PNG data URI."""
    # create a small 2x2 matrix
    matrix = np.array([[0.0, 0.5], [1.0, 0.75]], dtype=float)

    svg_str = matrix_as_svg(matrix)

    # ensure it's got the correct SVG wrapper
    assert svg_str.startswith("<svg"), "SVG should start with <svg>"
    assert svg_str.endswith("</svg>"), "SVG should end with </svg>"

    # find the embedded base64 image data
    match = re.search(r'data:image/png;base64,([^"]+)', svg_str)
    assert match, "Expected an embedded PNG in data URI format"

    png_data = base64.b64decode(match.group(1))

    # Image.open() is lazy and only parses the header, so on its own it would
    # accept a truncated or corrupt payload. Check the detected format, then
    # call verify(), which checks the file's integrity without decoding it.
    # The context manager closes the image afterwards.
    with Image.open(io.BytesIO(png_data)) as img:
        assert img.format == "PNG", f"Embedded image is {img.format}, expected PNG"
        img.verify()