grain.sharding module

APIs for sharding pipelines for distributed training.
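
To make the idea concrete, here is one plausible way a data source of `num_records` records could be divided among shards. This is a minimal sketch that assumes contiguous, near-equal ranges. `shard_bounds` is a hypothetical helper, not part of grain. Its parameters mirror the `shard_index`, `shard_count`, and `drop_remainder` options named in the member list.

```python
from __future__ import annotations


def shard_bounds(num_records: int, shard_index: int, shard_count: int,
                 drop_remainder: bool = False) -> tuple[int, int]:
    """Return the half-open range [start, end) of record indices for one shard.

    Hypothetical helper: assumes a contiguous, near-equal split, which is an
    illustration rather than a documented guarantee of grain.
    """
    if shard_count < 1 or not 0 <= shard_index < shard_count:
        raise ValueError("need shard_count >= 1 and 0 <= shard_index < shard_count")
    base, remainder = divmod(num_records, shard_count)
    if drop_remainder:
        # Every shard gets exactly `base` records; the last `remainder` are dropped.
        start = shard_index * base
        return start, start + base
    # Otherwise the first `remainder` shards each take one extra record.
    start = shard_index * base + min(shard_index, remainder)
    end = start + base + (1 if shard_index < remainder else 0)
    return start, end


# 10 records split across 3 processes.
print([shard_bounds(10, i, 3) for i in range(3)])
# → [(0, 4), (4, 7), (7, 10)]
print([shard_bounds(10, i, 3, drop_remainder=True) for i in range(3)])
# → [(0, 3), (3, 6), (6, 9)]
```

With a single shard (`shard_count=1`), the sketch returns the full range `(0, num_records)`, so every reader sees all of the data.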

List of Members

NoSharding()

Does not shard the data: every process reads the entire data source.

ShardByJaxProcess([drop_remainder])

Shards the data across JAX processes, so that each process reads the shard matching its JAX process index.

ShardOptions(shard_index, shard_count[, ...])

Dataclass to hold options for sharding a data source.
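
The three members can be thought of as one options record plus two convenience constructors. The following is a hypothetical stand-in, not grain's implementation: `ShardOptionsSketch`, `no_sharding`, and `shard_by_process` are invented names, and the `drop_remainder=False` default and the range checks are assumptions made for the sketch.

```python
from __future__ import annotations

import dataclasses


@dataclasses.dataclass(frozen=True)
class ShardOptionsSketch:
    """Hypothetical stand-in for ShardOptions (not the real grain class)."""

    shard_index: int
    shard_count: int
    drop_remainder: bool = False  # assumed default

    def __post_init__(self):
        # Validation added for this sketch; not taken from the grain docs.
        if self.shard_count < 1:
            raise ValueError("shard_count must be >= 1")
        if not 0 <= self.shard_index < self.shard_count:
            raise ValueError("shard_index must be in [0, shard_count)")


def no_sharding() -> ShardOptionsSketch:
    # Mirrors NoSharding(): one shard covering the whole data source.
    return ShardOptionsSketch(shard_index=0, shard_count=1)


def shard_by_process(process_index: int, process_count: int,
                     drop_remainder: bool = False) -> ShardOptionsSketch:
    # Mirrors ShardByJaxProcess(); here the process position is passed in
    # explicitly so the sketch runs without JAX installed.
    return ShardOptionsSketch(process_index, process_count, drop_remainder)


# Process 1 of a 4-process job.
print(shard_by_process(1, 4))
# → ShardOptionsSketch(shard_index=1, shard_count=4, drop_remainder=False)
```

In an actual multi-host JAX job, the process position would come from `jax.process_index()` and `jax.process_count()`.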