Samplers#

Samplers in Grain are responsible for determining the order in which records are processed. This allows Grain to implement global transformations (e.g. global shuffling, sharding, repeating for multiple epochs) before reading any records.

Samplers need to implement the following iterator protocol:

class Sampler(Protocol):

  def __iter__(self):
    ...

  def __next__(self) -> record.RecordMetadata:
    ...

@dataclasses.dataclass
class RecordMetadata:
  """RecordMetadata contains metadata about individual records.

  RecordMetadata objects are emitted by the sampler to refer to which record to
  read next (record_key), what its index is (for keeping progress and
  checkpointing) as well as having an optional rng for stateless random
  transformations. In addition, they are also used to keep information about
  records as they flow through the pipeline from one operation to the other.
  """
  index: int
  record_key: Optional[int] = None
  rng: Optional[np.random.Generator] = None

Index Sampler#

This is our recommended Sampler. It supports:

  • Sharding across multiple machines (shard_options parameter).

  • Global shuffle of the data (shuffle parameter).

  • Repeating records for multiple epochs (num_epochs parameter). Note that the shuffle order changes across epochs. Behind the scenes, this relies on tf.random_index_shuffle.

  • Stateless random operations. Each RecordMetadata object emitted by the IndexSampler contains an RNG uniquely seeded on a per-record basis. This RNG can be used for random augmentations while not relying on a global state.

index_sampler = grain.samplers.IndexSampler(
  num_records=5,
  num_epochs=2,
  shard_options=grain.sharding.ShardOptions(shard_index=0, shard_count=1, drop_remainder=True),
  shuffle=True,
  seed=0)
for record_metadata in index_sampler:
  print(record_metadata)

# Output
# RecordMetadata(index=0, record_key=0, rng=Generator(Philox) at 0x7FB09947AF80)
# RecordMetadata(index=1, record_key=4, rng=Generator(Philox) at 0x7FB0994789E0)
# RecordMetadata(index=2, record_key=2, rng=Generator(Philox) at 0x7FB099478740)
# RecordMetadata(index=3, record_key=3, rng=Generator(Philox) at 0x7FB0994789E0)
# RecordMetadata(index=4, record_key=1, rng=Generator(Philox) at 0x7FB099478740)
# RecordMetadata(index=5, record_key=1, rng=Generator(Philox) at 0x7FB0994789E0)
# RecordMetadata(index=6, record_key=0, rng=Generator(Philox) at 0x7FB099478740)
# RecordMetadata(index=7, record_key=3, rng=Generator(Philox) at 0x7FB0994789E0)
# RecordMetadata(index=8, record_key=4, rng=Generator(Philox) at 0x7FB099478740)
# RecordMetadata(index=9, record_key=2, rng=Generator(Philox) at 0x7FB0994789E0)

Implement your Own Sampler#

Grain can accommodate custom user-defined samplers. Users implementing their own sampler should ensure it:

  • implements the aforementioned interface.

  • is adequately performant. Since Grain’s DataLoader iterates sequentially through the sampler to distribute indices to child processes, a slow sampler will become a bottleneck and reduce end-to-end pipeline performance. As a reference, we recommend sampler iteration performance of at approx. 50,000 elements / sec for most use cases.