Metadata-Version: 2.1
Name: output-shape
Version: 0.0.2
Summary: A very lightweight and minimalistic output shape examiner of layers and models.
Home-page: https://github.com/avocardio/output-shape
Author: avocardio
License: MIT
Platform: UNKNOWN
Classifier: Development Status :: 4 - Beta
Classifier: Intended Audience :: Developers
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: License :: OSI Approved :: MIT License
Classifier: Programming Language :: Python :: 3.6
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch (>=1.6)

# output-shape 

[![PyPI version](https://badge.fury.io/py/output-shape.svg)](https://badge.fury.io/py/output-shape)

A lightweight, minimal tool for inspecting the output shapes of layers and models.

**Currently works with PyTorch models only. Keras / JAX support coming soon!**

# Installation
```bash
pip install output-shape
```

# Usage

Decorate the model's `forward` method with `@output_shape`, then enable shape printing in one of two ways:

```python
import torch
from output_shape import output_shape, debug_shapes

class Model(torch.nn.Module):
    def __init__(self, debug=False):
        super().__init__()
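        self.debug = debug  # enables shape printing (Option 2)
        ...  # define your layers here

    @output_shape
    def forward(self, x):
        ...  # run the layers and return the output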

# Option 1: Context manager
model = Model()
with debug_shapes():
    model(torch.randn(2, 1, 128, 128))

# Option 2: Instance flag
model = Model(debug=True)
model(torch.randn(2, 1, 128, 128))
```
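The package's internals are not described here, so the following is only a minimal sketch of how a decorator with these two switches *could* work, not the package's actual implementation. It assumes a module-level flag for the context manager and an instance attribute named `debug` for the flag option (as in the example above); `sketch_output_shape`, `sketch_debug_shapes` and `TinyNet` are hypothetical names. It uses PyTorch forward hooks (`register_forward_hook`) to print one row per submodule in the same `Name    torch.Size(...)` layout.

```python
import contextlib
import functools

import torch

# Hypothetical module-level switch, toggled by sketch_debug_shapes().
_SHAPES_ENABLED = False


@contextlib.contextmanager
def sketch_debug_shapes():
    """Enable shape printing inside a `with` block (hypothetical helper)."""
    global _SHAPES_ENABLED
    previous = _SHAPES_ENABLED
    _SHAPES_ENABLED = True
    try:
        yield
    finally:
        _SHAPES_ENABLED = previous  # restore, so nested blocks behave


def _print_row(name, shape):
    # Left-align the name in a 32-character column, then the torch.Size.
    print(f"{name:<32}{shape}")


def sketch_output_shape(forward):
    """Hypothetical decorator: print per-layer output shapes when enabled."""

    @functools.wraps(forward)
    def wrapper(self, x, *args, **kwargs):
        # Enabled either by the context manager or by an assumed `debug` attribute.
        if not (_SHAPES_ENABLED or getattr(self, "debug", False)):
            return forward(self, x, *args, **kwargs)

        def hook(module, inputs, output):
            # Called after each submodule's forward; skip non-tensor outputs.
            if isinstance(output, torch.Tensor):
                _print_row(type(module).__name__, output.shape)

        # Temporarily hook every submodule except the decorated model itself.
        handles = [m.register_forward_hook(hook) for m in self.modules() if m is not self]
        _print_row("Input", x.shape)
        try:
            out = forward(self, x, *args, **kwargs)
        finally:
            for handle in handles:
                handle.remove()
        _print_row(type(self).__name__, out.shape)  # assumes a tensor output
        return out

    return wrapper


class TinyNet(torch.nn.Module):
    def __init__(self, debug=False):
        super().__init__()
        self.debug = debug
        self.fc = torch.nn.Linear(16, 8)
        self.act = torch.nn.ReLU()

    @sketch_output_shape
    def forward(self, x):
        return self.act(self.fc(x))


with sketch_debug_shapes():
    TinyNet()(torch.randn(2, 16))
```

Running this prints an `Input` row, then `Linear` and `ReLU` rows as each submodule finishes, and finally a `TinyNet` row for the whole model. Because a child's hook fires before its parent returns, containers appear after their children, which matches the order in the listing (e.g. `Conv2d` before `PatchEmbed`). When neither switch is on, no hooks are registered at all.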

Example output for a Vision Transformer-style model (truncated):

```text
Input                           torch.Size([2, 1, 128, 128])
Conv2d                          torch.Size([2, 768, 8, 8])
PatchEmbed                      torch.Size([2, 64, 768])
LayerNorm                       torch.Size([2, 13, 768])
Linear                          torch.Size([2, 13, 2304])
Linear                          torch.Size([2, 13, 768])
Dropout                         torch.Size([2, 13, 768])
Attention                       torch.Size([2, 13, 768])
PreNorm                         torch.Size([2, 13, 768])
LayerNorm                       torch.Size([2, 13, 768])
Linear                          torch.Size([2, 13, 3072])
GELU                            torch.Size([2, 13, 3072])
Dropout                         torch.Size([2, 13, 3072])
Linear                          torch.Size([2, 13, 768])
Dropout                         torch.Size([2, 13, 768])
FeedForward                     torch.Size([2, 13, 768])
PreNorm                         torch.Size([2, 13, 768])
Transformer                     torch.Size([2, 13, 768])
LayerNorm                       torch.Size([2, 13, 768])
Linear
```