Dataset

Dataset(source=None, metadata_url=None, *, url=None)

A typed dataset built on WebDataset with lens transformations.

This class wraps WebDataset tar archives and provides type-safe iteration over samples of a specific PackableSample type. Samples are stored as msgpack-serialized data within WebDataset shards.

The dataset supports:

- Ordered and shuffled iteration
- Automatic batching with SampleBatch
- Type transformations via the lens system (as_type())
- Export to parquet format

Parameters

Name Type Description Default
ST The sample type for this dataset; must derive from PackableSample. required

Attributes

Name Type Description
url WebDataset brace-notation URL for the tar file(s).

Examples

>>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
>>> for sample in ds.ordered(batch_size=32):
...     # sample is SampleBatch[MyData] with batch_size samples
...     embeddings = sample.embeddings  # shape: (32, ...)
...
>>> # Transform to a different view
>>> ds_view = ds.as_type(MyDataView)
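The brace-notation URL above names ten shards. WebDataset expands such ranges via the braceexpand package; the following is a simplified, self-contained stand-in that illustrates the expansion (the function name `expand_braces` is hypothetical, not part of this API):

```python
import re

# Simplified stand-in for WebDataset's brace-notation expansion.
# Only handles a single numeric {lo..hi} range; the real braceexpand
# package supports nested and non-numeric patterns as well.
def expand_braces(url):
    m = re.search(r"\{(\d+)\.\.(\d+)\}", url)
    if not m:
        return [url]
    lo, hi = m.group(1), m.group(2)
    width = len(lo)  # zero-padding is taken from the range bounds
    return [url[:m.start()] + f"{i:0{width}d}" + url[m.end():]
            for i in range(int(lo), int(hi) + 1)]

shards = expand_braces("path/to/data-{000000..000009}.tar")
assert len(shards) == 10
assert shards[0] == "path/to/data-000000.tar"
```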

Note

This class uses Python’s __orig_class__ mechanism to extract the type parameter at runtime. Instances must be created using the subscripted syntax Dataset[MyType](url) rather than calling the constructor directly with an unsubscripted class.
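A minimal sketch of the __orig_class__ mechanism, using only the standard typing module (the `sample_type` property here is illustrative, not necessarily this class's real attribute name). CPython sets __orig_class__ on the instance after __init__ returns, but only when the instance was created through a subscripted alias:

```python
from typing import Generic, TypeVar, get_args

ST = TypeVar("ST")

class Dataset(Generic[ST]):
    def __init__(self, url):
        self.url = url

    @property
    def sample_type(self):
        # __orig_class__ holds Dataset[MyData] when the instance was
        # created as Dataset[MyData](url); get_args extracts MyData.
        return get_args(self.__orig_class__)[0]

class MyData: ...

ds = Dataset[MyData]("data-{000000..000009}.tar")
assert ds.sample_type is MyData
```

Creating the instance as a bare `Dataset(url)` leaves __orig_class__ unset, which is why the subscripted syntax is required.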

Methods

Name Description
as_type View this dataset through a different sample type via a registered lens.
describe Summary statistics: sample_type, fields, num_shards, shards, url, metadata.
filter Return a new dataset that yields only samples matching predicate.
get Retrieve a single sample by its __key__.
head Return the first n samples from the dataset.
list_shards Return all shard paths/URLs as a list.
map Return a new dataset that applies fn to each sample during iteration.
ordered Iterate over the dataset in order.
process_shards Process each shard independently, collecting per-shard results.
query Query this dataset using per-shard manifest metadata.
select Return samples at the given integer indices.
shuffled Iterate over the dataset in random order.
to_dict Materialize the dataset as a column-oriented dictionary.
to_pandas Materialize the dataset (or first limit samples) as a DataFrame.
to_parquet Export dataset to parquet file(s).
wrap Deserialize a raw WDS sample dict into type ST.
wrap_batch Deserialize a raw WDS batch dict into SampleBatch[ST].

as_type

Dataset.as_type(other)

View this dataset through a different sample type via a registered lens.

Raises

Name Type Description
ValueError If no lens exists between the current and target types.

describe

Dataset.describe()

Summary statistics: sample_type, fields, num_shards, shards, url, metadata.

filter

Dataset.filter(predicate)

Return a new dataset that yields only samples matching predicate.

The filter is applied lazily during iteration — no data is copied.

Parameters

Name Type Description Default
predicate Callable[[ST], bool] A function that takes a sample and returns True to keep it or False to discard it. required

Returns

Name Type Description
Dataset[ST] A new Dataset whose iterators apply the filter.

Examples

>>> long_names = ds.filter(lambda s: len(s.name) > 10)
>>> for sample in long_names:
...     assert len(sample.name) > 10

get

Dataset.get(key)

Retrieve a single sample by its __key__.

Scans shards sequentially until a sample with a matching key is found. This is O(n) for streaming datasets.

Parameters

Name Type Description Default
key str The WebDataset __key__ string to search for. required

Returns

Name Type Description
ST The matching sample.

Raises

Name Type Description
SampleKeyError If no sample with the given key exists.

Examples

>>> sample = ds.get("00000001-0001-1000-8000-010000000000")
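The O(n) sequential scan described above can be sketched as follows. This is a simplified stand-in (plain dicts instead of WebDataset streams, and a local `SampleKeyError` in place of the library's exception):

```python
# Stand-in for atdata's SampleKeyError.
class SampleKeyError(KeyError):
    pass

def get(shards, key):
    # Stream shards in order, comparing each sample's __key__;
    # worst case visits every sample once.
    for shard in shards:  # shards: iterable of iterables of sample dicts
        for sample in shard:
            if sample["__key__"] == key:
                return sample
    raise SampleKeyError(key)

shards = [[{"__key__": "a", "v": 1}], [{"__key__": "b", "v": 2}]]
assert get(shards, "b")["v"] == 2
```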

head

Dataset.head(n=5)

Return the first n samples from the dataset.

Parameters

Name Type Description Default
n int Number of samples to return. Default: 5. 5

Returns

Name Type Description
list[ST] List of up to n samples in shard order.

Examples

>>> samples = ds.head(3)
>>> len(samples)
3

list_shards

Dataset.list_shards()

Return all shard paths/URLs as a list.

map

Dataset.map(fn)

Return a new dataset that applies fn to each sample during iteration.

The mapping is applied lazily during iteration — no data is copied.

Parameters

Name Type Description Default
fn Callable[[ST], Any] A function that takes a sample of type ST and returns a transformed value. required

Returns

Name Type Description
Dataset A new Dataset whose iterators apply the mapping.

Examples

>>> names = ds.map(lambda s: s.name)
>>> for name in names:
...     print(name)

ordered

Dataset.ordered(batch_size=None)

Iterate over the dataset in order.

Parameters

Name Type Description Default
batch_size int | None The size of iterated batches. Default: None (unbatched). If None, iterates over one sample at a time with no batch dimension. None

Returns

Name Type Description
Iterable[ST] | Iterable[SampleBatch[ST]] A data pipeline that iterates over the dataset in its original sample order. When batch_size is None, yields individual samples of type ST. When batch_size is an integer, yields SampleBatch[ST] instances containing that many samples.

Examples

>>> for sample in ds.ordered():
...     process(sample)  # sample is ST
>>> for batch in ds.ordered(batch_size=32):
...     process(batch)  # batch is SampleBatch[ST]
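The batching behaviour can be sketched with a plain chunking generator (a stand-in for the real pipeline, which collates chunks into SampleBatch objects rather than lists; note the trailing batch may hold fewer than batch_size samples):

```python
from itertools import islice

def batched(iterable, batch_size):
    # Group a sample stream into fixed-size chunks, lazily.
    it = iter(iterable)
    while chunk := list(islice(it, batch_size)):
        yield chunk

assert list(batched(range(7), 3)) == [[0, 1, 2], [3, 4, 5], [6]]
```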

process_shards

Dataset.process_shards(fn, *, shards=None)

Process each shard independently, collecting per-shard results.

Unlike map (which is lazy and per-sample), this method eagerly processes each shard in turn, calling fn with the full list of samples from that shard. If some shards fail, raises PartialFailureError containing both the successful results and the per-shard errors.

Parameters

Name Type Description Default
fn Callable[[list[ST]], Any] Function receiving a list of samples from one shard and returning an arbitrary result. required
shards list[str] | None Optional list of shard identifiers to process. If None, processes all shards in the dataset. Useful for retrying only the failed shards from a previous PartialFailureError. None

Returns

Name Type Description
dict[str, Any] Dict mapping shard identifier to fn’s return value for each shard.

Raises

Name Type Description
PartialFailureError If at least one shard fails. The exception carries .succeeded_shards, .failed_shards, .errors, and .results for inspection and retry.

Examples

>>> results = ds.process_shards(lambda samples: len(samples))
>>> # On partial failure, retry just the failed shards:
>>> try:
...     results = ds.process_shards(expensive_fn)
... except PartialFailureError as e:
...     retry = ds.process_shards(expensive_fn, shards=e.failed_shards)
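The collect-then-raise semantics can be sketched as follows. Both the exception class and the attribute names here mirror the documented ones but are local stand-ins, not the atdata implementation:

```python
# Stand-in for atdata's PartialFailureError.
class PartialFailureError(Exception):
    def __init__(self, results, errors):
        super().__init__(f"{len(errors)} shard(s) failed")
        self.results = results            # shard -> fn return value
        self.errors = errors              # shard -> exception
        self.succeeded_shards = list(results)
        self.failed_shards = list(errors)

def process_shards(shard_samples, fn, shards=None):
    # Process each shard, collecting errors instead of aborting early.
    targets = shards if shards is not None else list(shard_samples)
    results, errors = {}, {}
    for shard in targets:
        try:
            results[shard] = fn(shard_samples[shard])
        except Exception as e:
            errors[shard] = e
    if errors:
        raise PartialFailureError(results, errors)
    return results

data = {"s0": [1, 2], "s1": [3]}
assert process_shards(data, len) == {"s0": 2, "s1": 1}
```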

query

Dataset.query(where)

Query this dataset using per-shard manifest metadata.

Requires manifests to have been generated during shard writing. Discovers manifest files alongside the tar shards, loads them, and executes a two-phase query (shard-level aggregate pruning, then sample-level parquet filtering).

The where argument accepts either a lambda/function that operates on a pandas DataFrame, or a Predicate built from the proxy DSL.

Parameters

Name Type Description Default
where Callable[[pd.DataFrame], pd.Series] | Predicate Predicate function or Predicate object that selects matching rows from the per-sample manifest DataFrame. required

Returns

Name Type Description
list[SampleLocation] List of SampleLocation for matching samples.

Raises

Name Type Description
FileNotFoundError If no manifest files are found alongside shards.

Examples

>>> locs = ds.query(where=lambda df: df["confidence"] > 0.9)
>>> len(locs)
42
>>> Q = ds.fields
>>> locs = ds.query(where=(Q.confidence > 0.9))
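The two-phase query can be sketched with plain dicts in place of parquet manifests (the manifest layout, the `max_confidence` aggregate, and the tuple-based sample locations below are all illustrative assumptions, not the real manifest schema):

```python
def query(manifests, min_conf):
    # Phase 1: shard-level pruning. Skip any shard whose aggregate
    # maximum cannot possibly satisfy the predicate.
    candidates = {s: m for s, m in manifests.items()
                  if m["max_confidence"] > min_conf}
    # Phase 2: sample-level filtering within surviving shards.
    locations = []
    for shard, m in candidates.items():
        for key, conf in m["samples"]:
            if conf > min_conf:
                locations.append((shard, key))
    return locations

manifests = {
    "shard-0": {"max_confidence": 0.5, "samples": [("a", 0.5)]},
    "shard-1": {"max_confidence": 0.95,
                "samples": [("b", 0.95), ("c", 0.2)]},
}
assert query(manifests, 0.9) == [("shard-1", "b")]
```

Pruning at the shard level means shards whose aggregates rule out a match are never opened, which is the point of keeping per-shard manifests.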

select

Dataset.select(indices)

Return samples at the given integer indices.

Iterates through the dataset in order and collects samples whose positional index matches. This is O(n) for streaming datasets.

Parameters

Name Type Description Default
indices Sequence[int] Sequence of zero-based indices to select. required

Returns

Name Type Description
list[ST] List of samples at the requested positions, in index order.

Examples

>>> samples = ds.select([0, 5, 10])
>>> len(samples)
3

shuffled

Dataset.shuffled(buffer_shards=100, buffer_samples=10000, batch_size=None)

Iterate over the dataset in random order.

Parameters

Name Type Description Default
buffer_shards int Number of shards to buffer for shuffling at the shard level. Larger values increase randomness but use more memory. Default: 100. 100
buffer_samples int Number of samples to buffer for shuffling within shards. Larger values increase randomness but use more memory. Default: 10,000. 10000
batch_size int | None The size of iterated batches. Default: None (unbatched). If None, iterates over one sample at a time with no batch dimension. None

Returns

Name Type Description
Iterable[ST] | Iterable[SampleBatch[ST]] A data pipeline that iterates over the dataset in randomized order. When batch_size is None, yields individual samples of type ST. When batch_size is an integer, yields SampleBatch[ST] instances containing that many samples.

Examples

>>> for sample in ds.shuffled():
...     process(sample)  # sample is ST
>>> for batch in ds.shuffled(batch_size=32):
...     process(batch)  # batch is SampleBatch[ST]
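The sample-level half of the shuffle (controlled by buffer_samples) can be sketched as a streaming shuffle buffer; the real pipeline also shuffles at the shard level before samples are read. This is a simplified stand-in, not the WebDataset implementation:

```python
import random

def buffered_shuffle(iterable, buffer_size, rng):
    # Fill a buffer, then emit a random buffered element for each new
    # arrival; larger buffers give better shuffling at more memory cost.
    buf = []
    for item in iterable:
        buf.append(item)
        if len(buf) >= buffer_size:
            yield buf.pop(rng.randrange(len(buf)))
    while buf:  # drain the remainder
        yield buf.pop(rng.randrange(len(buf)))

rng = random.Random(0)
out = list(buffered_shuffle(range(10), 4, rng))
assert sorted(out) == list(range(10))
```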

to_dict

Dataset.to_dict(limit=None)

Materialize the dataset as a column-oriented dictionary.

Parameters

Name Type Description Default
limit int | None Maximum number of samples to include. None means all. None

Returns

Name Type Description
dict[str, list[Any]] Dictionary mapping field names to lists of values (one entry per sample).

Warning

With limit=None this loads the entire dataset into memory.

Examples

>>> d = ds.to_dict(limit=10)
>>> d.keys()
dict_keys(['name', 'embedding'])
>>> len(d['name'])
10

to_pandas

Dataset.to_pandas(limit=None)

Materialize the dataset (or first limit samples) as a DataFrame.

Parameters

Name Type Description Default
limit int | None Maximum number of samples to include. None means all samples (may use significant memory for large datasets). None

Returns

Name Type Description
pd.DataFrame A pandas DataFrame with one row per sample and columns matching the sample fields.

Warning

With limit=None this loads the entire dataset into memory.

Examples

>>> df = ds.to_pandas(limit=100)
>>> df.columns.tolist()
['name', 'embedding']

to_parquet

Dataset.to_parquet(path, sample_map=None, maxcount=None, **kwargs)

Export dataset to parquet file(s).

Parameters

Name Type Description Default
path Pathlike Output path. With maxcount, files are named {stem}-{segment:06d}.parquet. required
sample_map Optional[SampleExportMap] Convert sample to dict. Defaults to dataclasses.asdict. None
maxcount Optional[int] Split into files of at most this many samples. Without it, the entire dataset is loaded into memory. None
**kwargs Passed to pandas.DataFrame.to_parquet(). {}

Examples

>>> ds.to_parquet("output.parquet", maxcount=50000)
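The segmented output naming used with maxcount can be sketched as follows (the helper `segment_paths` is hypothetical; only the {stem}-{segment:06d}.parquet pattern comes from the documented behaviour):

```python
from pathlib import Path

def segment_paths(path, num_samples, maxcount):
    # Compute the output file names for a segmented parquet export:
    # ceil(num_samples / maxcount) files, each up to maxcount samples.
    p = Path(path)
    n_files = -(-num_samples // maxcount)  # ceiling division
    return [p.with_name(f"{p.stem}-{i:06d}{p.suffix}")
            for i in range(n_files)]

paths = segment_paths("output.parquet", 120_000, 50_000)
assert [p.name for p in paths] == [
    "output-000000.parquet",
    "output-000001.parquet",
    "output-000002.parquet",
]
```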

wrap

Dataset.wrap(sample)

Deserialize a raw WDS sample dict into type ST.

wrap_batch

Dataset.wrap_batch(batch)

Deserialize a raw WDS batch dict into SampleBatch[ST].