Advanced Dataset usage#
If you decided to use the Dataset APIs, there's a good chance you want to apply one or more of the processing steps described in this section, especially if you are working on data ingestion for generative model training.
# @test {"output": "ignore"}
!pip install grain
# @test {"output": "ignore"}
!pip install tensorflow_datasets
import grain
import numpy as np
import tensorflow_datasets as tfds
from pprint import pprint
Checkpointing#
We provide Checkpoint{Save|Restore} to checkpoint the DatasetIterator. It is recommended to use it with Orbax, which can checkpoint both the input pipeline and the model, and handles the edge cases for distributed training.
ds = (
grain.MapDataset.source(tfds.data_source("mnist", split="train"))
.seed(seed=45)
.shuffle()
.to_iter_dataset()
)
num_steps = 4
ds_iter = iter(ds)
# Read some elements.
for i in range(num_steps):
x = next(ds_iter)
print(i, x["label"])
0 7
1 4
2 0
3 1
# @test {"output": "ignore"}
!pip install orbax
import orbax.checkpoint as ocp
!rm -rf /tmp/orbax
mngr = ocp.CheckpointManager("/tmp/orbax")
# Save the checkpoint.
assert mngr.save(
step=num_steps, args=grain.checkpoint.CheckpointSave(ds_iter), force=True
)
# Checkpoint saving in Orbax is asynchronous by default, so wait until it
# finishes before examining the checkpoint.
mngr.wait_until_finished()
# @test {"output": "ignore"}
!ls -R /tmp/orbax
/tmp/orbax:
4
/tmp/orbax/4:
_CHECKPOINT_METADATA
default
/tmp/orbax/4/default:
process_0-of-1.json
!cat /tmp/orbax/*/*/*.json
{
"next_index": 4
}
# Read more elements and advance the iterator.
for i in range(4, 8):
x = next(ds_iter)
print(i, x["label"])
4 7
5 4
6 8
7 0
# Restore iterator from the previously saved checkpoint.
mngr.restore(num_steps, args=grain.checkpoint.CheckpointRestore(ds_iter))
# The restored iterator produces elements starting from step 4 again.
for i in range(4, 8):
x = next(ds_iter)
print(i, x["label"])
4 7
5 4
6 8
7 0
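Since Orbax can checkpoint the input pipeline and the model together, a common pattern is to save both under a single CheckpointManager. The sketch below is not part of the notebook output above; it assumes Orbax's Composite, StandardSave and StandardRestore handlers, and model_state is a hypothetical stand-in for a real train state.
# A minimal sketch of checkpointing model state and iterator together.
# `model_state` is a toy placeholder, not a real train state.
model_state = {"params": np.zeros((3, 3)), "step": np.asarray(num_steps)}
combined_mngr = ocp.CheckpointManager(
    "/tmp/orbax_combined", item_names=("model", "data")
)
combined_mngr.save(
    step=num_steps,
    args=ocp.args.Composite(
        model=ocp.args.StandardSave(model_state),
        data=grain.checkpoint.CheckpointSave(ds_iter),
    ),
    force=True,
)
combined_mngr.wait_until_finished()
# Restoring returns the model pytree and resets the iterator in place.
restored = combined_mngr.restore(
    num_steps,
    args=ocp.args.Composite(
        model=ocp.args.StandardRestore(model_state),
        data=grain.checkpoint.CheckpointRestore(ds_iter),
    ),
)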
Mixing datasets#
Dataset allows mixing multiple data sources with potentially different transformations. There are two ways of mixing Datasets: MapDataset.mix and IterDataset.mix. If the mixed Datasets are sparse (e.g. one of the mixture components needs to be filtered), use IterDataset.mix; otherwise use MapDataset.mix.
tfds.core.DatasetInfo.file_format = (
tfds.core.file_adapters.FileFormat.ARRAY_RECORD
)
# This particular mixture combines medical images with handwritten digits;
# probably not useful in practice, but it illustrates the API on small datasets.
source1 = tfds.data_source(name="pneumonia_mnist", split="train")
source2 = tfds.data_source(name="mnist", split="train")
ds1 = grain.MapDataset.source(source1).map(lambda features: features["image"])
ds2 = grain.MapDataset.source(source2).map(lambda features: features["image"])
ds = grain.MapDataset.mix([ds1, ds2], weights=[0.7, 0.3])
print(f"Mixed dataset length = {len(ds)}")
pprint(np.shape(ds[0]))
Mixed dataset length = 6728
(28, 28, 1)
If filtering inputs to the mixture, use IterDataset.mix.
source1 = tfds.data_source(name="pneumonia_mnist", split="train")
source2 = tfds.data_source(name="mnist", split="train")
ds1 = (
grain.MapDataset.source(source1)
.filter(lambda features: int(features["label"]) == 1)
.to_iter_dataset()
)
ds2 = (
grain.MapDataset.source(source2)
.filter(lambda features: int(features["label"]) > 4)
.to_iter_dataset()
)
ds = grain.IterDataset.mix([ds1, ds2], weights=[0.7, 0.3]).map(
lambda features: features["image"]
)
pprint(np.shape(next(iter(ds))))
(28, 28, 1)
Multi-epoch training#
The mixed dataset length is determined by a combination of the length of the shortest input dataset and the mixing weights. This means that once the shortest component is exhausted, a new epoch begins and the remainder of the other datasets is discarded. This can be avoided by repeating the inputs to the mixture.
source1 = tfds.data_source(name="pneumonia_mnist", split="train")
source2 = tfds.data_source(name="mnist", split="train")
ds1 = grain.MapDataset.source(source1).repeat()
ds2 = grain.MapDataset.source(source2).repeat()
ds = grain.MapDataset.mix([ds1, ds2], weights=[1, 2])
print(f"Mixed dataset length = {len(ds1)}") # sys.maxsize
print(f"Mixed dataset length = {len(ds2)}") # sys.maxsize
# Ds1 and ds2 are repeated to fill out the sys.maxsize with respect to weights.
print(f"Mixed dataset length = {len(ds)}") # sys.maxsize
Mixed dataset length = 9223372036854775807
Mixed dataset length = 9223372036854775807
Mixed dataset length = 9223372036854775807
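Since the repeated mixture is effectively infinite, bound the training loop by a number of steps rather than by epochs. The loop below is a minimal sketch; num_train_steps is an arbitrary illustration value.
# The repeated mixture never ends, so bound iteration by a step count.
num_train_steps = 5
for step in range(num_train_steps):
    example = ds[step]  # MapDataset supports random access by index.
    print(step, np.shape(example["image"]))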
Shuffling#
If you need to globally shuffle the mixed data, prefer shuffling the individual Datasets before mixing. This ensures that the actual weights of the mixed Datasets are stable and as close as possible to the provided weights.
Additionally, make sure to provide different seeds to different mixture components. This way there is no chance of introducing a seed dependency between the components if their random transformations overlap.
source1 = tfds.data_source(name="pneumonia_mnist", split="train")
source2 = tfds.data_source(name="mnist", split="train")
ds1 = grain.MapDataset.source(source1).seed(42).shuffle().repeat()
ds2 = grain.MapDataset.source(source2).seed(43).shuffle().repeat()
ds = grain.MapDataset.mix([ds1, ds2], weights=[1, 2])
print(f"Mixed dataset length = {len(ds)}") # sys.maxsize
Mixed dataset length = 9223372036854775807
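To sanity-check that the effective proportions match the provided weights, you can tag each component with a marker and count markers over a window of mixed elements. This is a small sketch that is not part of the original pipeline; the 0/1 markers and the 300-element window are arbitrary illustration choices.
# Tag each component, mix, and measure empirical proportions over a window.
ds1_tagged = (
    grain.MapDataset.source(source1).seed(42).shuffle().repeat().map(lambda _: 0)
)
ds2_tagged = (
    grain.MapDataset.source(source2).seed(43).shuffle().repeat().map(lambda _: 1)
)
mixed = grain.MapDataset.mix([ds1_tagged, ds2_tagged], weights=[1, 2])
counts = np.bincount([mixed[i] for i in range(300)])
print(counts / counts.sum())  # Should be close to [1/3, 2/3].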
Prefetching#
Grain offers prefetching mechanisms for potential performance improvements.
Thread prefetching#
ThreadPrefetchIterDataset allows processing a buffer of size cpu_buffer_size on the CPU ahead of time.
import grain
import jax
import tensorflow_datasets as tfds
cpu_buffer_size = 3
source = tfds.data_source(name="mnist", split="train")
ds = grain.MapDataset.source(source).to_iter_dataset()
ds = ds.map(lambda x: x)  # Dummy map to illustrate the usage.
ds = grain.experimental.ThreadPrefetchIterDataset(
    ds, prefetch_buffer_size=cpu_buffer_size
)
ds = ds.map(jax.device_put)
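As a quick usage check (not part of the original snippet), pulling one element shows that its array leaves have already been transferred to a JAX device while the background thread keeps the buffer filled.
# Pull one element; its leaves are already jax.Arrays on the default device.
element = next(iter(ds))
print(jax.tree_util.tree_map(lambda leaf: leaf.devices(), element))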
grain.experimental.device_put allows processing a buffer of size cpu_buffer_size on the CPU ahead of time and transferring a buffer of size tpu_buffer_size to the device, which can be a jax.Device or a jax.sharding.Sharding.
import grain
import jax
import numpy as np
cpu_buffer_size = 3
tpu_buffer_size = 2
source = tfds.data_source(name="mnist", split="train")
ds = grain.MapDataset.source(source).to_iter_dataset()
ds = ds.map(lambda x: x)  # Dummy map to illustrate the usage.
devices = jax.devices()
mesh = jax.sharding.Mesh(np.array(devices), axis_names=('data',))
p = jax.sharding.PartitionSpec('data')
sharding = jax.sharding.NamedSharding(mesh, p)
ds = grain.experimental.device_put(
ds=ds,
device=sharding,
cpu_buffer_size=cpu_buffer_size,
device_buffer_size=tpu_buffer_size,
)
Multithread prefetching#
PrefetchIterDataset allows using a pool of threads to prefetch a buffer (defined by ReadOptions) while supporting random access.
import grain
import jax
import numpy as np
# If not set, ReadOptions defaults to 16 threads and a prefetch buffer of 500.
read_options = grain.ReadOptions(num_threads=32, prefetch_buffer_size=400)
source = tfds.data_source(name="mnist", split="train")
ds = grain.MapDataset.source(source).to_iter_dataset(read_options=read_options)
Multithread prefetch Autotune#
PrefetchIterDataset (invoked via to_iter_dataset in the example) can leverage the autotuning feature to automatically choose the buffer size based on the user-provided RAM budget and the dataset.
import grain
import jax
import numpy as np
source = tfds.data_source(name="mnist", split="train")
ds = grain.MapDataset.source(source)
performance_config = grain.experimental.pick_performance_config(
ds=ds,
ram_budget_mb=1024,
max_workers=None,
max_buffer_size=None
)
ds = ds.to_iter_dataset(read_options=performance_config.read_options)
Multiprocess Prefetch#
MultiprocessPrefetchIterDataset allows processing the IterDataset in parallel across multiple processes. MultiprocessingOptions allows specifying num_workers, per_worker_buffer_size, and enable_profiling.
Multiple processes can speed up the pipeline if it is compute bound and bottlenecked on CPython's GIL. The default num_workers value of 0 means no Python multiprocessing, and as a result all data loading and transformation will run in the main Python process.
per_worker_buffer_size: size of the buffer of preprocessed elements that each worker maintains. These are elements after all transformations. If your transformations include batching, a single element is a batch; see the batching sketch after the next example.
import grain
import tensorflow_datasets as tfds
source = tfds.data_source(name="mnist", split="train")
ds = grain.MapDataset.source(source).to_iter_dataset()
prefetch_lazy_iter_ds = ds.mp_prefetch(
grain.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=10),
)
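As noted above, if the pipeline batches, each element in the per-worker buffer is a whole batch, so a small per_worker_buffer_size already holds a lot of data. The pipeline below is a sketch; the batch size and buffer size are arbitrary illustration values.
# Batching before mp_prefetch: each buffered element is now a full batch.
batched_prefetch_ds = (
    grain.MapDataset.source(source)
    .to_iter_dataset()
    .batch(32)
    .mp_prefetch(
        grain.MultiprocessingOptions(num_workers=3, per_worker_buffer_size=2)
    )
)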
Multiprocess Prefetch Autotune#
MultiprocessPrefetchIterDataset can leverage the autotuning feature to automatically choose the number of workers based on the user-provided RAM budget and the dataset. Note that the number of workers in the config may change depending on the hardware, so to preserve Grain determinism the recommendation is to store the config in a persistent file system and pass it to the pipeline.
import grain
import tensorflow_datasets as tfds
source = tfds.data_source(name="mnist", split="train")
ds = grain.MapDataset.source(source).to_iter_dataset()
performance_config = grain.experimental.pick_performance_config(
ds=ds,
ram_budget_mb=1024,
max_workers=None,
max_buffer_size=None
)
prefetch_lazy_iter_ds = ds.mp_prefetch(
performance_config.multiprocessing_options,
)
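To follow the recommendation above and keep runs deterministic, persist the autotuned options and reuse them instead of re-running autotuning on different hardware. The snippet below is a minimal sketch; the file path and JSON layout are illustrative choices, not a Grain API.
import json
# Persist the autotuned worker count so later runs reuse the same value.
config_path = "/tmp/grain_mp_options.json"  # Illustrative path.
with open(config_path, "w") as f:
    json.dump(
        {"num_workers": performance_config.multiprocessing_options.num_workers}, f
    )
# Later / elsewhere: load and rebuild the options deterministically.
with open(config_path) as f:
    saved = json.load(f)
restored_options = grain.MultiprocessingOptions(num_workers=saved["num_workers"])
prefetch_lazy_iter_ds = ds.mp_prefetch(restored_options)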