Transformations#

Grain Transforms interface denotes transformations which are applied to data. In the case of local transformations (such as map, random map, filter), the transforms receive an element on which custom changes are applied. For global transformations (such as batching), one must provide the batch size.

The Grain core transforms interface code is here.

Map Transform#

Map Transform is for 1:1 transformations of elements. Elements can be of any type, it is the user’s responsibility to use the transformation such that the inputs it receives correspond to the signature.

Example of transformation which implements Map Transform (for elements of type int):

class PlusOne(grain.transforms.Map):

  def map(self, x: int) -> int:
    return x + 1

MapWithIndex Transform#

MapWithIndex Transform is similar to Map transform in being a 1:1 transformations of elements, but also takes in the index/position of the element as the first argument. This is useful for pairing elements with an index key or even keeping it as metadata alongside the actual data.

Example of transformation which implements MapWithIndex transform (for elements of type int):

class PlusOneWithIndexKey(grain.transforms.MapWithIndex):

  def map_with_index(self, i: int, x: int) -> tuple[int, int]:
    return (x + 1, i)

RandomMap Transform#

RandomMap Transform is for 1:1 random transformations of elements. The interface requires a np.random.Generator as parameter to the random_map function.

Example of a RandomMap Transform:

class PlusRandom(grain.transforms.RandomMap):

  def random_map(self, x: int, rng: np.random.Generator) -> int:
    return x + rng.integers(100_000)

FlatMap Transform#

FlatMap Transform is for splitting operations of individual elements. The max_fan_out is the maximum number of splits that an element can generate. Please consult the code for detailed info.

Example of a FlatMap Transform:

class FlatMapTransformExample(grain.experimental.FlatMapTransform):
  max_fan_out: int

  def flat_map(self, element: int):
    for _ in range(self.max_fan_out):
      yield element

Filter Transform#

Filter Transform is for applying filtering to individual elements. Elements for which the filter function returns False will be removed.

Example of a Filter Transform that removes all even elements:

class RemoveEvenElements(grain.transforms.Filter):

  def filter(self, element: int) -> bool:
    return element % 2

Batch#

To apply the Batch transform, pass grain.transforms.Batch(batch_size=batch_size, drop_remainder=drop_remainder).

Note: The batch size used when passing Batch transform will be the global batch size if it is done before sharding and the per host batch size if it is after. Typically usage with IndexSampler is after sharding.