Typed Dataset Pipeline

Build a complete typed dataset from scratch: define, write, shard, and iterate

This example walks through the full lifecycle of a typed dataset. You will define a sample type with the @packable decorator, generate synthetic data, write it across multiple shards, and iterate over it with automatic batch aggregation.

1 — Define a sample type

The @packable decorator turns a plain class into a serializable dataclass with automatic msgpack encoding and transparent NDArray conversion.

import numpy as np
from numpy.typing import NDArray
import atdata


@atdata.packable
class SensorReading:
    """A single sensor observation with metadata."""

    waveform: NDArray  # 1-D time-series array
    sensor_id: str
    temperature: float
    anomaly: bool

Every field is strongly typed. NDArray fields are automatically compressed to bytes during serialization and restored on read—no manual conversion needed.
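
As a quick check (a minimal sketch; the field values are placeholders, not part of the example data), the array field behaves like an ordinary ndarray in memory, and the conversion to bytes only happens when the sample is serialized:

# Sketch: NDArray fields stay plain ndarrays until serialization.
reading = SensorReading(
    waveform=np.zeros(16, dtype=np.float32),  # placeholder waveform
    sensor_id="sensor_00",
    temperature=21.5,
    anomaly=False,
)
assert isinstance(reading.waveform, np.ndarray)  # unchanged in memory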

2 — Generate synthetic samples

rng = np.random.default_rng(42)

samples = [
    SensorReading(
        waveform=rng.standard_normal(512).astype(np.float32),
        sensor_id=f"sensor_{i % 8:02d}",
        temperature=20.0 + rng.normal(0, 3),
        anomaly=rng.random() < 0.05,
    )
    for i in range(2_000)
]

print(f"Created {len(samples)} samples")
print(f"Waveform shape: {samples[0].waveform.shape}")
Created 2000 samples
Waveform shape: (512,)

3 — Verify round-trip serialization

Before writing to disk, confirm that packing and unpacking preserves data.

original = samples[0]
packed = original.packed                         # -> bytes (msgpack)
restored = SensorReading.from_bytes(packed)      # -> SensorReading

assert original.sensor_id == restored.sensor_id
assert original.temperature == restored.temperature
assert original.anomaly == restored.anomaly
assert np.allclose(original.waveform, restored.waveform)

print(f"Packed size: {len(packed):,} bytes")
print("Round-trip: OK")
Packed size: 2,124 bytes
Round-trip: OK

Numpy scalars (like the np.float64 produced by rng.normal()) are automatically coerced to Python natives, so temperature is a plain float on both sides of the round trip:

print(f"Temperature type (original)  : {type(original.temperature).__name__}")
print(f"Temperature type (restored)  : {type(restored.temperature).__name__}")
assert isinstance(restored.temperature, float)
Temperature type (original)  : float
Temperature type (restored)  : float

4 — Write sharded tar files

write_samples serializes samples to WebDataset tar archives. Setting maxcount splits the output across multiple shards—essential for parallel I/O at scale.

import tempfile
from pathlib import Path

tmpdir = Path(tempfile.mkdtemp(prefix="atdata_example_"))
ds = atdata.write_samples(
    samples,
    tmpdir / "readings.tar",
    maxcount=500,       # 4 shards of 500 samples each
    manifest=True,      # generate per-shard metadata manifests
)

print(f"Dataset URL : {ds.url}")
print(f"Sample type : {ds.sample_type.__name__}")
print(f"Shard count : {len(ds.list_shards())}")
# writing /var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_example_5y_x_15h/readings-000000.tar 0 0.0 GB 0
# writing /var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_example_5y_x_15h/readings-000001.tar 500 0.0 GB 500
# writing /var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_example_5y_x_15h/readings-000002.tar 500 0.0 GB 1000
# writing /var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_example_5y_x_15h/readings-000003.tar 500 0.0 GB 1500
Dataset URL : /private/var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_example_5y_x_15h/readings-00000{0,1,2,3}.tar
Sample type : SensorReading
Shard count : 4

The returned Dataset[SensorReading] is immediately usable—no separate loading step required.
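
For instance, you can read a sample straight back from the shards you just wrote (a minimal sketch; it assumes that calling ordered() without a batch_size yields individual SensorReading objects, as the batching notes in the next step imply):

# Sketch: pull the first sample back out of the freshly written dataset.
# Assumes ordered() without batch_size yields single SensorReading objects.
first = next(iter(ds.ordered()))
print(type(first).__name__, first.sensor_id, first.waveform.shape)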

5 — Iterate with automatic batching

Dataset.ordered() streams samples in shard order. When you pass a batch_size, atdata wraps consecutive samples in a SampleBatch that aggregates fields automatically:

  • NDArray fields are stacked into a single array with a leading batch dimension.
  • Scalar/string fields become Python lists.

for batch in ds.ordered(batch_size=64):
    # NDArray field -> (64, 512) float32 array
    print(f"waveform    : {batch.waveform.shape}, {batch.waveform.dtype}")

    # Scalar fields -> plain lists
    print(f"sensor_id   : {batch.sensor_id[:3]}...")
    print(f"temperature : {[round(t, 1) for t in batch.temperature[:3]]}...")
    print(f"anomaly     : {batch.anomaly[:3]}...")
    break
waveform    : (64, 512), float32
sensor_id   : ['sensor_00', 'sensor_01', 'sensor_02']...
temperature : [20.3, 13.5, 19.2]...
anomaly     : [True, False, False]...

6 — Shuffled iteration for training

For model training you want randomized order. shuffled() applies two-level shuffling (shard order + in-buffer sample shuffling) to give well-mixed batches while staying streaming-friendly.

sensor_ids_seen: set[str] = set()

for batch in ds.shuffled(batch_size=128):
    sensor_ids_seen.update(batch.sensor_id)
    if len(sensor_ids_seen) == 8:
        break

print(f"Unique sensors in early batches: {sorted(sensor_ids_seen)}")
Unique sensors in early batches: ['sensor_00', 'sensor_01', 'sensor_02', 'sensor_03', 'sensor_04', 'sensor_05', 'sensor_06', 'sensor_07']
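
In a training script this usually sits inside an epoch loop; the sketch below uses a per-batch statistic as a stand-in for a real training step (the epoch count is a placeholder):

# Sketch: epoch-style loop over shuffled batches.
for epoch in range(2):                        # placeholder epoch count
    for batch in ds.shuffled(batch_size=128):
        # Stand-in for a training step: one value per sample in the batch.
        batch_mean = batch.waveform.mean(axis=1)
    print(f"epoch {epoch}: last-batch mean = {batch_mean.mean():.4f}")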

7 — Clean up

import shutil

shutil.rmtree(tmpdir, ignore_errors=True)

Key takeaways

Concept                          API
Type-safe samples                @atdata.packable
Automatic serialization          .packed / .from_bytes()
Sharded writes with manifests    atdata.write_samples(..., maxcount=N, manifest=True)
Numpy scalar coercion            np.float64 → float automatically
Ordered iteration                ds.ordered(batch_size=N)
Shuffled iteration               ds.shuffled(batch_size=N)
Batch aggregation                batch.field (NDArray stacking or list)