Multi-Split Datasets

Real-world ML workflows almost always have multiple splits—train, validation, test. atdata’s DatasetDict and load_dataset support this natively, following the HuggingFace convention of naming shards by split.

This example generates a toy classification dataset with train/test splits, writes them as separate shard sets, and loads everything back through the unified load_dataset API.

1 — Define a sample type

import numpy as np
from numpy.typing import NDArray

import atdata


@atdata.packable
class DigitSample:
    """A small grayscale digit image with label."""

    image: NDArray  # (28, 28) float32
    label: int
    split: str
2 — Generate train and test data
rng = np.random.default_rng(12)
def make_samples(n: int, split: str) -> list[DigitSample]:
    return [
        DigitSample(
            image=rng.random((28, 28), dtype=np.float32),
            label=int(rng.integers(0, 10)),
            split=split,
        )
        for _ in range(n)
    ]
train_samples = make_samples(1_000, "train")
test_samples = make_samples(200, "test")
print(f"Train: {len(train_samples)} samples")
print(f"Test : {len(test_samples)} samples")Train: 1000 samples
Test : 200 samples
3 — Write each split to its own shard set
Following the {split}-{shard} naming convention lets load_dataset detect splits automatically.
import tempfile
from pathlib import Path
tmpdir = Path(tempfile.mkdtemp(prefix="atdata_splits_"))
train_ds = atdata.write_samples(
    train_samples,
    tmpdir / "train.tar",
    maxcount=500,
)
print(f"Train shards: {train_ds.list_shards()}")
test_ds = atdata.write_samples(
    test_samples,
    tmpdir / "test.tar",
    maxcount=200,
)
print(f"Test shards : {test_ds.list_shards()}")

# writing /var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_splits_j7hbnd7k/train-000000.tar 0 0.0 GB 0
# writing /var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_splits_j7hbnd7k/train-000001.tar 500 0.0 GB 500
Train shards: ['/private/var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_splits_j7hbnd7k/train-000000.tar', '/private/var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_splits_j7hbnd7k/train-000001.tar']
# writing /var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_splits_j7hbnd7k/test-000000.tar 0 0.0 GB 0
Test shards : ['/private/var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_splits_j7hbnd7k/test-000000.tar']
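The convention is simple enough to inspect by hand. As an illustrative sketch only (SHARD_RE and split_of are hypothetical helpers written for this page, not part of the atdata API), this is roughly how a split name can be recovered from a shard filename:

import re

# Hypothetical helper: extract the split name from a shard path that
# follows the {split}-{NNNNNN}.tar convention used above.
SHARD_RE = re.compile(r"^(?P<split>.+)-(?P<shard>\d{6})\.tar$")

def split_of(shard_path: Path) -> str:
    match = SHARD_RE.match(shard_path.name)
    if match is None:
        raise ValueError(f"not a split shard: {shard_path}")
    return match.group("split")

print(sorted({split_of(p) for p in tmpdir.glob("*.tar")}))  # ['test', 'train']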
4 — Load a single split
Pass a specific shard URL and split to load_dataset to get a single Dataset.
ds_train = atdata.load_dataset(
    str(tmpdir / "train-{000000..000001}.tar"),
    DigitSample,
    split="train",
)
for batch in ds_train.ordered(batch_size=32):
    print(f"Image batch shape : {batch.image.shape}")
    print(f"Labels : {batch.label[:8]}...")
    break

Image batch shape : (32, 28, 28)
Labels : [2, 2, 3, 4, 7, 9, 7, 1]...
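The {000000..000001} part of the URL is brace notation for a range of shard indices (a convention also used by WebDataset). Expanded by hand, the pattern above is equivalent to this explicit list, shown purely for illustration:

# Illustration: the brace pattern names one URL per shard index in the range.
train_urls = [str(tmpdir / f"train-{i:06d}.tar") for i in range(2)]
print(train_urls)  # ['.../train-000000.tar', '.../train-000001.tar']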
5 — Build a DatasetDict from multiple splits
Wrap per-split Dataset objects into a DatasetDict for convenient multi-split access.
ds_dict = atdata.DatasetDict({
    "train": atdata.Dataset[DigitSample](str(tmpdir / "train-{000000..000001}.tar")),
    "test": atdata.Dataset[DigitSample](str(tmpdir / "test-000000.tar")),
})
print(f"Splits: {list(ds_dict.keys())}")
print(f"Train type: {type(ds_dict['train']).__name__}")
print(f"Test type : {type(ds_dict['test']).__name__}")Splits: ['train', 'test']
Train type: Dataset
Test type : Dataset
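Indexing a DatasetDict by split name returns the underlying typed Dataset, so everything from step 4 applies per split. A quick usage sketch, using only the calls shown above:

# Each split is an ordinary Dataset; iterate it exactly as in step 4.
for batch in ds_dict["test"].ordered(batch_size=50):
    print(f"Test batch labels: {batch.label[:5]}")
    break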
6 — Single-split proxy
When a DatasetDict has exactly one split, you can call Dataset methods directly on it—no need to index by split name first.
single = atdata.DatasetDict({
    "train": atdata.Dataset[DigitSample](str(tmpdir / "train-{000000..000001}.tar")),
})
# .ordered() is proxied to the sole split automatically
for batch in single.ordered(batch_size=32):
    print(f"Proxy batch shape: {batch.image.shape}")
    break

Proxy batch shape: (32, 28, 28)
7 — Iterate over each split
for split_name, split_ds in ds_dict.items():
    count = 0
    for batch in split_ds.ordered(batch_size=64):
        count += len(batch.label)
    print(f" {split_name:6s}: {count} samples iterated")

 train : 1000 samples iterated
test : 200 samples iterated
8 — Cross-split statistics
Because both splits share the same DigitSample type, you can compute comparable statistics directly.
for split_name, split_ds in ds_dict.items():
    all_labels: list[int] = []
    for batch in split_ds.ordered(batch_size=256):
        all_labels.extend(batch.label)
    label_arr = np.array(all_labels)
    counts = np.bincount(label_arr, minlength=10)
    print(f" {split_name:6s}: {len(all_labels)} samples, label distribution: {counts.tolist()}")

 train : 1000 samples, label distribution: [113, 85, 95, 100, 109, 91, 110, 100, 103, 94]
test : 200 samples, label distribution: [23, 20, 19, 23, 21, 20, 19, 19, 19, 17]
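The same pattern works for any field of the sample type, not just labels. A minimal sketch computing per-split mean pixel intensity, assuming batch.image stacks into a NumPy array as the batch shapes in step 4 suggest:

for split_name, split_ds in ds_dict.items():
    total = 0.0
    count = 0
    for batch in split_ds.ordered(batch_size=256):
        images = np.asarray(batch.image)
        total += float(images.sum())
        count += images.size
    print(f" {split_name:6s}: mean pixel intensity = {total / count:.4f}")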
9 — Clean up
import shutil
shutil.rmtree(tmpdir, ignore_errors=True)

Key takeaways
| Concept | API |
|---|---|
| Write split shards | write_samples(samples, "split.tar", maxcount=N) |
| Load single split | load_dataset("path/train-{000..N}.tar", SampleType, split="train") |
| Build a DatasetDict | DatasetDict({"train": ds1, "test": ds2}) |
| Single-split proxy | ds_dict.ordered() when only one split |
| Iterate splits | for name, ds in ds_dict.items() |
| Typed across splits | Same @packable type works everywhere |
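As a compact recap, here is a self-contained end-to-end sketch that uses only the calls demonstrated in this example (it assumes the DigitSample type and the make_samples helper defined above):

import shutil
import tempfile
from pathlib import Path

# Write two splits, then load them back through one DatasetDict.
root = Path(tempfile.mkdtemp(prefix="atdata_recap_"))
atdata.write_samples(make_samples(100, "train"), root / "train.tar", maxcount=50)
atdata.write_samples(make_samples(20, "test"), root / "test.tar", maxcount=20)

recap = atdata.DatasetDict({
    "train": atdata.Dataset[DigitSample](str(root / "train-{000000..000001}.tar")),
    "test": atdata.Dataset[DigitSample](str(root / "test-000000.tar")),
})
for name, ds in recap.items():
    first = next(iter(ds.ordered(batch_size=8)))
    print(f"{name}: first batch image shape {first.image.shape}")

shutil.rmtree(root, ignore_errors=True)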