Metadata-Version: 2.4
Name: litfit
Version: 0.1.5
Summary: A lightweight library for fast finetuning of embeddings
Author: Mikhail Kindulov
License: MIT
Project-URL: Homepage, https://github.com/b0nce/litfit
Project-URL: Repository, https://github.com/b0nce/litfit.git
Project-URL: Issues, https://github.com/b0nce/litfit/issues
Keywords: embedding,finetuning,pytorch,numpy
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.11
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Requires-Dist: numpy>=1.24.0
Requires-Dist: tqdm>=4.60.0
Provides-Extra: dev
Requires-Dist: pytest>=7.0.0; extra == "dev"
Requires-Dist: ruff>=0.1.0; extra == "dev"
Requires-Dist: mypy>=1.0.0; extra == "dev"
Requires-Dist: black>=23.0.0; extra == "dev"
Requires-Dist: pre-commit>=3.0.0; extra == "dev"
Dynamic: license-file

# litfit

[![CI](https://github.com/b0nce/litfit/actions/workflows/ci.yml/badge.svg)](https://github.com/b0nce/litfit/actions/workflows/ci.yml)
[![PyPI](https://img.shields.io/pypi/v/litfit)](https://pypi.org/project/litfit/)
[![Python](https://img.shields.io/pypi/pyversions/litfit)](https://pypi.org/project/litfit/)
[![License](https://img.shields.io/github/license/b0nce/litfit)](LICENSE)

**litfit** /lɪt fɪt/ — the shortest path from someone else's embedding to your task.

## Why litfit?

Fine-tuning dense embedding models means writing a training loop, picking a loss function, tuning a learning rate, and waiting minutes to hours — whether you're working with text, images, or multimodal embeddings.
litfit takes a different approach: given pairs of items that should be similar (duplicates, relevant matches, same-class images), it computes covariance statistics and solves for the optimal linear projection in closed form. No gradient descent, no hyperparameters to babysit.
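
For intuition, here is a toy closed-form fit in the same spirit (ridge regression from one side of each pair to the other). This is an illustration only, not litfit's actual solver, and every name in it is made up for the example:

```python
import numpy as np

# Toy illustration, not litfit's solver: for positive pairs (x_i, y_i),
# the ridge objective min_W ||XW - Y||^2 + lam*||W||^2 has the exact
# solution W = (X^T X + lam*I)^(-1) X^T Y. No gradient steps needed.
rng = np.random.default_rng(0)
X = rng.standard_normal((500, 64))            # one side of each pair
Y = X @ rng.standard_normal((64, 64)) * 0.5   # synthetic "similar" side

lam = 1e-2
W = np.linalg.solve(X.T @ X + lam * np.eye(64), X.T @ Y)
projected = X @ W   # apply the learned linear projection
```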

**What you get:**
- **Fast** — one pass over your pairs to collect statistics, then everything is solved in closed form; no iterative training.
- **Any dense embeddings** — text, vision, multimodal. If it outputs a vector, litfit can probably improve it.
- **Simple** — pass in embeddings + pair labels, get a well-tuned projection matrix back.

### Benchmarks

| Task | Model | Setup | R@1 | MAP@50 | Dims | Time (CPU) |
|---|---|---|---|---|---|---|
| Fashion retrieval | SigLIP2-SO400M | baseline | 0.833 | 0.532 | 1152 | — |
| (DeepFashion In-Shop) | | + litfit | **0.923** | **0.738** | 228 | 37s |
| Duplicate detection | e5-base-v2 | baseline | 0.556 | 0.522 | 768 | — |
| (AskUbuntu) | | + litfit | **0.598** | **0.590** | 294 | ~3min |

*Embeddings precomputed. DeepFashion In-Shop uses fast mode (~40 configs); AskUbuntu uses the full sweep (~860 projections). Because the solution is closed-form, you can safely merge val into the training data after model selection without risk of overfitting; refitting typically adds a fraction of a second.*

Try it yourself — no setup needed [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/b0nce/litfit/blob/main/notebooks/quickstart.ipynb)

More benchmarks welcome — if you run litfit on your dataset, open an issue or PR with results!

## Installation

```bash
pip install litfit
```

For an editable (development) install from a cloned checkout:

```bash
pip install -e ".[dev]"
```

For faster statistics computation on CUDA GPUs:

```bash
pip install triton
```

## Usage

```python
from litfit import (
    load_askubuntu, encode_texts, split_data,
    compute_stats, generate_all_projections, evaluate_projections,
)

# Load AskUbuntu duplicate questions and embed them.
all_ids, all_texts, id_to_group = load_askubuntu(max_groups=1000)
embs = encode_texts("intfloat/e5-base-v2", all_texts)

# Group-aware train/val/test split (no group leaks across splits).
data = split_data(all_ids, all_texts, embs, id_to_group)
train_ids, _, train_embs, _ = data["train"]
val_ids, _, val_embs, _ = data["val"]
test_ids, _, test_embs, _ = data["test"]

# Covariance statistics -> candidate projections -> evaluation.
st = compute_stats(train_embs, train_ids, id_to_group)
all_W = generate_all_projections(st, neg=None, include_neg_methods=False)
results, summary = evaluate_projections(
    all_W, val_embs, val_ids, id_to_group,
    test_embs=test_embs, test_ids=test_ids,
    dim_fractions=(0.1, 0.2, 0.5, 1.0),
)
```

### Streaming + fast + dim search (low memory)

Combine streaming statistics, fast projections (~40 configs), lazy evaluation, and
automatic dimension search for a memory-efficient pipeline:

```python
from litfit import (
    compute_stats_streaming, generate_fast_projections,
    find_dim_range, evaluate_projections,
)

# X_pairs_memmap[i] and Y_pairs_memmap[i] hold the two embeddings of
# the i-th positive pair (e.g. numpy memmaps too large to fit in RAM).
def pair_batches():
    for i in range(0, len(X_pairs_memmap), 1024):
        yield X_pairs_memmap[i:i+1024], Y_pairs_memmap[i:i+1024]

st = compute_stats_streaming(pair_batches())
dim_fractions = find_dim_range(st, val_embs, val_ids, id_to_group)
all_W = generate_fast_projections(st, lazy=True)
results, summary = evaluate_projections(
    all_W, test_embs, test_ids, id_to_group,
    dim_fractions=dim_fractions,
)
```
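
If you don't already have the pair arrays on disk, one way to build them is with numpy memmaps. The names, shapes, and file paths below are purely illustrative:

```python
import numpy as np

# Hypothetical setup: write pair embeddings to disk so they never need
# to fit in RAM at once. X_pairs_memmap[i] and Y_pairs_memmap[i] are
# the two sides of the i-th positive pair.
n_pairs, dim = 2_000_000, 768
X_pairs_memmap = np.lib.format.open_memmap(
    "x_pairs.npy", mode="w+", dtype=np.float32, shape=(n_pairs, dim)
)
Y_pairs_memmap = np.lib.format.open_memmap(
    "y_pairs.npy", mode="w+", dtype=np.float32, shape=(n_pairs, dim)
)
# ...fill both arrays in chunks, e.g. X_pairs_memmap[i:j] = batch_embs...
```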

<details>
<summary><strong>Full walkthrough: data concepts, splitting, extracting the best projection</strong></summary>

litfit operates on three data structures:

- **`ids`** — a list of unique identifiers, one per embedding (strings, ints, anything hashable).
- **`id_to_group`** — a dict mapping each id to a group label. Items that share a group are treated as positives (duplicates / paraphrases / relevant matches). Everything else is a negative.
- **`embs`** — a numpy array or torch tensor of shape `(n, d)`, one row per id.

For example, if questions 0, 1, 2 are duplicates and 3, 4 are duplicates:
```python
ids = [0, 1, 2, 3, 4]
id_to_group = {0: 'A', 1: 'A', 2: 'A', 3: 'B', 4: 'B'}
```
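
To make the positive/negative semantics concrete, here is how the positive pairs implied by `id_to_group` could be enumerated, continuing the example above. litfit derives these from `id_to_group` for you; this sketch is only to show what "items that share a group are positives" means:

```python
from collections import defaultdict
from itertools import combinations

# Illustration only: enumerate the positive pairs implied by id_to_group.
groups = defaultdict(list)
for item_id, group in id_to_group.items():
    groups[group].append(item_id)

pos_pairs = [p for members in groups.values() for p in combinations(members, 2)]
# -> [(0, 1), (0, 2), (1, 2), (3, 4)]
```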

Here is a complete pipeline — from loading data to exporting a `torch.nn.Linear`:

```python
import torch
import torch.nn as nn
from litfit import (
    load_askubuntu, encode_texts, split_data,
    compute_stats, generate_fast_projections,
    find_dim_range, evaluate_projections,
)

# --- 1. Load & encode ---
# load_askubuntu returns (ids, texts, id_to_group).
# max_groups limits how many duplicate-groups to keep (for speed).
all_ids, all_texts, id_to_group = load_askubuntu(max_groups=1500)
embs = encode_texts("intfloat/e5-base-v2", all_texts)

# --- 2. Split into train / val / test ---
# split_data does a group-aware split: all items in a group stay together,
# so no group leaks across splits. Default: 60/20/20.
data = split_data(all_ids, all_texts, embs, id_to_group)
train_ids, _, train_embs, _ = data["train"]
val_ids,   _, val_embs,   _ = data["val"]
test_ids,  _, test_embs,  _ = data["test"]

# --- 3. Compute sufficient statistics from training pairs ---
# compute_stats builds covariance matrices (Sigma_XX, Sigma_XY, etc.)
# from all positive pairs implied by id_to_group.
st = compute_stats(train_embs, train_ids, id_to_group)

# --- 4. Find useful dimension range ---
# Scans Rayleigh projections at many dims to find where performance peaks.
# Returns dim_fractions focused on the useful range.
dim_fractions = find_dim_range(st, val_embs, val_ids, id_to_group)

# --- 5. Generate & evaluate projections ---
# generate_fast_projections returns ~40 (method, hyperparams) configs.
# evaluate_projections uses explore-exploit scheduling on the val set.
all_W = generate_fast_projections(st)
results, summary = evaluate_projections(
    all_W, val_embs, val_ids, id_to_group,
    test_embs=test_embs, test_ids=test_ids,
    dim_fractions=dim_fractions,
)

# --- 6. Extract the best projection ---
# results keys are tuples like ('m_rayleigh', 'reg=0.1').
# Each value is {n_dims: {'MAP@50': ..., 'R@1': ..., ...}}.
# n_dims=None means full-dimensional.
best_key = max(results, key=lambda k: results[k][None]['MAP@50'])
W = all_W[best_key]                # shape (d, d) or (d, k)

# Optionally truncate to the best reduced dimension (read it off
# `results`, e.g. the n_dims with the highest val MAP@50):
best_dim = 128
projected = test_embs @ W[:, :best_dim]  # shape (n, best_dim)

# --- 7. (Optional) Recompute stats on ALL data for best performance ---
# The train split was used for fitting and val/test for model selection.
# Once you've picked the best config, recompute stats on all available
# embeddings so the final projection sees the most signal.
full_st = compute_stats(embs, all_ids, id_to_group)
all_W_full = generate_fast_projections(full_st, verbose=False)
W = all_W_full[best_key]

# --- 8. Export as torch.nn.Linear for inference ---
out_dim = best_dim             # or W.shape[1] for full
layer = nn.Linear(W.shape[0], out_dim, bias=False)
layer.weight = nn.Parameter(W[:, :out_dim].T.cpu().float())
# Use it: projected = layer(input_embs)
```
</details>

See the [docs](docs/) for [more examples](docs/examples.md), [architecture diagrams](docs/architecture.md), and streaming scripts.

## Device Support

- **CUDA**: Full support with optional Triton acceleration
- **CPU**: Full support
- **MPS**: Not supported (missing linalg ops)
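
litfit accepts numpy arrays or torch tensors (see the walkthrough above), so a minimal device-selection pattern might look like the sketch below; `train_embs` is assumed to come from your own pipeline:

```python
import torch

# Prefer CUDA when available; fall back to CPU. MPS is skipped because
# it lacks the linear-algebra ops litfit needs.
device = "cuda" if torch.cuda.is_available() else "cpu"
train_embs = torch.as_tensor(train_embs).to(device)
```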

## How it works

1. You provide embeddings and group labels (which items are duplicates/relevant/same-class)
2. litfit computes covariance matrices from all positive pairs (the "sufficient statistics")
3. It generates ~40 candidate projections in fast mode, or 800+ in the full sweep, using different methods (generalized Rayleigh quotients, CCA-style decompositions, asymmetric refinements, MSE regularization)
4. You get a projection matrix `W` — multiply your embeddings by it and you're done

The result is a linear transformation that can also reduce dimensionality: a 1152-dim SigLIP embedding projected to 228 dims can score *better* than the original on your task.
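
As a rough sketch of the first family, a generalized Rayleigh quotient seeks directions `w` that maximize positive-pair alignment relative to overall variance, which reduces to a generalized eigenproblem. The snippet below is a toy demonstration of that idea, not litfit's exact method; SciPy is used only for brevity and is not a litfit dependency:

```python
import numpy as np
from scipy.linalg import eigh

# Toy sketch: maximize w^T A w / w^T B w, where A is a symmetrized
# cross-covariance of positive pairs and B a regularized data
# covariance. Solved as the generalized eigenproblem A w = lam * B w.
rng = np.random.default_rng(0)
X = rng.standard_normal((1000, 32))
Y = X + 0.1 * rng.standard_normal((1000, 32))  # synthetic positives

A = (X.T @ Y + Y.T @ X) / (2 * len(X))
B = X.T @ X / len(X) + 1e-3 * np.eye(32)

eigvals, W = eigh(A, B)      # eigenvalues in ascending order
W = W[:, ::-1]               # best directions first; truncate to reduce dims
```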

## Development

```bash
pip install -e ".[dev]"
pytest
ruff check litfit tests
mypy litfit
black litfit tests
```
