grain.DataLoaderIterator
- class grain.DataLoaderIterator(data_loader, state)
DataLoader iterator providing get/set state functionality.
This is the only iterator we expose to users. It wraps the underlying MultipleProcessIterator. To set state, it recreates the underlying iterator from scratch using the new state.
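The recreate-on-restore approach can be illustrated with a toy, single-process stand-in. This is a hypothetical sketch: `ToyCheckpointableIterator` and its `{"next_index": ...}` state format are invented for illustration, and a plain generator over a list stands in for MultipleProcessIterator.

```python
from __future__ import annotations

from typing import Any, Iterator


class ToyCheckpointableIterator:
    """Hypothetical stand-in for a state-restorable iterator wrapper.

    A plain generator over an in-memory list plays the role of the
    underlying (multi-process) iterator.
    """

    def __init__(self, records: list[Any], state: dict[str, Any] | None = None):
        self._records = records
        self._state = dict(state) if state is not None else {"next_index": 0}
        self._iterator = self._make_iterator()

    def _make_iterator(self) -> Iterator[Any]:
        # Start from the position stored in the current state.
        for i in range(self._state["next_index"], len(self._records)):
            # Record progress before handing the element out, so get_state()
            # always points at the next element to produce.
            self._state["next_index"] = i + 1
            yield self._records[i]

    def __iter__(self) -> ToyCheckpointableIterator:
        return self

    def __next__(self) -> Any:
        return next(self._iterator)

    def get_state(self) -> dict[str, Any]:
        # Return a copy so callers cannot mutate the live state.
        return dict(self._state)

    def set_state(self, state: dict[str, Any]) -> None:
        # Rather than rewinding the existing iterator, discard it and
        # recreate a fresh one from the new state.
        self._state = dict(state)
        self._iterator = self._make_iterator()


it = ToyCheckpointableIterator(["a", "b", "c", "d"])
print(next(it), next(it))  # a b
saved = it.get_state()     # {'next_index': 2}
print(next(it), next(it))  # c d
it.set_state(saved)
print(next(it))            # c
```

Recreating the iterator keeps set_state simple: the wrapper never has to know how to rewind the underlying iterator, only how to build one from a given state.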
Checkpointing for DataLoaderIterator: DataLoaderIterator uses GrainPool, which distributes RecordMetadata from produced records among worker processes in a round-robin fashion. In general, at a given training step some workers may have processed more elements than others. The checkpointing logic works as follows:
1) With each output batch produced, GrainPool emits the worker_index of the worker that processed the batch.
2) DataLoaderIterator keeps track of the last_seen_index at each worker.
3) When restoring from a state, DataLoaderIterator finds the minimum last_seen_index across all workers and the worker that processed that index. GrainPool is then instructed to start distributing indices to the next worker.
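The bookkeeping described above can be sketched in a single process, assuming records are assigned round-robin so that record index i goes to worker i % num_workers. `LastSeenTracker`, `record_output`, and `restore_point` are hypothetical names, not part of the grain API; the real GrainPool does this across separate worker processes.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class LastSeenTracker:
    """Hypothetical sketch of per-worker last_seen_index bookkeeping.

    Assumes each output reports the index of the worker that produced it
    and the record index it carried.
    """

    num_workers: int
    # worker_index -> last record index seen from that worker (-1 = none yet)
    last_seen_index: dict[int, int] = field(default_factory=dict)

    def __post_init__(self) -> None:
        for worker in range(self.num_workers):
            self.last_seen_index.setdefault(worker, -1)

    def record_output(self, worker_index: int, record_index: int) -> None:
        # Remember the most recent index each worker has emitted.
        self.last_seen_index[worker_index] = record_index

    def restore_point(self) -> tuple[int, int]:
        """Return (minimum last_seen_index, worker to resume distribution at)."""
        # Find the smallest last-seen index and the worker that produced it...
        worker, min_index = min(self.last_seen_index.items(), key=lambda kv: kv[1])
        # ...then resume round-robin distribution at the next worker.
        return min_index, (worker + 1) % self.num_workers


tracker = LastSeenTracker(num_workers=3)
# Worker 0 has run ahead; worker 2 has not yet emitted record 5.
for record_index in [0, 1, 2, 3, 4, 6]:
    tracker.record_output(worker_index=record_index % 3, record_index=record_index)
print(tracker.last_seen_index)  # {0: 6, 1: 4, 2: 2}
print(tracker.restore_point())  # (2, 0)
```

In this trace the minimum last_seen_index is 2, emitted by worker 2, so distribution resumes at worker 0. Under the round-robin assumption the record after index 2 (index 3) also belongs to worker 0, which is why resuming at the next worker keeps the assignment consistent.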
- Parameters:
data_loader (DataLoader)
state (_IteratorState)
- __init__(data_loader, state)
- Parameters:
data_loader (DataLoader)
state (dict[str, Any])
Methods
- __init__(data_loader, state)
- get_state()
- load(directory): Loads the iterator state from a directory.
- save(directory): Saves the iterator state to a directory.
- set_state(state): Sets the state for the underlying iterator.
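The save/load pair can be pictured as a round trip of the iterator state through a directory. This is a sketch only: the file name `iterator_state.json`, the JSON encoding, and the helpers `save_state`/`load_state` are assumptions for illustration, not grain's documented on-disk format.

```python
from __future__ import annotations

import json
import os
import tempfile
from typing import Any

# Hypothetical file name; the real on-disk layout is not specified here.
_STATE_FILENAME = "iterator_state.json"


def save_state(state: dict[str, Any], directory: str) -> str:
    """Write a JSON-serializable iterator state into `directory`."""
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, _STATE_FILENAME)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(state, f)
    return path


def load_state(directory: str) -> dict[str, Any]:
    """Read back a state previously written by save_state."""
    with open(os.path.join(directory, _STATE_FILENAME), encoding="utf-8") as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as ckpt_dir:
    # JSON object keys are always strings, so int-keyed dicts would come
    # back with str keys; use str keys up front to get an exact round trip.
    save_state({"next_index": 2, "last_seen_index": {"0": 6, "1": 4}}, ckpt_dir)
    restored = load_state(ckpt_dir)
print(restored["last_seen_index"])  # {'0': 6, '1': 4}
```

The restored dictionary would then be handed to set_state, which rebuilds the underlying iterator from it.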