Metadata-Version: 2.3
Name: easydel
Version: 0.1.5.dev8
Summary: Accelerate, Optimize performance with streamlined training and serving options with JAX.
License: Apache-2.0
Keywords: Deep Learning,Machine Learning,JAX,CUDA,XLA,Triton,Pallas
Author: Erfan Zare Chavoshi
Author-email: Erfanzare810@gmail.com
Requires-Python: >=3.10,<3.14
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Provides-Extra: all
Provides-Extra: tf
Requires-Dist: datasets (>=3.6.0)
Requires-Dist: eformer (==0.0.40)
Requires-Dist: einops (>=0.8.0,<0.9.0)
Requires-Dist: fastapi (>=0.115.2,<0.116.0)
Requires-Dist: flax (==0.10.6)
Requires-Dist: gcsfs
Requires-Dist: jax (>=0.6.0)
Requires-Dist: jaxlib (>=0.6.0)
Requires-Dist: jaxtyping (>=0.3.2,<0.4.0)
Requires-Dist: jinja2 (>=3.1.5)
Requires-Dist: optax (>=0.2.2,<0.3.0)
Requires-Dist: tensorboard (>=2.19.0) ; extra == "tf" or extra == "all"
Requires-Dist: tensorflow (>=2.19.0) ; extra == "tf" or extra == "all"
Requires-Dist: tensorflow-datasets (>=4.9.6) ; extra == "tf" or extra == "all"
Requires-Dist: tqdm
Requires-Dist: transformers (>=4.51.0)
Requires-Dist: triton (>=3.2.0,<3.3.0)
Requires-Dist: uvicorn (>=0.32.0,<0.33.0)
Requires-Dist: uvloop (==0.21.0)
Requires-Dist: wandb (>=0.18.5)
Requires-Dist: zstandard (>=0.23.0)
Project-URL: Documentation, https://easydel.readthedocs.io/en/latest/
Project-URL: Homepage, https://github.com/erfanzar/EasyDeL
Project-URL: Repository, https://github.com/erfanzar/EasyDeL
Description-Content-Type: text/markdown

<div align="center">
 <div style="margin-bottom: 50px;">
  <a href="">
  <img src="https://raw.githubusercontent.com/erfanzar/easydel/main/images/easydel-logo-with-text.png" height="80">
  </a>
 </div>
 <div>
 <a href="https://discord.gg/FCAMNqnGtt"> <img src="https://raw.githubusercontent.com/erfanzar/easydel/main/images/discord-button.png" height="48"></a>&nbsp;&nbsp;
 <a href="https://easydel.readthedocs.io/en/latest/"><img src="https://raw.githubusercontent.com/erfanzar/easydel/main/images/documentation-button.png" height="48"></a>&nbsp;&nbsp;
 <a href="https://easydel.readthedocs.io/en/latest/install.html"><img src="https://raw.githubusercontent.com/erfanzar/easydel/main/images/quick-start-button.png" height="48"></a>
 </div>
</div>

-----

EasyDeL is an open-source framework designed to enhance and streamline the training process of machine learning models, with a primary focus on Jax/Flax. Built on modern Flax NNX, it provides convenient and effective solutions for training and serving Flax/Jax models on TPU/GPU at scale.

## Key Features

- **Modern Architecture**: Built on Flax NNX for better integration, modularity, and performance
- **Diverse Model Support**: Seamless support for Transformers, Mamba, RWKV, Vision Models and more
- **Advanced Trainers**: Specialized trainers with unified base class for consistent behavior:
  - SFTTrainer for supervised fine-tuning
  - DPOTrainer for direct preference optimization
  - ORPOTrainer for offline reinforcement learning
  - More Trainers Like image-text-to-image and others are also supported in main trainer

- **Vision Model Support**: Comprehensive support for:
  - Vision-to-Vision tasks
  - Image-Text-to-Image generation
  - Image-to-Text processing
- **Production-Ready Serving**:
  - `vInference` engine for efficient LLM inference
  - `vInferenceApiServer` for OpenAI-compatible API endpoints
- **Performance Optimization**:
  - Integration with multiple attention mechanisms
  - Advanced quantization support including NF4, A8BIT, A8Q, and A4Q
  - Platform-specific optimizations (TRITON, XLA, Pallas)

### Fully Customizable and Hackable 🛠️

EasyDeL's architecture is designed for maximum flexibility:

- **Custom Module System**: Built on Flax NNX, allowing easy creation and integration of custom modules
- **Transparent Architecture**: Every component is open for inspection and modification
- **Dynamic Configuration**: Easily customize model architecture, training pipeline, and inference settings
- **Platform Flexibility**: Choose between different platforms (TRITON, XLA, Pallas) for optimal performance

### Optimizations

#### Attention Mechanisms

EasyDeL offers various attention implementations optimized for different use cases:

```python
import easydel as ed

# Configure attention mechanism in model config
config = ed.EasyDeLBaseConfigDict(attn_mechanism=ed.AttentionMechanisms.FLASH_ATTN2,) # Choose mechanism
```

Available mechanisms:

- **FLASH_ATTN2**: Optimized Flash Attention 2 implementation
- **CUDA_FLASH_ATTN2**: CUDA-specific Flash Attention 2 variant
- **RING**: Memory-efficient ring attention for distributed computing
- **VANILLA**: Standard attention implementation
- **SPLASH**: TPU-optimized splash attention
- **CUDNN**: cuDNN-accelerated attention
- **BLOCKWISE**: Memory-efficient blockwise attention
- **SDPA**: Scaled dot-product attention

Example of configuring a model with specific platform and attention:

```python
import easydel as ed
import jax.numpy as jnp

# Complete configuration example
model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-instruct",
    platform=ed.EasyDeLPlatforms.TRITON,
    config_kwargs=ed.EasyDeLBaseConfigDict(
        attn_mechanism=ed.AttentionMechanisms.FLASH_ATTN2,
        attn_dtype=jnp.float16,
        gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NONE,
    ),
    dtype=jnp.float16,
    param_dtype=jnp.float16,
    auto_shard_model=True,
    sharding_axis_dims=(1, 1, 1, -1, 1) # Fully Tensor Parallel
)
```

## Quick Start

### Installation

```bash
pip install easydel
```

### Basic Inference Example

```python
import easydel as ed
from transformers import AutoTokenizer
import jax.numpy as jnp

# Initialize model
model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-instruct",
    dtype=jnp.float16,
    platform=ed.EasyDeLPlatforms.TRITON,
    auto_shard_model=True
)

# Setup tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-instruct")
tokenizer.pad_token_id = tokenizer.eos_token_id

# Create inference engine
sampling_params=ed.SamplingParams(
    max_tokens=1024,
    temperature=0.8,
    top_p=0.95,
    top_k=10,
)
inference = ed.vInference(
    model=model,
    tokenizer=tokenizer,
    generation_config=ed.vInferenceConfig(
        max_new_tokens=1024,
        sampling_params=sampling_params,
        streaming_chunks=32,
    )
)

# Create API server (OpenAI compatible)
api_server = ed.vInferenceApiServer({inference.inference_name: inference})
api_server.fire()
```

### DPOTraining Example

`DPOTrainer` implements Direct Preference Optimization (DPO) - a state-of-the-art algorithm used for preference-based model fine-tuning. DPO is widely used across major language models like Llama 3 to align them with human preferences. Here's how to use DPOTrainer in EasyDeL:

```python
import easydel as ed
from transformers import AutoTokenizer
from datasets import load_dataset


ed.DPOTrainer(
    model=ed.AutoEasyDeLModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct"),
    arguments=ed.DPOConfig(
        max_completion_length=128,
        max_prompt_length=128,
        max_length=256,
        save_steps=100,
    ),
    train_dataset=load_dataset("trl-lib/ultrafeedback_binarized", split="train"),
    processing_class=AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct"),
).train()

```

### Building Custom Modules 🛠️

The `EasyDeLBaseModule` is the core foundation for creating custom modules and models in EasyDeL. It provides a powerful set of built-in features for configuration, data handling, parameter management, and more. It inherits from `flax.nnx.Module`, `BaseModuleProtocol`, `EasyBridgeMixin`, and `EasyGenerationMixin`.

```python
import easydel as ed
import jax
import jax.numpy as jnp
import typing as tp
from flax import nnx as nn

# Example Custom Module
class MyCustomModule(ed.EasyDeLBaseModule):
    def __init__(
        self,
        config,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        other_parameters=...,
        other_parameters_1=...,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(
        config=config,
        dtype=dtype,
        param_dtype=param_dtype,
        precision=precision,
        rngs=rngs,
        )
        self.other_parameters = other_parameters

    def __call__(self, x):
        # Custom implementation
        return x


# Example Model using the Custom Module
class MyModel(ed.EasyDeLBaseModule):
    config_class = ed.EasyDeLBaseConfig
    def __init__(
        self,
        config,
        dtype: jnp.dtype = jnp.float32,
        param_dtype: jnp.dtype = jnp.float32,
        precision: jax.lax.PrecisionLike = None,
        other_parameters=...,
        other_parameters_1=...,
        other_parameters_2=...,
        *,
        rngs: nn.Rngs,
    ):
        super().__init__(
        config=config,
        dtype=dtype,
        param_dtype=param_dtype,
        precision=precision,
        rngs=rngs,
        )
        self.custom_module = MyCustomModule(
        config=config,
        dtype=dtype,
        param_dtype=param_dtype,
        precision=precision,
        other_parameters=other_parameters,
        other_parameters_1=other_parameters_1,
        rngs=rngs,
        )

    def __call__(self, x):
        return self.custom_module(x)
```

### Key Features at a Glance

- **Core Class:** Foundation for building EasyDeL modules, inheriting from `flax.nnx.Module`, `BaseModuleProtocol`, `EasyBridgeMixin` and `EasyGenerationMixin`.
- **Configuration:**
  - Uses a `config` attribute of type `EasyDeLBaseConfig` for model settings.
  - Provides `mesh` property for retrieving mesh from the config.
  - Includes class attributes for model metadata: `config_class`, `base_model_prefix`, `_model_task`, and `_model_type`.
- **Data Handling:**
  - Controls data types (`dtype`, `param_dtype`) and numerical precision (`precision`).
  - Supports quick precision switching with `half()` (float16) and `float()` (float32).
  - `module_dtype` property returns the parameters dtype
- **Randomness:** Manages random number generation via `nn.Rngs`.
- **Parameter Shape:** Inspects the shape of all model parameters with `graphtree_params_shape` property.
- **Causal Mask & Frequencies:** Accesses `causal_mask` and `frequencies` from the config for sequence modeling.
- **Static Arguments:**
  - `static_arguments` property with ability to define static arguments via `get_static_arguments` method.
- **Dynamic Loss:** Dynamically loads loss functions via `loss_function` based on name or `loss_type`.
- **Sharding/Gathering:** Offers `shard_model` and `gather_model` for parameter distribution across devices and has access to current sharding (`_shard_fns`) and gathering (`_gather_fns`) functions.
- **Quantization:** Applies post-training quantization to linear layers using `quantize`.
- **State Conversion:** Converts the module to a state object via `to_state`.
- **Input Preparation**: Provides a method for pre-processing inputs with `prepare_inputs_for_call`.
- **Generic Loss Calculation:** Computes loss via `compute_loss`.
- **Easy Bridge:** Enables saving/loading models using `save_pretrained` and pushing to the Hub with `push_to_hub`, while also offering methods for loading from checkpoints (`_load_model_weights`),  from pretrained models(`from_pretrained`) and from torch (`_from_torch_pretrained`).
- **Easy Generation:** Facilitates text generation via `generate` with multiple decoding algorithms, while using methods for caching, decoder initialization, beam search configurations, logits processing and other aspects of generation.

**Benefits:**

- **Modularity:** Encourages reusable, composable model components.
- **Flexibility:** Adapts to diverse model architectures.
- **Efficiency:** Integrates with EasyDeL's distributed training and optimization features.
- **Ease of Use:** Simplifies saving/loading and text generation processes.

### Training Example

```python
import easydel as ed
import jax.numpy as jnp

# Create model
model = ed.LlamaForCausalLM(
    config=model_config,
    dtype=jnp.float32,
    param_dtype=jnp.float32
)

# Configure training
training_args = ed.TrainingArguments(
    save_directory="checkpoints",
    model_name="my_model",
    num_train_epochs=3,
    total_batch_size=8,
    learning_rate=3e-4,
    optimizer=ed.EasyDeLOptimizers.ADAMW,
    scheduler=ed.EasyDeLSchedulers.COSINE
)

# Initialize trainer
trainer = ed.Trainer(
    arguments=training_args,
    model=model,
    dataset_train=train_dataset,
    dataset_eval=eval_dataset
)

# Start training
trainer.train()
```

## Latest Updates 🔥

- **Modernized Architecture**:
  - Migrated to Flax NNX for improved modularity and performance
  - Custom module system for better extensibility
  - Unified base trainer for consistent behavior
- **Enhanced Performance**:
  - Faster training and inference
  - Better large-scale handling
  - Improved memory management
- **Vision Model Support**:
  - Added comprehensive vision model capabilities
  - Support for multimodal tasks
- **Major Improvements**:
  - Fixed core issues from easydel-linen
  - Enhanced stability and reliability
  - Better error handling and debugging

## Documentation 💫

For comprehensive documentation, tutorials, and API reference, visit [EasyDeL Documentation](https://easydel.readthedocs.io/en/latest/).

## Contributing

We welcome contributions! Please see our contributing guidelines in the repository.

## License 📜

EasyDeL is released under the Apache v2 license. See the LICENSE file for details.

## Citation

```bibtex
@misc{Zare Chavoshi_2023,
    title={EasyDeL: An open-source library for enhancing and streamlining the training process of machine learning models},
    url={https://github.com/erfanzar/EasyDeL},
    author={Zare Chavoshi, Erfan},
    year={2023}
}
```

## Contact

For questions or comments about EasyDeL, contact: _erfanzare810@gmail.com_

