Metadata-Version: 2.3
Name: flash-hog
Version: 0.1.0
Summary: Add your description here
Requires-Dist: absl-py>=2.4.0
Requires-Dist: chex>=0.1.91
Requires-Dist: einops>=0.8.1
Requires-Dist: equinox>=0.13.2
Requires-Dist: jax[cuda13]>=0.8.0
Requires-Dist: nvidia-cutlass-dsl>=4.2.1
Requires-Dist: pytest>=8.4.2
Requires-Dist: ruff>=0.14.2
Requires-Dist: torch>=2.9.0
Requires-Dist: ty>=0.0.1a24
Requires-Python: >=3.12, <3.14
Description-Content-Type: text/markdown

# Flash Hog
<p align="center">
<img src="assets/logo.png" alt="Flash Hog Logo" width="256" />
</p>

This repo contains the code for Flash Higher-Order-Gradients, a.k.a. Flash Hog.
This kernel achieves around a 3.7x speedup over an XLA-optimized kernel, with memory that scales linearly rather than quadratically in sequence length.

<p align="center">
<img src="assets/speedup.png" alt="Hog Speedup" width="512"/>
</p>

## Installation
TODO

## Method
Flash Hog makes 4 recomputation passes, which avoids both atomics and saving any intermediate tensors of shape `(N_Q, N_K)`.
The first 3 passes are tiled thread-wise across Q: one computes `dd`, one computes `b`, and one computes both `dQ'` and `ddO`.
The final pass is tiled over K and produces `dK'` and `dV'`.
The equations we implement are the following:


<p align="center">
<img src="assets/handwritten_equations.png" alt="Equations" width="512"/>
</p>
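
The tiling pattern described in this section can be illustrated with a small NumPy sketch. This is not the Flash Hog kernel and it does not implement the equations above: as a stand-in it uses the ordinary first-order attention quantities (the row-wise log-sum-exp and `dV = P^T dO`), and every function name in it is hypothetical. It only shows why a pass tiled over Q can own its per-query outputs, why a pass tiled over K can own its per-key outputs, and how recomputing score tiles avoids ever storing an `(N_Q, N_K)` tensor. The softmax scale is omitted for brevity.

```python
import numpy as np


def probs_tile(q_tile, k_tile, lse_tile):
    """Recompute one (tile_q, tile_k) block of softmax probabilities."""
    s = q_tile @ k_tile.T  # score tile; never stored beyond this call
    return np.exp(s - lse_tile[:, None])


def row_logsumexp_q_tiled(q, k, tile):
    """Q-tiled pass: each Q tile writes only its own rows, so no atomics."""
    n_q = q.shape[0]
    lse = np.empty(n_q)
    for qs in range(0, n_q, tile):
        q_tile = q[qs:qs + tile]
        m = np.full(q_tile.shape[0], -np.inf)  # running row max
        l = np.zeros(q_tile.shape[0])          # running row sum of exp
        for ks in range(0, k.shape[0], tile):
            s = q_tile @ k[ks:ks + tile].T
            m_new = np.maximum(m, s.max(axis=1))
            l = l * np.exp(m - m_new) + np.exp(s - m_new[:, None]).sum(axis=1)
            m = m_new
        lse[qs:qs + tile] = m + np.log(l)
    return lse


def dv_k_tiled(q, k, d_o, lse, tile):
    """K-tiled pass: each K tile accumulates its own rows of dV = P^T dO."""
    n_k = k.shape[0]
    dv = np.zeros((n_k, d_o.shape[1]))
    for ks in range(0, n_k, tile):
        k_tile = k[ks:ks + tile]
        acc = np.zeros((k_tile.shape[0], d_o.shape[1]))
        for qs in range(0, q.shape[0], tile):
            # Recompute the probability tile instead of loading a saved one.
            p = probs_tile(q[qs:qs + tile], k_tile, lse[qs:qs + tile])
            acc += p.T @ d_o[qs:qs + tile]
        dv[ks:ks + tile] = acc
    return dv
```

Only per-row vectors such as `lse` are carried between passes, so memory stays linear in `N_Q + N_K`. The cost is that each pass recomputes the score tiles.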