Metadata-Version: 2.3
Name: eformer
Version: 0.0.20
Summary: eformer (EasyDel Former) is a utility library designed to simplify and enhance machine learning development in JAX
License: Apache-2.0
Keywords: JAX,Deep Learning,Machine Learning,Flax,XLA,EasyDeL
Author: Erfan Zare Chavoshi
Author-email: Erfanzare810@gmail.com
Requires-Python: >=3.10,<4.0
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Dist: chex
Requires-Dist: jax (>=0.5.2)
Requires-Dist: mergedeep (>=1.3.4,<1.4.0)
Requires-Dist: ml-collections (==0.1.1)
Requires-Dist: optax (>=0.2.2)
Requires-Dist: plum-dispatch (==2.3.2)
Requires-Dist: pyyaml (>=6.0.2,<6.1.0)
Requires-Dist: ray[default] (>=2.44.1)
Requires-Dist: tblib (>=3.0.0,<3.1.0)
Project-URL: Documentation, https://erfanzar.github.io/eformer
Project-URL: Homepage, https://github.com/erfanzar/eformer
Project-URL: Repository, https://github.com/erfanzar/eformer
Description-Content-Type: text/markdown

# eformer (EasyDel Former)

[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Python](https://img.shields.io/badge/Python-3.10%2B-blue)](https://www.python.org/)
[![JAX](https://img.shields.io/badge/JAX-Compatible-brightgreen)](https://github.com/google/jax)

**eformer** (EasyDel Former) is a utility library designed to simplify and enhance the development of machine learning models using JAX. It provides a collection of tools for sharding, custom PyTrees, quantization, mixed precision training, and optimized operations, making it easier to build and scale models efficiently.

## Features

- **Mixed Precision Training (`mpric`)**: Advanced mixed precision utilities supporting float8, float16, and bfloat16 with dynamic loss scaling.
- **Sharding Utilities (`escale`)**: Tools for efficient sharding and distributed computation in JAX.
- **Custom PyTrees (`jaximus`)**: Enhanced utilities for creating custom PyTrees and `ArrayValue` objects, adapted from Equinox.
- **Custom Calling (`callib`)**: A tool for custom function calls and direct integration with Triton kernels in JAX.
- **Optimizer Factory**: A flexible factory for creating and configuring optimizers like AdamW, Adafactor, Lion, and RMSProp.
- **Custom Operations and Kernels**:
  - Flash Attention 2 for GPUs/TPUs (via Triton and Pallas).
  - 8-bit and NF4 quantization kernels for memory-efficient models.
  - More to be added.
- **Quantization Support**: Tools for 8-bit and NF4 quantization, enabling memory-efficient model deployment.

## Installation

You can install `eformer` via pip:

```bash
pip install eformer
```

## Quick Start

### Mixed Precision Handler with mpric

```python
from eformer.mpric import PrecisionHandler

# Create a handler with float8 compute precision
handler = PrecisionHandler(
    policy="p=f32,c=f8_e4m3,o=f32",  # params in f32, compute in float8, output in f32
    use_dynamic_scale=True
)
```
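
The policy string packs three dtype roles into one comma-separated value. As a rough illustration of how such a string can be read, the sketch below splits it into named roles. `parse_policy` is a hypothetical helper written for this README, not part of eformer's API, and the mapping of `p`, `c`, and `o` to parameter, compute, and output dtypes is taken from the comment in the example above.

```python
# Hypothetical helper: NOT eformer's parser, only an illustration of the string format.
def parse_policy(policy: str) -> dict[str, str]:
    """Split a 'p=...,c=...,o=...' string into named dtype roles."""
    roles = {"p": "param_dtype", "c": "compute_dtype", "o": "output_dtype"}
    parsed = {}
    for part in policy.split(","):
        # partition keeps everything after the first '=' as the dtype name
        key, _, value = part.strip().partition("=")
        if key not in roles or not value:
            raise ValueError(f"unrecognized policy entry: {part!r}")
        parsed[roles[key]] = value
    return parsed


print(parse_policy("p=f32,c=f8_e4m3,o=f32"))
```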

### Customizing Arrays With ArrayValue

```python
import jax

from eformer.jaximus import ArrayValue, implicit
from eformer.ops.quantization.quantization_functions import (
    dequantize_row_q8_0,
    quantize_row_q8_0,
)

array = jax.random.normal(jax.random.key(0), (256, 64), "f2")


class Array8B(ArrayValue):
    scale: jax.Array
    weight: jax.Array

    def __init__(self, array: jax.Array):
        self.weight, self.scale = quantize_row_q8_0(array)

    def materialize(self):
        return dequantize_row_q8_0(self.weight, self.scale)


qarray = Array8B(array)


@jax.jit
@implicit
def sqrt(x):
    return jax.numpy.sqrt(x)


print(sqrt(qarray))
print(qarray)
```
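
Conceptually, the example above lets an ordinary function receive a wrapped value and operate on its materialized form. The plain-Python sketch below mimics that pattern with a decorator that unwraps any argument exposing `materialize()`. `LazyValue` and `materialize_args` are hypothetical names, and this is only an analogy: it assumes nothing about how `implicit` is actually implemented, which presumably works on traced JAX computations rather than on Python arguments.

```python
import functools
import math


class LazyValue:
    """Hypothetical stand-in for a wrapper that stores a compact form plus a scale."""

    def __init__(self, stored, scale):
        self.stored = stored
        self.scale = scale

    def materialize(self):
        # Rebuild the full-precision value from the stored representation.
        return self.stored * self.scale


def materialize_args(fn):
    """Hypothetical decorator: unwrap arguments that expose materialize() before calling fn."""

    def unwrap(arg):
        return arg.materialize() if hasattr(arg, "materialize") else arg

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*(unwrap(a) for a in args), **{k: unwrap(v) for k, v in kwargs.items()})

    return wrapper


@materialize_args
def sqrt(x):
    return math.sqrt(x)


print(sqrt(LazyValue(stored=8, scale=2.0)))  # materializes to 16.0 before the square root
```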

### Optimizer Factory

```python
from eformer.optimizers import OptimizerFactory, SchedulerConfig, AdamWConfig

# Create an AdamW optimizer with a cosine scheduler
scheduler_config = SchedulerConfig(scheduler_type="cosine", learning_rate=1e-3, steps=1000)
optimizer, scheduler = OptimizerFactory.create("adamw", scheduler_config, AdamWConfig())
```
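
For intuition about what a cosine scheduler with `learning_rate=1e-3` and `steps=1000` produces, here is a minimal sketch of the standard cosine decay formula in plain Python. It assumes no warmup and a final value of zero; `cosine_decay` and `end_value` are illustrative names, and the schedule eformer actually builds may differ in those details.

```python
import math


def cosine_decay(step: int, learning_rate: float = 1e-3, steps: int = 1000,
                 end_value: float = 0.0) -> float:
    """Standard cosine decay from learning_rate to end_value over `steps` steps."""
    # Clamp so steps before 0 or after `steps` stay at the schedule's endpoints.
    progress = min(max(step, 0), steps) / steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_value + (learning_rate - end_value) * cosine


for step in (0, 250, 500, 750, 1000):
    print(step, cosine_decay(step))
```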

### Quantization

```python
import jax

from eformer.quantization import Array8B, ArrayNF4

# Quantize an array to 8-bit
qarray = Array8B(jax.random.normal(jax.random.key(0), (256, 64), "f2"))

# Quantize an array to NF4
n4array = ArrayNF4(jax.random.normal(jax.random.key(0), (256, 64), "f2"), 64)
```
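
To show where the memory savings and the rounding error of 8-bit quantization come from, the sketch below implements a simple symmetric scheme in plain Python: each row is stored as integers in [-127, 127] plus one float scale. This is an assumption-level illustration, not eformer's kernel; `quantize_int8` and `dequantize_int8` are hypothetical names, and the library's Q8_0 routines may, for example, use per-block rather than per-row scales.

```python
def quantize_int8(row: list[float]) -> tuple[list[int], float]:
    """Symmetric 8-bit quantization: one scale per row, integers in [-127, 127]."""
    max_abs = max((abs(v) for v in row), default=0.0)
    # Fall back to a scale of 1.0 for an all-zero row to avoid dividing by zero.
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    return [round(v / scale) for v in row], scale


def dequantize_int8(weights: list[int], scale: float) -> list[float]:
    """Recover approximate floats from the integers and the shared scale."""
    return [w * scale for w in weights]


row = [0.5, -1.27, 0.0, 1.0]
weights, scale = quantize_int8(row)
restored = dequantize_int8(weights, scale)
max_error = max(abs(a - b) for a, b in zip(row, restored))
print(weights, scale, max_error)
```

The round trip is lossy: each value can be off by up to half the scale, which is the price paid for storing one byte per element instead of two or four.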

### Advanced Mixed Precision Configuration

```python
import jax.numpy as jnp

from eformer.mpric import LossScaleConfig, Policy, PrecisionHandler

# Create a custom precision policy
policy = Policy(
    param_dtype=jnp.float32,
    compute_dtype=jnp.bfloat16,
    output_dtype=jnp.float32
)

# Configure loss scaling
loss_config = LossScaleConfig(
    initial_scale=2**15,
    growth_interval=2000,
    scale_factor=2,
    min_scale=1.0
)

# Create handler with custom configuration
handler = PrecisionHandler(
    policy=policy,
    use_dynamic_scale=True,
    loss_scale_config=loss_config
)
```
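
The loss-scaling fields above follow the usual dynamic loss scaling recipe: grow the scale after a run of overflow-free steps, and shrink it when gradients overflow. The plain-Python sketch below spells out that update rule under the common convention. `LossScaleState` is a hypothetical class that only reuses the field names from the configuration above; eformer's actual update logic may differ.

```python
from dataclasses import dataclass


@dataclass
class LossScaleState:
    """Hypothetical state container reusing the config field names shown above."""

    scale: float = 2.0**15
    growth_interval: int = 2000
    scale_factor: float = 2.0
    min_scale: float = 1.0
    good_steps: int = 0

    def update(self, grads_finite: bool) -> None:
        if grads_finite:
            # Count overflow-free steps and grow the scale once the interval is reached.
            self.good_steps += 1
            if self.good_steps >= self.growth_interval:
                self.scale *= self.scale_factor
                self.good_steps = 0
        else:
            # Overflow: shrink the scale (never below min_scale) and restart the count.
            self.scale = max(self.scale / self.scale_factor, self.min_scale)
            self.good_steps = 0


state = LossScaleState(growth_interval=3)
for finite in [True, True, True, False, False]:
    state.update(finite)
print(state.scale)  # grows once to 65536.0, then halves twice to 16384.0
```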

## Contributing

We welcome contributions! Please read our [Contributing Guidelines](CONTRIBUTING.md) to get started.

## License

This project is licensed under the Apache License 2.0. See the [LICENSE](LICENSE) file for details.

