Multi-Split Datasets

Work with train/test splits via DatasetDict and load_dataset

Real-world ML workflows almost always have multiple splits—train, validation, test. atdata’s DatasetDict and load_dataset support this natively, following the HuggingFace convention of naming shards by split.

This example generates a toy classification dataset with train/test splits, writes them as separate shard sets, and loads everything back through the unified load_dataset API.
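With the {split}-{shard} convention, the two-split dataset built below lands on disk as:

train-000000.tar
train-000001.tar
test-000000.tar

These are exactly the shard names produced in step 3.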

1 — Define a sample type

import numpy as np
from numpy.typing import NDArray
import atdata


@atdata.packable
class DigitSample:
    """A small grayscale digit image with label."""

    image: NDArray  # (28, 28) float32
    label: int
    split: str

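The decorated class constructs like a plain dataclass, which step 2 relies on. A minimal sketch; the zero image and the label here are arbitrary placeholders:

sample = DigitSample(
    image=np.zeros((28, 28), dtype=np.float32),
    label=7,
    split="train",
)
print(sample.label, sample.split)  # 7 train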
2 — Generate train and test data

rng = np.random.default_rng(12)

def make_samples(n: int, split: str) -> list[DigitSample]:
    return [
        DigitSample(
            image=rng.random((28, 28), dtype=np.float32),
            label=int(rng.integers(0, 10)),
            split=split,
        )
        for _ in range(n)
    ]

train_samples = make_samples(1_000, "train")
test_samples = make_samples(200, "test")

print(f"Train: {len(train_samples)} samples")
print(f"Test : {len(test_samples)} samples")
Train: 1000 samples
Test : 200 samples

3 — Write each split to its own shard set

Following the {split}-{shard} naming convention lets load_dataset detect splits automatically.

import tempfile
from pathlib import Path

tmpdir = Path(tempfile.mkdtemp(prefix="atdata_splits_"))

train_ds = atdata.write_samples(
    train_samples,
    tmpdir / "train.tar",
    maxcount=500,
)
print(f"Train shards: {train_ds.list_shards()}")

test_ds = atdata.write_samples(
    test_samples,
    tmpdir / "test.tar",
    maxcount=200,
)
print(f"Test shards : {test_ds.list_shards()}")
# writing /var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_splits_j7hbnd7k/train-000000.tar 0 0.0 GB 0
# writing /var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_splits_j7hbnd7k/train-000001.tar 500 0.0 GB 500
Train shards: ['/private/var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_splits_j7hbnd7k/train-000000.tar', '/private/var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_splits_j7hbnd7k/train-000001.tar']
# writing /var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_splits_j7hbnd7k/test-000000.tar 0 0.0 GB 0
Test shards : ['/private/var/folders/hx/9l078dds5z945qcv8j1hsnr00000gn/T/atdata_splits_j7hbnd7k/test-000000.tar']
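As a quick sanity check: maxcount caps samples per shard, so 1,000 train samples at maxcount=500 fill exactly two shards and 200 test samples at maxcount=200 fill one. A sketch using only the list_shards call shown above:

# 1,000 samples / maxcount=500 -> 2 train shards;
# 200 samples / maxcount=200 -> 1 test shard.
assert len(train_ds.list_shards()) == 2
assert len(test_ds.list_shards()) == 1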

4 — Load a single split

Pass a shard URL pattern, the sample type, and a split name to load_dataset to get a single typed Dataset.

ds_train = atdata.load_dataset(
    str(tmpdir / "train-{000000..000001}.tar"),
    DigitSample,
    split="train",
)

for batch in ds_train.ordered(batch_size=32):
    print(f"Image batch shape : {batch.image.shape}")
    print(f"Labels            : {batch.label[:8]}...")
    break
Image batch shape : (32, 28, 28)
Labels            : [2, 2, 3, 4, 7, 9, 7, 1]...
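The test split loads the same way; only the shard pattern and split name change:

ds_test = atdata.load_dataset(
    str(tmpdir / "test-000000.tar"),
    DigitSample,
    split="test",
)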

5 — Build a DatasetDict from multiple splits

Wrap per-split Dataset objects into a DatasetDict for convenient multi-split access.

ds_dict = atdata.DatasetDict({
    "train": atdata.Dataset[DigitSample](str(tmpdir / "train-{000000..000001}.tar")),
    "test": atdata.Dataset[DigitSample](str(tmpdir / "test-000000.tar")),
})

print(f"Splits: {list(ds_dict.keys())}")
print(f"Train type: {type(ds_dict['train']).__name__}")
print(f"Test type : {type(ds_dict['test']).__name__}")
Splits: ['train', 'test']
Train type: Dataset
Test type : Dataset
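Indexing by split name returns the underlying Dataset, so any one split can be batched directly:

# Pull a single batch from the test split.
for batch in ds_dict["test"].ordered(batch_size=50):
    print(f"Test batch shape: {batch.image.shape}")
    break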

6 — Single-split proxy

When a DatasetDict has exactly one split, you can call Dataset methods directly on it—no need to index by split name first.

single = atdata.DatasetDict({
    "train": atdata.Dataset[DigitSample](str(tmpdir / "train-{000000..000001}.tar")),
})

# .ordered() is proxied to the sole split automatically
for batch in single.ordered(batch_size=32):
    print(f"Proxy batch shape: {batch.image.shape}")
    break
Proxy batch shape: (32, 28, 28)

7 — Iterate over each split

for split_name, split_ds in ds_dict.items():
    count = 0
    for batch in split_ds.ordered(batch_size=64):
        count += len(batch.label)
    print(f"  {split_name:6s}: {count} samples iterated")
  train : 1000 samples iterated
  test  : 200 samples iterated

8 — Cross-split statistics

Because both splits share the same DigitSample type, you can compute comparable statistics directly.

for split_name, split_ds in ds_dict.items():
    all_labels: list[int] = []
    for batch in split_ds.ordered(batch_size=256):
        all_labels.extend(batch.label)

    label_arr = np.array(all_labels)
    counts = np.bincount(label_arr, minlength=10)
    print(f"  {split_name:6s}: {len(all_labels)} samples, label distribution: {counts.tolist()}")
  train : 1000 samples, label distribution: [113, 85, 95, 100, 109, 91, 110, 100, 103, 94]
  test  : 200 samples, label distribution: [23, 20, 19, 23, 21, 20, 19, 19, 19, 17]
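The same pattern extends to the image field. A sketch comparing mean pixel intensity per split (the images are drawn uniformly from [0, 1), so both means should land near 0.5):

for split_name, split_ds in ds_dict.items():
    total, n_pixels = 0.0, 0
    for batch in split_ds.ordered(batch_size=256):
        total += float(batch.image.sum())  # batches stack to (B, 28, 28)
        n_pixels += batch.image.size
    print(f"  {split_name:6s}: mean pixel = {total / n_pixels:.3f}")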

9 — Clean up

import shutil

shutil.rmtree(tmpdir, ignore_errors=True)

Key takeaways

Concept               API
Write split shards    write_samples(samples, "split.tar", maxcount=N)
Load single split     load_dataset("path/train-{000..N}.tar")
Build a DatasetDict   DatasetDict({"train": ds1, "test": ds2})
Single-split proxy    ds_dict.ordered() when only one split
Iterate splits        for name, ds in ds_dict.items()
Typed across splits   Same @packable type works everywhere