Dataset
Dataset(source=None, metadata_url=None, *, url=None)
A typed dataset built on WebDataset with lens transformations.
This class wraps WebDataset tar archives and provides type-safe iteration over samples of a specific PackableSample type. Samples are stored as msgpack-serialized data within WebDataset shards.
The dataset supports:

- Ordered and shuffled iteration
- Automatic batching with SampleBatch
- Type transformations via the lens system (as_type())
- Export to parquet format
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| ST | | The sample type for this dataset; must derive from PackableSample. | required |
Attributes
| Name | Type | Description |
|---|---|---|
| url | WebDataset brace-notation URL for the tar file(s). |
Examples
>>> ds = Dataset[MyData]("path/to/data-{000000..000009}.tar")
>>> for sample in ds.ordered(batch_size=32):
... # sample is SampleBatch[MyData] with batch_size samples
... embeddings = sample.embeddings # shape: (32, ...)
...
>>> # Transform to a different view
>>> ds_view = ds.as_type(MyDataView)Note
This class uses Python’s __orig_class__ mechanism to extract the type parameter at runtime. Instances must be created using the subscripted syntax Dataset[MyType](url) rather than calling the constructor directly with an unsubscripted class.
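For example (MyData stands in for any PackableSample subclass):

>>> ds = Dataset[MyData]("data-{000000..000009}.tar")  # type parameter is recorded
>>> ds = Dataset("data-{000000..000009}.tar")  # wrong: the sample type cannot be recovered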
Methods
| Name | Description |
|---|---|
| as_type | View this dataset through a different sample type via a registered lens. |
| describe | Summary statistics: sample_type, fields, num_shards, shards, url, metadata. |
| filter | Return a new dataset that yields only samples matching predicate. |
| get | Retrieve a single sample by its __key__. |
| head | Return the first n samples from the dataset. |
| list_shards | Return all shard paths/URLs as a list. |
| map | Return a new dataset that applies fn to each sample during iteration. |
| ordered | Iterate over the dataset in order. |
| process_shards | Process each shard independently, collecting per-shard results. |
| query | Query this dataset using per-shard manifest metadata. |
| select | Return samples at the given integer indices. |
| shuffled | Iterate over the dataset in random order. |
| to_dict | Materialize the dataset as a column-oriented dictionary. |
| to_pandas | Materialize the dataset (or first limit samples) as a DataFrame. |
| to_parquet | Export dataset to parquet file(s). |
| wrap | Deserialize a raw WDS sample dict into type ST. |
| wrap_batch | Deserialize a raw WDS batch dict into SampleBatch[ST]. |
as_type
Dataset.as_type(other)
View this dataset through a different sample type via a registered lens.
Raises
| Name | Type | Description |
|---|---|---|
| ValueError | If no lens exists between the current and target types. |
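A sketch of both outcomes, assuming a lens is registered from MyData to MyDataView while UnrelatedType (a hypothetical type here) has none:

>>> view = ds.as_type(MyDataView)  # view is Dataset[MyDataView]
>>> ds.as_type(UnrelatedType)
Traceback (most recent call last):
    ...
ValueError: ...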
describe
Dataset.describe()
Summary statistics: sample_type, fields, num_shards, shards, url, metadata.
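A sketch, assuming the summary is returned as a dict-like mapping keyed by the names above (the shard count follows the ten-shard URL from the class example):

>>> info = ds.describe()
>>> info["num_shards"]
10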
filter
Dataset.filter(predicate)
Return a new dataset that yields only samples matching predicate.
The filter is applied lazily during iteration — no data is copied.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| predicate | Callable[[ST], bool] | A function that takes a sample and returns True to keep it or False to discard it. | required |
Returns
| Name | Type | Description |
|---|---|---|
| Dataset[ST] | A new Dataset whose iterators apply the filter. |
Examples
>>> long_names = ds.filter(lambda s: len(s.name) > 10)
>>> for sample in long_names:
... assert len(sample.name) > 10
get
Dataset.get(key)
Retrieve a single sample by its __key__.
Scans shards sequentially until a sample with a matching key is found. This is O(n) for streaming datasets.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| key | str | The WebDataset __key__ string to search for. | required |
Returns
| Name | Type | Description |
|---|---|---|
| ST | The matching sample. |
Raises
| Name | Type | Description |
|---|---|---|
| SampleKeyError | If no sample with the given key exists. |
Examples
>>> sample = ds.get("00000001-0001-1000-8000-010000000000")
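Missing keys raise SampleKeyError; "no-such-key" below is a placeholder:

>>> try:
...     ds.get("no-such-key")
... except SampleKeyError:
...     print("not found")
not found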
head
Dataset.head(n=5)
Return the first n samples from the dataset.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| n | int | Number of samples to return. Default: 5. | 5 |
Returns
| Name | Type | Description |
|---|---|---|
| list[ST] | List of up to n samples in shard order. |
Examples
>>> samples = ds.head(3)
>>> len(samples)
3
list_shards
Dataset.list_shards()
Return all shard paths/URLs as a list.
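Continuing the class example, the brace-notation URL path/to/data-{000000..000009}.tar expands to ten shards:

>>> shards = ds.list_shards()
>>> len(shards)
10
>>> shards[0]
'path/to/data-000000.tar'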
map
Dataset.map(fn)
Return a new dataset that applies fn to each sample during iteration.
The mapping is applied lazily during iteration — no data is copied.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[[ST], Any] | A function that takes a sample of type ST and returns a transformed value. | required |
Returns
| Name | Type | Description |
|---|---|---|
| Dataset | A new Dataset whose iterators apply the mapping. |
Examples
>>> names = ds.map(lambda s: s.name)
>>> for name in names:
... print(name)
ordered
Dataset.ordered(batch_size=None)
Iterate over the dataset in order.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| batch_size | int \| None | The size of iterated batches. Default: None (unbatched). If None, iterates over one sample at a time with no batch dimension. | None |
Returns
| Name | Type | Description |
|---|---|---|
| Iterable[ST] \| Iterable[SampleBatch[ST]] | A data pipeline that iterates over the dataset in its original sample order. When batch_size is None, yields individual samples of type ST. When batch_size is an integer, yields SampleBatch[ST] instances containing that many samples. |
Examples
>>> for sample in ds.ordered():
... process(sample) # sample is ST
>>> for batch in ds.ordered(batch_size=32):
... process(batch) # batch is SampleBatch[ST]
process_shards
Dataset.process_shards(fn, *, shards=None)
Process each shard independently, collecting per-shard results.
Unlike map (which is lazy and per-sample), this method eagerly processes each shard in turn, calling fn with the full list of samples from that shard. If any shard fails, it raises PartialFailureError containing both the successful results and the per-shard errors.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| fn | Callable[[list[ST]], Any] | Function receiving a list of samples from one shard and returning an arbitrary result. | required |
| shards | list[str] \| None | Optional list of shard identifiers to process. If None, processes all shards in the dataset. Useful for retrying only the failed shards from a previous PartialFailureError. | None |
Returns
| Name | Type | Description |
|---|---|---|
| dict[str, Any] | Dict mapping shard identifier to fn’s return value for each shard. |
Raises
| Name | Type | Description |
|---|---|---|
| PartialFailureError | If at least one shard fails. The exception carries .succeeded_shards, .failed_shards, .errors, and .results for inspection and retry. |
Examples
>>> results = ds.process_shards(lambda samples: len(samples))
>>> # On partial failure, retry just the failed shards:
>>> try:
... results = ds.process_shards(expensive_fn)
... except PartialFailureError as e:
... retry = ds.process_shards(expensive_fn, shards=e.failed_shards)
query
Dataset.query(where)
Query this dataset using per-shard manifest metadata.
Requires manifests to have been generated during shard writing. Discovers manifest files alongside the tar shards, loads them, and executes a two-phase query (shard-level aggregate pruning, then sample-level parquet filtering).
The where argument accepts either a lambda/function that operates on a pandas DataFrame, or a Predicate built from the proxy DSL.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| where | Callable[[pd.DataFrame], pd.Series] \| Predicate | Predicate function or Predicate object that selects matching rows from the per-sample manifest DataFrame. | required |
Returns
| Name | Type | Description |
|---|---|---|
| list[SampleLocation] | List of SampleLocation for matching samples. |
Raises
| Name | Type | Description |
|---|---|---|
| FileNotFoundError | If no manifest files are found alongside shards. |
Examples
>>> locs = ds.query(where=lambda df: df["confidence"] > 0.9)
>>> len(locs)
42
>>> Q = ds.fields
>>> locs = ds.query(where=(Q.confidence > 0.9))
select
Dataset.select(indices)
Return samples at the given integer indices.
Iterates through the dataset in order and collects samples whose positional index matches. This is O(n) for streaming datasets.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| indices | Sequence[int] | Sequence of zero-based indices to select. | required |
Returns
| Name | Type | Description |
|---|---|---|
| list[ST] | List of samples at the requested positions, in index order. |
Examples
>>> samples = ds.select([0, 5, 10])
>>> len(samples)
3
shuffled
Dataset.shuffled(buffer_shards=100, buffer_samples=10000, batch_size=None)
Iterate over the dataset in random order.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| buffer_shards | int | Number of shards to buffer for shuffling at the shard level. Larger values increase randomness but use more memory. Default: 100. | 100 |
| buffer_samples | int | Number of samples to buffer for shuffling within shards. Larger values increase randomness but use more memory. Default: 10,000. | 10000 |
| batch_size | int \| None | The size of iterated batches. Default: None (unbatched). If None, iterates over one sample at a time with no batch dimension. | None |
Returns
| Name | Type | Description |
|---|---|---|
| Iterable[ST] \| Iterable[SampleBatch[ST]] | A data pipeline that iterates over the dataset in randomized order. When batch_size is None, yields individual samples of type ST. When batch_size is an integer, yields SampleBatch[ST] instances containing that many samples. |
Examples
>>> for sample in ds.shuffled():
... process(sample) # sample is ST
>>> for batch in ds.shuffled(batch_size=32):
... process(batch) # batch is SampleBatch[ST]
to_dict
Dataset.to_dict(limit=None)
Materialize the dataset as a column-oriented dictionary.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| limit | int \| None | Maximum number of samples to include. None means all. | None |
Returns
| Name | Type | Description |
|---|---|---|
| dict[str, list[Any]] | Dictionary mapping field names to lists of values (one entry per sample). |
Warning
With limit=None this loads the entire dataset into memory.
Examples
>>> d = ds.to_dict(limit=10)
>>> d.keys()
dict_keys(['name', 'embedding'])
>>> len(d['name'])
10
to_pandas
Dataset.to_pandas(limit=None)
Materialize the dataset (or first limit samples) as a DataFrame.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| limit | int \| None | Maximum number of samples to include. None means all samples (may use significant memory for large datasets). | None |
Returns
| Name | Type | Description |
|---|---|---|
| pd.DataFrame | A pandas DataFrame with one row per sample and columns matching the sample fields. |
Warning
With limit=None this loads the entire dataset into memory.
Examples
>>> df = ds.to_pandas(limit=100)
>>> df.columns.tolist()
['name', 'embedding']
to_parquet
Dataset.to_parquet(path, sample_map=None, maxcount=None, **kwargs)
Export dataset to parquet file(s).
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| path | Pathlike | Output path. With maxcount, files are named {stem}-{segment:06d}.parquet. | required |
| sample_map | Optional[SampleExportMap] | Convert each sample to a dict. Defaults to dataclasses.asdict. | None |
| maxcount | Optional[int] | Split into files of at most this many samples. Without it, the entire dataset is loaded into memory. | None |
| **kwargs | | Passed to pandas.DataFrame.to_parquet(). | {} |
Examples
>>> ds.to_parquet("output.parquet", maxcount=50000)
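A custom sample_map can reshape records before export; a sketch using the name/embedding fields from the earlier examples (embedding_norm is a hypothetical derived column):

>>> import numpy as np
>>> ds.to_parquet(
...     "output.parquet",
...     sample_map=lambda s: {"name": s.name,
...                           "embedding_norm": float(np.linalg.norm(s.embedding))},
... )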
wrap
Dataset.wrap(sample)
Deserialize a raw WDS sample dict into type ST.
wrap_batch
Dataset.wrap_batch(batch)
Deserialize a raw WDS batch dict into SampleBatch[ST].
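Both hooks consume raw WebDataset sample dicts; a sketch using the webdataset package directly (assuming the shards were written with this library's msgpack layout):

>>> import webdataset as wds
>>> raw = next(iter(wds.WebDataset(ds.url)))
>>> sample = ds.wrap(raw)  # sample is ST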