docs for pattern_lens v0.1.0
View Source on GitHub

pattern_lens.load_activations

loading activations from .npz on disk. implements some custom Exception classes


  1"loading activations from .npz on disk. implements some custom Exception classes"
  2
  3import base64
  4import hashlib
  5import json
  6from pathlib import Path
  7from typing import Literal, overload
  8
  9import numpy as np
 10
 11
 12class GetActivationsError(ValueError):
 13    """base class for errors in getting activations"""
 14
 15    pass
 16
 17
 18class ActivationsMissingError(GetActivationsError, FileNotFoundError):
 19    """error for missing activations -- can't find the activations file"""
 20
 21    pass
 22
 23
 24class ActivationsMismatchError(GetActivationsError):
 25    """error for mismatched activations -- the prompt text or hash do not match
 26
 27    raised by `compare_prompt_to_loaded`
 28    """
 29
 30    pass
 31
 32
 33def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
 34    """compare a prompt to a loaded prompt, raise an error if they do not match
 35
 36    # Parameters:
 37     - `prompt : dict`
 38     - `prompt_loaded : dict`
 39
 40    # Returns:
 41     - `None`
 42
 43    # Raises:
 44     - `ActivationsMismatchError` : if the prompt text or hash do not match
 45    """
 46    for key in ("text", "hash"):
 47        if prompt[key] != prompt_loaded[key]:
 48            raise ActivationsMismatchError(
 49                f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
 50            )
 51
 52
 53def augment_prompt_with_hash(prompt: dict) -> dict:
 54    """if a prompt does not have a hash, add one
 55
 56    # Parameters:
 57     - `prompt : dict`
 58
 59    # Returns:
 60     - `dict`
 61
 62    # Modifies:
 63    the input `prompt` dictionary, if it does not have a `"hash"` key
 64    """
 65    if "hash" not in prompt:
 66        prompt_str: str = prompt["text"]
 67        prompt_hash: str = (
 68            base64.urlsafe_b64encode(hashlib.md5(prompt_str.encode()).digest())
 69            .decode()
 70            .rstrip("=")
 71        )
 72        prompt.update(hash=prompt_hash)
 73    return prompt
 74
 75
 76@overload
 77def load_activations(
 78    model_name: str,
 79    prompt: dict,
 80    save_path: Path,
 81    return_fmt: Literal["torch"] = "torch",
 82) -> "tuple[Path, dict[str, torch.Tensor]]":  # type: ignore[name-defined] # noqa: F821
 83    ...
 84@overload
 85def load_activations(
 86    model_name: str,
 87    prompt: dict,
 88    save_path: Path,
 89    return_fmt: Literal["numpy"] = "numpy",
 90) -> "tuple[Path, dict[str, np.ndarray]]": ...
 91def load_activations(
 92    model_name: str,
 93    prompt: dict,
 94    save_path: Path,
 95    return_fmt: Literal["torch", "numpy"] = "torch",
 96) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
 97    """load activations for a prompt and model, from an npz file
 98
 99    # Parameters:
100     - `model_name : str`
101     - `prompt : dict`
102     - `save_path : Path`
103     - `return_fmt : Literal["torch", "numpy"]`
104       (defaults to `"torch"`)
105
106    # Returns:
107     - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
108         the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on `return_fmt`
109
110    # Raises:
111     - `ActivationsMissingError` : if the activations file is missing
112     - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`
113    """
114
115    if return_fmt not in ("torch", "numpy"):
116        raise ValueError(
117            f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
118        )
119    if return_fmt == "torch":
120        import torch
121
122    augment_prompt_with_hash(prompt)
123
124    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
125    prompt_file: Path = prompt_dir / "prompt.json"
126    if not prompt_file.exists():
127        raise ActivationsMissingError(f"Prompt file {prompt_file} does not exist")
128    with open(prompt_dir / "prompt.json", "r") as f:
129        prompt_loaded: dict = json.load(f)
130        compare_prompt_to_loaded(prompt, prompt_loaded)
131
132    activations_path: Path = prompt_dir / "activations.npz"
133
134    cache: dict
135
136    with np.load(activations_path) as npz_data:
137        if return_fmt == "numpy":
138            cache = {k: v for k, v in npz_data.items()}
139        elif return_fmt == "torch":
140            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}
141
142    return activations_path, cache

class GetActivationsError(builtins.ValueError):
13class GetActivationsError(ValueError):
14    """base class for errors in getting activations"""
15
16    pass

base class for errors in getting activations

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
class ActivationsMissingError(GetActivationsError, builtins.FileNotFoundError):
19class ActivationsMissingError(GetActivationsError, FileNotFoundError):
20    """error for missing activations -- can't find the activations file"""
21
22    pass

error for missing activations -- can't find the activations file

Inherited Members
builtins.ValueError
ValueError
builtins.OSError
errno
strerror
filename
filename2
winerror
characters_written
builtins.BaseException
with_traceback
add_note
args
class ActivationsMismatchError(GetActivationsError):
25class ActivationsMismatchError(GetActivationsError):
26    """error for mismatched activations -- the prompt text or hash do not match
27
28    raised by `compare_prompt_to_loaded`
29    """
30
31    pass

error for mismatched activations -- the prompt text or hash do not match

raised by compare_prompt_to_loaded

Inherited Members
builtins.ValueError
ValueError
builtins.BaseException
with_traceback
add_note
args
def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
34def compare_prompt_to_loaded(prompt: dict, prompt_loaded: dict) -> None:
35    """compare a prompt to a loaded prompt, raise an error if they do not match
36
37    # Parameters:
38     - `prompt : dict`
39     - `prompt_loaded : dict`
40
41    # Returns:
42     - `None`
43
44    # Raises:
45     - `ActivationsMismatchError` : if the prompt text or hash do not match
46    """
47    for key in ("text", "hash"):
48        if prompt[key] != prompt_loaded[key]:
49            raise ActivationsMismatchError(
50                f"Prompt file does not match prompt at key {key}:\n{prompt}\n{prompt_loaded}"
51            )

compare a prompt to a loaded prompt, raise an error if they do not match

Parameters:

  • prompt : dict
  • prompt_loaded : dict

Returns:

  • None

Raises:

  • ActivationsMismatchError : if the prompt text or hash do not match

def augment_prompt_with_hash(prompt: dict) -> dict:
54def augment_prompt_with_hash(prompt: dict) -> dict:
55    """if a prompt does not have a hash, add one
56
57    # Parameters:
58     - `prompt : dict`
59
60    # Returns:
61     - `dict`
62
63    # Modifies:
64    the input `prompt` dictionary, if it does not have a `"hash"` key
65    """
66    if "hash" not in prompt:
67        prompt_str: str = prompt["text"]
68        prompt_hash: str = (
69            base64.urlsafe_b64encode(hashlib.md5(prompt_str.encode()).digest())
70            .decode()
71            .rstrip("=")
72        )
73        prompt.update(hash=prompt_hash)
74    return prompt

if a prompt does not have a hash, add one

Parameters:

  • prompt : dict

Returns:

  • dict

Modifies:

the input prompt dictionary, if it does not have a "hash" key

def load_activations( model_name: str, prompt: dict, save_path: pathlib.Path, return_fmt: Literal['torch', 'numpy'] = 'torch') -> tuple[pathlib.Path, dict[str, torch.Tensor] | dict[str, numpy.ndarray]]:
 92def load_activations(
 93    model_name: str,
 94    prompt: dict,
 95    save_path: Path,
 96    return_fmt: Literal["torch", "numpy"] = "torch",
 97) -> "tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]":  # type: ignore[name-defined] # noqa: F821
 98    """load activations for a prompt and model, from an npz file
 99
100    # Parameters:
101     - `model_name : str`
102     - `prompt : dict`
103     - `save_path : Path`
104     - `return_fmt : Literal["torch", "numpy"]`
105       (defaults to `"torch"`)
106
107    # Returns:
108     - `tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]]`
109         the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on `return_fmt`
110
111    # Raises:
112     - `ActivationsMissingError` : if the activations file is missing
113     - `ValueError` : if `return_fmt` is not `"torch"` or `"numpy"`
114    """
115
116    if return_fmt not in ("torch", "numpy"):
117        raise ValueError(
118            f"Invalid return_fmt: {return_fmt}, expected 'torch' or 'numpy'"
119        )
120    if return_fmt == "torch":
121        import torch
122
123    augment_prompt_with_hash(prompt)
124
125    prompt_dir: Path = save_path / model_name / "prompts" / prompt["hash"]
126    prompt_file: Path = prompt_dir / "prompt.json"
127    if not prompt_file.exists():
128        raise ActivationsMissingError(f"Prompt file {prompt_file} does not exist")
129    with open(prompt_dir / "prompt.json", "r") as f:
130        prompt_loaded: dict = json.load(f)
131        compare_prompt_to_loaded(prompt, prompt_loaded)
132
133    activations_path: Path = prompt_dir / "activations.npz"
134
135    cache: dict
136
137    with np.load(activations_path) as npz_data:
138        if return_fmt == "numpy":
139            cache = {k: v for k, v in npz_data.items()}
140        elif return_fmt == "torch":
141            cache = {k: torch.from_numpy(v) for k, v in npz_data.items()}
142
143    return activations_path, cache

load activations for a prompt and model, from an npz file

Parameters:

  • model_name : str
  • prompt : dict
  • save_path : Path
  • return_fmt : Literal["torch", "numpy"] (defaults to "torch")

Returns:

  • tuple[Path, dict[str, torch.Tensor]|dict[str, np.ndarray]] the path to the activations file and the activations as a dictionary of numpy arrays or torch tensors, depending on return_fmt

Raises:

  • ActivationsMissingError : if the activations file is missing
  • ValueError : if return_fmt is not "torch" or "numpy"