Advanced Dataset usage


If you have decided to use the Dataset APIs, there is a good chance you will want to perform one or more of the processing steps described in this section, especially if you are working on data ingestion for generative model training.

# @test {"output": "ignore"}
!pip install grain
# @test {"output": "ignore"}
!pip install tensorflow_datasets
import grain
import numpy as np
import tensorflow_datasets as tfds
from pprint import pprint

Checkpointing

Grain provides grain.checkpoint.CheckpointSave and grain.checkpoint.CheckpointRestore for checkpointing a DatasetIterator. We recommend using them with Orbax, which can checkpoint both the input pipeline and the model, and which handles the edge cases of distributed training.

ds = (
    grain.MapDataset.source(tfds.data_source("mnist", split="train"))
    .seed(seed=45)
    .shuffle()
    .to_iter_dataset()
)

num_steps = 4
ds_iter = iter(ds)

# Read some elements.
for i in range(num_steps):
  x = next(ds_iter)
  print(i, x["label"])
0 7
1 4
2 0
3 1
# @test {"output": "ignore"}
!pip install orbax-checkpoint
import orbax.checkpoint as ocp

# Start from an empty checkpoint directory before creating the manager.
!rm -rf /tmp/orbax

mngr = ocp.CheckpointManager("/tmp/orbax")

# Save the checkpoint.
assert mngr.save(
    step=num_steps, args=grain.checkpoint.CheckpointSave(ds_iter), force=True
)
# Checkpoint saving in Orbax is asynchronous by default, so we'll wait until
# finished before examining checkpoint.
mngr.wait_until_finished()

# @test {"output": "ignore"}
!ls -R /tmp/orbax
/tmp/orbax:
4

/tmp/orbax/4:
_CHECKPOINT_METADATA
default

/tmp/orbax/4/default:
process_0-of-1.json
!cat /tmp/orbax/*/*/*.json
{
    "next_index": 4
}
# Read more elements and advance the iterator.
for i in range(4, 8):
  x = next(ds_iter)
  print(i, x["label"])
4 7
5 4
6 8
7 0
# Restore iterator from the previously saved checkpoint.
mngr.restore(num_steps, args=grain.checkpoint.CheckpointRestore(ds_iter))
# Iterator should be set back to start from 4.
for i in range(4, 8):
  x = next(ds_iter)
  print(i, x["label"])
4 7
5 4
6 8
7 0
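The checkpoint above stores nothing but {"next_index": 4}. To see why that is enough, it can help to look at a toy model. The sketch below is a hypothetical, stdlib-only class (ToyIndexIterator is not part of Grain); its get_state and set_state methods only mirror the method names that Grain's DatasetIterator uses. It assumes, as the single-field checkpoint suggests, that the seeded MapDataset shuffle behaves like a fixed permutation of indices, so restoring the index alone replays the same elements.

```python
import json


class ToyIndexIterator:
    """Hypothetical iterator over a random-access sequence; its state is one index."""

    def __init__(self, source):
        self._source = source
        self._next_index = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self._next_index >= len(self._source):
            raise StopIteration
        element = self._source[self._next_index]
        self._next_index += 1
        return element

    def get_state(self):
        return {"next_index": self._next_index}

    def set_state(self, state):
        self._next_index = state["next_index"]


it = ToyIndexIterator(list(range(100, 108)))
first = [next(it) for _ in range(4)]       # [100, 101, 102, 103]
saved = json.dumps(it.get_state())         # '{"next_index": 4}'
after_save = [next(it) for _ in range(4)]  # [104, 105, 106, 107]
it.set_state(json.loads(saved))            # Rewind to the saved position.
replayed = [next(it) for _ in range(4)]
print(saved, replayed == after_save)       # {"next_index": 4} True
```

The real pipeline behaves the same way at this level of detail: elements 4 to 7 are produced again after the restore, as the outputs above show.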

Mixing datasets

Datasets can mix multiple data sources, each with potentially different transformations. There are two ways to mix Datasets: MapDataset.mix and IterDataset.mix. If any of the mixed Datasets is sparse (for example, because one of the mixture components is filtered), use IterDataset.mix; otherwise, use MapDataset.mix.

tfds.core.DatasetInfo.file_format = (
    tfds.core.file_adapters.FileFormat.ARRAY_RECORD
)
# This particular dataset mixes medical images with hand written numbers,
# probably not useful but allows to illustrate the API on small datasets.
source1 = tfds.data_source(name="pneumonia_mnist", split="train")
source2 = tfds.data_source(name="mnist", split="train")
ds1 = grain.MapDataset.source(source1).map(lambda features: features["image"])
ds2 = grain.MapDataset.source(source2).map(lambda features: features["image"])
ds = grain.MapDataset.mix([ds1, ds2], weights=[0.7, 0.3])
print(f"Mixed dataset length = {len(ds)}")
pprint(np.shape(ds[0]))
Mixed dataset length = 6728
(28, 28, 1)
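One way to picture index-based mixing is as a deterministic plan that maps every global position to a (component, local index) pair, which is what keeps the mixture random-access. The sketch below is an assumption for illustration, not Grain's selection algorithm: mix_plan and mix are hypothetical helpers that use smooth weighted round-robin with integer weights (7:3 is equivalent to the 0.7/0.3 used above), so every prefix of the mixture stays close to the requested proportions.

```python
def mix_plan(weights, num_elements):
    """Hypothetical sketch: map each global position to (component, local_index).

    Uses smooth weighted round-robin so every prefix stays close to the weights.
    """
    total = sum(weights)
    credit = [0] * len(weights)
    next_local = [0] * len(weights)
    plan = []
    for _ in range(num_elements):
        for i, w in enumerate(weights):
            credit[i] += w
        chosen = max(range(len(weights)), key=lambda i: credit[i])
        credit[chosen] -= total
        plan.append((chosen, next_local[chosen]))
        next_local[chosen] += 1
    return plan


def mix(datasets, weights, num_elements):
    """Resolve each mixed position through the plan, like an index lookup."""
    return [datasets[c][j] for c, j in mix_plan(weights, num_elements)]


images_a = [f"pneumonia-{i}" for i in range(10)]
images_b = [f"mnist-{i}" for i in range(10)]
mixed = mix([images_a, images_b], [7, 3], 10)
print(mixed)
```

This sketch replays the plan from the start for simplicity; an efficient implementation would resolve an arbitrary index directly. The plan only works if every local index of a component actually holds an element, which is why sparse (filtered) components call for IterDataset.mix.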

If filtering inputs to the mixture, use IterDataset.mix.

source1 = tfds.data_source(name="pneumonia_mnist", split="train")
source2 = tfds.data_source(name="mnist", split="train")
ds1 = (
    grain.MapDataset.source(source1)
    .filter(lambda features: int(features["label"]) == 1)
    .to_iter_dataset()
)
ds2 = (
    grain.MapDataset.source(source2)
    .filter(lambda features: int(features["label"]) > 4)
    .to_iter_dataset()
)

ds = grain.IterDataset.mix([ds1, ds2], weights=[0.7, 0.3]).map(
    lambda features: features["image"]
)
pprint(np.shape(next(iter(ds))))
(28, 28, 1)
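After filtering, there is no cheap way to know which source position holds the k-th surviving element, so mixing has to pull elements lazily instead of looking them up by index. The plain-Python sketch below illustrates that iterator-style mixing; iter_mix is a hypothetical helper, not Grain's IterDataset.mix, and it assumes integer weights, a coarse fixed selection block, and the policy of stopping as soon as a selected input is exhausted.

```python
import itertools


def iter_mix(iterables, weights):
    """Hypothetical lazy mixer for integer weights.

    Repeats a fixed selection block (weights[0] picks from input 0, then
    weights[1] from input 1, ...) and stops when a selected input runs dry.
    """
    iterators = [iter(it) for it in iterables]
    block = [i for i, w in enumerate(weights) for _ in range(w)]
    for chosen in itertools.cycle(block):
        try:
            yield next(iterators[chosen])
        except StopIteration:
            return


# Filtered inputs: which source positions survive is only known by scanning.
evens = (x for x in range(100) if x % 2 == 0)
high_digits = (x for x in range(1000, 1100) if x % 10 > 4)

mixed = list(itertools.islice(iter_mix([evens, high_digits], [7, 3]), 10))
print(mixed)  # [0, 2, 4, 6, 8, 10, 12, 1005, 1006, 1007]
```

A production mixer would interleave more finely than this block pattern, but the key property is the same: no element is requested until it is needed, so filtering never has to be resolved ahead of time.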

Multi-epoch training

The length of a mixed dataset is determined by the lengths of the input datasets together with the mixing weights. This means that once the first component is exhausted (the shortest one relative to its weight), a new epoch begins and the remaining elements of the other datasets are discarded. You can avoid this by repeating the inputs to the mixture.

source1 = tfds.data_source(name="pneumonia_mnist", split="train")
source2 = tfds.data_source(name="mnist", split="train")
ds1 = grain.MapDataset.source(source1).repeat()
ds2 = grain.MapDataset.source(source2).repeat()

ds = grain.MapDataset.mix([ds1, ds2], weights=[1, 2])
print(f"Repeated ds1 length = {len(ds1)}")  # sys.maxsize
print(f"Repeated ds2 length = {len(ds2)}")  # sys.maxsize
# ds1 and ds2 repeat indefinitely, so the mixture length is also sys.maxsize.
print(f"Mixed dataset length = {len(ds)}")  # sys.maxsize
Repeated ds1 length = 9223372036854775807
Repeated ds2 length = 9223372036854775807
Mixed dataset length = 9223372036854775807
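The effect of repeating the inputs can be illustrated with plain iterators. In this hypothetical sketch, alternate is an equal-weight mixer (not Grain's selection logic) that stops as soon as either input runs out, so the tail of the longer input is never used; wrapping both inputs in itertools.cycle, as a stand-in for .repeat(), keeps the mixture going so every element of the longer input is eventually reached.

```python
import itertools


def alternate(a, b):
    """Hypothetical equal-weight mixer that stops as soon as either input runs out."""
    its = [iter(a), iter(b)]
    for it in itertools.cycle(its):
        try:
            yield next(it)
        except StopIteration:
            return


short_ds = ["s0", "s1"]
long_ds = ["l0", "l1", "l2", "l3", "l4"]

# Finite inputs: the mixture ends when short_ds is exhausted; l2..l4 are dropped.
finite = list(alternate(short_ds, long_ds))
print(finite)  # ['s0', 'l0', 's1', 'l1']

# Repeated inputs: nothing is dropped, so take a bounded slice of the
# now-infinite mixture.
repeated = list(
    itertools.islice(
        alternate(itertools.cycle(short_ds), itertools.cycle(long_ds)), 10
    )
)
print(repeated)  # ['s0', 'l0', 's1', 'l1', 's0', 'l2', 's1', 'l3', 's0', 'l4']
```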

Shuffling

If you need to globally shuffle the mixed data, prefer shuffling the individual Datasets before mixing. This ensures that the realized weights of the mixed Datasets are stable and as close as possible to the provided weights.

Additionally, make sure to provide different seeds to different mixture components. This way there’s no chance of introducing a seed dependency between the components if the random transformations overlap.

source1 = tfds.data_source(name="pneumonia_mnist", split="train")
source2 = tfds.data_source(name="mnist", split="train")
ds1 = grain.MapDataset.source(source1).seed(42).shuffle().repeat()
ds2 = grain.MapDataset.source(source2).seed(43).shuffle().repeat()

ds = grain.MapDataset.mix([ds1, ds2], weights=[1, 2])
print(f"Mixed dataset length = {len(ds)}")  # sys.maxsize
Mixed dataset length = 9223372036854775807
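Shuffling inside each component only reorders elements within that component, so the component that fills each position of the mixture, and therefore the local proportions, is unaffected. The stdlib sketch below illustrates this under stated assumptions: random.Random stands in for Grain's seeded shuffle, a fixed 7:3 selection pattern stands in for the mixer, and the components are hypothetical tagged lists. It also shows why distinct seeds matter: reusing one seed gives equal-length components the same permutation.

```python
import random


def shuffled(items, seed):
    """Return a shuffled copy using a component-specific seed."""
    rng = random.Random(seed)
    items = list(items)
    rng.shuffle(items)
    return items


# Hypothetical components; each element is tagged with its source.
component_a = shuffled([("a", i) for i in range(70)], seed=42)
component_b = shuffled([("b", i) for i in range(30)], seed=43)

# A fixed 7:3 selection pattern stands in for the mixer.
pattern = [0] * 7 + [1] * 3
sources = [iter(component_a), iter(component_b)]
mixed = [next(sources[c]) for _ in range(10) for c in pattern]


def ratio(window):
    """Count (component a, component b) elements in a window."""
    n_a = sum(1 for tag, _ in window if tag == "a")
    return n_a, len(window) - n_a


# Shuffling happened inside each component, so every 10-element window of
# the mixture still holds exactly 7 elements from a and 3 from b.
window_ratios = {ratio(mixed[k:k + 10]) for k in range(0, len(mixed), 10)}
print(window_ratios)  # {(7, 3)}

# Reusing one seed gives equal-length components identical permutations,
# which is the kind of seed dependency that distinct seeds avoid.
same_permutation = shuffled(range(30), seed=7) == shuffled(range(30), seed=7)
print(same_permutation)  # True
```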

Prefetching

Grain offers prefetching mechanisms for potential performance improvements.

Thread prefetching

ThreadPrefetchIterDataset processes elements on the CPU ahead of time in a background thread, keeping a buffer of up to cpu_buffer_size elements (passed as prefetch_buffer_size).

import grain
import jax
import tensorflow_datasets as tfds

cpu_buffer_size = 3
source = tfds.data_source(name="mnist", split="train")
ds = grain.MapDataset.source(source).to_iter_dataset()
ds = ds.map(lambda x: x)  # Dummy map to illustrate the usage.
ds = grain.experimental.ThreadPrefetchIterDataset(ds, prefetch_buffer_size=cpu_buffer_size)
ds = ds.map(jax.device_put)
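Conceptually, thread prefetching is a bounded producer-consumer queue: a background thread runs the upstream pipeline and keeps at most a fixed number of finished elements ready, while the consumer (for example, the training loop) takes them off the front. The stdlib sketch below illustrates that pattern only; thread_prefetch is a hypothetical helper and is not how ThreadPrefetchIterDataset is implemented internally.

```python
import queue
import threading

_END = object()  # Sentinel marking the end of the upstream iterator.


def thread_prefetch(iterable, buffer_size):
    """Hypothetical sketch: run `iterable` in a background thread.

    At most `buffer_size` produced elements wait in the queue at any time.
    Errors raised by the producer are not propagated in this sketch.
    """
    buf = queue.Queue(maxsize=buffer_size)

    def producer():
        try:
            for element in iterable:
                buf.put(element)  # Blocks while the buffer is full.
        finally:
            buf.put(_END)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        element = buf.get()
        if element is _END:
            return
        yield element


def preprocess(x):
    return x * x  # Stand-in for CPU-side preprocessing.


results = list(thread_prefetch((preprocess(x) for x in range(10)), buffer_size=3))
print(results)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```

The bounded queue is the important design choice: it caps memory use while still letting preprocessing overlap with whatever the consumer is doing.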

grain.experimental.device_put processes a buffer of cpu_buffer_size elements on the CPU ahead of time and transfers a buffer of device_buffer_size elements (tpu_buffer_size in the example) to the device, which can be a jax.Device or a jax.sharding.Sharding.

import grain
import jax
import numpy as np

cpu_buffer_size = 3
tpu_buffer_size = 2
source = tfds.data_source(name="mnist", split="train")
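ds = grain.MapDataset.source(source).to_iter_dataset()
ds = ds.map(lambda x: x)  # Dummy map to illustrate the usage.
# Batch so the leading dimension can be split across devices: sharding along
# 'data' needs arrays of rank >= 1 whose first dimension divides evenly.
ds = ds.batch(batch_size=8 * jax.device_count(), drop_remainder=True)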

devices = jax.devices()

mesh = jax.sharding.Mesh(np.array(devices), axis_names=('data',))
p = jax.sharding.PartitionSpec('data')
sharding = jax.sharding.NamedSharding(mesh, p)

ds = grain.experimental.device_put(
        ds=ds,
        device=sharding,
        cpu_buffer_size=cpu_buffer_size,
        device_buffer_size=tpu_buffer_size,
    )

Multithread prefetching

PrefetchIterDataset uses a pool of threads to prefetch a buffer of elements from a random-access MapDataset. The number of threads and the buffer size are defined by ReadOptions.

import grain
import tensorflow_datasets as tfds

# If not set, ReadOptions defaults to 16 threads and a buffer of 500 elements.
read_options = grain.ReadOptions(num_threads=32, prefetch_buffer_size=400)

source = tfds.data_source(name="mnist", split="train")
ds = grain.MapDataset.source(source).to_iter_dataset(read_options=read_options)
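Random access is what makes multithreaded prefetching straightforward: since the indices of upcoming elements are known in advance, several reads can be issued at once and the results yielded back in index order. The stdlib sketch below illustrates that idea; pool_prefetch and its num_threads and buffer_size arguments are hypothetical analogues of the ReadOptions fields, not Grain's implementation.

```python
import collections
from concurrent.futures import ThreadPoolExecutor


def pool_prefetch(source, transform, num_threads, buffer_size):
    """Hypothetical sketch: prefetch transform(source[i]) with a thread pool.

    Elements are yielded in index order even though reads may finish out of order.
    """

    def load(index):
        # Runs on a pool thread: random access means any index can be read now.
        return transform(source[index])

    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        in_flight = collections.deque()
        next_index = 0
        while next_index < len(source) or in_flight:
            # Top up the window of outstanding reads to buffer_size.
            while next_index < len(source) and len(in_flight) < buffer_size:
                in_flight.append(pool.submit(load, next_index))
                next_index += 1
            # Wait for the oldest read so output order matches index order.
            yield in_flight.popleft().result()


records = [f"record-{i}" for i in range(20)]
prefetched = list(pool_prefetch(records, str.upper, num_threads=4, buffer_size=8))
print(prefetched[:3])  # ['RECORD-0', 'RECORD-1', 'RECORD-2']
```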

Multithread prefetch Autotune

PrefetchIterDataset (created by to_iter_dataset in the example) can use autotuning to choose the buffer size automatically, based on the dataset and a user-provided RAM budget.

import grain
import tensorflow_datasets as tfds

source = tfds.data_source(name="mnist", split="train")
ds = grain.MapDataset.source(source)
performance_config = grain.experimental.pick_performance_config(
    ds=ds,
    ram_budget_mb=1024,
    max_workers=None,
    max_buffer_size=None,
)
ds = ds.to_iter_dataset(read_options=performance_config.read_options)

Multiprocess Prefetch

MultiprocessPrefetchIterDataset processes an IterDataset in parallel across multiple processes. MultiprocessingOptions lets you specify num_workers, per_worker_buffer_size, and enable_profiling.

num_workers: Multiple processes can speed up the pipeline if it is compute-bound and bottlenecked on CPython's GIL. The default value of 0 means no Python multiprocessing, so all data loading and transformation runs in the main Python process.

per_worker_buffer_size: The size of the buffer of preprocessed elements that each worker maintains. These are elements after all transformations, so if your transformations include batching, a single element is a batch.

import grain
import tensorflow_datasets as tfds

source = tfds.data_source(name="mnist", split="train")
ds = grain.MapDataset.source(source).to_iter_dataset()

prefetch_lazy_iter_ds = ds.mp_prefetch(
        grain.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=10),
    )
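One common way to keep multi-worker loading deterministic is to give each worker a disjoint, strided slice of the input and have the consumer read from the workers in a fixed round-robin order. The sketch below simulates this without starting any processes; worker_slice and round_robin are hypothetical helpers for intuition, and Grain's mp_prefetch may split and interleave work differently.

```python
def worker_slice(dataset, worker_index, num_workers, transform):
    """Elements a hypothetical worker would produce: a strided slice, fully transformed."""
    return (transform(x) for x in dataset[worker_index::num_workers])


def round_robin(iterators):
    """Read one element from each worker in a fixed order until all are exhausted.

    Strided slicing gives earlier workers at least as many elements as later
    ones, so this interleaving reproduces the original order.
    """
    active = list(iterators)
    while active:
        still_active = []
        for it in active:
            try:
                element = next(it)
            except StopIteration:
                continue
            still_active.append(it)
            yield element
        active = still_active


def preprocess(x):
    return {"value": x, "squared": x * x}  # Stand-in for per-element transformations.


dataset = list(range(10))
num_workers = 3
workers = [worker_slice(dataset, i, num_workers, preprocess) for i in range(num_workers)]
parallel_order = list(round_robin(workers))
sequential_order = [preprocess(x) for x in dataset]  # What num_workers=0 would produce.
print(parallel_order == sequential_order)  # True
```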

Multiprocess Prefetch Autotune

MultiprocessPrefetchIterDataset can use autotuning to choose the number of workers automatically, based on the dataset and a user-provided RAM budget. Note that the chosen number of workers may vary with the hardware, so to preserve Grain's determinism we recommend storing the config in a persistent file system and passing it to the pipeline.

import grain
import tensorflow_datasets as tfds

source = tfds.data_source(name="mnist", split="train")
ds = grain.MapDataset.source(source).to_iter_dataset()

performance_config = grain.experimental.pick_performance_config(
        ds=ds,
        ram_budget_mb=1024,
        max_workers=None,
        max_buffer_size=None
    )

prefetch_lazy_iter_ds = ds.mp_prefetch(
        performance_config.multiprocessing_options,
    )
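One way to follow that recommendation is to write the chosen settings to a file once and read them back on every later run, so a different machine does not silently change the worker count. The sketch below round-trips a plain dict through JSON; load_or_create, the file name, and the hard-coded values are hypothetical. In a real pipeline you would fill the dict from performance_config.multiprocessing_options and rebuild the options with grain.MultiprocessingOptions(num_workers=..., per_worker_buffer_size=...).

```python
import json
import os
import tempfile


def load_or_create(path, choose_config):
    """Reuse a persisted config if present; otherwise choose one and persist it."""
    if not os.path.exists(path):
        with open(path, "w") as f:
            json.dump(choose_config(), f)
    with open(path) as f:
        return json.load(f)


config_path = os.path.join(tempfile.mkdtemp(), "grain_mp_config.json")  # Hypothetical location.

# First run: pretend autotuning picked 3 workers.
config = load_or_create(config_path, lambda: {"num_workers": 3, "per_worker_buffer_size": 10})
# Later run on different hardware: autotuning might pick 8, but the stored value wins.
again = load_or_create(config_path, lambda: {"num_workers": 8, "per_worker_buffer_size": 10})
print(config, again == config)  # {'num_workers': 3, 'per_worker_buffer_size': 10} True
# In a Grain pipeline (not executed here):
# ds.mp_prefetch(grain.MultiprocessingOptions(**config))
```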