Dataset basics#
Dataset is a low-level API that uses chaining syntax to define data
transformation steps. It allows more general types of processing (e.g. dataset
mixing) and more control over the execution (e.g. different order of data
sharding and shuffling). Dataset transformations are composed in a way that
preserves the random access property past the source and some of the
transformations. This, among other things, can be used for debugging by
evaluating dataset elements at specific positions without processing the entire
dataset.
There are 3 main classes comprising the Dataset API: MapDataset,
IterDataset, and DatasetIterator. Most data pipelines will start with one or
more MapDataset (often derived from a RandomAccessDataSource) and switch to
IterDataset late or not at all. The following sections provide more details
about each class.
Install and import Grain#
# @test {"output": "ignore"}
!pip install grain
from pprint import pprint
import grain
MapDataset#
MapDataset defines a dataset that supports efficient random access. Think of
it as an (infinite) Sequence that computes values lazily. It will either be
the starting point of the input pipeline or in the middle of the pipeline
following another MapDataset. Grain provides many basic transformations for
users to get started.
dataset = (
    # You can also use a shortcut grain.MapDataset.range for
    # range-like input.
    grain.MapDataset.source([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    .shuffle(seed=10)  # Shuffles globally.
    .map(lambda x: x + 1)  # Maps each element.
    .batch(batch_size=2)  # Batches consecutive elements.
)
pprint(dataset[0])
pprint(list(dataset))
array([ 6, 11])
[array([ 6, 11]),
array([ 2, 10]),
array([3, 5]),
array([1, 4]),
array([8, 9]),
array([7])]
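As the comment in the example above notes, grain.MapDataset.range is a shortcut
for range-like inputs. A minimal sketch (assuming it takes the same arguments as
Python's built-in range, as the later grain.MapDataset.range(10) example
suggests):
# Intended to match grain.MapDataset.source(list(range(11))).
range_dataset = grain.MapDataset.range(11)
pprint(len(range_dataset))
pprint(range_dataset[0])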
The requirement for MapDataset's source is the grain.RandomAccessDataSource
interface: i.e. __getitem__ and __len__.
# Note: Inheriting `grain.RandomAccessDataSource` is optional but recommended.
class MySource(grain.sources.RandomAccessDataSource):

  def __init__(self):
    self._data = [0, 1, 2, 3, 4, 5, 6, 7]

  def __getitem__(self, idx):
    return self._data[idx]

  def __len__(self):
    return len(self._data)
source = MySource()
dataset = (
    grain.MapDataset.source(source)
    .shuffle(seed=10)  # Shuffles globally.
    .map(lambda x: x + 1)  # Maps each element.
    .batch(batch_size=2)  # Batches consecutive elements.
)
pprint(dataset[0])
pprint(list(dataset))
array([6, 7])
[array([6, 7]), array([2, 8]), array([3, 5]), array([1, 4])]
Access by index will never raise an IndexError and can treat indices that are
equal to or larger than the length as belonging to a different epoch (e.g.
shuffle differently, use different random numbers).
# Prints the 3rd element of the second epoch.
pprint(dataset[len(dataset) + 2])
array([7, 3])
Note that dataset[idx] == dataset[len(dataset) + idx] iff there are no random
transformations. Since dataset has a global shuffle, different epochs are
shuffled differently:
pprint(dataset[len(dataset) + 2] == dataset[2])
array([False, False])
You can use filter to remove unneeded elements, but it will return None to
indicate that there is no element at the given index.
Returning None for the majority of positions can negatively impact the
performance of the pipeline. For example, if your pipeline filters out 90% of
the data, it might be better to store a filtered version of your dataset.
filtered_dataset = dataset.filter(lambda e: (e[0] + e[1]) % 2 == 0)
pprint(f"Length of this dataset: {len(filtered_dataset)}")
pprint([filtered_dataset[i] for i in range(len(filtered_dataset))])
'Length of this dataset: 4'
[None, array([2, 8]), array([3, 5]), None]
MapDataset also supports slicing using the same syntax as Python lists. This
returns a MapDataset representing the sliced section. Slicing is the easiest
way to “shard” data during distributed training.
shard_index = 0
shard_count = 2
sharded_dataset = dataset[shard_index::shard_count]
print(f"Sharded dataset length = {len(sharded_dataset)}")
pprint(sharded_dataset[0])
pprint(sharded_dataset[1])
Sharded dataset length = 2
array([6, 7])
array([3, 5])
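For completeness, here is a sketch of the complementary shard; with
shard_count=2, the two slices together cover every element of dataset:
other_shard = dataset[1::shard_count]  # The elements skipped by shard 0.
print(f"Other shard length = {len(other_shard)}")
pprint([other_shard[i] for i in range(len(other_shard))])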
To actually run training with the dataset, we should convert MapDataset into
IterDataset to leverage parallel prefetching, which hides each element's IO
latency using Python threads.
This brings us to the next section of the tutorial: IterDataset.
iter_dataset = sharded_dataset.to_iter_dataset(
    grain.ReadOptions(num_threads=16, prefetch_buffer_size=500)
)
for element in iter_dataset:
  pprint(element)
array([6, 7])
array([3, 5])
IterDataset#
Most data pipelines will start with one or more MapDataset (often derived from
a RandomAccessDataSource) and switch to IterDataset late or not at all.
IterDataset does not support efficient random access and only supports
iterating over it. It's an Iterable.
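A quick sketch of this (using the to_iter_dataset conversion described next):
an IterDataset can be looped over like any Iterable, but it does not support
indexing.
iter_ds = grain.MapDataset.range(3).to_iter_dataset()
for element in iter_ds:
  pprint(element)
# Note: iter_ds[0] is not supported -- IterDataset has no random access.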
Any MapDataset can be turned into an IterDataset by calling
to_iter_dataset. When possible this should happen late in the pipeline since
it will restrict the transformations that can come after it (e.g. global
shuffle must come before). This conversion by default skips None elements.
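A minimal sketch of the None handling, reusing the filter behaviour shown
earlier: the MapDataset reports None at filtered-out positions, while the
converted IterDataset omits them.
filtered = grain.MapDataset.range(6).filter(lambda x: x % 2 == 0)
pprint([filtered[i] for i in range(len(filtered))])  # None at odd positions.
pprint(list(filtered.to_iter_dataset()))  # None elements are skipped.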
Some transformations have implementations for both MapDataset and
IterDataset, e.g. filter, map, random_map, batch. They produce the
same result with one caveat: MapDataset.batch cannot follow
MapDataset.filter - you will need to convert to IterDataset before applying
batch:
ds = (
    grain.MapDataset.range(10)
    .filter(lambda x: x % 2 == 0)
    .to_iter_dataset()
    .batch(2)  # Calling `batch` before `to_iter_dataset` will raise an error.
)
pprint(list(ds))
[array([0, 2]), array([4, 6]), array([8])]
DatasetIterator is a stateful iterator of IterDataset. The state of the
iterator can be cheaply saved and restored. This is intended for checkpointing
the input pipeline together with the trained model. The returned state will not
contain data that flows through the pipeline.
Essentially, DatasetIterator only checkpoints the index information needed to
recover its position (assuming the underlying content of the files will not
change).
dataset_iter = iter(dataset)
pprint(isinstance(dataset_iter, grain.DatasetIterator))
True
pprint(next(dataset_iter))
checkpoint = dataset_iter.get_state()
pprint(next(dataset_iter))
# Recover the iterator to the state after the first produced element.
dataset_iter.set_state(checkpoint)
pprint(next(dataset_iter)) # This should generate the same element as above
array([6, 7])
array([2, 8])
array([2, 8])
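Since the returned state only holds index information, it is small and cheap to
persist next to the model checkpoint. A sketch of what this could look like
(the use of pickle and the file path are illustrative assumptions, not a
prescribed mechanism):
import pickle

# Assumption: the state returned by get_state() is a small, picklable object.
with open("/tmp/input_pipeline_state.pkl", "wb") as f:
  pickle.dump(checkpoint, f)

with open("/tmp/input_pipeline_state.pkl", "rb") as f:
  restored_state = pickle.load(f)

dataset_iter.set_state(restored_state)
pprint(next(dataset_iter))  # Continues from the position saved above.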