Coverage for pattern_lens\load_activations.py: 85%
47 statements
« prev ^ index » next coverage.py v7.6.9, created at 2025-01-03 23:21 -0700
"loading activations from .npz on disk. implements some custom Exception classes"
3import base64
4import hashlib
5import json
6from pathlib import Path
7from typing import Literal, overload
9import numpy as np
class GetActivationsError(ValueError):
    """Root of the activation-loading exception hierarchy.

    Subclasses `ValueError` so callers catching generic value errors
    still see failures from this module.
    """
class ActivationsMissingError(GetActivationsError, FileNotFoundError):
    """Raised when the activations file for a prompt cannot be found on disk.

    Also subclasses `FileNotFoundError`, so `except FileNotFoundError`
    handlers catch it as well.
    """
class ActivationsMismatchError(GetActivationsError):
    """Raised when a stored prompt's text or hash differs from the requested one.

    Raised by `compare_prompt_to_loaded`.
    """
def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
    """Check that a loaded prompt matches the requested one on `"text"` and `"hash"`.

    # Parameters:
    - `prompt : dict`
        the prompt the caller asked for
    - `prompt_loaded : dict`
        the prompt read back from disk

    # Returns:
    - `None`

    # Raises:
    - `ActivationsMismatchError` : if the prompt text or hash do not match
    """
    # only these two keys define prompt identity; other keys are ignored
    for key in ("text", "hash"):
        requested = prompt[key]
        stored = prompt_loaded[key]
        if requested != stored:
            raise ActivationsMismatchError(
                f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
            )
def augment_prompt_with_hash(prompt: dict) -> dict:
    """Ensure `prompt` carries a `"hash"` key, deriving one from `"text"` if absent.

    The hash is the url-safe base64 encoding of the MD5 digest of the prompt
    text, with `=` padding stripped (used as a filesystem directory name, not
    for security).

    # Parameters:
    - `prompt : dict`

    # Returns:
    - `dict`
        the same dictionary, for chaining

    # Modifies:
    the input `prompt` dictionary, if it does not have a `"hash"` key
    """
    if "hash" in prompt:
        # already hashed; nothing to do
        return prompt
    digest: bytes = hashlib.md5(prompt["text"].encode()).digest()
    prompt["hash"] = base64.urlsafe_b64encode(digest).decode().rstrip("=")
    return prompt
# typed overloads: the value type of the returned dict depends on `return_fmt`
# ("torch" -> torch.Tensor, "numpy" -> np.ndarray); torch is only referenced
# inside string annotations so it need not be imported at module level
@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]]": # type: ignore[name-defined] # noqa: F821
    ...
@overload
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["numpy"] = "numpy",
) -> "tuple[Path, dict[str, np.ndarray]]": ...
def load_activations(
    model_name: str,
    prompt: dict,
    save_path: Path,
    return_fmt: Literal["torch", "numpy"] = "torch",
) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]": # type: ignore[name-defined] # noqa: F821
    """load activations for a prompt and model, from an npz file

    # Parameters:
    - `model_name : str`
    - `prompt : dict`
        must contain a `"text"` key; a `"hash"` key is added in-place if missing
    - `save_path : Path`
        base directory containing `<model_name>/prompts/<hash>/`
    - `return_fmt : Literal["torch", "numpy"]`
        (defaults to `"torch"`)

    # Returns:
    - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
        the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on `return_fmt`

    # Raises:
    - `ActivationsMissingError` : if the prompt file or the activations file is missing
    - `ActivationsMismatchError` : if the stored prompt's text or hash does not match `prompt`
    - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`
    """
    # validate before touching the filesystem
    if return_fmt not in ("torch", "numpy"):
        raise ValueError(
            f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
        )
    if return_fmt == "torch":
        # deferred import so numpy-only callers don't need torch installed
        import torch

    # ensure the prompt has a hash, since it determines the storage directory
    augment_prompt_with_hash(prompt)

    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
    prompt_file: Path = prompt_dir / "prompt.json"
    if not prompt_file.exists():
        raise ActivationsMissingError(f"Prompt file {prompt_file} does not exist")
    with open(prompt_file, "r") as f:
        prompt_loaded: dict = json.load(f)
    # stored prompt must agree with the requested one on "text" and "hash"
    compare_prompt_to_loaded(prompt, prompt_loaded)

    activations_path: Path = prompt_dir / "activations.npz"
    # raise the package-specific error (still a FileNotFoundError subclass)
    # rather than letting np.load raise a bare FileNotFoundError
    if not activations_path.exists():
        raise ActivationsMissingError(
            f"Activations file {activations_path} does not exist"
        )

    cache: dict
    with np.load(activations_path) as npz_data:
        if return_fmt == "numpy":
            cache = {k: v for k, v in npz_data.items()}
        else:
            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}

    return activations_path, cache