grain.experimental.InterleaveIterDataset#

class grain.experimental.InterleaveIterDataset(datasets, *, cycle_length)#

Interleaves the given sequence of datasets.

The sequence can be a MapDataset.

Creates at most cycle_length iterators at a time that are processed concurrently and interleaves their elements. If cycle_length is larger than the number of datasets, then the behavior is similar to mixing the datasets with equal proportions. If cycle_length is 1, the datasets are chained.
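The semantics described above can be illustrated with a short pure-Python sketch (this is an illustration only, not the Grain implementation): at most cycle_length iterators are open at once, elements are drawn round-robin, and an exhausted iterator is replaced by the next dataset in the sequence.

```python
from typing import Any, Iterable, Iterator


def interleave(datasets: Iterable[Iterable[Any]], cycle_length: int) -> Iterator[Any]:
  """Sketch of interleaving semantics: round-robin over open iterators."""
  pending = iter(datasets)
  active: list[Iterator[Any]] = []
  # Open the first `cycle_length` iterators.
  for ds in pending:
    active.append(iter(ds))
    if len(active) == cycle_length:
      break
  while active:
    i = 0
    while i < len(active):
      try:
        yield next(active[i])
        i += 1
      except StopIteration:
        # Replace the exhausted iterator with the next dataset, if any.
        nxt = next(pending, None)
        if nxt is None:
          del active[i]
        else:
          active[i] = iter(nxt)


# With cycle_length=2 the first two datasets are interleaved; the third
# starts only once one of them is exhausted.
print(list(interleave([[1, 1], [2, 2], [3, 3]], cycle_length=2)))
# → [1, 2, 1, 2, 3, 3]
```

With cycle_length=1 the same sketch degenerates to chaining, matching the behavior described above.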

Can be used with mp_prefetch to parallelize reading from sources that do not support random access and are implemented as IterDataset:

def make_source(filename: str) -> grain.IterDataset:
  ...

ds = grain.MapDataset.source(filenames).shuffle(seed=42).map(make_source)
ds = grain.experimental.InterleaveIterDataset(ds, cycle_length=4)
ds = ...
ds = ds.mp_prefetch(grain.MultiprocessingOptions(num_workers=2))
for element in ds:
  ...
__init__(datasets, *, cycle_length)#

Parameters:

datasets – a sequence of datasets to interleave; may be given as a MapDataset of datasets.

cycle_length – the maximum number of iterators that are open and processed concurrently.

Methods

__init__(datasets, *, cycle_length)

batch(batch_size, *[, drop_remainder, batch_fn])

Returns a dataset of elements batched along a new first dimension.
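As a minimal sketch of the grouping that batch performs (Grain additionally stacks each group along a new first dimension via batch_fn, e.g. into NumPy arrays; this list-based illustration shows only the grouping and drop_remainder behavior):

```python
from typing import Any, Iterable


def batch(elements: Iterable[Any], batch_size: int, drop_remainder: bool = False) -> list[list[Any]]:
  """Groups consecutive elements into lists of length batch_size."""
  out: list[list[Any]] = []
  cur: list[Any] = []
  for x in elements:
    cur.append(x)
    if len(cur) == batch_size:
      out.append(cur)
      cur = []
  # A short final group is kept unless drop_remainder is set.
  if cur and not drop_remainder:
    out.append(cur)
  return out


print(batch([1, 2, 3, 4, 5], 2))  # → [[1, 2], [3, 4], [5]]
```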

filter(transform)

Returns a dataset containing only the elements that match the filter.

map(transform)

Returns a dataset containing the elements transformed by transform.

map_with_index(transform)

Returns a dataset containing the elements transformed by the transform, which also receives each element's index.

mp_prefetch([options, worker_init_fn])

Returns a dataset prefetching elements in multiple processes.

pipe(func, /, *args, **kwargs)

Syntactic sugar for applying a callable to this dataset.
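The pipe idiom generally reduces to func(self, *args, **kwargs), which keeps method-chaining style when applying a standalone helper. A minimal pure-Python sketch (a toy stand-in class, not Grain's; repeat_each is a hypothetical helper):

```python
from typing import Any, Iterable


class Dataset:
  """Toy stand-in for a dataset; only pipe() is sketched here."""

  def __init__(self, elements: Iterable[Any]):
    self.elements = list(elements)

  def pipe(self, func, /, *args, **kwargs):
    # Equivalent to func(self, *args, **kwargs).
    return func(self, *args, **kwargs)


def repeat_each(ds: Dataset, n: int) -> Dataset:
  # Hypothetical helper: repeats every element n times.
  return Dataset(x for x in ds.elements for _ in range(n))


ds = Dataset([1, 2]).pipe(repeat_each, n=2)
print(ds.elements)  # → [1, 1, 2, 2]
```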

prefetch(multiprocessing_options)

Deprecated, use mp_prefetch instead.

random_map(transform, *[, seed])

Returns a dataset containing the elements transformed by transform.

seed(seed)

Returns a dataset that uses the seed for default seed generation.

set_slice(sl)

Attributes

parents