Dataset#
Dataset classes.
MapDataset defines a dataset that supports efficient random access. It has 3 important properties:
- __len__() returns the length of a single epoch over the dataset.
- __getitem__() returns the element at the given positive index. The "true" length of a MapDataset is infinite.
- Individual dataset elements are only evaluated when calling __getitem__(). MapDatasets are stateless and will not hold elements.

IterDataset defines a dataset that does not support efficient random access but can be iterated over. A MapDataset can be turned into an IterDataset, but going from IterDataset to MapDataset is as expensive as materializing the whole dataset.

DatasetIterator defines a stateful iterator of an IterDataset. The state of the iterator can be saved and restored.
Using the interfaces defined in collections.abc, you can think of MapDataset as an (infinite) Sequence, IterDataset as an Iterable and DatasetIterator as an Iterator.
MapDataset is typically created by one of the factory methods in MapDatasetMeta (e.g. MapDataset.range(5)). IterDataset is either created by calling to_iter_dataset() on a MapDataset or by one of the factory methods in IterDatasetMeta (e.g. IterDataset.mix([...])).
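For orientation, a minimal end-to-end sketch tying the three interfaces together (assuming the package is imported as grain, matching the grain.MapDataset / grain.IterDataset class names documented below):

ds = grain.MapDataset.range(5)    # random access: supports len(ds) and ds[i]
it_ds = ds.to_iter_dataset()      # IterDataset: sequential access only
it = iter(it_ds)                  # DatasetIterator: stateful
first = next(it)                  # first == 0
state = it.get_state()            # the iterator state can be saved...
it.set_state(state)               # ...and restored later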
List of Members#
- class _src.python.dataset.dataset.MapDatasetMeta(name, bases, namespace, /, **kwargs)#
Metaclass for MapDataset containing factory transformations.
- concatenate(datasets)#
Returns a dataset of elements from all input datasets.
Example usage:
ds1 = MapDataset.range(3)
ds2 = MapDataset.range(3, 8)
list(MapDataset.concatenate([ds1, ds2])) == [0, 1, 2, 3, 4, 5, 6, 7]
- Parameters:
datasets (Sequence[MapDataset[T]]) – The datasets to concatenate.
- Returns:
A MapDataset that represents a concatenation of the input datasets. The n-th epoch of the returned dataset will be the n-th epoch of the component datasets.
- Return type:
MapDataset[T]
- mix(datasets, weights=None)#
Returns a dataset that mixes input datasets with the given weights.
The length of the mixed dataset will be determined by the length of the shortest input dataset. If you need an infinite dataset, consider repeating the input datasets before mixing.
If you need to shuffle the mixed dataset while preserving the correct proportions, you should shuffle the input datasets before mixing.
Example usage:
ds1 = MapDataset.range(5)
ds2 = MapDataset.range(7, 10)
list(MapDataset.mix([ds1, ds2])) == [0, 7, 1, 8, 2, 9]
- Parameters:
datasets (Sequence[MapDataset[T]]) – The datasets to mix.
weights (Sequence[float] | None) – The weights to use for mixing. Defaults to uniform weights if not specified.
- Returns:
A MapDataset that represents a mixture of the input datasets according to the given weights.
- Return type:
MapDataset[T]
- range(start, stop=None, step=1)#
Returns a dataset with a range of integers.
Input arguments are interpreted the same way as in the Python built-in range:
- range(n) => start=0, stop=n, step=1
- range(m, n) => start=m, stop=n, step=1
- range(m, n, p) => start=m, stop=n, step=p
The produced values are consistent with the built-in range function:
list(MapDataset.range(...)) == list(range(...))
- Parameters:
start (int) – The start of the range.
stop (int | None) – The stop of the range.
step (int) – The step of the range.
- Returns:
A MapDataset with a range of integers.
- Return type:
MapDataset[int]
- select_from_datasets(datasets, selection_map)#
Returns a dataset selected from the inputs according to the given map.
Allows more general types of dataset mixing than mix.
- Parameters:
datasets (Sequence[MapDataset[T]]) – The datasets to select from.
selection_map (DatasetSelectionMap) – Mapping from an index within the mixed dataset to a selected dataset index and an index within that dataset. The length of the resulting dataset will be determined by the length of the selection_map.
- Returns:
A MapDataset that represents a mixture of the input datasets according to the given selection map.
- Return type:
MapDataset[T]
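No example usage is shown for this factory; the following is a hedged sketch only, assuming a DatasetSelectionMap can be provided as an object with __len__ and __getitem__ returning a (dataset index, index within that dataset) pair, as the parameter description above suggests (a real implementation may need to subclass the DatasetSelectionMap base class):

# Hypothetical selection map: two elements from ds1, then one from ds2, repeated.
class TwoToOneSelectionMap:
    def __len__(self):
        return 6
    def __getitem__(self, index):
        group, offset = divmod(index, 3)
        if offset < 2:
            return 0, 2 * group + offset  # (dataset index, index within that dataset)
        return 1, group

ds1 = MapDataset.range(10)
ds2 = MapDataset.range(100, 110)
ds = MapDataset.select_from_datasets([ds1, ds2], TwoToOneSelectionMap())
list(ds) == [0, 1, 100, 2, 3, 101]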
- source(source)#
Returns a dataset that wraps a data source supporting random access.
Example:
ds = MapDataset.source(ArrayRecordDataSource(paths))
Works with Sequence inputs as well:
list(MapDataset.source([1, 2, 3, 4, 5])) == [1, 2, 3, 4, 5]
- Parameters:
source (Sequence[T] | RandomAccessDataSource[T]) – Data source supporting efficient random access.
- Returns:
A MapDataset that wraps the data source and allows chaining other MapDataset transformations.
- Return type:
MapDataset[T]
- class grain.MapDataset(parents=())#
Bases: _Dataset, Generic[T]

Represents a dataset with transformations that support random access.
Transformations do not mutate the dataset object. Instead, they return a new dataset.
MapDataset is immutable.

Note: MapDataset transformations such as .filter() use None to indicate the absence of an element. Generally, the implementation of MapDataset transformations already handles None as a special case (e.g. by returning None as soon as __getitem__ sees None). This means the user-defined functions passed to MapDataset transformations do not need to explicitly handle Nones.
- Parameters:
parents (MapDataset | Sequence[MapDataset])
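To illustrate the None propagation described in the note above, a small sketch (behavior as documented for filter, map and to_iter_dataset below):

ds = MapDataset.range(5).filter(lambda x: x % 2 == 0).map(lambda x: x * 10)
ds[2] == 20       # 2 passes the filter, then the map is applied
ds[1] is None     # 1 was filtered out; the user-defined map never sees None
list(ds) == [0, 20, 40]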
- abstractmethod __getitem__(index: slice) → MapDataset[T]#
- abstractmethod __getitem__(index: int) → T | None
Returns the element for the index or None if missing.
- __init__(parents=())#
- Parameters:
parents (MapDataset | Sequence[MapDataset])
- __iter__()#
- Return type:
DatasetIterator[T]
- abstractmethod __len__()#
Returns length of a single epoch of this dataset.
- Return type:
int
- batch(batch_size, *, drop_remainder=False, batch_fn=None)#
Returns a dataset of elements batched along a new first dimension.
Dataset elements are expected to be PyTrees.
Example usage:
ds = MapDataset.range(5)
ds = ds.batch(batch_size=2)
list(ds) == [np.ndarray([0, 1]), np.ndarray([2, 3]), np.ndarray([4])]
- Parameters:
batch_size (int) – The number of elements to batch together.
drop_remainder (bool) – Whether to drop the last batch if it is smaller than batch_size.
batch_fn (Callable[[Sequence[T]], S] | None) – A function that takes a list of elements and returns a batch. Defaults to stacking the elements along a new first batch dimension.
- Returns:
A dataset of elements with the same PyTree structure with leaves concatenated along the first dimension.
- Return type:
MapDataset[S]
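As a sketch of the batch_fn parameter, batching into plain Python lists instead of stacked arrays (any callable taking a sequence of elements works):

ds = MapDataset.range(5)
ds = ds.batch(batch_size=2, batch_fn=list)  # batch_fn receives the elements of one batch
list(ds) == [[0, 1], [2, 3], [4]]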
- filter(transform)#
Returns a dataset containing only the elements that match the filter.
Accessing an element of the returned dataset using subscription (ds[i]) returns:
- None if transform returned False
- the element if transform returned True

Iterating over a filtered dataset skips None elements by default.

Example usage:
ds = MapDataset.range(5)
ds = ds.filter(lambda x: x % 2 == 0)
ds[2] == 2
ds[1] == None
list(ds) == [0, 2, 4]
NOTE: list(ds) converts the dataset to an IterDataset with to_iter_dataset() under the hood, which by default skips None elements. to_iter_dataset produces a warning when iterating through a filtered dataset if the filter removes more than 90% of the elements. You can adjust the threshold through grain.experimental.DatasetOptions.filter_warn_threshold_ratio used in WithOptionsIterDataset. In order to produce an exception in such a case, use filter_raise_threshold_ratio.
- Parameters:
transform (Filter | Callable[[T], bool]) – Either a FilterTransform containing the filter method or a callable that takes an element and returns a boolean.
- Returns:
A dataset of the same type containing only the elements for which the filter transform returns True.
- Return type:
MapDataset[T]
- map(transform)#
Returns a dataset containing the elements transformed by transform.
Example usage:
ds = MapDataset.range(5)
ds = ds.map(lambda x: x + 10)
list(ds) == [10, 11, 12, 13, 14]
- Parameters:
transform (MapTransform | Callable[[T], S]) – Either a MapTransform containing the map method or a callable that takes an element and returns a new element.
- Returns:
A dataset containing the elements of the original dataset transformed by transform.
- Return type:
MapDataset[S]
- map_with_index(transform)#
Returns a dataset containing the elements transformed by transform.
The transform is called with the index of the element within the dataset and the element itself.
Example usage:
ds = MapDataset.source(["a", "b", "c", "d"])
ds = ds.map_with_index(lambda i, x: x + str(i))
list(ds) == ["a0", "b1", "c2", "d3"]
- Parameters:
transform (MapWithIndex | Callable[[int, T], S]) – Either a MapWithIndexTransform containing the map_with_index method or a callable that takes an index and an element and returns a new element.
- Returns:
A dataset containing the elements of the original dataset transformed by transform.
- Return type:
MapDataset[S]
- property parents: Sequence[MapDataset]#
- pipe(func, /, *args, **kwargs)#
Syntactic sugar for applying a callable to this dataset.
The pipe method, borrowed from pandas.DataFrame, is convenient because it allows for using method chaining syntax in an extensible fashion, with transformations that are not built-in methods on Dataset.

For example, suppose you want to shuffle a dataset within a window. Functionality for this is available in WindowShuffleMapDataset, but not as a method on MapDataset, e.g.:

dataset = (
    grain.experimental.WindowShuffleMapDataset(
        grain.MapDataset.range(1000),
        window_size=128,
        seed=0,
    )
    .batch(16)
)
This solution suffers in readability, because the shuffle transformation appears out of order relative to the data flow.
In contrast, with pipe you can write:

dataset = (
    grain.MapDataset.range(1000)
    .pipe(
        grain.experimental.WindowShuffleMapDataset,
        window_size=128,
        seed=0,
    )
    .batch(16)
)
- Parameters:
func (Callable[[...], T]) – The callable to apply to this dataset.
*args – Additional positional arguments to pass to the callable.
**kwargs – Keyword arguments to pass to the callable.
- Returns:
The result of calling func(self, *args, **kwargs).
- Return type:
T
- random_map(transform, *, seed=None)#
Returns a dataset containing the elements transformed by transform.

The transform is called with the element and an np.random.Generator instance that should be used inside the transform to preserve determinism. The seed can either be provided explicitly or set via ds.seed(seed). Prefer the latter if you don't need to control the random map seed individually; it allows passing a single seed from which seeds for all downstream random transformations in the pipeline are derived. The generator is seeded by a combination of the seed and the index of the element in the dataset.

NOTE: Avoid using the provided RNG outside of the transform function (e.g. by passing it to the next transformation along with the data). The RNG is going to be reused.

Example usage:
ds = MapDataset.range(5)
ds = ds.random_map(lambda x, rng: x + rng.integers(5, 10))
set(ds).issubset(set(range(5, 15)))
- Parameters:
transform (RandomMapTransform | Callable[[T, numpy.random.Generator], S]) – Either a RandomMapTransform containing the random_map method or a callable that takes an element and an np.random.Generator and returns a new element.
seed (int | None) – An optional integer between 0 and 2**32-1 representing the seed used to initialize the random number generator used by transform. If you don't need to control the transformation seed individually, prefer setting the pipeline-level seed with ds.seed(seed) instead.
- Returns:
A dataset containing the elements of the original dataset transformed by transform.
- Return type:
MapDataset[S]
- repeat(num_epochs=None)#
Returns a dataset repeating the elements of this dataset multiple times.
Specifying None for num_epochs will repeat the dataset infinitely, and causes len(ds) to return sys.maxsize.

Since MapDataset allows accessing elements past len(ds) - 1 anyway (and uses the index modulo len(ds)), this transformation effectively only changes the length of the dataset.

Cannot be called on an infinite dataset.
Example usage:
list(MapDataset.range(5).repeat(2)) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
ds = MapDataset.range(5).repeat()
len(ds) == sys.maxsize
ds[11111] == 1
- Parameters:
num_epochs (int | None) – Either a positive integer representing the number of times this dataset should be repeated or None to repeat infinitely.
- Returns:
A dataset repeating the elements of this dataset multiple times.
- Return type:
MapDataset[T]
- seed(seed)#
Returns a dataset that uses the seed for default seed generation.
When default seed generation is enabled by calling ds.seed, every downstream random transformation will be automatically seeded with a unique seed by default. This simplifies seed management, making it easier to avoid:
- Having to provide a seed in multiple transformations.
- Accidentally reusing the same seed across transformations.
It is recommended to call this right after the source. ds.seed has to be called before any random transformations (such as shuffle or random_map that rely on default seed generation to control their seeding). Given the same seed, the pipeline is guaranteed to always use the same seeds for each transformation.

WARNING: The seed for downstream random transformations is derived from the seed passed to ds.seed and the absolute position of the transformation in the pipeline. This means that if you add transformations before the random transformation, its seed will change. For instance, if this random transformation is shuffle, adding a transformation before shuffle will change its seed and, consequently, the data order. To avoid this, pass the seed to the transformation directly.

Note about custom dataset implementations: the default seed generation is available through _default_seed, but the private API is not guaranteed to be stable.

Example 1:
ds = ds.seed(seed).shuffle()
shuffle will automatically derive its own seed (different from seed).

Example 2:
ds = ds.seed(seed).shuffle().random_map(...)
shuffle and random_map will each derive their own seed, and the seeds are going to be different.

Example 3:
ds = ds.seed(seed).random_map(transform, seed=seed1)
random_map will use seed1 and will not be affected by seed. This can be used to control individual transformation seeding independently from the rest of the pipeline.

Example 4:
ds = ds.seed(seed1).shuffle().seed(seed2).random_map(...)
ds.seed only affects the downstream transformations and can be overridden by a subsequent seed call. shuffle will derive its seed from seed1, random_map from seed2, and will not be affected by seed1. This can be used to control your transformation seeding even if you don't own the first part of the pipeline.

Example 5:
ds1 = ds.source(...).seed(seed1).shuffle()
ds2 = ds.source(...).seed(seed2).shuffle()
ds = MapDataset.mix([ds1, ds2], ...).random_map(...)
Each shuffle will derive its own seed from seed1 or seed2 respectively. random_map will derive its seed from both seed1 and seed2.
- Parameters:
seed (int) – Seed to use.
- Returns:
A dataset with elements unchanged.
- Return type:
MapDataset[T]
- shuffle(seed=None)#
Returns a dataset with the same elements in a globally shuffled order.
The shuffle is deterministic and will always produce the same result given the same seed. The seed can either be provided explicitly or set via ds.seed(seed). Prefer the latter if you don't need to control the shuffle seed individually; it allows passing a single seed from which seeds for all downstream random transformations in the pipeline are derived.

In multi-epoch training each epoch will be shuffled differently (i.e. the seed is combined with the epoch number). In such cases it is recommended to shuffle before repeat to avoid mixing elements from different epochs.

Example usage:
ds = MapDataset.range(5).shuffle()
set(ds) == {0, 1, 2, 3, 4}
list(ds) != [0, 1, 2, 3, 4]  # With probability (1 - 1/5!).
- Parameters:
seed (int | None) – An optional integer between 0 and 2**32-1 representing the seed used by the shuffling algorithm. If you don't need to control the shuffle seed individually, prefer setting the pipeline-level seed with ds.seed(seed) instead.
- Returns:
A dataset containing the same elements but in a shuffled order.
- Return type:
MapDataset[T]
- slice(sl)#
Returns a dataset containing only the elements with indices in sl.

For most implementations of MapDataset slicing is also available through the subscript operator: list(ds.slice(slice(1, 10, 2))) == ds[1:10:2].

Example usage:
ds = MapDataset.range(5)
list(ds.slice(slice(1, 3))) == [1, 2]
list(ds.slice(slice(1, None, 2))) == [1, 3]
Commonly used for sharding: ds = ds.slice(slice(shard_index, None, shard_count)), or, equivalently, ds = ds[shard_index::shard_count].
- Parameters:
sl (slice) – A slice object (https://docs.python.org/3/library/functions.html#slice) representing the slice of elements that should constitute the returned dataset.
- Returns:
A dataset containing only the elements with indices in the sl slice.
- Return type:
MapDataset[T]
- to_iter_dataset(read_options=None, *, allow_nones=False)#
Converts this dataset to an IterDataset.

Elements from this dataset may be processed in multiple threads.
Note that some of the transformations are not available on IterDataset. These are roughly transformations operating on the element index, such as shuffle, map_with_index, slice and repeat.
- Parameters:
read_options (ReadOptions | None) – Controls multithreading when reading the data and applying transformations in this dataset.
allow_nones (bool) – Whether to allow None values in the dataset (e.g. produced by filter). If False (the default), None values will be filtered out.
- Returns:
An IterDataset with the same non-None elements as this dataset.
- Return type:
IterDataset[T]
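A short sketch of the conversion, assuming ReadOptions is constructed with a num_threads argument (hedged; check ReadOptions for the exact fields) and using the filter behavior documented above:

ds = MapDataset.range(10).filter(lambda x: x % 2 == 0)
it_ds = ds.to_iter_dataset(read_options=ReadOptions(num_threads=2))
list(it_ds) == [0, 2, 4, 6, 8]   # None elements produced by filter are dropped by default
list(ds.to_iter_dataset(allow_nones=True)) == [0, None, 2, None, 4, None, 6, None, 8, None]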
- class _src.python.dataset.dataset.IterDatasetMeta(name, bases, namespace, /, **kwargs)#
Metaclass for IterDataset containing factory transformations.
- mix(datasets, weights=None)#
Returns a dataset that mixes input datasets with the given weights.
NOTE: Stops producing elements once any input dataset is exhausted. If you need an infinite mixed dataset, consider repeating the input datasets before mixing.
Example usage:
ds1 = MapDataset.range(5).to_iter_dataset()
ds2 = MapDataset.range(7, 10).to_iter_dataset()
list(IterDataset.mix([ds1, ds2])) == [0, 7, 1, 8, 2, 9, 3]
- Parameters:
datasets (Sequence[IterDataset[T]]) – The datasets to mix.
weights (Sequence[float] | None) – The weights to use for mixing. Defaults to uniform weights if not specified.
- Returns:
A dataset that represents a mixture of the input datasets according to the given weights.
- Return type:
IterDataset[T]
- class grain.IterDataset(parents=())#
Bases: _Dataset, Iterable[T]

Represents a dataset with transformations that support the Iterable interface.
Transformations do not mutate the dataset object. Instead, they return a new dataset.
IterDataset is immutable.
- Parameters:
parents (MapDataset | IterDataset | Sequence[MapDataset | IterDataset])
- __init__(parents=())#
- Parameters:
parents (MapDataset | IterDataset | Sequence[MapDataset | IterDataset])
- abstractmethod __iter__()#
Returns an iterator for this dataset.
- Return type:
DatasetIterator[T]
- batch(batch_size, *, drop_remainder=False, batch_fn=None)#
Returns a dataset of elements batched along a new first dimension.
Dataset elements are expected to be PyTrees.
Example usage:
ds = MapDataset.range(5).to_iter_dataset()
ds = ds.batch(batch_size=2)
list(ds) == [np.ndarray([0, 1]), np.ndarray([2, 3]), np.ndarray([4])]
- Parameters:
batch_size (int) – The number of elements to batch together.
drop_remainder (bool) – Whether to drop the last batch if it is smaller than batch_size.
batch_fn (Callable[[Sequence[T]], S] | None) – A function that takes a list of elements and returns a batch. Defaults to stacking the elements along a new first batch dimension.
- Returns:
A dataset of elements with the same PyTree structure with leaves concatenated along the first dimension.
- Return type:
IterDataset[S]
- filter(transform)#
Returns a dataset containing only the elements that match the filter.
Example usage:
ds = MapDataset.range(5).to_iter_dataset()
ds = ds.filter(lambda x: x % 2 == 0)
list(ds) == [0, 2, 4]
Produces a warning if the filter removes more than 90% of the elements. You can adjust the threshold through grain.experimental.DatasetOptions.filter_warn_threshold_ratio used in WithOptionsIterDataset. In order to produce an exception in such a case, use filter_raise_threshold_ratio.
- Parameters:
transform (Filter | Callable[[T], bool]) – Either a FilterTransform containing the filter method or a callable that takes an element and returns a boolean.
- Returns:
A dataset of the same type containing only the elements for which the filter transform returns True.
- Return type:
IterDataset[T]
- map(transform)#
Returns a dataset containing the elements transformed by transform.
Example usage:
ds = MapDataset.range(5).to_iter_dataset()
ds = ds.map(lambda x: x + 10)
list(ds) == [10, 11, 12, 13, 14]
- Parameters:
transform (MapTransform | Callable[[T], S]) – Either a MapTransform containing the map method or a callable that takes an element and returns a new element.
- Returns:
A dataset containing the elements of the original dataset transformed by transform.
- Return type:
IterDataset[S]
- map_with_index(transform)#
Returns a dataset of the elements transformed by the transform.

The transform is called with the index of the element in the dataset as the first argument and the element as the second argument.

Example usage:
ds = MapDataset.range(5).to_iter_dataset()
ds = ds.map_with_index(lambda i, x: (i, 2**x))
list(ds) == [(0, 1), (1, 2), (2, 4), (3, 8), (4, 16)]
- Parameters:
transform (MapWithIndex | Callable[[int, T], S]) – Either a MapWithIndexTransform containing the map_with_index method or a callable that takes an index and an element and returns a new element.
- Returns:
A dataset containing the elements of the original dataset transformed by transform.
- Return type:
IterDataset[S]
- mp_prefetch(options=None, worker_init_fn=None)#
Returns a dataset prefetching elements in multiple processes.
Each of the processes works on a slice of the dataset. The slicing happens after all MapDataset transformations (right before to_iter_dataset).
WARNING: If the dataset contains many-to-one transformations (such as filter) or stateful transformations (such as packing), the output of mp_prefetch may change if num_workers is changed. However, it is still going to be deterministic. If you need elasticity in the number of prefetch workers, consider moving many-to-one and stateful transformations to after mp_prefetch or outside of the Grain pipeline.
- Parameters:
options (MultiprocessingOptions | None) – Options for the prefetching processes. options.num_workers must be greater than or equal to 0. If options.num_workers is 0, mp_prefetch has no effect. Defaults to MultiprocessingOptions(num_workers=10).
worker_init_fn (Callable[[int, int], None] | None) – A function that is called in each worker process before the data is processed. The function takes two arguments: the current worker index and the total worker count.
- Returns:
A dataset prefetching input elements in separate processes.
- Return type:
IterDataset[T]
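A hedged sketch of mp_prefetch, assuming MultiprocessingOptions is constructed with num_workers as in the parameter description above:

ds = (
    MapDataset.range(1_000)
    .map(lambda x: x + 1)
    .to_iter_dataset()
    .mp_prefetch(MultiprocessingOptions(num_workers=4))  # 4 worker processes
)
next(iter(ds)) == 1   # iteration is unchanged for the caller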
- property parents: Sequence[MapDataset | IterDataset]#
- pipe(func, /, *args, **kwargs)#
Syntactic sugar for applying a callable to this dataset.
The pipe method, borrowed from pandas.DataFrame, is convenient because it allows for using method chaining syntax in an extensible fashion, with transformations that are not built-in methods on Dataset.

For example, suppose you want to shuffle a dataset within a window. Functionality for this is available in WindowShuffleMapDataset, but not as a method on MapDataset, e.g.:

dataset = (
    grain.experimental.WindowShuffleMapDataset(
        grain.MapDataset.range(1000),
        window_size=128,
        seed=0,
    )
    .batch(16)
)
This solution suffers in readability, because the shuffle transformation appears out of order relative to the data flow.
In contrast, with pipe you can write:

dataset = (
    grain.MapDataset.range(1000)
    .pipe(
        grain.experimental.WindowShuffleMapDataset,
        window_size=128,
        seed=0,
    )
    .batch(16)
)
- Parameters:
func (Callable[[...], T]) – The callable to apply to this dataset.
*args – Additional positional arguments to pass to the callable.
**kwargs – Keyword arguments to pass to the callable.
- Returns:
The result of calling func(self, *args, **kwargs).
- Return type:
T
- prefetch(multiprocessing_options)#
Deprecated, use mp_prefetch instead.

Returns a dataset prefetching the elements in multiple processes.
Each of the processes will process a slice of the dataset after all MapDataset transformations.

WARNING: If the dataset contains many-to-one transformations (such as batch), the output of prefetch may change if you change the number of workers. However, it is still going to be deterministic.
- Parameters:
multiprocessing_options (MultiprocessingOptions) – Options for the prefetching processes. num_workers must be greater than 0.
- Returns:
A dataset prefetching input elements concurrently.
- Return type:
IterDataset[T]
- random_map(transform, *, seed=None)#
Returns a dataset containing the elements transformed by transform.

The transform is called with the element and an np.random.Generator instance that should be used inside the transform to preserve determinism. The seed can either be provided explicitly or set via ds.seed(seed). Prefer the latter if you don't need to control the random map seed individually; it allows passing a single seed from which seeds for all downstream random transformations in the pipeline are derived. The generator is seeded by a combination of the seed and a counter of elements produced by the dataset.

NOTE: Avoid using the provided RNG outside of the transform function (e.g. by passing it to the next transformation along with the data). The RNG is going to be reused.

Example usage:
ds = MapDataset.range(5).to_iter_dataset()
ds = ds.random_map(lambda x, rng: x + rng.integers(5, 10))
set(ds).issubset(set(range(5, 15)))
- Parameters:
transform (RandomMapTransform | Callable[[T, numpy.random.Generator], S]) – Either a RandomMapTransform containing the random_map method or a callable that takes an element and an np.random.Generator and returns a new element.
seed (int | None) – An integer between 0 and 2**32-1 representing the seed used to initialize the random number generator used by transform. If you don't need to control the transformation seed individually, prefer setting the pipeline-level seed with ds.seed(seed) instead.
- Returns:
A dataset containing the elements of the original dataset transformed by transform.
- Return type:
IterDataset[S]
- seed(seed)#
Returns a dataset that uses the seed for default seed generation.
When default seed generation is enabled by calling ds.seed, every downstream random transformation will be automatically seeded with a unique seed by default. This simplifies seed management, making it easier to avoid:
- Having to provide a seed in multiple transformations.
- Accidentally reusing the same seed across transformations.
It is recommended to call this right after the source. ds.seed has to be called before any random transformations (such as random_map that rely on default seed generation to control their seeding). Given the same seed, the pipeline is guaranteed to always use the same seeds for each transformation.

Note about custom dataset implementations: the default seed generation is available through _default_seed, but the private API is not guaranteed to be stable.

Example 1:
ds = ds.seed(seed).random_map(...)
random_map will automatically derive its own seed (different from seed).

Example 2:
ds = ds.seed(seed).random_map().random_map(...)
The first and second random_maps will each derive their own seed, and the seeds are going to be different.

Example 3:
ds = ds.seed(seed).random_map(transform, seed=seed1)
random_map will use seed1 and will not be affected by seed. This can be used to control individual transformation seeding independently from the rest of the pipeline.

Example 4:
ds = ds.seed(seed1).random_map(...).seed(seed2).random_map(...)
ds.seed only affects the downstream transformations and can be overridden by a subsequent seed call. The first random_map will derive its seed from seed1, the second from seed2, and will not be affected by seed1. This can be used to control your transformation seeding even if you don't own the first part of the pipeline.

Example 5:
ds1 = ds.source(...).seed(seed1).shuffle().to_iter_dataset()
ds2 = ds.source(...).seed(seed2).shuffle().to_iter_dataset()
ds = IterDataset.mix([ds1, ds2], ...).random_map(...)
Each shuffle will derive its own seed from seed1 or seed2 respectively. random_map will derive its seed from both seed1 and seed2.
- Parameters:
seed (int) – Seed to use.
- Returns:
A dataset with elements unchanged.
- Return type:
IterDataset[T]
- class grain.DatasetIterator(parents=())#
Bases: Iterator[T], ABC

IterDataset iterator.

NOTE: The methods are assumed to be thread-unsafe. Please ensure only a single thread can access a DatasetIterator instance.
- Parameters:
parents (DatasetIterator | Sequence[DatasetIterator])
- __init__(parents=())#
- Parameters:
parents (DatasetIterator | Sequence[DatasetIterator])
- __iter__()#
- Return type:
DatasetIterator[T]
- abstractmethod get_state()#
Returns the current state of the iterator.
We reserve the right to evolve the state format over time. The states returned from this method are only guaranteed to be restorable by the same version of the code that produced them.
Implementation Note: It is recommended that iterator implementations always produce states with the same shapes and types throughout the lifetime of the iterator. Some frameworks rely on this property to perform checkpointing, and all standard library iterators are compliant. It is also recommended to produce state values that support shapes and types, e.g. using np.int64 instead of int. The standard library iterators are not currently compliant with this recommendation.
- Return type:
dict[str, Any]
- async load(directory)#
Loads the iterator state from a directory.
The state may be loaded and set in a background thread. The main thread should not alter the state content while the load is in progress.
- Parameters:
directory (Path) – The directory to load the state from.
- Returns:
A coroutine that has not been awaited. This is called by Orbax in a background thread to perform I/O without blocking the main thread.
- Return type:
Awaitable[None]
- async save(directory)#
Saves the iterator state to a directory.
The current state (get_state) is used for saving, so any updates to the state after returning from this method will not affect the saved checkpoint.
- Parameters:
directory (PathAwaitingCreation) – A path in the process of being created. Must call await_creation before accessing the physical path.
- Returns:
A coroutine that has not been awaited. This is called by Orbax in a background thread to perform I/O without blocking the main thread.
- Return type:
Awaitable[None]
- abstractmethod set_state(state)#
Sets the current state of the iterator.
- Parameters:
state (dict[str, Any])
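A minimal sketch of manual checkpointing with get_state / set_state (independent of the Orbax-based save / load methods above):

it = iter(MapDataset.range(10).to_iter_dataset())
for _ in range(3):
    next(it)                      # consume a few elements
state = it.get_state()            # dict describing the current position

new_it = iter(MapDataset.range(10).to_iter_dataset())
new_it.set_state(state)           # resume from the saved position
next(new_it) == 3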
- start_prefetch()#
Asynchronously starts processing and buffering elements.
NOTE: Only available on iterators of asynchronous transformations.
Can be useful when the iterator can be created in advance but the elements are not needed immediately. For instance, when recovering the iterator and model from a checkpoint, recover the iterator first, call start_prefetch and then recover the model. This way the time to get the first batch from the iterator will be partially or fully hidden behind the time it takes to recover the model.
- Return type:
None
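A sketch of the recovery pattern described above; restore_model and the checkpoint objects are hypothetical, and start_prefetch is only available on iterators of asynchronous transformations:

it = iter(ds)                            # build the input pipeline first
it.set_state(iterator_checkpoint)        # restore the iterator state
it.start_prefetch()                      # begin buffering elements in the background
model = restore_model(model_checkpoint)  # hypothetical: overlaps with prefetching
first_batch = next(it)                   # likely already buffered by now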