Metadata-Version: 2.4
Name: pytorch-essentials
Version: 0.1.1
Summary: PyTorch training utilities
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=2.0.0
Requires-Dist: torchvision>=0.15.0
Requires-Dist: tqdm>=4.65.0
Requires-Dist: scikit-learn>=1.3.0
Requires-Dist: matplotlib>=3.7.0
Requires-Dist: PyYAML>=6.0
Requires-Dist: wandb>=0.15.0
Dynamic: license-file

# PyTorch Essentials

A collection of useful utilities for PyTorch model training, evaluation, and experiment management.

## Features

- **Training Loop**: Ready-to-use training and evaluation loops with progress tracking
- **Early Stopping**: Prevent overfitting with configurable early stopping callback
- **Visualization**: Plot training curves and metrics
- **Model Management**: Save/load models with metadata and results
- **Configuration**: YAML-based configuration management
- **Utilities**: Device detection, seed setting, parameter counting
- **Wandb Integration**: Comprehensive logging of classification metrics including precision, recall, F1-score, and confusion matrices

## Installation

### From PyPI (recommended)

```bash
pip install pytorch-essentials
```

**With optional Weights & Biases support:**
```bash
pip install pytorch-essentials[wandb]
```

**For development:**
```bash
pip install pytorch-essentials[dev]
```

### From source

```bash
git clone https://github.com/yourusername/pytorch_essentials.git
cd pytorch_essentials
pip install -e .
```

**Requirements:**
- Python 3.8+
- PyTorch 2.0+
- See `pyproject.toml` for full list of dependencies

## Quick Start

### Basic Usage

```python
from pytorch_essentials import (
    train, get_device, EarlyStopping,
    plot_loss_curves, set_seeds
)

# Set reproducibility
set_seeds(42)

# Get device
device = get_device()

# Initialize model, optimizer, loss
model = YourModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

# Optional: Early stopping
early_stopping = EarlyStopping(patience=5, delta=0.001)

# Train
results = train(
    model=model,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    num_epochs=100,
    device=device,
    config=config,
    early_stopping=early_stopping
)

# Visualize results
fig = plot_loss_curves(results)
fig.savefig('training_curves.png')
```

### With Configuration File

Create a `config.yaml`:

```yaml
project_name: my_project
hyperparameters:
  learning_rate: 1e-3
  batch_size: 64
  epochs: 100
flags:
  use_subset: false
  save_model: true
  debug: false
  use_wandb: false
```

Then in your code:

```python
from pytorch_essentials import load_config, print_config

config = load_config('config.yaml')
print_config(config)
```

## Examples

Check the `examples/` directory for complete working examples:

- `examples/basic_training.py` - Complete MNIST training example
- `examples/config.yaml` - Example configuration file

Run the example:

```bash
cd examples
python basic_training.py
```

## API Reference

### Training Functions

#### `train(model, train_dataloader, val_dataloader, optimizer, loss_fn, num_epochs, device, config, early_stopping=None, scheduler=None)`

Main training loop with validation.

**Returns:** Dictionary with training history including:
- `train_loss`: List of training losses
- `train_acc`: List of training accuracies
- `test_loss`: List of validation losses
- `test_acc`: List of validation accuracies
- `best_epoch`: Best epoch number (if early stopping used)

#### `evaluate_model(model, test_dataloader, loss_fn, device, class_names, log_to_wandb=False)`

Evaluate model on test set with confusion matrix and metrics.

**Args:**
- `log_to_wandb`: If True, log all metrics to Weights & Biases including:
  - Test loss and accuracy
  - Precision, recall, F1-score (macro and weighted)
  - Per-class metrics
  - Confusion matrix visualization

### Callbacks

#### `EarlyStopping(patience=5, delta=0.01, verbose=True)`

Early stopping to prevent overfitting.

**Args:**
- `patience`: Number of epochs to wait for improvement
- `delta`: Minimum change to qualify as improvement
- `verbose`: Print status messages

### Utilities

#### `get_device()`

Return best available device (CUDA > MPS > CPU).

#### `save_model(model, fig, results, save_path)`

Save model weights, plots, and training results.

#### `count_parameters(model)`

Count and display model parameters.

#### `set_seeds(seed=42)`

Set random seeds for reproducibility.

#### `load_config(config_path)`

Load YAML configuration file.

### Visualization

#### `plot_loss_curves(results)`

Plot training and validation loss/accuracy curves.

**Returns:** Matplotlib figure

#### `print_train_time(start, end, device=None)`

Print elapsed training time.

## Project Structure

```
pytorch_essentials/
├── pytorch_essentials/         # Main package
│   ├── __init__.py            # Package exports
│   ├── engine.py              # Training/evaluation loops
│   ├── callbacks.py           # Early stopping callback
│   ├── utils.py               # Utility functions
│   └── visualization.py       # Plotting functions
├── examples/                   # Example scripts
│   ├── basic_training.py      # MNIST example
│   └── config.yaml            # Example config
├── requirements.txt           # Dependencies
├── README.md                  # This file
└── LICENSE                    # License file
```
