Coverage for pattern_lens\attn_figure_funcs.py: 100%

10 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2025-01-03 23:21 -0700

1"""default figure functions 

2 

3- If you are making a PR, add your new figure function here. 

4- If you are using this as a library, you can find examples here. 

5 

6 

7Note that for `pattern_lens.figures` to recognize your function, you need to use the `register_attn_figure_func` decorator, 

8which adds your function to `ATTENTION_MATRIX_FIGURE_FUNCS`. 

9 

10""" 

11 

12from pattern_lens.consts import AttentionMatrix 

13from pattern_lens.figure_util import ( 

14 AttentionMatrixFigureFunc, 

15 save_matrix_wrapper, 

16 Matrix2D, 

17) 

18 

19 

# Registry of attention-matrix figure functions, kept in registration order.
# It is populated by the `register_attn_figure_func` decorator; according to the
# module docstring, `pattern_lens.figures` only recognizes functions added here.
ATTENTION_MATRIX_FIGURE_FUNCS: list[AttentionMatrixFigureFunc] = []

21 

22 

def register_attn_figure_func(
    func: AttentionMatrixFigureFunc,
) -> AttentionMatrixFigureFunc:
    """decorator for registering an attention matrix figure function

    if you want to add a new figure function, you should use this decorator.
    The function is appended to `ATTENTION_MATRIX_FIGURE_FUNCS` and returned
    unchanged (the very same object), so the decorated name still refers to
    your original function.

    Functions are registered in the order the decorator runs, which is
    normally definition order. There is no de-duplication: decorating the
    same function twice registers it twice.

    # Parameters:
     - `func : AttentionMatrixFigureFunc`
       your function, which should take an attention matrix and path

    # Returns:
     - `AttentionMatrixFigureFunc`
       your function (unmodified), after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`

    # Usage:
    ```python
    from pathlib import Path
    import matplotlib.pyplot as plt

    @register_attn_figure_func
    def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.matshow(attn_matrix, cmap="viridis")
        ax.set_title("My New Figure Function")
        ax.axis("off")
        plt.savefig(path / "my_new_figure_func", format="svgz")
        plt.close(fig)
    ```

    """
    # No `global` declaration is needed: the module-level list is only
    # mutated in place via `.append`, never rebound.
    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
    return func

56 

57 

@register_attn_figure_func
@save_matrix_wrapper(fmt="png")
def raw(attn_matrix: AttentionMatrix) -> Matrix2D:
    """Identity figure function: the figure is the attention matrix itself.

    Decorators apply bottom-up. `save_matrix_wrapper(fmt="png")` wraps this
    function first (presumably it writes the returned matrix out as a PNG
    image — see `pattern_lens.figure_util`). The resulting wrapped function
    is then passed to `register_attn_figure_func`, which appends it to
    `ATTENTION_MATRIX_FIGURE_FUNCS`.

    # Parameters:
     - `attn_matrix : AttentionMatrix`
       the attention pattern to render

    # Returns:
     - `Matrix2D`
       `attn_matrix`, unchanged
    """
    # No transformation is applied: the raw pattern is the figure.
    return attn_matrix

62 

63 

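64# some more examples (kept commented out; to enable them, import `plt` and `matplotlib_figure_saver`, which this module does not import, and rename the first one, since `raw` is already defined above): 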

65 

66# @register_attn_figure_func 

67# @matplotlib_figure_saver 

68# def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None: 

69# ax.matshow(attn_matrix, cmap="viridis") 

70# ax.set_title("Raw Attention Pattern") 

71# ax.axis("off") 

72 

73# @register_attn_figure_func 

74# @save_matrix_wrapper(fmt="svg") 

75# def raw_svg(attn_matrix: AttentionMatrix) -> Matrix2D: 

76# return attn_matrix 

77 

78# @register_attn_figure_func 

79# @save_matrix_wrapper(fmt="svgz") 

80# def raw_svgz(attn_matrix: AttentionMatrix) -> Matrix2D: 

81# return attn_matrix