docs for pattern_lens v0.1.0

Contents

pattern-lens

visualization of LLM attention patterns and things computed about them

pattern-lens makes it easy to:

- compute and save attention patterns from LLMs for a batch of prompts
- generate figures from those patterns, including your own custom visualization functions
- browse the results in a web UI

Installation

pip install pattern-lens

Usage

The pipeline is as follows:

- compute activations with pattern_lens.activations
- generate figures with pattern_lens.figures
- serve and browse the results with pattern_lens.server

Basic CLI

Generate attention patterns and default visualizations:

# generate activations
python -m pattern_lens.activations --model gpt2 --prompts data/pile_1k.jsonl --save-path attn_data
# create visualizations
python -m pattern_lens.figures --model gpt2 --save-path attn_data

serve the web UI:

python -m pattern_lens.server --path attn_data

Web UI

View a demo of the web UI at miv.name/pattern-lens/demo.

Custom Figures

Add custom visualization functions by decorating them with @register_attn_figure_func. You should still generate the activations first:

python -m pattern_lens.activations --model gpt2 --prompts data/pile_1k.jsonl --save-path attn_data

and then write+run a script/notebook that looks something like this (see demo.ipynb for a full example):

# imports
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.linalg import svd

from pattern_lens.figure_util import matplotlib_figure_saver, save_matrix_wrapper
from pattern_lens.attn_figure_funcs import register_attn_figure_func
from pattern_lens.figures import figures_main

# define your own functions
# this one uses `matplotlib_figure_saver` -- define a function that takes a matrix and a `plt.Axes`, and modifies the axes in place
@register_attn_figure_func
@matplotlib_figure_saver(fmt="svgz")
def svd_spectra(attn_matrix: np.ndarray, ax: plt.Axes) -> None:
    # Perform SVD
    U, s, Vh = svd(attn_matrix)

    # Plot singular values
    ax.plot(s, "o-")
    ax.set_yscale("log")
    ax.set_xlabel("Singular Value Index")
    ax.set_ylabel("Singular Value")
    ax.set_title("Singular Value Spectrum of Attention Matrix")


# run the figures pipeline
figures_main(
    model_name="gpt2",
    save_path=Path("attn_data"),
    n_samples=5,
    force=False,
)

Submodules

View Source on GitHub

pattern_lens

View Source on GitHub


computing and saving activations given a model and prompts

Usage:

from the command line:

python -m pattern_lens.activations --model <model_name> --prompts <prompts_path> --save-path <save_path> --min-chars <min_chars> --max-chars <max_chars> --n-samples <n_samples>

from a script:

from pattern_lens.activations import activations_main
activations_main(
    model_name="gpt2",
    save_path="demo/",
    prompts_path="data/pile_1k.jsonl",
)

API Documentation

View Source on GitHub

pattern_lens.activations

View Source on GitHub

def compute_activations

(
    prompt: dict,
    model: transformer_lens.HookedTransformer.HookedTransformer | None = None,
    save_path: pathlib.Path = Path('attn_data'),
    return_cache: bool = True,
    names_filter: Union[Callable[[str], bool], re.Pattern] = re.compile('blocks\\.(\\d+)\\.attn\\.hook_pattern')
) -> tuple[pathlib.Path, dict[str, numpy.ndarray] | None]

View Source on GitHub

get activations for a given model and prompt, possibly from a cache

if from a cache, prompt_meta must be passed and contain the prompt hash

Parameters:

Returns:

def get_activations

(
    prompt: dict,
    model: transformer_lens.HookedTransformer.HookedTransformer | str,
    save_path: pathlib.Path = Path('attn_data'),
    allow_disk_cache: bool = True,
    return_cache: bool = True
) -> tuple[pathlib.Path, dict[str, numpy.ndarray] | None]

View Source on GitHub

given a prompt and a model, save or load activations

Parameters:

Returns:

def activations_main

(
    model_name: str,
    save_path: str,
    prompts_path: str,
    raw_prompts: bool,
    min_chars: int,
    max_chars: int,
    force: bool,
    n_samples: int,
    no_index_html: bool,
    shuffle: bool = False
) -> None

View Source on GitHub

main function for computing activations

Parameters:

def main

()

View Source on GitHub


default figure functions

note that for pattern_lens.figures to recognize your function, you need to use the register_attn_figure_func decorator which adds your function to ATTENTION_MATRIX_FIGURE_FUNCS
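The registration mechanism amounts to appending your function to a module-level list and returning it unchanged; a minimal self-contained sketch (illustrative only, not the actual implementation):

```python
# minimal sketch of the figure-function registry (illustrative, not the real implementation)
from typing import Callable

ATTENTION_MATRIX_FIGURE_FUNCS: list[Callable] = []

def register_attn_figure_func(func: Callable) -> Callable:
    """append `func` to the registry and return it unchanged"""
    ATTENTION_MATRIX_FIGURE_FUNCS.append(func)
    return func

@register_attn_figure_func
def my_figure_func(attn_matrix, path) -> None:
    ...
```

Because the decorator returns the function unchanged, it can be stacked above other decorators such as matplotlib_figure_saver.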

API Documentation

View Source on GitHub

pattern_lens.attn_figure_funcs

View Source on GitHub

def register_attn_figure_func

(
    func: Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]
) -> Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]

View Source on GitHub

decorator for registering attention matrix figure function

if you want to add a new figure function, you should use this decorator

Parameters:

 - `func : AttentionMatrixFigureFunc`
   your function, which should take an attention matrix and path

Returns:

 - `AttentionMatrixFigureFunc`
   your function, after we add it to `ATTENTION_MATRIX_FIGURE_FUNCS`

Usage:

@register_attn_figure_func
def my_new_figure_func(attn_matrix: AttentionMatrix, path: Path) -> None:
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.matshow(attn_matrix, cmap="viridis")
    ax.set_title("My New Figure Function")
    ax.axis("off")
    plt.savefig(path / "my_new_figure_func", format="svgz")
    plt.close(fig)

def raw

(
    attn_matrix: jaxtyping.Float[ndarray, 'n_ctx n_ctx']
) -> jaxtyping.Float[ndarray, 'n m']

View Source on GitHub


implements some constants and types

API Documentation

View Source on GitHub

pattern_lens.consts

View Source on GitHub

type alias for attention matrix

type alias for a cache of attention matrices, subset of ActivationCache

default directory for attention data

regex for finding attention patterns in model state dicts

default kwargs for muutils.spinner.Spinner

divider string for separating sections

divider string for separating subsections


implements a bunch of types, default values, and templates which are useful for figure functions

notably, you can use the decorators matplotlib_figure_saver, save_matrix_wrapper to make your functions save figures

API Documentation

View Source on GitHub

pattern_lens.figure_util

View Source on GitHub

Type alias for a function that, given an attention matrix, saves a figure

Type alias for a 2D matrix (plottable)

Type alias for a 2D matrix with 3 channels (RGB)

Type alias for a function that, given an attention matrix, returns a 2D matrix

format for saving matplotlib figures

Type alias for the format to save a matrix as when saving raw matrix, not matplotlib figure

default for whether to normalize the matrix to range [0, 1]

default colormap for saving matrices

default format for saving matrices

template for saving an n by m matrix as an svg/svgz

def matplotlib_figure_saver

(
    func: Optional[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], matplotlib.axes._axes.Axes], NoneType]] = None,
    *args,
    fmt: str = 'svgz'
) -> Union[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType], Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], matplotlib.axes._axes.Axes], NoneType], str], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]]]

View Source on GitHub

decorator for functions which take an attention matrix and predefined ax object, making it save a figure

Parameters:

Returns:

Usage:

@register_attn_figure_func
@matplotlib_figure_saver
def raw(attn_matrix: AttentionMatrix, ax: plt.Axes) -> None:
    ax.matshow(attn_matrix, cmap="viridis")
    ax.set_title("Raw Attention Pattern")
    ax.axis("off")

def matrix_to_image_preprocess

(
    matrix: jaxtyping.Float[ndarray, 'n m'],
    normalize: bool = False,
    cmap: str | matplotlib.colors.Colormap = 'viridis'
) -> jaxtyping.UInt8[ndarray, 'n m rgb=3']

View Source on GitHub

preprocess a 2D matrix into a plottable heatmap image

Parameters:

Returns:

def matrix2drgb_to_png_bytes

(
    matrix: jaxtyping.UInt8[ndarray, 'n m rgb=3'],
    buffer: _io.BytesIO | None = None
) -> bytes | None

View Source on GitHub

Convert a Matrix2Drgb to valid PNG bytes via PIL

Parameters:

Returns:
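A minimal sketch mirroring the documented signature, assuming PIL's Image.fromarray for the PNG encoding: if a buffer is passed, the PNG is written into it and None is returned; otherwise the bytes are returned directly:

```python
import io
from typing import Optional

import numpy as np
from PIL import Image

def rgb_to_png_bytes_sketch(matrix: np.ndarray, buffer: Optional[io.BytesIO] = None) -> Optional[bytes]:
    """sketch: encode a (n, m, 3) uint8 array as PNG; return the bytes,
    or write into `buffer` and return None if one is given"""
    img = Image.fromarray(matrix)
    if buffer is None:
        out = io.BytesIO()
        img.save(out, format="PNG")
        return out.getvalue()
    img.save(buffer, format="PNG")
    return None

png_bytes = rgb_to_png_bytes_sketch(np.zeros((4, 4, 3), dtype=np.uint8))
```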

def matrix_as_svg

(
    matrix: jaxtyping.Float[ndarray, 'n m'],
    normalize: bool = False,
    cmap='viridis'
) -> str

View Source on GitHub

quickly convert a 2D matrix to an SVG image, without matplotlib

Parameters:

Returns:
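To illustrate the idea of emitting an SVG heatmap directly as a string without matplotlib, a simplified stand-in that draws one `<rect>` per cell with a grayscale fill (the real matrix_as_svg supports colormaps and produces a more compact encoding):

```python
import numpy as np

def matrix_as_svg_sketch(matrix: np.ndarray, normalize: bool = False) -> str:
    """simplified stand-in: render a 2D matrix as an SVG string, one <rect> per cell"""
    m = matrix.astype(np.float64)
    if normalize:
        # rescale to [0, 1]
        m = (m - m.min()) / (m.max() - m.min() + 1e-12)
    n_rows, n_cols = m.shape
    rects = []
    for i in range(n_rows):
        for j in range(n_cols):
            v = int(np.clip(m[i, j], 0.0, 1.0) * 255)
            rects.append(f'<rect x="{j}" y="{i}" width="1" height="1" fill="rgb({v},{v},{v})"/>')
    return (
        f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {n_cols} {n_rows}">'
        + "".join(rects)
        + "</svg>"
    )

svg = matrix_as_svg_sketch(np.array([[0.0, 1.0]]))
```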

def save_matrix_wrapper

(
    func: Optional[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']]] = None,
    *args,
    fmt: Literal['png', 'svg', 'svgz'] = 'svgz',
    normalize: bool = False,
    cmap='viridis'
) -> Union[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType], Callable[[Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx']], jaxtyping.Float[ndarray, 'n m']]], Callable[[jaxtyping.Float[ndarray, 'n_ctx n_ctx'], pathlib.Path], NoneType]]]

View Source on GitHub

Decorator for functions that process an attention matrix and save it as an SVGZ image. Can handle both argumentless usage and with arguments.

Parameters:

Returns:

AttentionMatrixFigureFunc|Callable[[AttentionMatrixToMatrixFunc], AttentionMatrixFigureFunc]

Usage:

@save_matrix_wrapper
def identity_matrix(matrix):
    return matrix

@save_matrix_wrapper(normalize=True, fmt="png")
def scale_matrix(matrix):
    return matrix * 2

@save_matrix_wrapper(normalize=True, cmap="plasma")
def scale_matrix_plasma(matrix):
    return matrix * 2


code for generating figures from attention patterns, using the functions decorated with register_attn_figure_func

API Documentation

View Source on GitHub

pattern_lens.figures

View Source on GitHub

class HTConfigMock:

View Source on GitHub

Mock of transformer_lens.HookedTransformerConfig for type hinting and loading config json

can be initialized with any kwargs, and will update its __dict__ with them. does, however, require the following attributes:

- n_layers: int
- n_heads: int
- model_name: str
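The described behavior can be sketched as follows (an illustrative stand-in, not the real class):

```python
class HTConfigMockSketch:
    """illustrative stand-in for HTConfigMock: accept arbitrary kwargs,
    but require n_layers, n_heads, and model_name"""

    def __init__(self, **kwargs):
        # store all kwargs as attributes
        self.__dict__.update(kwargs)
        # enforce the documented required attributes
        for attr in ("n_layers", "n_heads", "model_name"):
            if not hasattr(self, attr):
                raise AttributeError(f"missing required attribute: {attr}")

cfg = HTConfigMockSketch(n_layers=6, n_heads=8, model_name="gpt2", d_model=512)
```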

HTConfigMock

(**kwargs)

View Source on GitHub

def serialize

(self)

View Source on GitHub

serialize the config to json. values which aren’t serializable will be converted via muutils.json_serialize.json_serialize

def load

(cls, data: dict)

View Source on GitHub

try to load a config from a dict, using the __init__ method

def process_single_head

(
    layer_idx: int,
    head_idx: int,
    attn_pattern: jaxtyping.Float[ndarray, 'n_ctx n_ctx'],
    save_dir: pathlib.Path,
    force_overwrite: bool = False
) -> dict[str, bool | Exception]

View Source on GitHub

process a single head’s attention pattern, running all the functions in ATTENTION_MATRIX_FIGURE_FUNCS on the attention pattern

Parameters:

Returns:

def compute_and_save_figures

(
    model_cfg: 'HookedTransformerConfig|HTConfigMock',
    activations_path: pathlib.Path,
    cache: dict[str, numpy.ndarray],
    save_path: pathlib.Path = Path('attn_data'),
    force_overwrite: bool = False,
    track_results: bool = False
) -> None

View Source on GitHub

compute and save figures for all heads in the model, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS

Parameters:

def process_prompt

(
    prompt: dict,
    model_cfg: 'HookedTransformerConfig|HTConfigMock',
    save_path: pathlib.Path,
    force_overwrite: bool = False
) -> None

View Source on GitHub

process a single prompt, loading the activations and computing and saving the figures

basically just calls load_activations and then compute_and_save_figures

Parameters:

def figures_main

(
    model_name: str,
    save_path: str,
    n_samples: int,
    force: bool,
    parallel: bool | int = True
) -> None

View Source on GitHub

main function for generating figures from attention patterns, using the functions in ATTENTION_MATRIX_FIGURE_FUNCS

Parameters:

def main

()

View Source on GitHub


writes indexes to the model directory for the frontend to use or for record keeping

API Documentation

View Source on GitHub

pattern_lens.indexes

View Source on GitHub

def generate_prompts_jsonl

(model_dir: pathlib.Path)

View Source on GitHub

creates a prompts.jsonl file with all the prompts in the model directory

looks in all directories in {model_dir}/prompts for a prompt.json file

def generate_models_jsonl

(path: pathlib.Path)

View Source on GitHub

creates a models.jsonl file with all the models

def get_func_metadata

(func: Callable) -> dict[str, str | None]

View Source on GitHub

get metadata for a function

Parameters:

Returns:

dict[str, str | None] dictionary:
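A sketch of what collecting function metadata via `inspect` might look like; the keys here (name, doc, module) are illustrative assumptions, not necessarily what get_func_metadata actually returns:

```python
import inspect
from typing import Callable

def get_func_metadata_sketch(func: Callable) -> dict:
    """sketch of collecting metadata for a function via `inspect`;
    the keys here are illustrative assumptions"""
    return {
        "name": func.__name__,
        "doc": inspect.getdoc(func),
        "module": getattr(func, "__module__", None),
    }

def example_func(x):
    """an example figure function"""
    return x

meta = get_func_metadata_sketch(example_func)
```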

def generate_functions_jsonl

(path: pathlib.Path)

View Source on GitHub

unions all functions from file and current ATTENTION_MATRIX_FIGURE_FUNCS into a functions.jsonl file

def write_html_index

(path: pathlib.Path)

View Source on GitHub

writes an index.html file to the path


loading activations from .npz on disk. implements some custom Exception classes

API Documentation

View Source on GitHub

pattern_lens.load_activations

View Source on GitHub

class GetActivationsError(builtins.ValueError):

View Source on GitHub

base class for errors in getting activations

Inherited Members

class ActivationsMissingError(GetActivationsError, builtins.FileNotFoundError):

View Source on GitHub

error for missing activations – can’t find the activations file

Inherited Members

class ActivationsMismatchError(GetActivationsError):

View Source on GitHub

error for mismatched activations – the prompt text or hash do not match

raised by compare_prompt_to_loaded

Inherited Members

def compare_prompt_to_loaded

(prompt: dict, prompt_loaded: dict) -> None

View Source on GitHub

compare a prompt to a loaded prompt, raise an error if they do not match

Parameters:

Returns:

Raises:
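A simplified sketch of this comparison; the real function raises ActivationsMismatchError, for which a plain ValueError stands in here:

```python
def compare_prompt_to_loaded_sketch(prompt: dict, prompt_loaded: dict) -> None:
    """simplified sketch: check that prompt text and hash match the loaded copy;
    the real function raises ActivationsMismatchError, ValueError stands in here"""
    for key in ("text", "hash"):
        if prompt.get(key) != prompt_loaded.get(key):
            raise ValueError(
                f"prompt mismatch on {key!r}: {prompt.get(key)!r} != {prompt_loaded.get(key)!r}"
            )

# matching prompts pass silently
compare_prompt_to_loaded_sketch({"text": "hi", "hash": "abc"}, {"text": "hi", "hash": "abc"})
```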

def augment_prompt_with_hash

(prompt: dict) -> dict

View Source on GitHub

if a prompt does not have a hash, add one

Parameters:

Returns:

Modifies:

the input prompt dictionary, if it does not have a "hash" key
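A sketch of the described behavior; the actual hashing scheme is not specified here, so sha256 of the prompt text is used as an assumption:

```python
import hashlib

def augment_prompt_with_hash_sketch(prompt: dict) -> dict:
    """sketch: add a hash of the prompt text if one is missing.
    the hashing scheme (sha256 of the text) is an assumption"""
    if "hash" not in prompt:
        prompt["hash"] = hashlib.sha256(prompt["text"].encode("utf-8")).hexdigest()
    return prompt

p = augment_prompt_with_hash_sketch({"text": "hello"})
```

Note that, matching the "Modifies" entry above, the input dictionary is mutated in place when it lacks a "hash" key.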

def load_activations

(
    model_name: str,
    prompt: dict,
    save_path: pathlib.Path,
    return_fmt: Literal['torch', 'numpy'] = 'torch'
) -> tuple[pathlib.Path, dict[str, torch.Tensor] | dict[str, numpy.ndarray]]

View Source on GitHub

load activations for a prompt and model, from an npz file

Parameters:

Returns:

Raises:
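Activations are stored on disk as .npz archives keyed by hook name (e.g. blocks.0.attn.hook_pattern); a round-trip sketch of that storage format using plain numpy, with an illustrative array shape:

```python
import tempfile
from pathlib import Path

import numpy as np

# cache keyed by hook name; the shape (batch, heads, n_ctx, n_ctx) is illustrative
cache = {"blocks.0.attn.hook_pattern": np.random.rand(1, 4, 8, 8).astype(np.float32)}

with tempfile.TemporaryDirectory() as tmp:
    npz_path = Path(tmp) / "activations.npz"
    # save the whole cache as one compressed archive
    np.savez_compressed(npz_path, **cache)
    # load it back as a plain dict of arrays
    with np.load(npz_path) as npz:
        loaded = {k: npz[k] for k in npz.files}
```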


implements load_text_data for loading prompts

API Documentation

View Source on GitHub

pattern_lens.prompts

View Source on GitHub

def load_text_data

(
    fname: pathlib.Path,
    min_chars: int | None = None,
    max_chars: int | None = None,
    shuffle: bool = False
) -> list[dict]

View Source on GitHub

given fname, the path to a jsonl file, split prompts up into more reasonable sizes

Parameters:

Returns:
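A simplified sketch of jsonl prompt loading with length filtering; note that the real load_text_data splits over-long prompts into smaller pieces, whereas this sketch truncates them as a simplifying assumption:

```python
import json
import tempfile
from pathlib import Path

def load_text_data_sketch(fname: Path, min_chars=None, max_chars=None) -> list:
    """simplified sketch: load prompts from a jsonl file, drop too-short ones,
    truncate too-long ones (the real function splits them instead)"""
    prompts = []
    for line in fname.read_text(encoding="utf-8").splitlines():
        if not line.strip():
            continue
        prompt = json.loads(line)
        if min_chars is not None and len(prompt["text"]) < min_chars:
            continue
        if max_chars is not None and len(prompt["text"]) > max_chars:
            prompt["text"] = prompt["text"][:max_chars]
        prompts.append(prompt)
    return prompts

with tempfile.TemporaryDirectory() as tmp:
    jsonl_path = Path(tmp) / "prompts.jsonl"
    jsonl_path.write_text('{"text": "short"}\n{"text": "a much longer prompt line"}\n')
    prompts = load_text_data_sketch(jsonl_path, min_chars=6, max_chars=10)
```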


cli for starting the server to show the web ui.

can also run with --rewrite-index to update the index.html file. this is useful for working on the ui.

API Documentation

View Source on GitHub

pattern_lens.server

View Source on GitHub

def main

(path: str, port: int = 8000)

View Source on GitHub