Metadata-Version: 2.4
Name: mgcv-rust
Version: 0.6.0
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Rust
Classifier: Programming Language :: Python :: Implementation :: CPython
Classifier: Topic :: Scientific/Engineering
Requires-Dist: numpy>=1.20.0
License-File: LICENSE
Summary: Rust implementation of Generalized Additive Models with Python bindings
Keywords: statistics,machine-learning,gam,regression,splines
Author-email: Aleksander Jaworski <ale.jaworski@gmail.com>
Requires-Python: >=3.8
Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM
Project-URL: Homepage, https://github.com/AlekJaworski/nn_exploring
Project-URL: Issues, https://github.com/AlekJaworski/nn_exploring/issues
Project-URL: Repository, https://github.com/AlekJaworski/nn_exploring

# mgcv_rust: Generalized Additive Models in Rust

A Rust implementation of Generalized Additive Models (GAMs) with automatic smoothing parameter selection using REML (Restricted Maximum Likelihood) and the PiRLS (Penalized Iteratively Reweighted Least Squares) algorithm, inspired by R's `mgcv` package.

## Features

- **Multiple Distribution Families**: Gaussian, Binomial, Poisson, and Gamma
- **Flexible Basis Functions**:
  - Cubic B-splines with natural boundary conditions
  - Thin plate splines for smooth multivariate regression
- **Automatic Smoothing**:
  - REML (Restricted Maximum Likelihood) criterion
  - GCV (Generalized Cross-Validation) criterion
- **PiRLS Algorithm**: Efficient fitting via Penalized Iteratively Reweighted Least Squares
- **Pure Rust**: No external BLAS/LAPACK dependencies
- **Test-Driven Development**: Comprehensive test suite with 20+ passing tests

## Installation

Add to your `Cargo.toml`:

```toml
[dependencies]
mgcv_rust = { path = "." }
ndarray = "0.16"
```

## Quick Start

### Python (Recommended)

```python
import numpy as np
from mgcv_rust import GAM

# Generate data: y = sin(2πx) + noise
X = np.random.uniform(0, 1, (500, 2))
y = np.sin(2 * np.pi * X[:, 0]) + 0.5 * (X[:, 1] - 0.5)**2

# Fit GAM with automatic smooth setup
gam = GAM()
result = gam.fit(X, y, k=[10, 15])  # That's it!

print(f"Lambda values: {result['lambda']}")
print(f"Deviance: {result['deviance']}")

# Make predictions
X_test = np.random.uniform(0, 1, (100, 2))
predictions = gam.predict(X_test)
```

**Performance**: 1.5x - 65x faster than R's mgcv (problem-dependent)

See `API_SIMPLIFICATION.md` for more details on the simplified Python API.

### Rust

```rust
use mgcv_rust::{GAM, Family, SmoothTerm, OptimizationMethod};
use ndarray::{Array1, Array2};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Generate data: y = sin(2πx) + noise
    let n = 100;
    let x_data: Vec<f64> = (0..n).map(|i| i as f64 / n as f64).collect();
    let y_data: Vec<f64> = x_data
        .iter()
        .map(|&xi| (2.0 * std::f64::consts::PI * xi).sin() + noise())
        .collect();

    let x = Array1::from_vec(x_data);
    let y = Array1::from_vec(y_data);
    let x_matrix = x.into_shape((n, 1))?;

    // Create GAM with cubic spline smooth
    let mut gam = GAM::new(Family::Gaussian);
    let smooth = SmoothTerm::cubic_spline("x".to_string(), 20, 0.0, 1.0)?;
    gam.add_smooth(smooth);

    // Fit with REML smoothing parameter selection
    gam.fit(
        &x_matrix,
        &y,
        OptimizationMethod::REML,
        5,    // max outer iterations
        50,   // max inner iterations (PiRLS)
        1e-4  // convergence tolerance
    )?;

    // Make predictions
    let predictions = gam.predict(&x_test)?;

    Ok(())
}
```

## Architecture

### Core Components

1. **`basis.rs`**: Basis function implementations
   - `CubicSpline`: Cubic B-spline basis with configurable knots
   - `ThinPlateSpline`: Radial basis functions for smooth regression

2. **`penalty.rs`**: Penalty matrix construction
   - Second derivative penalties for smoothness
   - Supports multiple penalty types per basis

3. **`pirls.rs`**: Penalized IRLS fitting algorithm
   - Implements PiRLS for GLMs with penalties
   - Supports all standard GLM families
   - Automatic weight computation and convergence checking

4. **`reml.rs`**: Smoothing parameter selection
   - REML criterion for optimal smoothing
   - GCV criterion as alternative
   - Log-determinant computations

5. **`smooth.rs`**: Smoothing parameter optimization
   - Coordinate descent optimization
   - Grid search for initialization
   - Works in log-space for numerical stability

6. **`gam.rs`**: Main GAM model interface
   - Combines all components
   - Handles multiple smooth terms
   - Outer loop for lambda optimization

7. **`linalg.rs`**: Linear algebra operations
   - Gaussian elimination with partial pivoting
   - Matrix inversion via Gauss-Jordan
   - Determinant computation via LU decomposition

## Mathematical Background

### GAM Model

```
g(E[Y]) = β₀ + f₁(x₁) + f₂(x₂) + ... + fₚ(xₚ)
```

Where:
- `g()` is the link function
- `fᵢ()` are smooth functions represented by basis expansions
- Each smooth is penalized by `λᵢ ∫ (f''ᵢ(x))² dx`

### PiRLS Algorithm

1. Initialize: η = g(y)
2. Until convergence:
   - Compute μ = g⁻¹(η)
   - Compute weights: w = (g'(μ))² / V(μ)
   - Compute working response: z = η + (y - μ) / g'(μ)
   - Solve: β = (X'WX + λS)⁻¹ X'Wz
   - Update: η = Xβ

### REML Criterion

```
REML(λ) = n·log(RSS) + log|X'WX + λS| - log|S|
```

Minimized with respect to λ to select optimal smoothing parameters.

## Examples

See `examples/simple_gam.rs` for a complete working example:

```bash
cargo run --example simple_gam --release
```

## Project Structure

```
├── src/                    # Core Rust library code
├── examples/               # Rust usage examples
├── benches/               # Rust benchmarks
├── tests/                 # Rust tests
├── scripts/               # Testing and benchmarking scripts
│   ├── python/            # Python scripts
│   │   ├── tests/         # Python test scripts
│   │   └── benchmarks/    # Python benchmark scripts
│   └── r/                 # R scripts
│       ├── tests/         # R test scripts
│       └── benchmarks/    # R benchmark scripts
├── docs/                  # Documentation and analysis
└── test_data/            # Test data and results
```

## Testing

Run the Rust test suite:

```bash
cargo test
```

All 20 tests should pass, covering:
- Basis function evaluation
- Penalty matrix construction
- Linear algebra operations
- REML/GCV criteria
- PiRLS convergence
- Full GAM fitting pipeline

Additional tests and benchmarks are available in the `scripts/` directory.

## Implementation Notes

- **TDD Approach**: Every feature was implemented with tests first
- **No External Dependencies**: Custom linear algebra to avoid BLAS/LAPACK issues
- **Numerical Stability**: Operations performed in log-space where appropriate
- **Extensible Design**: Easy to add new basis types, families, or criteria

## Limitations & Future Work

- Smoothing parameter optimization could be improved with better algorithms (e.g., Newton-Raphson)
- Eigendecomposition for handling penalty null spaces more rigorously
- Confidence intervals and standard errors
- Model diagnostics and residual analysis
- Tensor product smooths for multivariate terms
- Parallel processing for large datasets

## References

- Wood, S.N. (2017). Generalized Additive Models: An Introduction with R (2nd ed.). Chapman and Hall/CRC.
- Wood, S.N. (2011). Fast stable restricted maximum likelihood and marginal likelihood estimation of semiparametric generalized linear models. JRSS-B, 73(1), 3-36.

## License

MIT License - see LICENSE file for details

## Author

Implemented as a Rust port of R's mgcv package core functionality.

## Update: REML Implementation Fixed! ✅

**You were absolutely right** - the REML implementation had bugs that caused it to always select λ ≈ 0.

### What Was Wrong

1. **Singular Penalty Handling**: REML was incorrectly handling rank-deficient penalty matrices, setting `log|S| = 0` which broke the criterion
2. **Lambda Passing**: Optimization was passing `λ = 1.0` with pre-multiplied penalties, confusing the `rank(S)*log(λ)` term
3. **Insufficient Data**: Examples used n=30 with p=15 (ratio 2:1), which is too small for REML/GCV

### What Was Fixed

1. **REML Criterion**: Now correctly uses `log|λS| = rank(S)*log(λ) + constant`
2. **Optimization**: Passes actual λ values to criterion functions
3. **Data Size**: Increased to n=300 for proper n/p ratio (20:1)
4. **REML Search**: Uses fine grid search (gradient descent had issues)

### Current Performance (n=300)

```
GCV:  λ = 0.067, Test RMSE = 0.480  ✅
REML: λ = 0.058, Test RMSE = 0.480  ✅
```

Both methods now select nearly optimal smoothing parameters!

See `IMPLEMENTATION_SUMMARY.md` for complete details.

