Dataset basics#


Dataset is a low-level API that uses chaining syntax to define data transformation steps. It allows more general types of processing (e.g. dataset mixing) and more control over the execution (e.g. a different order of data sharding and shuffling). Dataset transformations are composed in a way that preserves the random access property past the source and through some of the transformations. This, among other things, can be used for debugging by evaluating dataset elements at specific positions without processing the entire dataset.

There are three main classes comprising the Dataset API: MapDataset, IterDataset, and DatasetIterator. Most data pipelines will start with one or more MapDataset (often derived from a RandomAccessDataSource) and switch to IterDataset late or not at all. The following sections provide more details about each class.

Install and import Grain#

# @test {"output": "ignore"}
!pip install grain
from pprint import pprint
import grain

MapDataset#

MapDataset defines a dataset that supports efficient random access. Think of it as an (infinite) Sequence that computes values lazily. It will either be the starting point of the input pipeline or come in the middle of the pipeline following another MapDataset. Grain provides many basic transformations for users to get started.

dataset = (
    # You can also use a shortcut grain.MapDataset.range for
    # range-like input.
    grain.MapDataset.source([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    .shuffle(seed=10)  # Shuffles globally.
    .map(lambda x: x + 1)  # Maps each element.
    .batch(batch_size=2)  # Batches consecutive elements.
)

pprint(dataset[0])
pprint(list(dataset))
array([ 6, 11])
[array([ 6, 11]),
 array([ 2, 10]),
 array([3, 5]),
 array([1, 4]),
 array([8, 9]),
 array([7])]
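
The grain.MapDataset.range shortcut mentioned in the comment above builds the same kind of dataset without an explicit source. Because random access is preserved through shuffle and map, you can also probe a single position without evaluating the rest of the pipeline (a minimal sketch; the exact values depend on the seed):

range_dataset = (
    grain.MapDataset.range(11)  # Same elements as source([0, ..., 10]).
    .shuffle(seed=10)
    .map(lambda x: x + 1)
    .batch(batch_size=2)
)

# Only the element at index 3 is computed.
pprint(range_dataset[3])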

The requirement for MapDataset’s source is the grain.RandomAccessDataSource interface, i.e. __getitem__ and __len__.

# Note: Inheriting `grain.sources.RandomAccessDataSource` is optional but recommended.
class MySource(grain.sources.RandomAccessDataSource):

  def __init__(self):
    self._data = [0, 1, 2, 3, 4, 5, 6, 7]

  def __getitem__(self, idx):
    return self._data[idx]

  def __len__(self):
    return len(self._data)

source = MySource()

dataset = (
    grain.MapDataset.source(source)
    .shuffle(seed=10)  # Shuffles globally.
    .map(lambda x: x + 1)  # Maps each element.
    .batch(batch_size=2)  # Batches consecutive elements.
)

pprint(dataset[0])
pprint(list(dataset))
array([6, 7])
[array([6, 7]), array([2, 8]), array([3, 5]), array([1, 4])]

Access by index will never raise an IndexError and can treat indices that are equal to or larger than the length as belonging to a different epoch (e.g. shuffled differently, using different random numbers).

# Prints the 3rd element of the second epoch.
pprint(dataset[len(dataset) + 2])
array([7, 3])

Note that dataset[idx] == dataset[len(dataset) + idx] iff there are no random transformations. Since dataset has a global shuffle, different epochs are shuffled differently:

pprint(dataset[len(dataset) + 2] == dataset[2])
array([False, False])

You can use filter to remove unneeded elements, but it will return None to indicate that there is no element at the given index.

Returning None for the majority of positions can negatively impact the performance of the pipeline. For example, if your pipeline filters out 90% of the data, it might be better to store a pre-filtered version of your dataset (a sketch follows the example below).

filtered_dataset = dataset.filter(lambda e: (e[0] + e[1]) % 2 == 0)

pprint(f"Length of this dataset: {len(filtered_dataset)}")
pprint([filtered_dataset[i] for i in range(len(filtered_dataset))])
'Length of this dataset: 4'
[None, array([2, 8]), array([3, 5]), None]
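
As a minimal sketch of the pre-filtering advice above, assuming the filtered data fits in memory (for large datasets you would instead write the filtered records back to storage), you can materialize the surviving elements once and wrap them in a new MapDataset:

# Materialize only the elements that survive the filter.
prefiltered_data = [
    filtered_dataset[i]
    for i in range(len(filtered_dataset))
    if filtered_dataset[i] is not None
]
prefiltered_dataset = grain.MapDataset.source(prefiltered_data)

pprint(len(prefiltered_dataset))  # No None placeholders remain.
pprint(list(prefiltered_dataset))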

MapDataset also supports slicing using the same syntax as Python lists. This returns a MapDataset representing the sliced section. Slicing is the easiest way to “shard” data during distributed training.

shard_index = 0
shard_count = 2

sharded_dataset = dataset[shard_index::shard_count]
print(f"Sharded dataset length = {len(sharded_dataset)}")
pprint(sharded_dataset[0])
pprint(sharded_dataset[1])
Sharded dataset length = 2
array([6, 7])
array([3, 5])
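
In multi-host training the shard index and count usually come from the framework rather than being hard-coded. For example, with JAX (a sketch assuming a JAX multi-process environment; JAX is not otherwise used in this tutorial):

import jax

# Each process takes every process_count()-th element, starting at its own index.
process_sharded_dataset = dataset[jax.process_index() :: jax.process_count()]
print(f"Shard for this process has {len(process_sharded_dataset)} elements")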

To actually run training with the dataset, we should convert the MapDataset into an IterDataset to leverage parallel prefetching, which hides the IO latency of each element using Python threads.

This brings us to the next section of the tutorial: IterDataset.

iter_dataset = sharded_dataset.to_iter_dataset(
    grain.ReadOptions(num_threads=16, prefetch_buffer_size=500)
)

for element in iter_dataset:
  pprint(element)
array([6, 7])
array([3, 5])

IterDataset#

Most data pipelines will start with one or more MapDataset (often derived from a RandomAccessDataSource) and switch to IterDataset late or not at all. IterDataset does not support efficient random access and only supports iterating over it. It’s an Iterable.

Any MapDataset can be turned into an IterDataset by calling to_iter_dataset. When possible this should happen late in the pipeline since it restricts the transformations that can come after it (e.g. global shuffle must come before). This conversion by default skips None elements.
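
For instance, converting the filtered dataset from the previous section drops the None placeholders (a quick check reusing filtered_dataset from above):

pprint(list(filtered_dataset.to_iter_dataset()))
# Expected: [array([2, 8]), array([3, 5])] -- the None entries are skipped.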

Some transformations have implementations for both MapDataset and IterDataset, e.g. filter, map, random_map, batch. They produce the same result with one caveat: MapDataset.batch cannot follow MapDataset.filter; you will need to convert to IterDataset before applying batch:

ds = (
    grain.MapDataset.range(10)
    .filter(lambda x: x % 2 == 0)
    .to_iter_dataset()
    .batch(2)  # Calling `batch` before `to_iter_dataset` will raise an error.
)
pprint(list(ds))
[array([0, 2]), array([4, 6]), array([8])]

DatasetIterator is a stateful iterator over an IterDataset. The state of the iterator can be cheaply saved and restored. This is intended for checkpointing the input pipeline together with the trained model. The returned state does not contain the data that flows through the pipeline.

Essentially, DatasetIterator checkpoints only the index information it needs to recover its position (assuming the underlying contents of the files do not change).

dataset_iter = iter(dataset)
pprint(isinstance(dataset_iter, grain.DatasetIterator))
True
pprint(next(dataset_iter))

checkpoint = dataset_iter.get_state()

pprint(next(dataset_iter))

# Recover the iterator to the state after the first produced element.
dataset_iter.set_state(checkpoint)

pprint(next(dataset_iter))  # This should generate the same element as above
array([6, 7])
array([2, 8])
array([2, 8])
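
In a real training loop you would persist this state alongside the model checkpoint. A minimal sketch, assuming the state dict returned by get_state is JSON-serializable (the file path is arbitrary):

import json

# Save the iterator state next to the model checkpoint.
with open("/tmp/iterator_state.json", "w") as f:
  json.dump(dataset_iter.get_state(), f)

# Later, restore the iterator to the saved position.
with open("/tmp/iterator_state.json") as f:
  dataset_iter.set_state(json.load(f))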