Metadata-Version: 2.4
Name: zmlx
Version: 0.10.0
Summary: ZMLX: Metal-kernel toolkit and optimization lab for MLX on Apple Silicon. Fused MoE decode (+2-12% on LFM2/Qwen3.5), custom GPU kernels in one line, 70+ kernel catalog.
Project-URL: Homepage, https://github.com/Hmbown/ZMLX
Project-URL: Repository, https://github.com/Hmbown/ZMLX
Project-URL: Documentation, https://github.com/Hmbown/ZMLX#readme
Project-URL: Issues, https://github.com/Hmbown/ZMLX/issues
Project-URL: Changelog, https://github.com/Hmbown/ZMLX/blob/main/CHANGELOG.md
Author: Hunter Bown
License: MIT
License-File: LICENSE
Keywords: apple-silicon,autograd,jit,kernels,metal,mlx,zmlx
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: MacOS :: MacOS X
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries
Requires-Python: >=3.10
Requires-Dist: mlx>=0.30.0
Provides-Extra: dev
Requires-Dist: mypy>=1.10.0; extra == 'dev'
Requires-Dist: numpy>=1.24.0; extra == 'dev'
Requires-Dist: pytest>=8.0.0; extra == 'dev'
Requires-Dist: pyyaml>=6.0; extra == 'dev'
Requires-Dist: ruff>=0.5.0; extra == 'dev'
Requires-Dist: types-pyyaml>=6.0; extra == 'dev'
Provides-Extra: discover
Requires-Dist: anthropic>=0.40.0; extra == 'discover'
Requires-Dist: openai>=1.50.0; extra == 'discover'
Provides-Extra: kvtc
Requires-Dist: numpy>=1.24.0; extra == 'kvtc'
Provides-Extra: lm
Requires-Dist: mlx-lm>=0.25.0; extra == 'lm'
Description-Content-Type: text/markdown

# ZMLX — Metal kernels and model patching for MLX on Apple Silicon

[![PyPI](https://img.shields.io/pypi/v/zmlx.svg)](https://pypi.org/project/zmlx/)
[![Python 3.10+](https://img.shields.io/badge/python-3.10%2B-blue.svg)](https://www.python.org/downloads/)
[![License: MIT](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE)
[![Platform: macOS Apple Silicon](https://img.shields.io/badge/platform-macOS%20Apple%20Silicon-lightgrey.svg)](https://github.com/ml-explore/mlx)

ZMLX extends [MLX](https://github.com/ml-explore/mlx) with a Python-first Metal kernel toolkit and model-aware patching for faster MoE decode on Apple Silicon.

**What ZMLX does**

- **Metal kernels from Python:** write `elementwise("x * tanh(log(1 + exp(x)))")` and get a compiled Metal kernel with caching and autograd support, plus a catalog of 70+ ready-made kernels.
- **Model patching:** `patch(model)` replaces MoE gating/combine/activation sequences with fused Metal kernels, reducing dispatch overhead during decode. Token-identical output; verify with `python -m zmlx.validate`.
- **Works with stock MLX:** LFM2-8B (+12%) and LFM2-24B (+6-7%) show consistent decode gains with a plain `pip install mlx`; no custom build is required.
- **Qwen3.5-35B-A3B support (new):** `patch(model)` auto-detects Qwen3.5's hybrid DeltaNet+Attention MoE architecture and applies fused MoE decode. ~+2% decode on M4 Max 36GB, token-identical. Your results may vary depending on hardware.
- **Optional custom primitive (GLM/Qwen3):** build the custom `gather_qmm_swiglu` primitive to fuse quantized expert projections for GLM-4.7-Flash and Qwen3-30B-A3B. See [`docs/EXPERIMENTAL_MLX.md`](docs/EXPERIMENTAL_MLX.md). On stock MLX these models auto-skip safely.

## Measured Results

All numbers below are on **M4 Max 36GB** with greedy decoding. Your results will vary depending on hardware, thermal state, and prompt length. Verify on your machine with `python -m zmlx.validate <model>`.

### Stock MLX (works with `pip install mlx`)

| Model | Decode | Prefill | Fidelity |
|:--|--:|--:|:--|
| LFM2-8B-A1B-4bit | **+12.8%** (197.8 -> 223.2 tok/s) | neutral | token-identical |
| LFM2-24B-A2B-4bit | **+6.0%** (152.0 -> 161.1 tok/s) | neutral | token-identical |
| Qwen3.5-35B-A3B-4bit | **~+2%** (~36.2 -> ~36.8 tok/s) | **~+4%** | token-identical |
| GPT-OSS-20B-4bit | +1.0% (121.8 -> 122.9 tok/s) | neutral | token-identical |
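
The percentages are relative throughput changes. A minimal sketch of that arithmetic, assuming Change = patched / baseline - 1 (the README does not spell out the formula). Recomputing from the rounded tok/s values shown can differ from a listed percentage by about 0.1 point.

```python
def pct_change(baseline_tps: float, patched_tps: float) -> float:
    """Relative decode-throughput change in percent (positive means faster)."""
    if baseline_tps <= 0:
        raise ValueError("baseline throughput must be positive")
    return (patched_tps / baseline_tps - 1.0) * 100.0


# Baseline and patched tok/s from the table (M4 Max 36GB).
rows = {
    "LFM2-8B-A1B-4bit": (197.8, 223.2),
    "LFM2-24B-A2B-4bit": (152.0, 161.1),
}
for model_name, (base, patched) in rows.items():
    print(f"{model_name}: {pct_change(base, patched):+.1f}%")  # +12.8%, +6.0%
```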

### Custom MLX primitive (requires building `mlx_local/`)

| Model | Decode change | Average | Fidelity |
|:--|--:|--:|:--|
| GLM-4.7-Flash-4bit | +6.2% (200 tok), +6.7% (1024 tok) | **~+6.4%** | PASS |

See [`docs/EXPERIMENTAL_MLX.md`](docs/EXPERIMENTAL_MLX.md) for build instructions.

Full methodology, raw data, and repro capsules: [`docs/BENCHMARKS.md`](docs/BENCHMARKS.md) and `benchmarks/repro_capsules/`.

## Quick Start

**Requirements:** macOS 14+ (Apple Silicon), Python >= 3.10, `mlx>=0.30.0`

1. Install (patching examples use `mlx-lm`):

```bash
pip install "zmlx[lm]"       # includes mlx-lm for model patching
# pip install zmlx            # kernel authoring only
```

2. Patch a model and generate (no weight conversion; patches apply in-place):

```python
import mlx_lm
from zmlx.patch import patch

# Works with any supported model — just change the model ID
model, tokenizer = mlx_lm.load("LiquidAI/LFM2-24B-A2B-MLX-4bit")
patch(model)  # auto-detects model family, applies safe optimizations

print(
    mlx_lm.generate(
        model,
        tokenizer,
        prompt="Explain mixture-of-experts in one paragraph.",
        max_tokens=200,
    )
)
```

That's it. `patch(model)` handles everything automatically — model detection, kernel selection, and safety checks. No env vars or configuration needed.

3. Verify token fidelity + throughput on your hardware:

```bash
# LFM2-24B (+6-7% on M4 Max)
python -m zmlx.validate LiquidAI/LFM2-24B-A2B-MLX-4bit --max-tokens 200 --runs 3

# LFM2-8B (+12% on M4 Max)
python -m zmlx.validate mlx-community/LFM2-8B-A1B-4bit --max-tokens 200 --runs 3
```
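
`zmlx.validate` reports token fidelity and throughput for unpatched vs patched runs. The two checks can be sketched in plain Python; this is an illustration assuming greedy token IDs and per-run (tokens, seconds) decode timings are already collected, not the tool's actual code. `first_divergence` and `median_tok_per_s` are hypothetical helpers.

```python
from statistics import median
from typing import Optional, Sequence, Tuple


def first_divergence(baseline: Sequence[int], patched: Sequence[int]) -> Optional[int]:
    """Index of the first differing token, or None if the sequences are identical."""
    for i, (a, b) in enumerate(zip(baseline, patched)):
        if a != b:
            return i
    if len(baseline) != len(patched):
        return min(len(baseline), len(patched))  # one run stopped early
    return None


def median_tok_per_s(runs: Sequence[Tuple[int, float]]) -> float:
    """Median decode throughput over (generated_tokens, seconds) runs."""
    return median(tokens / seconds for tokens, seconds in runs)


baseline = [101, 7, 42, 9]
patched = [101, 7, 42, 9]
print("token-identical" if first_divergence(baseline, patched) is None else "diverged")
```

Using the median over `--runs` repetitions keeps a single slow or thermally throttled run from skewing the comparison.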

One-command smoke inference (loads model, applies `zmlx.patch.patch(model)`, then generates):

```bash
source .venv/bin/activate && python examples/inference_smoke.py --model-id <model> --prompt "<prompt>" --max-tokens 64
```

Expected output (one line per stage):
- `[load] model=<model>`
- `[patch] Applying zmlx.patch.patch(model) with safe defaults`
- `[patch] Patched ...`
- `[generate] prompt='...' max_tokens=64`
- `[output]` followed by generated text

Tip: large model downloads use the Hugging Face cache; set `HF_HOME` to control its location.

## What's Inside

- **Model patching:** `zmlx.patch.patch()` (preset-based) and `zmlx.patch.smart_patch()` (auto-benchmark patterns).
- **Kernel authoring:** `zmlx.api.elementwise()`, `reduce()`, `map_reduce()`, and `@zmlx.jit`.
- **Autograd support:** optional custom VJP paths via MLX custom functions.
- **Benchmarking:** `zmlx.bench.compare()` and `python -m zmlx.bench.report` (repro capsules in `benchmarks/repro_capsules/`).
- **Custom MLX primitive (opt-in):** build a custom MLX with `gather_qmm_swiglu` (see [`docs/EXPERIMENTAL_MLX.md`](docs/EXPERIMENTAL_MLX.md); patch lives in `integrations/mlx_local_integration/`).

## exo Integration

ZMLX works with [exo](https://github.com/exo-explore/exo) for faster GLM-4.7-Flash and Qwen3-30B-A3B decode. No source patching needed.

From a ZMLX checkout (recommended; clones exo into `./exo` and generates `exo/run_zmlx.sh`):

```bash
bash setup_zmlx.sh
bash exo/run_zmlx.sh
```

If `exo` is already installed in your environment:

```bash
pip install zmlx
zmlx-exo
```

For GLM/Qwen3 speedups, first build the optional custom MLX primitive (`gather_qmm_swiglu`) per [`docs/EXPERIMENTAL_MLX.md`](docs/EXPERIMENTAL_MLX.md), then re-run `bash setup_zmlx.sh` so the exo venv picks it up.

ZMLX hooks into exo's model loading at runtime — when GLM/Qwen3 load with the custom MLX primitive, MoE expert dispatch is fused. Measured speedups vary by prompt/length; see [`docs/EXO.md`](docs/EXO.md) and repro capsules in `benchmarks/repro_capsules/`.
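
A generic Python way to hook a library's model loading at runtime, without editing its source, is to wrap the loader function so a post-load step runs on every loaded model. This is a hedged sketch of that technique only: `install_load_hook`, `load_model`, and the `SimpleNamespace` loader are hypothetical and are not exo's or ZMLX's actual names.

```python
import functools
from types import SimpleNamespace
from typing import Any, Callable


def install_load_hook(owner: Any, attr: str, post_load: Callable[[Any], None]) -> Callable[..., Any]:
    """Wrap `owner.<attr>` so `post_load(result)` runs after each call; return the original."""
    original = getattr(owner, attr)

    @functools.wraps(original)
    def hooked(*args: Any, **kwargs: Any) -> Any:
        result = original(*args, **kwargs)
        post_load(result)  # e.g. apply model patches to the freshly loaded model
        return result

    setattr(owner, attr, hooked)
    return original


# Hypothetical loader standing in for a third-party module.
loader = SimpleNamespace(load_model=lambda name: {"name": name, "patched": False})
original = install_load_hook(loader, "load_model", lambda m: m.update(patched=True))

model = loader.load_model("demo")
print(model)  # {'name': 'demo', 'patched': True}

setattr(loader, "load_model", original)  # uninstall the hook
```

Returning the original function makes the hook reversible, which is useful in tests or when the patched path needs to be disabled.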

## Docs

| Doc | What's inside |
|:--|:--|
| [`docs/TOUR.md`](docs/TOUR.md) | Quick walkthrough and how to verify results |
| [`docs/QUICKSTART.md`](docs/QUICKSTART.md) | 5-minute kernel authoring tutorial |
| [`docs/COOKBOOK.md`](docs/COOKBOOK.md) | Recipes for common patterns |
| [`docs/KERNELS.md`](docs/KERNELS.md) | Kernel catalog (by module/domain) |
| [`docs/KNOWLEDGE_BASE.md`](docs/KNOWLEDGE_BASE.md) | Canonical KB schema, rebuild, and validation |
| [`docs/FOUNDRY.md`](docs/FOUNDRY.md) | Kernel template evaluation, dataset generation, SFT export |
| [`docs/kernel_discovery.md`](docs/kernel_discovery.md) | Hamiltonian-guided fused-boundary kernel discovery (`zmlx.kd`) |
| [`docs/BENCHMARKS.md`](docs/BENCHMARKS.md) | Benchmark methodology + raw data |
| [`docs/ARCHITECTURE.md`](docs/ARCHITECTURE.md) | Design philosophy |
| [`docs/EXO.md`](docs/EXO.md) | exo integration guide (GLM/Qwen3) |
| [`docs/EXPERIMENTAL_MLX.md`](docs/EXPERIMENTAL_MLX.md) | Custom MLX primitive details |
| [`UPSTREAM_PLAN.md`](UPSTREAM_PLAN.md) | What belongs upstream in MLX |

## Contributing / Development

See [`CONTRIBUTING.md`](CONTRIBUTING.md) for setup, testing, and conventions.

```bash
git clone https://github.com/Hmbown/ZMLX.git
cd ZMLX
pip install -e ".[dev]"
pytest
```

---

<details>
<summary>Benchmarks (stock MLX — works with pip install mlx)</summary>

These results use **released MLX** (`pip install mlx`). The speedup comes from ZMLX's own Python-level Metal kernels (fused gating, combine, SwiGLU activation) — no custom C++ or MLX fork required.

Full methodology and raw data: [`docs/BENCHMARKS.md`](docs/BENCHMARKS.md).

| Model | Hardware | Decode (baseline -> patched) | Change | Fidelity | Capsule |
|:--|:--|--:|--:|:--|:--|
| LFM2-8B-A1B-4bit | M4 Max 36 GB | 197.8 tok/s -> 223.2 tok/s | **+12.8%** | token-identical | [`benchmarks/repro_capsules/lfm2_m4max_20260205_rerun_mlx0304dev2f324cc.json`](benchmarks/repro_capsules/lfm2_m4max_20260205_rerun_mlx0304dev2f324cc.json) |
| LFM2-8B-A1B-4bit | M1 Pro 16 GB | 105.5 tok/s -> 115.3 tok/s | +9.3% | token-identical | [`benchmarks/repro_capsules/lfm2_m1pro_20260131.json`](benchmarks/repro_capsules/lfm2_m1pro_20260131.json) |
| LFM2-24B-A2B-4bit | M4 Max 36 GB | 152.0 tok/s -> 161.1 tok/s | **+6.0%** | token-identical (500 tok) | [`benchmarks/repro_capsules/lfm2_24b_dsimd_gate_m4max_20260224.json`](benchmarks/repro_capsules/lfm2_24b_dsimd_gate_m4max_20260224.json) |
| GPT-OSS-20B-4bit | M4 Max 36 GB | 121.8 tok/s -> 122.9 tok/s | +1.0% | token-identical | — |

To print a report from a capsule:

```bash
python -m zmlx.bench.report benchmarks/repro_capsules/<capsule>.json
```

</details>

<details>
<summary>Benchmarks (custom MLX primitive — requires building mlx_local/)</summary>

Any GLM/Qwen3 improvements on custom MLX come from `gather_qmm_swiglu`, a **custom C++ Metal primitive we wrote** (~800 lines of C++/Metal). It fuses gate projection + up projection + SwiGLU activation for quantized MoE experts into a single GPU dispatch. This primitive is not part of released MLX — build it by applying the patch described in [`docs/EXPERIMENTAL_MLX.md`](docs/EXPERIMENTAL_MLX.md).
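
As a correctness reference, the computation being fused is, per selected expert, `silu(W_gate x) * (W_up x)`. This sketch is unquantized, uses plain Python lists, and the function names are hypothetical; it shows the math the primitive replaces, not how the primitive is implemented.

```python
import math
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]  # (out_features, in_features), like a Linear weight


def silu(v: float) -> float:
    """SiLU / swish: v * sigmoid(v), computed without overflowing exp()."""
    if v >= 0:
        return v / (1.0 + math.exp(-v))
    e = math.exp(v)
    return v * e / (1.0 + e)


def matvec(w: Matrix, x: Sequence[float]) -> List[float]:
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def expert_swiglu(x: Sequence[float], w_gate: Matrix, w_up: Matrix) -> List[float]:
    # Unfused, this is three steps: gate projection, up projection, activation.
    gate = matvec(w_gate, x)
    up = matvec(w_up, x)
    return [silu(g) * u for g, u in zip(gate, up)]


def gathered_swiglu(
    x: Sequence[float],
    expert_ids: Sequence[int],
    w_gates: Sequence[Matrix],
    w_ups: Sequence[Matrix],
) -> List[List[float]]:
    """Apply expert_swiglu for each selected expert id (the "gather" step)."""
    return [expert_swiglu(x, w_gates[e], w_ups[e]) for e in expert_ids]
```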

ZMLX provides the model-side integration: auto-detecting MoE architectures, rewiring forward passes to use the fused primitive, and using native MLX combine ops on GLM/Qwen3 for fidelity and lower dispatch overhead.

**On stock MLX (released 0.30.4/0.30.5), ZMLX auto-skips these models** (0 modules patched, 0% change) to avoid regressions. `patch()` is always safe to call.

| Model | Recommended config | Overall decode gain vs unpatched baseline | Fidelity | Evidence |
|:--|:--|--:|:--|:--|
| GLM-4.7-Flash-4bit-mxfp4 | `glm_combine_fp32_no_fma` | `+6.2%` (200), `+6.7%` (1024), `~+6.4%` average | PASS | `benchmarks/repro_capsules/glm47_combo_v8_fp32nofmaonly_t200_r2_summary.json`, `benchmarks/repro_capsules/glm47_combo_v8_fp32nofmaonly_t1024_r2_summary.json`, `benchmarks/repro_capsules/benchmark_vs_baseline_followup_20260211.json` |

Qwen3-30B-A3B: no candidate configuration has been promoted yet; the unpatched control baseline remains the recommendation until a clearly decode-positive variant is reproduced.

See [`docs/EXPERIMENTAL_MLX.md`](docs/EXPERIMENTAL_MLX.md) for build instructions. Repro capsules in `benchmarks/repro_capsules/`.

</details>

<details>
<summary>Model support summary</summary>

| Model | Stock MLX | + Custom primitive | What ZMLX does |
|:--|:--|:--|:--|
| LFM2-8B-A1B | **+12% decode** | same | Fused MoE gating + combine + SwiGLU activation |
| LFM2-24B-A2B | **+6-7% decode** | same | D-SIMD fused gating kernel (64 experts, K=4) |
| Qwen3.5-35B-A3B | **~+2% decode** | same | Fused MoE dispatch (256 experts, K=8, hybrid DeltaNet+Attention) |
| GLM-4.7-Flash | 0% (auto-skipped) | **~+6% decode** | ZMLX patching + custom `gather_qmm_swiglu` primitive |
| Qwen3-30B-A3B | 0% (auto-skipped) | experimental (no promoted config yet) | ZMLX patching + custom `gather_qmm_swiglu` primitive |
| GPT-OSS-20B | +1% decode | same | ZMLX Metal kernel: fused SwiGLU activation |
| Other models | safe no-op | same | `patch()` returns unchanged if no patterns match |

All results are token-identical under greedy decoding. Verify on your hardware with `python -m zmlx.validate <model>`.
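
For intuition about the "safe no-op" row: a toy sketch of in-place, type-based module replacement, where a model with no matching module is left untouched and the patch count is 0. `Module`, `MoEBlock`, `FusedMoEBlock`, and `patch_in_place` are hypothetical stand-ins, not ZMLX internals; real detection also applies per-family safety checks not modeled here.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, Type


@dataclass
class Module:
    """Hypothetical stand-in for a model module with named children."""
    name: str
    children: Dict[str, "Module"] = field(default_factory=dict)


class MoEBlock(Module):
    """Hypothetical unfused MoE block (gating, experts, combine as separate ops)."""


class FusedMoEBlock(Module):
    """Hypothetical replacement that would dispatch fused kernels instead."""


Rules = Dict[Type[Module], Callable[[Module], Module]]


def patch_in_place(root: Module, rules: Rules) -> int:
    """Swap children whose exact type has a rule; recurse into the rest. Returns the count."""
    swapped = 0
    for key, child in list(root.children.items()):
        make_fused = rules.get(type(child))
        if make_fused is not None:
            root.children[key] = make_fused(child)
            swapped += 1
        else:
            swapped += patch_in_place(child, rules)
    return swapped


rules: Rules = {MoEBlock: lambda old: FusedMoEBlock(old.name, old.children)}
model = Module("model", {f"layer{i}": Module(f"layer{i}", {"mlp": MoEBlock("mlp")}) for i in range(2)})
print(patch_in_place(model, rules))  # 2: both MoE blocks swapped
print(patch_in_place(Module("other", {"mlp": Module("mlp")}), rules))  # 0: nothing matches
```

Matching on the exact type also makes a second pass a no-op, since replaced blocks no longer match any rule.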

Patching controls:

```python
import mlx.core as mx
import mlx_lm
from zmlx.patch import patch, smart_patch

model, tokenizer = mlx_lm.load("LiquidAI/LFM2-24B-A2B-MLX-4bit")

# Use ONE of the options below on a freshly loaded model.
patch(model)                          # inference defaults (auto-skips unsafe patterns)
# patch(model, patterns=["moe_mlp"])  # override safety checks; validate output first

# Auto-benchmark: apply only patterns that actually help on your sample
# sample = mx.array([tokenizer.encode("Hello")])
# model = smart_patch(model, sample)
```
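
The README describes `smart_patch` only as auto-benchmarking and keeping patterns that help on a sample. One plausible strategy is greedy forward selection, sketched below. It is hypothetical: `select_patterns`, `time_call`, the `min_gain` threshold, and the pattern names other than `moe_mlp` are illustrative, not ZMLX API. Timing is abstracted behind a `measure` callable so the selection logic runs without a GPU; in practice `measure` would patch a fresh model and time a forward pass.

```python
import time
from statistics import median
from typing import Callable, Dict, List, Sequence


def time_call(fn: Callable[[], None], repeats: int = 5) -> float:
    """Median wall-clock seconds for fn(); with MLX, fn should force evaluation (mx.eval)."""
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return median(samples)


def select_patterns(
    candidates: Sequence[str],
    measure: Callable[[List[str]], float],
    min_gain: float = 0.02,
) -> List[str]:
    """Greedy forward selection: keep a pattern only if it cuts latency by >= min_gain."""
    chosen: List[str] = []
    best = measure(chosen)  # baseline latency with no patterns applied
    for pattern in candidates:
        latency = measure(chosen + [pattern])
        if latency <= best * (1.0 - min_gain):
            chosen.append(pattern)
            best = latency
    return chosen


# Fake per-pattern latency effects (seconds) so the logic runs without a GPU.
effects: Dict[str, float] = {"moe_mlp": -2.0, "pattern_b": -0.5, "pattern_c": +1.0}


def fake_measure(patterns: List[str]) -> float:
    return 10.0 + sum(effects[p] for p in patterns)


print(select_patterns(list(effects), fake_measure))  # ['moe_mlp', 'pattern_b']
```

The `min_gain` margin guards against keeping a pattern whose apparent win is just timing noise.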

</details>

<details>
<summary>How patching works (MoE decode)</summary>

MoE decode is often dominated by Metal kernel dispatch overhead (many small ops per token).

ZMLX targets the multi-op sequences that show up during decode:

- **Gating:** top-k softmax selection fused into one kernel (`topk_gating_softmax`).
- **Combine:** weight-and-reduce across experts fused into one kernel (`moe_combine`).
- **Expert SwiGLU (when available):** gate+up projection+SwiGLU fused into one dispatch via custom `gather_qmm_swiglu` primitive.
- **Guards:** fused paths only activate at small sequence lengths (decode), keeping prefill throughput neutral.
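
The unfused math behind the gating, combine, and guard steps can be written as a small scalar reference. A minimal sketch, assuming the router softmax is normalized over the selected top-k logits (equivalent to a full softmax followed by renormalization; some MoE variants skip that renormalization). The function names and the `max_fused_len` threshold are hypothetical, not ZMLX's kernels or settings.

```python
import math
from typing import List, Sequence, Tuple


def topk_gating_softmax(logits: Sequence[float], k: int) -> Tuple[List[int], List[float]]:
    """Pick the k largest router logits and softmax-normalize over just those k."""
    ids = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in ids)  # subtract the max for numerical stability
    exps = [math.exp(logits[i] - m) for i in ids]
    total = sum(exps)
    return ids, [e / total for e in exps]


def moe_combine(expert_outputs: Sequence[Sequence[float]], weights: Sequence[float]) -> List[float]:
    """Weighted sum over the k selected experts: out[d] = sum_k weights[k] * y[k][d]."""
    dim = len(expert_outputs[0])
    return [sum(w * y[d] for w, y in zip(weights, expert_outputs)) for d in range(dim)]


def use_fused_path(seq_len: int, max_fused_len: int = 1) -> bool:
    """Decode guard: only take the fused path for very short sequences."""
    return seq_len <= max_fused_len


ids, w = topk_gating_softmax([0.1, 2.0, -1.0, 1.5], k=2)
y = [[1.0, 0.0], [0.0, 1.0]]  # outputs of experts ids[0] and ids[1]
print(ids, moe_combine(y, w))
```

Each of these is several small ops per token when written with separate MLX calls, which is the dispatch overhead the fused kernels target.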

Deeper dives:

- Walkthrough: [`docs/TOUR.md`](docs/TOUR.md)
- Design notes: [`docs/ARCHITECTURE.md`](docs/ARCHITECTURE.md)

</details>

<details>
<summary>Kernel authoring (very short example)</summary>

ZMLX can compile small elementwise expression strings into Metal kernels via MLX's `mx.fast.metal_kernel`:

```python
from zmlx.api import elementwise
import mlx.core as mx

mish = elementwise("x * tanh(log(1 + exp(x)))", name="mish")
y = mish(mx.random.normal((1024,)))
mx.eval(y)
```
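
When checking a custom kernel, a scalar reference helps. A minimal Python sketch of Mish for comparing against kernel outputs (for example, keep the input array and pass `x.tolist()` and `y.tolist()`). Python's `math.exp` raises `OverflowError` above about 709, so the reference uses the mathematically identical form `max(x, 0) + log1p(exp(-|x|))` for `log(1 + exp(x))`. `max_abs_error` is a hypothetical helper.

```python
import math
from typing import Sequence


def softplus(x: float) -> float:
    # log(1 + exp(x)) rewritten as max(x, 0) + log1p(exp(-|x|)): same value, no overflow.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def mish(x: float) -> float:
    return x * math.tanh(softplus(x))


def max_abs_error(kernel_out: Sequence[float], inputs: Sequence[float]) -> float:
    """Largest |kernel - reference| over paired Python floats."""
    return max(abs(k - mish(x)) for k, x in zip(kernel_out, inputs))
```

Since the kernel computes in float32, expect a small nonzero error against this double-precision reference.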

Next steps:

- 5-minute tutorial: [`docs/QUICKSTART.md`](docs/QUICKSTART.md)
- Recipes: [`docs/COOKBOOK.md`](docs/COOKBOOK.md)
- Catalog: [`docs/KERNELS.md`](docs/KERNELS.md)

</details>

<details>
<summary>Troubleshooting</summary>

| Symptom | Fix |
|:--|:--|
| `ModuleNotFoundError: No module named 'mlx'` | Install MLX on an Apple Silicon Mac running macOS (`pip install mlx`). ZMLX does not support Intel Macs or Linux. |
| `ModuleNotFoundError: No module named 'mlx_lm'` | Install with `pip install "zmlx[lm]"` for model patching examples. |
| Model downloads fill disk | Set `HF_HOME` to a larger drive before running. |
| `patch()` shows 0 modules patched | The model may not match any patterns, or ZMLX auto-skipped them for safety. Run `python -m zmlx.validate <model>` to verify. |
| GLM/Qwen shows 0 modules patched | Expected on stock MLX. Requires building the custom `gather_qmm_swiglu` primitive in `mlx_local/` (see [docs](docs/EXPERIMENTAL_MLX.md)). |

</details>

<details>
<summary>Precision note</summary>

Most kernels compute internally in **float32** regardless of input dtype. The exception is `moe_combine_exact`, which accumulates in the input dtype to match MLX's bfloat16 semantics. GLM and Qwen3 use native MLX ops for the combine step (`(y * scores[..., None]).sum(axis=-2)`) to match the original model code exactly and avoid custom-kernel dispatch overhead.
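
Why accumulation dtype matters can be shown with a small emulation. The sketch rounds values to bfloat16 in pure Python (via the float32 bit pattern, ties to even) and computes the combine `(y * scores[..., None]).sum(axis=-2)` for one token, with `y` as K x D expert outputs. It assumes a simple left-to-right summation order, which is not necessarily MLX's reduction order; `to_bf16` and `combine` are hypothetical names.

```python
import struct
from typing import List, Sequence


def to_bf16(x: float) -> float:
    """Round a finite float (within float32 range) to the nearest bfloat16, ties to even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-to-nearest-even on the dropped 16 bits
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def combine(y: Sequence[Sequence[float]], scores: Sequence[float], accumulate_in_bf16: bool) -> List[float]:
    """out[d] = sum_k scores[k] * y[k][d], with the final result stored as bfloat16.

    accumulate_in_bf16=True rounds after every multiply and add (input-dtype accumulation);
    False accumulates in Python float and rounds once at the end.
    """
    out = []
    for d in range(len(y[0])):
        acc = 0.0
        for k, s in enumerate(scores):
            prod = y[k][d] * s
            if accumulate_in_bf16:
                acc = to_bf16(acc + to_bf16(prod))
            else:
                acc += prod
        out.append(to_bf16(acc))
    return out


y = [[1.0], [2.0 ** -8], [2.0 ** -8]]
scores = [1.0, 1.0, 1.0]
print(combine(y, scores, accumulate_in_bf16=True))   # [1.0]
print(combine(y, scores, accumulate_in_bf16=False))  # [1.0078125]
```

Each `2**-8` term is exactly half a bfloat16 step at 1.0, so step-by-step rounding discards both, while wider accumulation keeps their sum. Matching the model's accumulation behavior is what keeps patched output token-identical.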

</details>

---

## Acknowledgments

Built on [MLX](https://github.com/ml-explore/mlx) by Apple machine learning research. If you use ZMLX in your work, please also cite MLX:

```bibtex
@software{mlx2023,
  author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
  title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
  url = {https://github.com/ml-explore},
  version = {0.0},
  year = {2023},
}
```

## License

MIT. See [`LICENSE`](LICENSE).
