Metadata-Version: 2.4
Name: parametrix
Version: 0.1.5
Summary: Flax-like computed parameters for bare JAX
Author: Gabriel S. Gerlero
Author-email: Gabriel S. Gerlero <ggerlero@cimec.unl.edu.ar>
License-Expression: Apache-2.0
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: 3.14
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries
Classifier: Typing :: Typed
Classifier: Operating System :: OS Independent
Requires-Dist: equinox>=0.12.1,<0.14
Requires-Dist: jax>=0.5,<0.9
Requires-Dist: numpy>=1,<3
Requires-Dist: typing-extensions>=4,<5 ; python_full_version < '3.11'
Requires-Python: >=3.10
Project-URL: Documentation, https://parametrix.readthedocs.io
Project-URL: Homepage, https://github.com/gerlero/parametrix
Project-URL: Repository, https://github.com/gerlero/parametrix
Description-Content-Type: text/markdown

# <div align="center">[<img src="https://raw.githubusercontent.com/gerlero/parametrix/main/logo.png" alt="Parametrix logo" width=250></img>](https://github.com/gerlero/parametrix/)</div>

**[`flax.nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)-like computed parameters for bare [JAX](https://github.com/jax-ml/jax) (and [Equinox](https://github.com/patrick-kidger/equinox)).**

[![Documentation](https://img.shields.io/readthedocs/parametrix)](https://parametrix.readthedocs.io/)
[![CI](https://github.com/gerlero/parametrix/actions/workflows/ci.yml/badge.svg)](https://github.com/gerlero/parametrix/actions/workflows/ci.yml)
[![Codecov](https://codecov.io/gh/gerlero/parametrix/branch/main/graph/badge.svg)](https://codecov.io/gh/gerlero/parametrix)
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
[![ty](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ty/main/assets/badge/v0.json)](https://github.com/astral-sh/ty)
[![uv](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/uv/main/assets/badge/v0.json)](https://github.com/astral-sh/uv)
[![Publish](https://github.com/gerlero/parametrix/actions/workflows/pypi-publish.yml/badge.svg)](https://github.com/gerlero/parametrix/actions/workflows/pypi-publish.yml)
[![PyPI](https://img.shields.io/pypi/v/parametrix)](https://pypi.org/project/parametrix/)
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/parametrix)](https://pypi.org/project/parametrix/)


## Installation

```bash
pip install parametrix
```

## Example

The following example subclasses [`Param`](https://parametrix.readthedocs.io) to define a parameter that is always positive. It stores the logarithm of the value and exponentiates it on access:

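```python
import jax.numpy as jnp
from parametrix import Param

class PositiveOnlyParam(Param):
    def __init__(self, value):
        # Store the unconstrained logarithm as the raw (backing) value
        super().__init__(jnp.log(value))

    @property
    def value(self):
        # Exponentiate on access, so the exposed value is always positive
        return jnp.exp(self.raw_value)
```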
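The following sketch shows why storing the logarithm enforces positivity. It is a minimal pure-Python illustration under a few assumptions. It uses plain floats and the standard `math` module instead of JAX. It hand-derives the gradient instead of using `jax.grad`. It uses a hypothetical `PositiveScalar` stand-in rather than parametrix's `Param`. Gradient descent updates the unconstrained `raw_value` freely, yet `value = exp(raw_value)` never leaves the positive reals, even when the target is negative.

```python
import math


class PositiveScalar:
    """Hypothetical stand-in mirroring PositiveOnlyParam; not the parametrix API."""

    def __init__(self, value: float) -> None:
        # Store the unconstrained log; any real number is a valid raw value.
        self.raw_value = math.log(value)

    @property
    def value(self) -> float:
        # exp maps every real raw value to a strictly positive number.
        return math.exp(self.raw_value)


def fit(p: PositiveScalar, target: float, lr: float = 0.1, steps: int = 200) -> PositiveScalar:
    """Minimize (p.value - target)**2 by plain gradient descent on p.raw_value."""
    for _ in range(steps):
        v = p.value
        # Chain rule: d/d(raw) (exp(raw) - target)**2 = 2 * (exp(raw) - target) * exp(raw)
        grad = 2.0 * (v - target) * v
        p.raw_value -= lr * grad
    return p


reachable = fit(PositiveScalar(1.0), target=2.0)
print(reachable.value)  # converges to about 2.0

unreachable = fit(PositiveScalar(1.0), target=-1.0)
print(unreachable.value)  # shrinks toward 0 but stays positive
```

The optimizer never needs to know about the constraint. The reparametrization alone guarantees it, which is the point of keeping the raw value separate from the exposed one.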

The backing values of `Param`s are always stored as `jax.Array`s, so libraries such as Equinox automatically pick them up as learnable parameters.

`Param` objects also behave like numeric types, so they can be used in models and any other functions without changes to existing code.
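One common way to make an object behave like a numeric type is operator forwarding: each arithmetic dunder method unwraps the object to its `value` and delegates. The sketch below is a hypothetical pure-Python illustration of that technique. `NumericProxy`, `Positive` and `decay` are made-up names, and plain floats stand in for `jax.Array`s. It is not necessarily how parametrix implements this. It shows that a function written only for plain numbers accepts such objects unchanged.

```python
import math


class NumericProxy:
    """Hypothetical base class: forwards arithmetic to ``self.value``."""

    @property
    def value(self) -> float:
        raise NotImplementedError

    @staticmethod
    def _unwrap(x):
        # Allow two proxies to be combined, e.g. p + q.
        return x.value if isinstance(x, NumericProxy) else x

    def __add__(self, other):
        return self.value + self._unwrap(other)

    def __radd__(self, other):  # other + self, where other is not a proxy
        return other + self.value

    def __sub__(self, other):
        return self.value - self._unwrap(other)

    def __rsub__(self, other):
        return other - self.value

    def __mul__(self, other):
        return self.value * self._unwrap(other)

    def __rmul__(self, other):
        return other * self.value

    def __truediv__(self, other):
        return self.value / self._unwrap(other)

    def __rtruediv__(self, other):
        return other / self.value

    def __neg__(self):
        return -self.value

    def __float__(self) -> float:
        # Lets math.* functions accept the object directly.
        return float(self.value)


class Positive(NumericProxy):
    """Log-parametrized positive scalar, following the positivity example."""

    def __init__(self, value: float) -> None:
        self.raw_value = math.log(value)

    @property
    def value(self) -> float:
        return math.exp(self.raw_value)


def decay(amplitude, rate, t):
    """Ordinary numeric code with no knowledge of NumericProxy."""
    return amplitude * math.exp(-rate * t)


print(decay(2.0, 0.5, 1.0))
print(decay(Positive(2.0), Positive(0.5), 1.0))  # same result, up to rounding
```

Each result is a plain number, so downstream code never sees the wrapper. The `exp(log(x))` round trip may differ from `x` in the last bits.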

## Documentation

API documentation is available at [Read the Docs](https://parametrix.readthedocs.io/).
