grain Dataset#

Dataset classes.

  • MapDataset defines a dataset that supports efficient random access. Its key properties are:

    • __len__() returns the length of a single epoch over the dataset.

    • __getitem__() returns the element at the given non-negative index. Indices beyond a single epoch are also valid, so the “true” length of a MapDataset is infinite.

    • Individual dataset elements are only evaluated when calling __getitem__(). MapDataset objects are stateless and do not hold elements.

  • IterDataset defines a dataset that does not support efficient random access but can be iterated over. A MapDataset can be turned into an IterDataset, but going from an IterDataset to a MapDataset is as expensive as materializing the whole dataset.

  • DatasetIterator defines a stateful iterator of IterDataset. The state of the iterator can be saved and restored.

In terms of the interfaces defined in collections.abc, you can think of MapDataset as an (infinite) Sequence, IterDataset as an Iterable, and DatasetIterator as an Iterator.

MapDataset is typically created by one of the factory methods in MapDatasetMeta (e.g. MapDataset.range(5)). IterDataset is either created by calling to_iter_dataset() on a MapDataset or by one of the factory methods in IterDatasetMeta (e.g. IterDataset.mix([...])).
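
As an informal sketch of how the three abstractions relate (assuming the top-level import grain used in the class names below, and the checkpointing methods documented under DatasetIterator):

import grain

ds = grain.MapDataset.range(10)   # random access: len(ds) and ds[i] work
iter_ds = ds.to_iter_dataset()    # sequential access only
iterator = iter(iter_ds)          # stateful DatasetIterator
assert next(iterator) == 0
state = iterator.get_state()      # the iterator state can be saved...
assert next(iterator) == 1
iterator.set_state(state)         # ...and restored later
assert next(iterator) == 1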

List of Members#

class _src.python.dataset.dataset.MapDatasetMeta(name, bases, namespace, /, **kwargs)#

Metaclass for MapDataset containing factory transformations.

concatenate(datasets)#

Returns a dataset of elements from all input datasets.

Example usage:

ds1 = MapDataset.range(3)
ds2 = MapDataset.range(3, 8)
list(MapDataset.concatenate([ds1, ds2])) == [0, 1, 2, 3, 4, 5, 6, 7]
Parameters:

datasets (Sequence[MapDataset[T]]) – The datasets to concatenate.

Returns:

A MapDataset that represents a concatenation of the input datasets. The n-th epoch of the returned dataset will be the n-th epoch of the component datasets.

Return type:

MapDataset[T]

mix(datasets, weights=None)#

Returns a dataset that mixes input datasets with the given weights.

The length of the mixed dataset is determined by the length of the shortest input dataset. If you need an infinite dataset, consider repeating the input datasets before mixing.

If you need to shuffle the mixed dataset while preserving the correct proportions, you should shuffle the input datasets before mixing.

Example usage:

ds1 = MapDataset.range(5)
ds2 = MapDataset.range(7, 10)
list(MapDataset.mix([ds1, ds2])) == [0, 7, 1, 8, 2, 9]
Parameters:
  • datasets (Sequence[MapDataset[T]]) – The datasets to mix.

  • weights (Sequence[float] | None) – The weights to use for mixing. Defaults to uniform weights if not specified.

Returns:

A MapDataset that represents a mixture of the input datasets according to the given weights.

Return type:

MapDataset[T]

range(start, stop=None, step=1)#

Returns a dataset with a range of integers.

Input arguments are interpreted the same way as in Python built-in range:

  • range(n) => start=0, stop=n, step=1

  • range(m, n) => start=m, stop=n, step=1

  • range(m, n, p) => start=m, stop=n, step=p

The produced values are consistent with the built-in range function:

list(MapDataset.range(...)) == list(range(...))
Parameters:
  • start (int) – The start of the range.

  • stop (int | None) – The stop of the range.

  • step (int) – The step of the range.

Returns:

A MapDataset with a range of integers.

Return type:

MapDataset[int]

select_from_datasets(datasets, selection_map)#

Returns a dataset selected from the inputs according to the given map.

Allows more general types of dataset mixing than mix.

Parameters:
  • datasets (Sequence[MapDataset[T]]) – The datasets to select from.

  • selection_map (DatasetSelectionMap) – Mapping from index within the mixed dataset to a selected dataset index and index within that dataset. Length of the resulting dataset will be determined by the length of the selection_map.

Returns:

A MapDataset that represents a mixture of the input datasets according to the given selection map.

Return type:

MapDataset[T]
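
A minimal sketch, assuming a DatasetSelectionMap can be implemented with __len__ and a __getitem__ that returns a (selected dataset index, index within that dataset) pair; consult the DatasetSelectionMap documentation for the exact interface:

ds1 = MapDataset.range(3)         # [0, 1, 2]
ds2 = MapDataset.range(10, 13)    # [10, 11, 12]

class Interleave:
    """Hypothetical selection map alternating between the two datasets."""

    def __len__(self):
        return 6

    def __getitem__(self, index):
        # (selected dataset index, index within that dataset)
        return index % 2, index // 2

list(MapDataset.select_from_datasets([ds1, ds2], Interleave())) == [0, 10, 1, 11, 2, 12]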

source(source)#

Returns a dataset that wraps a data source supporting random access.

Example:

ds = MapDataset.source(ArrayRecordDataSource(paths))

Works with Sequence inputs as well:

list(MapDataset.source([1, 2, 3, 4, 5])) == [1, 2, 3, 4, 5]
Parameters:

source (Sequence[T] | RandomAccessDataSource[T]) – Data source supporting efficient random access.

Returns:

A MapDataset that wraps the data source and allows chaining other MapDataset transformations.

Return type:

MapDataset[T]

class grain.MapDataset(parents=())#

Bases: _Dataset, Generic[T]

Represents a dataset with transformations that support random access.

Transformations do not mutate the dataset object. Instead, they return a new dataset. MapDataset is immutable.

Note

MapDataset transformations such as .filter() use None to indicate absence of an element. Generally, the implementations of MapDataset transformations already handle None as a special case (e.g. by returning None as soon as __getitem__ sees None). This means that user-defined functions passed to MapDataset transformations do not need to handle None explicitly.

Parameters:

parents (MapDataset | Sequence[MapDataset])

abstractmethod __getitem__(index: slice) → MapDataset[T]#
abstractmethod __getitem__(index: int) → T | None

Returns the element for the index or None if missing.

__init__(parents=())#
Parameters:

parents (MapDataset | Sequence[MapDataset])

__iter__()#
Return type:

DatasetIterator[T]

abstractmethod __len__()#

Returns length of a single epoch of this dataset.

Return type:

int

batch(batch_size, *, drop_remainder=False, batch_fn=None)#

Returns a dataset of elements batched along a new first dimension.

Dataset elements are expected to be PyTrees.

Example usage:

ds = MapDataset.range(5)
ds = ds.batch(batch_size=2)
list(ds) == [np.array([0, 1]), np.array([2, 3]), np.array([4])]
Parameters:
  • batch_size (int) – The number of elements to batch together.

  • drop_remainder (bool) – Whether to drop the last batch if it is smaller than batch_size.

  • batch_fn (Callable[[Sequence[T]], S] | None) – A function that takes a list of elements and returns a batch. Defaults to stacking the elements along a new first batch dimension (see the example sketch below).

Returns:

A dataset of elements with the same PyTree structure, with leaves stacked along a new first (batch) dimension.

Return type:

MapDataset[S]
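
To illustrate the batch_fn parameter, a small sketch that keeps each batch as a plain Python list instead of stacking into an array; this assumes only what the parameter description above states, namely that batch_fn receives the elements of one batch as a sequence:

ds = MapDataset.range(5)
ds = ds.batch(batch_size=2, batch_fn=list)
list(ds) == [[0, 1], [2, 3], [4]]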

filter(transform)#

Returns a dataset containing only the elements that match the filter.

Accessing an element of the returned dataset using subscription (ds[i]) returns:

  • None if transform returned False

  • the element if transform returned True

Iterating over a filtered dataset skips None elements by default.

Example usage:

ds = MapDataset.range(5)
ds = ds.filter(lambda x: x % 2 == 0)
ds[2] == 2
ds[1] == None
list(ds) == [0, 2, 4]

NOTE: list(ds) converts the dataset to an IterDataset via to_iter_dataset() under the hood, which by default skips None elements.

to_iter_dataset produces a warning when iterating through a filtered dataset if the filter removes more than 90% of the elements. You can adjust the threshold through grain.experimental.DatasetOptions.filter_warn_threshold_ratio used in WithOptionsIterDataset. To raise an exception in that case, use filter_raise_threshold_ratio (see the sketch below).

Parameters:

transform (Filter | Callable[[T], bool]) – Either a FilterTransform containing the filter method or a callable that takes an element and returns a boolean.

Returns:

A dataset of the same type containing only the elements for which the filter transform returns True.

Return type:

MapDataset[T]
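
A sketch of adjusting the filter thresholds mentioned above; the exact constructor signatures of DatasetOptions and WithOptionsIterDataset are assumed here and should be checked against their documentation:

ds = MapDataset.range(1000).filter(lambda x: x == 0).to_iter_dataset()
options = grain.experimental.DatasetOptions(filter_raise_threshold_ratio=0.99)
ds = grain.experimental.WithOptionsIterDataset(ds, options)
# Iterating now raises once more than 99% of the elements have been filtered out.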

map(transform)#

Returns a dataset containing the elements transformed by transform.

Example usage:

ds = MapDataset.range(5)
ds = ds.map(lambda x: x + 10)
list(ds) == [10, 11, 12, 13, 14]
Parameters:

transform (MapTransform | Callable[[T], S]) – Either a MapTransform containing the map method or a callable that takes an element and returns a new element.

Returns:

A dataset containing the elements of the original dataset transformed by transform.

Return type:

MapDataset[S]

map_with_index(transform)#

Returns a dataset containing the elements transformed by transform.

The transform is called with the index of the element within the dataset and the element itself.

Example usage:

ds = MapDataset.source(["a", "b", "c", "d"])
ds = ds.map_with_index(lambda i, x: x + str(i))
list(ds) == ["a0", "b1", "c2", "d3"]
Parameters:

transform (MapWithIndex | Callable[[int, T], S]) – Either a MapWithIndexTransform containing the map_with_index method or a callable that takes an index and an element and returns a new element.

Returns:

A dataset containing the elements of the original dataset transformed by transform.

Return type:

MapDataset[S]

property parents: Sequence[MapDataset]#
pipe(func, /, *args, **kwargs)#

Syntactic sugar for applying a callable to this dataset.

The pipe method, borrowed from pandas.DataFrame, is convenient because it allows for using method chaining syntax in an extensible fashion, with transformations that are not built-in methods on Dataset.

For example, suppose you want to shuffle a dataset within a window. Functionality for this is available in WindowShuffleMapDataset, but not as a method on MapDataset, e.g.:

dataset = (
    grain.experimental.WindowShuffleMapDataset(
        grain.MapDataset.range(1000),
        window_size=128,
        seed=0,
    )
    .batch(16)
)

This hurts readability, because the shuffle transformation appears out of order with respect to the data flow.

In contrast, with pipe you can write:

dataset = (
    grain.MapDataset.range(1000)
    .pipe(
        grain.experimental.WindowShuffleMapDataset,
        window_size=128,
        seed=0
    )
    .batch(16)
)
Parameters:
  • func (Callable[[...], T]) – The callable to apply to this dataset.

  • *args – Additional positional arguments to pass to the callable.

  • **kwargs – Keyword arguments to pass to the callable.

Returns:

The result of calling func(self, *args, **kwargs).

Return type:

T

random_map(transform, *, seed=None)#

Returns a dataset containing the elements transformed by transform.

The transform is called with the element and a np.random.Generator instance that should be used inside the transform to preserve determinism. The seed can be provided either explicitly or via ds.seed(seed). Prefer the latter if you don’t need to control the random map seed individually; it allows passing a single seed to derive seeds for all downstream random transformations in the pipeline. The generator is seeded by a combination of the seed and the index of the element in the dataset.

NOTE: Avoid using the provided RNG outside of the transform function (e.g. by passing it to the next transformation along with the data). The RNG is going to be reused.

Example usage:

ds = MapDataset.range(5)
ds = ds.random_map(lambda x, rng: x + rng.integers(5, 10))
set(ds).issubset(set(range(5, 15)))
Parameters:
  • transform (RandomMapTransform | Callable[[T, numpy.random.Generator], S]) – Either a RandomMapTransform containing the random_map method or a callable that takes an element and a np.random.Generator and returns a new element.

  • seed (int | None) – An optional integer between 0 and 2**32-1 representing the seed used to initialize the random number generator used by transform. If you don’t need to control the transformation seed individually, prefer setting the pipeline-level seed with ds.seed(seed) instead.

Returns:

A dataset containing the elements of the original dataset transformed by transform.

Return type:

MapDataset[S]

repeat(num_epochs=None)#

Returns a dataset repeating the elements of this dataset multiple times.

Specifying None for num_epochs repeats the dataset infinitely and causes len(ds) to return sys.maxsize.

Since MapDataset allows accessing elements past len(ds) - 1 anyway (and uses the index modulo len(ds)), this transformation effectively only changes the length of the dataset.

Cannot be called on an infinite dataset.

Example usage:

list(MapDataset.range(5).repeat(2)) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
ds = MapDataset.range(5).repeat()
len(ds) == sys.maxsize
ds[11111] == 1
Parameters:

num_epochs (int | None) – Either a positive integer representing the number of times this dataset should be repeated or None to repeat infinitely.

Returns:

A dataset repeating the elements of this dataset multiple times.

Return type:

MapDataset[T]

seed(seed)#

Returns a dataset that uses the seed for default seed generation.

When default seed generation is enabled by calling ds.seed, every downstream random transformation will be automatically seeded with a unique seed by default. This simplifies seed management, making it easier to avoid:

  • Having to provide a seed in multiple transformations.

  • Accidentally reusing the same seed across transformations.

It is recommended to call this right after the source. ds.seed has to be called before any random transformations (such as shuffle or random_map) that rely on default seed generation to control their seeding. Given the same seed, the pipeline is guaranteed to always use the same seeds for each transformation.

WARNING: The seed for random downstream transformations is derived from the seed passed to ds.seed and the absolute position of the transformation in the pipeline. This means that if you add transformations before the random transformation, its seed will change. For instance, if this random transformation is shuffle, adding a transformation before shuffle will change its seed and, consequently, the data order. To avoid this, pass the seed to the transformation directly.

Note about custom dataset implementations: the default seed generation is available through _default_seed, but the private API is not guaranteed to be stable.

Example 1:

ds = ds.seed(seed).shuffle()

shuffle will automatically derive its own seed (different from seed).

Example 2:

ds = ds.seed(seed).shuffle().random_map(...)

shuffle and random_map will each derive their own seed and the seeds are going to be different.

Example 3:

ds = ds.seed(seed).random_map(transform, seed=seed1)

random_map will use seed1 and will not be affected by seed. This can be used to control individual transformation seeding independently from the rest of the pipeline.

Example 4:

ds = ds.seed(seed1).shuffle().seed(seed2).random_map(...)

ds.seed only affects the downstream transformations and can be overridden by a subsequent seed call. shuffle will derive its seed from seed1, random_map from seed2, and random_map is not affected by seed1. This can be used to control your transformation seeding even if you don’t own the first part of the pipeline.

Example 5:

ds1 = ds.source(...).seed(seed1).shuffle()
ds2 = ds.source(...).seed(seed2).shuffle()
ds = MapDataset.mix([ds1, ds2], ...).random_map(...)

Each shuffle will derive its own seed from seed1 or seed2 respectively. random_map will derive its seed from both seed1 and seed2.

Parameters:

seed (int) – Seed to use.

Returns:

A dataset with elements unchanged.

Return type:

MapDataset[T]

shuffle(seed=None)#

Returns a dataset with the same elements in a globally shuffled order.

The shuffle is deterministic and will always produce the same result given the same seed. The seed can be provided either explicitly or via ds.seed(seed). Prefer the latter if you don’t need to control the shuffle seed individually; it allows passing a single seed to derive seeds for all downstream random transformations in the pipeline.

In multi-epoch training each epoch will be shuffled differently (i.e. the seed is combined with the epoch number). In that case it is recommended to shuffle before repeat to avoid mixing elements from different epochs.

Example usage:

ds = MapDataset.range(5).shuffle()
set(ds) == {0, 1, 2, 3, 4}
list(ds) != [0, 1, 2, 3, 4]  # With probability (1 - 1/5!).
Parameters:

seed (int | None) – An optional integer between 0 and 2**32-1 representing the seed used by the shuffling algorithm. If you don’t need to control the shuffle seed individually, prefer setting the pipeline-level seed with ds.seed(seed) instead.

Returns:

A dataset containing the same elements but in a shuffled order.

Return type:

MapDataset[T]

slice(sl)#

Returns a dataset containing only the elements with indices in sl.

For most implementations of MapDataset, slicing is also available through the subscript operator: ds.slice(slice(1, 10, 2)) is equivalent to ds[1:10:2].

Example usage:

ds = MapDataset.range(5)
list(ds.slice(slice(1, 3))) == [1, 2]
list(ds.slice(slice(1, None, 2))) == [1, 3]

Commonly used for sharding: ds = ds.slice(slice(shard_index, None, shard_count)), or, equivalently, ds = ds[shard_index::shard_count].

Parameters:

sl (slice) – A slice object (https://docs.python.org/3/library/functions.html#slice) representing the slice of elements that should constitute the returned dataset.

Returns:

A dataset containing only the elements with indices in the sl slice.

Return type:

MapDataset[T]

to_iter_dataset(read_options=None, *, allow_nones=False)#

Converts this dataset to an IterDataset.

Elements from this dataset may be processed in multiple threads.

Note that some of the transformations are not available on IterDataset. These are roughly the transformations that operate on element indices, such as shuffle, slice and repeat.

Parameters:
  • read_options (ReadOptions | None) – Controls multithreading when reading the data and applying transformations in this dataset.

  • allow_nones (bool) – Whether to allow None values in the dataset (e.g. produced by filter). If False (the default), None values will be filtered out.

Returns:

An IterDataset with the same non-None elements as this dataset.

Return type:

IterDataset[T]
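
For illustration, a small sketch combining filter with the allow_nones flag described above:

ds = MapDataset.range(5).filter(lambda x: x % 2 == 0)
list(ds.to_iter_dataset()) == [0, 2, 4]
list(ds.to_iter_dataset(allow_nones=True)) == [0, None, 2, None, 4]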

class _src.python.dataset.dataset.IterDatasetMeta(name, bases, namespace, /, **kwargs)#

Metaclass for IterDataset containing factory transformations.

mix(datasets, weights=None)#

Returns a dataset that mixes input datasets with the given weights.

NOTE: Stops producing elements once any input dataset is exhausted. If you need an infinite mixed dataset, consider repeating the input datasets before mixing.

Example usage:

ds1 = MapDataset.range(5).to_iter_dataset()
ds2 = MapDataset.range(7, 10).to_iter_dataset()
list(IterDataset.mix([ds1, ds2])) == [0, 7, 1, 8, 2, 9, 3]
Parameters:
  • datasets (Sequence[IterDataset[T]]) – The datasets to mix.

  • weights (Sequence[float] | None) – The weights to use for mixing. Defaults to uniform weights if not specified.

Returns:

A dataset that represents a mixture of the input datasets according to the given weights.

Return type:

IterDataset[T]

class grain.IterDataset(parents=())#

Bases: _Dataset, Iterable[T]

Represents a dataset with transformations that supports the Iterable interface.

Transformations do not mutate the dataset object. Instead, they return a new dataset. IterDataset is immutable.

Parameters:

parents (MapDataset | IterDataset | Sequence[MapDataset | IterDataset])

__init__(parents=())#
Parameters:

parents (MapDataset | IterDataset | Sequence[MapDataset | IterDataset])

abstractmethod __iter__()#

Returns an iterator for this dataset.

Return type:

DatasetIterator[T]

batch(batch_size, *, drop_remainder=False, batch_fn=None)#

Returns a dataset of elements batched along a new first dimension.

Dataset elements are expected to be PyTrees.

Example usage:

ds = MapDataset.range(5).to_iter_dataset()
ds = ds.batch(batch_size=2)
list(ds) == [np.array([0, 1]), np.array([2, 3]), np.array([4])]
Parameters:
  • batch_size (int) – The number of elements to batch together.

  • drop_remainder (bool) – Whether to drop the last batch if it is smaller than batch_size.

  • batch_fn (Callable[[Sequence[T]], S] | None) – A function that takes a list of elements and returns a batch. Defaults to stacking the elements along a new first batch dimension.

Returns:

A dataset of elements with the same PyTree structure, with leaves stacked along a new first (batch) dimension.

Return type:

IterDataset[S]

filter(transform)#

Returns a dataset containing only the elements that match the filter.

Example usage:

ds = MapDataset.range(5).to_iter_dataset()
ds = ds.filter(lambda x: x % 2 == 0)
list(ds) == [0, 2, 4]

Produces a warning if the filter removes more than 90% of the elements. You can adjust the threshold through grain.experimental.DatasetOptions.filter_warn_threshold_ratio used in WithOptionsIterDataset. To raise an exception in that case, use filter_raise_threshold_ratio.

Parameters:

transform (Filter | Callable[[T], bool]) – Either a FilterTransform containing the filter method or a callable that takes an element and returns a boolean.

Returns:

A dataset of the same type containing only the elements for which the filter transform returns True.

Return type:

IterDataset[T]

map(transform)#

Returns a dataset containing the elements transformed by transform.

Example usage:

ds = MapDataset.range(5).to_iter_dataset()
ds = ds.map(lambda x: x + 10)
list(ds) == [10, 11, 12, 13, 14]
Parameters:

transform (MapTransform | Callable[[T], S]) – Either a MapTransform containing the map method or a callable that takes an element and returns a new element.

Returns:

A dataset containing the elements of the original dataset transformed by transform.

Return type:

IterDataset[S]

map_with_index(transform)#

Returns a dataset of the elements transformed by the transform.

The transform is called with the index of the element in the dataset as the first argument and the element as the second argument.

Example usage:

ds = MapDataset.range(5).to_iter_dataset()
ds = ds.map_with_index(lambda i, x: (i, 2**x))
list(ds) == [(0, 1), (1, 2), (2, 4), (3, 8), (4, 16)]
Parameters:

transform (MapWithIndex | Callable[[int, T], S]) – Either a MapWithIndexTransform containing the map_with_index method or a callable that takes an index and an element and returns a new element.

Returns:

A dataset containing the elements of the original dataset transformed by transform.

Return type:

IterDataset[S]

mp_prefetch(options=None, worker_init_fn=None)#

Returns a dataset prefetching elements in multiple processes.

Each of the processes works on a slice of the dataset. The slicing happens after all MapDataset transformations (right before to_iter_dataset).

WARNING: If the dataset contains many-to-one transformations (such as filter) or stateful transformations (such as packing), the output of mp_prefetch may change if num_workers is changed. However, it is still going to be deterministic. If you need elasticity in the number of prefetch workers, consider moving many-to-one and stateful transformations to after mp_prefetch or outside of the Grain pipeline.

Parameters:
  • options (MultiprocessingOptions | None) – options for the prefetching processes. options.num_workers must be greater than or equal to 0. If options.num_workers is 0, mp_prefetch has no effect. Defaults to MultiprocessingOptions(num_workers=10).

  • worker_init_fn (Callable[[int, int], None] | None) – A function that is called in each worker process before the data is processed. The function takes two arguments: the current worker index and the total worker count.

Returns:

A dataset prefetching input elements in separate processes.

Return type:

IterDataset[T]
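
A minimal sketch, assuming grain.MultiprocessingOptions exposes the num_workers field referenced above (the worker_init_fn is purely illustrative):

def square(x):
    return x * x

def init_worker(worker_index, worker_count):
    # Runs once in every worker process before any element is produced.
    print(f"worker {worker_index} of {worker_count} started")

ds = (
    MapDataset.range(1_000)
    .map(square)
    .to_iter_dataset()
    .mp_prefetch(
        grain.MultiprocessingOptions(num_workers=4),
        worker_init_fn=init_worker,
    )
)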

property parents: Sequence[MapDataset | IterDataset]#
pipe(func, /, *args, **kwargs)#

Syntactic sugar for applying a callable to this dataset.

The pipe method, borrowed from pandas.DataFrame, is convenient because it allows for using method chaining syntax in an extensible fashion, with transformations that are not built-in methods on Dataset.

For example, suppose you want to shuffle a dataset within a window. Functionality for this is available in WindowShuffleMapDataset, but not as a method on MapDataset, e.g.:

dataset = (
    grain.experimental.WindowShuffleMapDataset(
        grain.MapDataset.range(1000),
        window_size=128,
        seed=0,
    )
    .batch(16)
)

This hurts readability, because the shuffle transformation appears out of order with respect to the data flow.

In contrast, with pipe you can write:

dataset = (
    grain.MapDataset.range(1000)
    .pipe(
        grain.experimental.WindowShuffleMapDataset,
        window_size=128,
        seed=0
    )
    .batch(16)
)
Parameters:
  • func (Callable[[...], T]) – The callable to apply to this dataset.

  • *args – Additional positional arguments to pass to the callable.

  • **kwargs – Keyword arguments to pass to the callable.

Returns:

The result of calling func(self, *args, **kwargs).

Return type:

T

prefetch(multiprocessing_options)#

Deprecated, use mp_prefetch instead.

Returns a dataset prefetching the elements in multiple processes.

Each of the processes will process a slice of the dataset after all MapDataset transformations.

WARNING: If the dataset contains many-to-one transformations (such as batch), the output of prefetch may change if you change the number of workers. However, it is still going to be deterministic.

Parameters:

multiprocessing_options (MultiprocessingOptions) – options for the prefetching processes. num_workers must be greater than 0.

Returns:

A dataset prefetching input elements concurrently.

Return type:

IterDataset[T]

random_map(transform, *, seed=None)#

Returns a dataset containing the elements transformed by transform.

The transform is called with the element and a np.random.Generator instance that should be used inside the transform to preserve determinism. The seed can be provided either explicitly or via ds.seed(seed). Prefer the latter if you don’t need to control the random map seed individually; it allows passing a single seed to derive seeds for all downstream random transformations in the pipeline. The generator is seeded by a combination of the seed and a counter of elements produced by the dataset.

NOTE: Avoid using the provided RNG outside of the transform function (e.g. by passing it to the next transformation along with the data). The RNG is going to be reused.

Example usage:

ds = MapDataset.range(5).to_iter_dataset()
ds = ds.random_map(lambda x, rng: x + rng.integers(5, 10))
set(ds).issubset(set(range(5, 15)))
Parameters:
  • transform (RandomMapTransform | Callable[[T, numpy.random.Generator], S]) – Either a RandomMapTransform containing the random_map method or a callable that takes an element and a np.random.Generator and returns a new element.

  • seed (int | None) – An integer between 0 and 2**32-1 representing the seed used to initialize the random number generator used by transform. If you don’t need to control the transformation seed individually, prefer setting the pipeline-level seed with ds.seed(seed) instead.

Returns:

A dataset containing the elements of the original dataset transformed by transform.

Return type:

IterDataset[S]

seed(seed)#

Returns a dataset that uses the seed for default seed generation.

When default seed generation is enabled by calling ds.seed, every downstream random transformation will be automatically seeded with a unique seed by default. This simplifies seed management, making it easier to avoid:

  • Having to provide a seed in multiple transformations.

  • Accidentally reusing the same seed across transformations.

It is recommended to call this right after the source. ds.seed has to be called before any random transformations (such as random_map) that rely on default seed generation to control their seeding. Given the same seed, the pipeline is guaranteed to always use the same seeds for each transformation.

Note about custom dataset implementations: the default seed generation is available through _default_seed, but the private API is not guaranteed to be stable.

Example 1:

ds = ds.seed(seed).random_map(...)

random_map will automatically derive its own seed (different from seed).

Example 2:

ds = ds.seed(seed).random_map(...).random_map(...)

The two random_map calls will each derive their own seed, and the seeds will be different.

Example 3:

ds = ds.seed(seed).random_map(transform, seed=seed1)

random_map will use seed1 and will not be affected by seed. This can be used to control individual transformation seeding independently from the rest of the pipeline.

Example 4:

ds = ds.seed(seed1).random_map(...).seed(seed2).random_map(...)

ds.seed only affects the downstream transformations and can be overridden by a subsequent seed call. The first random_map will derive its seed from seed1, the second from seed2, and the second is not affected by seed1. This can be used to control your transformation seeding even if you don’t own the first part of the pipeline.

Example 5:

ds1 = ds.source(...).seed(seed1).shuffle().to_iter_dataset()
ds2 = ds.source(...).seed(seed2).shuffle().to_iter_dataset()
ds = IterDataset.mix([ds1, ds2], ...).random_map(...)

Each shuffle will derive its own seed from seed1 or seed2 respectively. random_map will derive its seed from both seed1 and seed2.

Parameters:

seed (int) – Seed to use.

Returns:

A dataset with elements unchanged.

Return type:

IterDataset[T]

class grain.DatasetIterator(parents=())#

Bases: Iterator[T], ABC

IterDataset iterator.

NOTE: The methods are not thread-safe. Please ensure that only a single thread accesses a DatasetIterator instance.

Parameters:

parents (DatasetIterator | Sequence[DatasetIterator])

__init__(parents=())#
Parameters:

parents (DatasetIterator | Sequence[DatasetIterator])

__iter__()#
Return type:

DatasetIterator[T]

abstractmethod get_state()#

Returns the current state of the iterator.

We reserve the right to evolve the state format over time. The states returned from this method are only guaranteed to be restorable by the same version of the code that produced them.

Implementation Note: It is recommended that iterator implementations always produce states with the same shapes and types throughout the lifetime of the iterator. Some frameworks rely on this property to perform checkpointing, and all standard library iterators are compliant. It is also recommended to produce state values with well-defined shapes and dtypes, e.g. np.int64 instead of int. The standard library iterators are not currently compliant with this recommendation.

Return type:

dict[str, Any]
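
For example, a checkpointing round trip with get_state/set_state might look like the following sketch (the contents of the state dictionary are an implementation detail, as noted above):

iterator = iter(MapDataset.range(100).shuffle(seed=42).to_iter_dataset())
for _ in range(10):
    next(iterator)
state = iterator.get_state()       # e.g. written to a checkpoint
next_element = next(iterator)
iterator.set_state(state)          # restore to the saved position
assert next(iterator) == next_element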

async load(directory)#

Loads the iterator state from a directory.

The state may be loaded and set in a background thread. The main thread should not alter the state content while the load is in progress.

Parameters:

directory (Path) – The directory to load the state from.

Returns:

A coroutine that has not been awaited. This is called by Orbax in a background thread to perform I/O without blocking the main thread.

Return type:

Awaitable[None]

async save(directory)#

Saves the iterator state to a directory.

The current state (get_state) is used for saving, so any updates to the state after returning from this method will not affect the saved checkpoint.

Parameters:

directory (PathAwaitingCreation) – A path in the process of being created. Must call await_creation before accessing the physical path.

Returns:

A coroutine that has not been awaited. This is called by Orbax in a background thread to perform I/O without blocking the main thread.

Return type:

Awaitable[None]

abstractmethod set_state(state)#

Sets the current state of the iterator.

Parameters:

state (dict[str, Any])

start_prefetch()#

Asynchronously starts processing and buffering elements.

NOTE: Only available on iterators of asynchronous transformations.

Can be useful when the iterator can be created in advance but the elements are not needed immediately. For instance, when recovering the iterator and model from a checkpoint, recover the iterator first, call start_prefetch, and then recover the model. This way the time to get the first batch from the iterator will be partially or fully hidden behind the time it takes to recover the model.

Return type:

None
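
As a sketch of the recovery pattern described above (restore_iterator_state and restore_model are hypothetical helpers standing in for your own checkpointing code):

iterator = iter(ds)                            # ds is an IterDataset
iterator.set_state(restore_iterator_state())   # hypothetical helper
iterator.start_prefetch()                      # begin buffering in the background
model = restore_model()                        # hypothetical, typically slow
batch = next(iterator)                         # much of the wait is already hidden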