Metadata-Version: 2.4
Name: torch-diode-lib
Version: 0.1.1
Summary: A toolkit for matrix multiplication data collection, model training, and evaluation (library version without PyTorch Inductor auto-registration)
Author-email: Gabriel Ferns <gabeferns@meta.com>
License-Expression: MIT
Project-URL: Homepage, https://github.com/exclamaforte/diode
Project-URL: Repository, https://github.com/exclamaforte/diode.git
Project-URL: Documentation, https://exclamaforte.github.io/diode/
Project-URL: Bug Tracker, https://github.com/exclamaforte/diode/issues
Keywords: machine-learning,pytorch,torch,matrix-multiplication,heuristics,optimization,hardware,library
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.9.0.dev
Requires-Dist: msgpack>=1.0.0
Provides-Extra: lint
Requires-Dist: lintrunner>=0.12.0; extra == "lint"
Requires-Dist: matplotlib>=3.5.0; extra == "lint"
Requires-Dist: tqdm>=4.62.0; extra == "lint"
Requires-Dist: flake8==6.1.0; extra == "lint"
Requires-Dist: flake8-bugbear==23.3.23; extra == "lint"
Requires-Dist: flake8-comprehensions==3.15.0; extra == "lint"
Requires-Dist: flake8-executable==2.1.3; extra == "lint"
Requires-Dist: flake8-logging-format==0.9.0; extra == "lint"
Requires-Dist: flake8-pyi==23.3.1; extra == "lint"
Requires-Dist: flake8-simplify==0.19.3; extra == "lint"
Requires-Dist: mccabe==0.7.0; extra == "lint"
Requires-Dist: pycodestyle==2.11.1; extra == "lint"
Requires-Dist: pyflakes==3.1.0; extra == "lint"
Requires-Dist: mypy==1.16.0; extra == "lint"
Requires-Dist: types-requests==2.27.25; extra == "lint"
Requires-Dist: types-pyyaml==6.0.1; extra == "lint"
Requires-Dist: types-setuptools==79.0.0.20250422; extra == "lint"
Requires-Dist: ruff==0.12.2; extra == "lint"
Requires-Dist: black==23.12.1; extra == "lint"
Requires-Dist: usort==1.0.8.post1; extra == "lint"
Requires-Dist: isort==6.0.1; extra == "lint"
Requires-Dist: codespell[toml]==2.4.1; extra == "lint"
Requires-Dist: packaging==25.0; extra == "lint"
Requires-Dist: tomli==2.2.1; python_version < "3.11" and extra == "lint"
Dynamic: license-file

Warning: code is in pre-Alpha


<img width="718" height="571" alt="diode" src="https://github.com/user-attachments/assets/308cb05a-01d9-4fc4-9c03-7e13ade91475" />

# torch-diode
`torch-diode` is a library for programattically altering the performance-relevant decisions made by `torch.compile`. It makes it easy to gather data on the outcomes of decisions, and then train Machine Learning models on that data. It initially focuses on Matmul Kernel selection, but it will be expanded to other decisions in the future. [Documentation](https://exclamaforte.github.io/diode/)

## Target Audience:
- Developers looking to adapt the compilation of their model to their specific situation.
- Hardware Vendors looking to optimize `torch.compile` heuristics for their hardware.
- OSS Contributors looking to add support for less popular hardware.

## Features:
- Pre-Trained Models: Profit from community efforts to gather data and train models.
- Data collection: Gather data from torch external interfaces.
- Stable Type Definitions: storing data from the external interfaces.
- Model Training Code: Train ML models on the gathered data and contribute back to the `torch` community.


## Featured Models
- Matrix Multiplication Kernel Prediction: Predict the runtime of matrix multiplication kernels. The results of this model are enabled in `fast-autotune`.

## Option 1: Installation with Pre-Trained Models

If you want to get access to the pre-trained performance models, as well as the libraries, install `torch-diode`:
```
$ pip install torch-diode
```
And then import torch_diode in python:
```
import torch_diode
```

This import has several side-effects, each of which are dependent on the success of the previous step:
1. Attempt to import `torch`.
1. Register dummy models to the relevant `torch.compile` interfaces.
1. For each registration that is successful, it will load the actual model and register it.
1. Enable the configs in `torch.compile` that engage the models.
## Option 2: Installation without Pre-Trained Models

`diode` requires nightly pytorch, or pytorch `2.9` or later.

For developers who don't want these side effects, simply installing `torch-diode-lib` will get access to the library.

```
$ pip install torch-diode-lib
```

The import remains the same:
```
import torch_diode
```

### Option 3: Install from Source
```bash
git clone https://github.com/exclamaforte/diode.git
cd diode
pip install .
```

## Model Organization

### Directory Structure
Models are organized in a structured directory format:
```
trained_models/
├── <model_purpose>/
│   ├── <model_name>.pt
│   └── ...
└── <other_model_file>.pt
```

Example:
```
trained_models/
├── matmul_kernel_runtime_prediction/
│   ├── v1_model.pt
│   └── v2_model.pt
└── matmul_model_exhaustive.pt
```
## Get Started

[The main entry point is in workflows.](https://github.com/exclamaforte/diode/tree/main/workflows#readme)

### Package Variants
- **torch-diode**: Full package with auto-registration to PyTorch Inductor
- **torch-diode-lib**: Library-only version without auto-registration
