Coverage for pattern_lens\load_activations.py: 85%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2025-01-03 23:21 -0700

1"loading activations from .npz on disk. implements some custom Exception classes" 

2 

3import base64 

4import hashlib 

5import json 

6from pathlib import Path 

7from typing import Literal, overload 

8 

9import numpy as np 

10 

11 

class GetActivationsError(ValueError):
    """root of the exception hierarchy for failures while fetching stored activations

    subclasses `ValueError`, so callers that only catch `ValueError` still see it
    """

16 

17 

class ActivationsMissingError(GetActivationsError, FileNotFoundError):
    """raised when a file needed to load activations is not present on disk

    inherits from both `GetActivationsError` and `FileNotFoundError`, so it can be
    caught either as an activations-loading failure or as an ordinary missing file
    """

22 

23 

class ActivationsMismatchError(GetActivationsError):
    """raised when a prompt's text or hash does not match the copy stored on disk

    raised by `compare_prompt_to_loaded`
    """

31 

32 

def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
    """check that a prompt agrees with a previously stored copy of it

    the `"text"` key is compared first, then `"hash"`; checking stops at the
    first key whose values differ, so a later key is never read once an
    earlier one mismatches.

    # Parameters:
     - `prompt : dict`
        the prompt supplied by the caller
     - `prompt_loaded : dict`
        the prompt to compare against (e.g. one read back from `prompt.json`)

    # Returns:
     - `None`

    # Raises:
     - `ActivationsMismatchError` : if the values under `"text"` or `"hash"` differ
     - `KeyError` : if a key that gets compared is absent from either dict
    """
    # `next` over a generator preserves the short-circuit order of a plain loop
    first_mismatch = next(
        (
            field
            for field in ("text", "hash")
            if prompt[field] != prompt_loaded[field]
        ),
        None,
    )
    if first_mismatch is not None:
        raise ActivationsMismatchError(
            f"Prompt file does not match prompt at key {first_mismatch}:\n{prompt}\n{prompt_loaded}"
        )

51 

52 

def augment_prompt_with_hash(prompt: dict) -> dict:
    """make sure `prompt` has a `"hash"` key, deriving one from its text if absent

    the hash is the MD5 digest of `prompt["text"]` (encoded with `str.encode`,
    i.e. UTF-8), base64-encoded with the URL-safe alphabet (so no `/` or `+`)
    and with the trailing `=` padding removed.

    # Parameters:
     - `prompt : dict`
        must contain `"text"` (a `str`) whenever `"hash"` is missing

    # Returns:
     - `dict`
        the very same `prompt` object, not a copy

    # Modifies:
        `prompt` in place, adding `"hash"` only if it was not already present;
        an existing `"hash"` value is left untouched
    """
    if "hash" in prompt:
        return prompt

    digest: bytes = hashlib.md5(prompt["text"].encode()).digest()
    prompt["hash"] = base64.urlsafe_b64encode(digest).decode().rstrip("=")
    return prompt

74 

75 

@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]]":  # type: ignore[name-defined] # noqa: F821
    ...


# no default here: the runtime default is "torch", so only an explicit
# `return_fmt="numpy"` should select this overload
@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["numpy"],
) -> "tuple[Path, dict[str, np.ndarray]]": ...


def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch", "numpy"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
    """load activations for a prompt and model, from an npz file

    files are looked up under
    `save_path / model_name / "prompts" / prompt["hash"]`, which must contain
    `prompt.json` (checked against `prompt`) and `activations.npz`.

    # Parameters:
     - `model_name : str`
     - `prompt : dict`
        must contain `"text"`; a `"hash"` key is added in place if missing
     - `save_path : Path`
     - `return_fmt : Literal["torch", "numpy"]`
        (defaults to `"torch"`)

    # Returns:
     - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
        the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on `return_fmt`

    # Modifies:
        `prompt`, via `augment_prompt_with_hash`, if it has no `"hash"` key

    # Raises:
     - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`
     - `ActivationsMissingError` : if `prompt.json` or `activations.npz` is missing
     - `ActivationsMismatchError` : if the stored prompt's text or hash differs from `prompt`
    """
    if return_fmt not in ("torch", "numpy"):
        raise ValueError(
            f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
        )
    if return_fmt == "torch":
        # only imported on the torch path, so the numpy path does not need torch
        import torch

    augment_prompt_with_hash(prompt)

    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    prompt_file: Path = prompt_dir / "prompt.json"
    if not prompt_file.exists():
        raise ActivationsMissingError(f"Prompt file {prompt_file} does not exist")
    # NOTE(review): no explicit encoding, so the platform default is used;
    # confirm this matches how prompt.json is written
    with open(prompt_file, "r") as f:
        prompt_loaded: dict = json.load(f)
    compare_prompt_to_loaded(prompt, prompt_loaded)

    activations_path: Path = prompt_dir / "activations.npz"
    # without this check np.load would raise a bare FileNotFoundError instead
    # of the documented ActivationsMissingError (a FileNotFoundError subclass)
    if not activations_path.exists():
        raise ActivationsMissingError(
            f"Activations file {activations_path} does not exist"
        )

    cache: dict
    with np.load(activations_path) as npz_data:
        # `.items()` reads each array fully, so the values remain valid after
        # the npz file is closed at the end of the `with` block
        if return_fmt == "numpy":
            cache = dict(npz_data.items())
        else:
            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}

    return activations_path, cache