Lens Transformations

Lenses let you reinterpret existing data under a different type schema—without duplicating storage. This example defines a rich sample type, creates two alternative views, and shows how the global LensNetwork routes as_type() calls automatically.

1 — Define the source schema

Imagine a multi-modal dataset that stores images, text captions, and classification metadata together.

```python
import numpy as np
from numpy.typing import NDArray

import atdata


@atdata.packable
class RichSample:
    """Multi-modal sample with image, caption, and classification."""

    image: NDArray
    caption: str
    label: str
    split: str
    confidence: float
```

2 — Define view schemas

Different consumers need different slices of this data. An image classifier only needs the image and label; a text model only needs the caption.
```python
@atdata.packable
class ClassificationView:
    """Minimal view for image classification tasks."""

    image: NDArray
    label: str


@atdata.packable
class CaptionView:
    """Text-only view for captioning or NLP tasks."""

    caption: str
    label: str
    confidence: float
```

3 — Register lenses
A lens is a getter (source -> view) and an optional putter (view, source -> source). The @atdata.lens decorator registers the transformation globally.
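The get/put contract is easy to state outside any particular library. The sketch below is purely illustrative — the `Lens` class, `Source`, and `View` here are hypothetical stand-ins, not atdata's implementation — but it shows the two round-trip laws a well-behaved lens satisfies:

```python
from dataclasses import dataclass, replace
from typing import Callable, Generic, TypeVar

S = TypeVar("S")  # source type
V = TypeVar("V")  # view type


@dataclass
class Lens(Generic[S, V]):
    """A getter (source -> view) paired with a putter (view, source -> source)."""

    get: Callable[[S], V]
    put: Callable[[V, S], S]


@dataclass
class Source:
    image_id: int
    label: str
    caption: str


@dataclass
class View:
    label: str


label_lens = Lens(
    get=lambda s: View(label=s.label),
    put=lambda v, s: replace(s, label=v.label),  # copy source, overwrite the focused field
)

s = Source(image_id=1, label="cat", caption="a cat")

# GetPut law: putting back an unmodified view is a no-op.
assert label_lens.put(label_lens.get(s), s) == s

# PutGet law: after a put, the getter sees exactly the new view.
v = View(label="kitten")
assert label_lens.get(label_lens.put(v, s)) == v
```

The atdata decorators below register the same shape of getter/putter pair, but globally, so `as_type()` can find it.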
```python
@atdata.lens
def to_classification(src: RichSample) -> ClassificationView:
    return ClassificationView(image=src.image, label=src.label)


@to_classification.putter
def to_classification_put(
    view: ClassificationView, src: RichSample
) -> RichSample:
    return RichSample(
        image=view.image,
        caption=src.caption,
        label=view.label,
        split=src.split,
        confidence=src.confidence,
    )


@atdata.lens
def to_caption(src: RichSample) -> CaptionView:
    return CaptionView(
        caption=src.caption,
        label=src.label,
        confidence=src.confidence,
    )
```

4 — Write a small dataset
```python
import tempfile
from pathlib import Path

rng = np.random.default_rng(0)
labels = ["cat", "dog", "bird", "fish"]

samples = [
    RichSample(
        image=rng.integers(0, 255, (64, 64, 3), dtype=np.uint8),
        caption=f"A photo of a {labels[i % 4]}",
        label=labels[i % 4],
        split="train",
        confidence=round(float(rng.uniform(0.7, 1.0)), 3),
    )
    for i in range(200)
]

tmpdir = Path(tempfile.mkdtemp(prefix="atdata_lens_"))
ds = atdata.write_samples(samples, tmpdir / "data.tar", maxcount=100)
print(f"Wrote {len(samples)} samples across {len(ds.list_shards())} shards")
```

```
# writing /var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_lens_kxp0682w/data-000000.tar 0 0.0 GB 0
# writing /var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_lens_kxp0682w/data-000001.tar 100 0.0 GB 100
Wrote 200 samples across 2 shards
```
5 — View through lenses
Dataset.as_type() looks up the registered lens automatically via the LensNetwork singleton.
```python
# Classification view -- only image + label
cls_ds = ds.as_type(ClassificationView)

for batch in cls_ds.ordered(batch_size=16):
    print(f"image shape : {batch.image.shape}")  # (16, 64, 64, 3)
    print(f"labels      : {batch.label[:4]}...")
    break
```

```
image shape : (16, 64, 64, 3)
labels      : ['cat', 'dog', 'bird', 'fish']...
```
```python
# Caption view -- only text fields
cap_ds = ds.as_type(CaptionView)

for batch in cap_ds.ordered(batch_size=16):
    print(f"captions   : {batch.caption[:2]}...")
    print(f"confidence : {batch.confidence[:4]}...")
    break
```

```
captions   : ['A photo of a cat', 'A photo of a dog']...
confidence : [0.714, 0.82, 0.896, 0.725]...
```
The underlying tar files are read once; the lens getter runs per-sample during iteration. No extra storage, no ETL step.
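Conceptually, this lazy behavior is just the lens getter mapped over the sample stream on demand. A minimal sketch of the idea — a plain generator, not atdata's actual iterator machinery — assuming dict-shaped samples for illustration:

```python
from typing import Callable, Iterable, Iterator, TypeVar

S = TypeVar("S")  # source sample type
V = TypeVar("V")  # view sample type


def as_view(samples: Iterable[S], getter: Callable[[S], V]) -> Iterator[V]:
    """Apply a lens getter lazily: one view per source sample, no second copy on disk."""
    for sample in samples:
        yield getter(sample)


# Usage: views are materialized one at a time, only when iterated.
source = [{"image": b"...", "label": "cat"}, {"image": b"...", "label": "dog"}]
views = as_view(source, lambda s: {"label": s["label"]})
print(next(views))  # {'label': 'cat'}
```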
6 — Round-trip with putter
A lens with a putter supports the put direction: update the view, then propagate changes back to the source while preserving untouched fields.
```python
original = samples[0]
print(f"Original label : {original.label}")

# Get the classification view
view = to_classification.get(original)
print(f"View label     : {view.label}")

# Modify the view
corrected = ClassificationView(image=view.image, label="kitten")

# Put it back -- caption, split, confidence are preserved
updated = to_classification.put(corrected, original)
print(f"Updated label  : {updated.label}")
print(f"Caption kept   : {updated.caption}")
```

```
Original label : cat
View label     : cat
Updated label  : kitten
Caption kept   : A photo of a cat
```
7 — Inspect the LensNetwork
The global registry lets you see all registered transformations at a glance.
```python
network = atdata.LensNetwork()
for (src, view), lens_obj in network._registry.items():
    if src is RichSample:
        print(f"  {src.__name__} -> {view.__name__} via {lens_obj.__name__}")
```

```
  RichSample -> ClassificationView via to_classification
  RichSample -> CaptionView via to_caption
```
8 — Clean up
```python
import shutil

shutil.rmtree(tmpdir, ignore_errors=True)
```

Key takeaways
| Concept | API |
|---|---|
| Define a lens | @atdata.lens decorator |
| Add a putter | @my_lens.putter decorator |
| View a dataset | ds.as_type(ViewType) |
| Forward transform | lens.get(source) |
| Reverse transform | lens.put(view, source) |
| Inspect registry | atdata.LensNetwork()._registry |
Lenses compose naturally with batching and shuffling—as_type() returns a full Dataset that supports every iteration mode.
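Composition itself falls out of the get/put signatures. A hedged, framework-free sketch — the `compose` helper and the dict-based lenses below are illustrative, not an atdata API:

```python
from dataclasses import dataclass
from typing import Callable, Generic, TypeVar

A = TypeVar("A")
B = TypeVar("B")
C = TypeVar("C")


@dataclass
class Lens(Generic[A, B]):
    get: Callable[[A], B]
    put: Callable[[B, A], A]


def compose(outer: "Lens[A, B]", inner: "Lens[B, C]") -> "Lens[A, C]":
    """Chain two lenses: get goes A -> B -> C; put threads the update back through B to A."""
    return Lens(
        get=lambda a: inner.get(outer.get(a)),
        put=lambda c, a: outer.put(inner.put(c, outer.get(a)), a),
    )


# Usage: focus on a nested field through two one-level lenses.
meta = Lens(get=lambda d: d["meta"], put=lambda m, d: {**d, "meta": m})
label = Lens(get=lambda m: m["label"], put=lambda v, m: {**m, "label": v})
meta_label = compose(meta, label)

sample = {"image": "...", "meta": {"label": "cat", "split": "train"}}
print(meta_label.get(sample))                    # cat
print(meta_label.put("kitten", sample)["meta"])  # {'label': 'kitten', 'split': 'train'}
```

The put of the composite rebuilds the intermediate value and writes it back, which is why untouched sibling fields (here `split`) survive the round trip.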