Metadata-Version: 2.4
Name: modelbatch
Version: 0.1.0
Summary: Train many independent PyTorch models simultaneously on a single GPU using vectorized operations
Keywords: pytorch,machine learning,vectorization,gpu,training
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.9
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Requires-Dist: numpy>=1.20.0
Requires-Dist: tqdm>=4.64.0
Dynamic: license-file

# ModelBatch

**Train many independent PyTorch models simultaneously on a single GPU using vectorized operations.**

[![Python 3.9+](https://img.shields.io/badge/python-3.9+-blue.svg)](https://www.python.org/downloads/)
[![PyTorch 2.0+](https://img.shields.io/badge/PyTorch-2.0+-orange.svg)](https://pytorch.org/)

## ⚠️ Current Status

**ModelBatch is still in active development. Core functionality is tested and working, but the API may be subject to change.**

## 🚀 Quick Start

### Installation

From PyPI:

```bash
# recommended
uv add modelbatch

# alternative
pip install modelbatch
```

From source:

```bash
uv sync --dev
uv pip install -e ".[dev]"
```

### Basic Example

```python
import torch
from modelbatch import ModelBatch

# Create multiple models
num_models = 4  # choose the number of models to batch
models = [SimpleNet() for _ in range(num_models)]

# Wrap with ModelBatch - that's it!
mb = ModelBatch(models, lr_list=[0.001] * num_models, optimizer_cls=torch.optim.Adam)

# Train normally (but many times faster!), batched across models
for batch in dataloader:
    mb.zero_grad()
    outputs = mb(batch)
    loss = mb.compute_loss(outputs, targets)  
    loss.backward()
    mb.step()
```

See [here](examples) for more examples.

## 📚 Documentation

See [docs](https://rock-z.github.io/ModelBatch/).

## 🛠️ Development

### Environment Setup

```bash
uv sync --dev
```

### Commands

```bash
# Tests (currently showing failures)
uv run -m pytest

# Linting  
uv run ruff check --fix . && uv run ruff format .

# Documentation
uv run mkdocs serve
```

## 📄 License

This project is licensed under the [MIT License](LICENSE).
