pattern_lens.load_activations
loading activations from .npz on disk. implements some custom Exception classes
"loading activations from .npz on disk. implements some custom Exception classes"

import base64
import hashlib
import json
from pathlib import Path
from typing import Literal, overload

import numpy as np


class GetActivationsError(ValueError):
    """base class for errors in getting activations

    subclasses `ValueError`, so it can be caught as either type
    """


class ActivationsMissingError(GetActivationsError, FileNotFoundError):
    """error for missing activations -- can't find the activations file

    inherits from both `GetActivationsError` and `FileNotFoundError`,
    so it can be caught as either type
    """


class ActivationsMismatchError(GetActivationsError):
    """error for mismatched activations -- the prompt text or hash do not match

    raised by `compare_prompt_to_loaded`
    """


def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
    """compare a prompt to a loaded prompt, raise an error if they do not match

    only the `"text"` and `"hash"` keys are compared; any other keys are ignored

    # Parameters:
     - `prompt : dict`
        the prompt the caller asked for, must contain `"text"` and `"hash"`
     - `prompt_loaded : dict`
        the prompt read back from disk

    # Returns:
     - `None`

    # Raises:
     - `ActivationsMismatchError` : if the prompt text or hash do not match,
        or if `prompt_loaded` is missing one of those keys
     - `KeyError` : if `prompt` itself is missing `"text"` or `"hash"`
    """
    for key in ("text", "hash"):
        # a key absent from the stored prompt is a mismatch, not a bare KeyError
        if key not in prompt_loaded or prompt[key] != prompt_loaded[key]:
            raise ActivationsMismatchError(
                f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
            )


def augment_prompt_with_hash(prompt: dict) -> dict:
    """if a prompt does not have a hash, add one

    the hash is the MD5 digest of the encoded prompt text (`str.encode()`, UTF-8 by default),
    written as URL-safe base64 with the trailing `=` padding stripped

    # Parameters:
     - `prompt : dict`
        must contain a `"text"` key if it has no `"hash"` key

    # Returns:
     - `dict`
        the same `prompt` object that was passed in

    # Modifies:
    the input `prompt` dictionary, if it does not have a `"hash"` key
    """
    if "hash" not in prompt:
        digest: bytes = hashlib.md5(prompt["text"].encode()).digest()
        prompt["hash"] = base64.urlsafe_b64encode(digest).decode().rstrip("=")
    return prompt


@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]]":  # type: ignore[name-defined] # noqa: F821
    ...
# no default on this overload: the omitted-argument call belongs to the "torch" overload only
@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["numpy"],
) -> "tuple[Path, dict[str, np.ndarray]]": ...
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch", "numpy"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
    """load activations for a prompt and model, from an npz file

    files are looked up under `save_path / model_name / "prompts" / prompt["hash"]`,
    which must contain `prompt.json` and `activations.npz`

    # Parameters:
     - `model_name : str`
     - `prompt : dict`
        must contain `"text"`; a `"hash"` key is added if absent
     - `save_path : Path`
     - `return_fmt : Literal["torch", "numpy"]`
        (defaults to `"torch"`)

    # Returns:
     - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
        the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on `return_fmt`

    # Raises:
     - `ActivationsMissingError` : if the prompt file or the activations file is missing
     - `ActivationsMismatchError` : if the stored prompt does not match `prompt`
     - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`

    # Modifies:
    the input `prompt` dictionary, if it does not have a `"hash"` key
    """
    if return_fmt not in ("torch", "numpy"):
        raise ValueError(
            f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
        )
    if return_fmt == "torch":
        # imported only on this path, so the "numpy" path never imports torch
        import torch

    augment_prompt_with_hash(prompt)

    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    prompt_file: Path = prompt_dir / "prompt.json"
    if not prompt_file.exists():
        raise ActivationsMissingError(f"Prompt file {prompt_file} does not exist")
    with open(prompt_file, "r") as f:
        prompt_loaded: dict = json.load(f)
        compare_prompt_to_loaded(prompt, prompt_loaded)

    activations_path: Path = prompt_dir / "activations.npz"
    # raise the documented error type instead of numpy's plain FileNotFoundError
    if not activations_path.exists():
        raise ActivationsMissingError(
            f"Activations file {activations_path} does not exist"
        )

    cache: dict

    with np.load(activations_path) as npz_data:
        if return_fmt == "numpy":
            cache = dict(npz_data.items())
        else:
            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}

    return activations_path, cache
class
GetActivationsError(builtins.ValueError):
class GetActivationsError(ValueError):
    """base class for errors in getting activations

    subclasses `ValueError`, so it can be caught as either type
    """
base class for errors in getting activations
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
class ActivationsMissingError(GetActivationsError, FileNotFoundError):
    """error for missing activations -- can't find the activations file

    inherits from both `GetActivationsError` and `FileNotFoundError`,
    so it can be caught as either type
    """
error for missing activations -- can't find the activations file
Inherited Members
- builtins.ValueError
- ValueError
- builtins.OSError
- errno
- strerror
- filename
- filename2
- winerror
- characters_written
- builtins.BaseException
- with_traceback
- add_note
- args
class ActivationsMismatchError(GetActivationsError):
    """error for mismatched activations -- the prompt text or hash do not match

    raised by `compare_prompt_to_loaded`
    """
error for mismatched activations -- the prompt text or hash do not match
raised by compare_prompt_to_loaded
Inherited Members
- builtins.ValueError
- ValueError
- builtins.BaseException
- with_traceback
- add_note
- args
def
compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
    """compare a prompt to a loaded prompt, raise an error if they do not match

    only the `"text"` and `"hash"` keys are compared; any other keys are ignored

    # Parameters:
     - `prompt : dict`
        the prompt the caller asked for, must contain `"text"` and `"hash"`
     - `prompt_loaded : dict`
        the prompt read back from disk

    # Returns:
     - `None`

    # Raises:
     - `ActivationsMismatchError` : if the prompt text or hash do not match,
        or if `prompt_loaded` is missing one of those keys
     - `KeyError` : if `prompt` itself is missing `"text"` or `"hash"`
    """
    for key in ("text", "hash"):
        # a key absent from the stored prompt is a mismatch, not a bare KeyError
        if key not in prompt_loaded or prompt[key] != prompt_loaded[key]:
            raise ActivationsMismatchError(
                f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
            )
compare a prompt to a loaded prompt, raise an error if they do not match
Parameters:
- prompt : dict
- prompt_loaded : dict
Returns:
None
Raises:
ActivationsMismatchError: if the prompt text or hash do not match
def
augment_prompt_with_hash(prompt: dict) -> dict:
def augment_prompt_with_hash(prompt: dict) -> dict:
    """if a prompt does not have a hash, add one

    the hash is the MD5 digest of the encoded prompt text (`str.encode()`, UTF-8 by default),
    written as URL-safe base64 with the trailing `=` padding stripped

    # Parameters:
     - `prompt : dict`
        must contain a `"text"` key if it has no `"hash"` key

    # Returns:
     - `dict`
        the same `prompt` object that was passed in

    # Modifies:
    the input `prompt` dictionary, if it does not have a `"hash"` key
    """
    if "hash" not in prompt:
        digest: bytes = hashlib.md5(prompt["text"].encode()).digest()
        prompt["hash"] = base64.urlsafe_b64encode(digest).decode().rstrip("=")
    return prompt
if a prompt does not have a hash, add one
Parameters:
prompt : dict
Returns:
dict
Modifies:
the input prompt dictionary, if it does not have a "hash" key
def
load_activations( model_name: str, prompt: dict, save_path: pathlib.Path, return_fmt: Literal['torch', 'numpy'] = 'torch') -> tuple[pathlib.Path, dict[str, torch.Tensor] | dict[str, numpy.ndarray]]:
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch", "numpy"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
    """load activations for a prompt and model, from an npz file

    files are looked up under `save_path / model_name / "prompts" / prompt["hash"]`,
    which must contain `prompt.json` and `activations.npz`

    # Parameters:
     - `model_name : str`
     - `prompt : dict`
        must contain `"text"`; a `"hash"` key is added if absent
     - `save_path : Path`
     - `return_fmt : Literal["torch", "numpy"]`
        (defaults to `"torch"`)

    # Returns:
     - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
        the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on `return_fmt`

    # Raises:
     - `ActivationsMissingError` : if the prompt file or the activations file is missing
     - `ActivationsMismatchError` : if the stored prompt does not match `prompt`
     - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`

    # Modifies:
    the input `prompt` dictionary, if it does not have a `"hash"` key
    """
    if return_fmt not in ("torch", "numpy"):
        raise ValueError(
            f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
        )
    if return_fmt == "torch":
        # imported only on this path, so the "numpy" path never imports torch
        import torch

    augment_prompt_with_hash(prompt)

    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    prompt_file: Path = prompt_dir / "prompt.json"
    if not prompt_file.exists():
        raise ActivationsMissingError(f"Prompt file {prompt_file} does not exist")
    with open(prompt_file, "r") as f:
        prompt_loaded: dict = json.load(f)
        compare_prompt_to_loaded(prompt, prompt_loaded)

    activations_path: Path = prompt_dir / "activations.npz"
    # raise the documented error type instead of numpy's plain FileNotFoundError
    if not activations_path.exists():
        raise ActivationsMissingError(
            f"Activations file {activations_path} does not exist"
        )

    cache: dict

    with np.load(activations_path) as npz_data:
        if return_fmt == "numpy":
            cache = dict(npz_data.items())
        else:
            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}

    return activations_path, cache
load activations for a prompt and model, from an npz file
Parameters:
- model_name : str
- prompt : dict
- save_path : Path
- return_fmt : Literal["torch", "numpy"] (defaults to "torch")
Returns:
- tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]: the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on return_fmt
Raises:
- ActivationsMissingError: if the activations file is missing
- ValueError: if return_fmt is not "torch" or "numpy"