Fused GEMM + CCL Operations#

Fused matrix multiplication and collective communication operations are accessible via the ops property on the Iris instance (e.g. shmem.ops.matmul_all_reduce(...)).

matmul_all_reduce#

OpsNamespace.matmul_all_reduce(output_tensor, A, B, bias=None, async_op=False, config=None, workspace=None)[source]#

Fused matrix multiplication and all-reduce.

Computes: output = all_reduce(A @ B + bias)

Parameters:
  • output_tensor – Output tensor (M, N)

  • A – Input matrix A (M, K)

  • B – Input matrix B (K, N)

  • bias – Optional bias vector (M,) or (N,)

  • async_op – If False, performs a barrier at the end of the operation

  • config – Optional FusedConfig for tuning

  • workspace – Optional pre-allocated workspace

Returns:

Updated workspace object

Return type:

workspace

Example

>>> A = shmem.randn((M, K), dtype=torch.float16)
>>> B = shmem.randn((K, N), dtype=torch.float16)
>>> output = shmem.zeros((M, N), dtype=torch.float16)
>>> shmem.ops.matmul_all_reduce(output, A, B)
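To make the reduce semantics concrete, here is a single-process NumPy sketch, not the Iris implementation. It assumes a sum all-reduce over ranks that each hold a K-shard of A and B (the typical tensor-parallel layout); bias is omitted for brevity, and matmul_all_reduce_reference is an illustrative name only:

```python
import numpy as np

# Reference semantics for matmul_all_reduce (sum reduction assumed):
# every rank computes its partial product A_r @ B_r, and the all-reduce
# sums the partials so all ranks hold the same (M, N) result.
def matmul_all_reduce_reference(A_per_rank, B_per_rank):
    partials = [A @ B for A, B in zip(A_per_rank, B_per_rank)]
    return np.sum(partials, axis=0)

rng = np.random.default_rng(0)
world_size, M, K_local, N = 2, 4, 3, 5
A_shards = [rng.standard_normal((M, K_local)) for _ in range(world_size)]
B_shards = [rng.standard_normal((K_local, N)) for _ in range(world_size)]
out = matmul_all_reduce_reference(A_shards, B_shards)
# Summing partials over K-shards equals the full (M, K) @ (K, N) product.
```

The sum of per-shard partial products is exactly the block-matrix identity [A0 A1] @ [B0; B1] = A0 @ B0 + A1 @ B1, which is why a K-sharded GEMM pairs naturally with an all-reduce.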

all_gather_matmul#

OpsNamespace.all_gather_matmul(output_tensor, A_sharded, B, bias=None, async_op=False, config=None, workspace=None)[source]#

Fused all-gather and matrix multiplication.

Computes: output = all_gather(A_sharded) @ B + bias

Parameters:
  • output_tensor – Output tensor (M, N)

  • A_sharded – Sharded input matrix (M, K_local)

  • B – Input matrix B (K, N) where K = K_local * world_size

  • bias – Optional bias vector (M,) or (N,)

  • async_op – If False, performs a barrier at the end of the operation

  • config – Optional FusedConfig for tuning

  • workspace – Optional pre-allocated workspace

Returns:

Updated workspace object

Return type:

workspace

Example

>>> K_local = K // world_size
>>> A_sharded = shmem.randn((M, K_local), dtype=torch.float16)
>>> B = shmem.randn((K, N), dtype=torch.float16)
>>> output = shmem.zeros((M, N), dtype=torch.float16)
>>> shmem.ops.all_gather_matmul(output, A_sharded, B)
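As a single-process sanity check of the gather semantics (an illustrative NumPy reference, not the Iris implementation), assuming the (M, K_local) shards are concatenated along the K dimension in rank order:

```python
import numpy as np

# Reference semantics for all_gather_matmul: gather the (M, K_local)
# shards along K to reconstruct the full (M, K) matrix, then multiply
# by the full B.
def all_gather_matmul_reference(A_shards, B):
    A_full = np.concatenate(A_shards, axis=1)  # (M, K_local * world_size)
    return A_full @ B

rng = np.random.default_rng(0)
world_size, M, K_local, N = 2, 4, 3, 5
K = K_local * world_size
A_shards = [rng.standard_normal((M, K_local)) for _ in range(world_size)]
B = rng.standard_normal((K, N))
out = all_gather_matmul_reference(A_shards, B)
```

Equivalently, each gathered shard multiplies the matching row-block of B, which is what lets a fused kernel overlap the gather with partial GEMM work.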

matmul_all_gather#

OpsNamespace.matmul_all_gather(output_tensor, A, B, bias=None, async_op=False, config=None, workspace=None)[source]#

Fused matrix multiplication and all-gather.

Computes: output = all_gather(A @ B + bias) along the M dimension

Parameters:
  • output_tensor – Output tensor (M*world_size, N)

  • A – Input matrix A (M, K)

  • B – Input matrix B (K, N)

  • bias – Optional bias vector (M,) or (N,)

  • async_op – If False, performs a barrier at the end of the operation

  • config – Optional FusedConfig for tuning

  • workspace – Optional pre-allocated workspace

Returns:

Updated workspace object

Return type:

workspace

Example

>>> M_local = M // world_size
>>> A = shmem.randn((M_local, K), dtype=torch.float16)
>>> B = shmem.randn((K, N), dtype=torch.float16)
>>> output = shmem.zeros((M, N), dtype=torch.float16)
>>> shmem.ops.matmul_all_gather(output, A, B)
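A single-process NumPy sketch of the gather-along-M semantics (illustrative only, not the Iris implementation): each rank computes its own (M_local, N) tile, and the all-gather stacks the tiles along M in rank order. Bias is omitted for brevity:

```python
import numpy as np

# Reference semantics for matmul_all_gather: each rank computes its
# (M_local, N) tile A_r @ B; the all-gather concatenates the tiles
# along the M dimension so every rank holds the full (M, N) output.
def matmul_all_gather_reference(A_per_rank, B):
    return np.concatenate([A @ B for A in A_per_rank], axis=0)

rng = np.random.default_rng(0)
world_size, M_local, K, N = 2, 4, 3, 5
A_per_rank = [rng.standard_normal((M_local, K)) for _ in range(world_size)]
B = rng.standard_normal((K, N))
out = matmul_all_gather_reference(A_per_rank, B)
```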

matmul_reduce_scatter#

OpsNamespace.matmul_reduce_scatter(output_tensor, A, B, bias=None, async_op=False, config=None, workspace=None)[source]#

Fused matrix multiplication and reduce-scatter.

Computes: output = reduce_scatter(A @ B + bias) along the N dimension

Parameters:
  • output_tensor – Output tensor (M, N_local) where N_local = N / world_size

  • A – Input matrix A (M, K)

  • B – Input matrix B (K, N)

  • bias – Optional bias vector (M,) or (N,)

  • async_op – If False, performs a barrier at the end of the operation

  • config – Optional FusedConfig for tuning

  • workspace – Optional pre-allocated workspace

Returns:

Updated workspace object

Return type:

workspace

Example

>>> N_local = N // world_size
>>> A = shmem.randn((M, K), dtype=torch.float16)
>>> B = shmem.randn((K, N), dtype=torch.float16)
>>> output = shmem.zeros((M, N_local), dtype=torch.float16)
>>> shmem.ops.matmul_reduce_scatter(output, A, B)
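A single-process NumPy sketch of the reduce-scatter semantics (illustrative, not the Iris implementation), assuming a sum reduction over ranks that each hold a K-shard of A and B: the partial products are summed, then rank r keeps only its N_local column slice. Bias is omitted for brevity:

```python
import numpy as np

# Reference semantics for matmul_reduce_scatter: sum the per-rank
# partial products A_r @ B_r, then scatter the columns so rank r
# keeps total[:, r*N_local:(r+1)*N_local].
def matmul_reduce_scatter_reference(A_per_rank, B_per_rank, world_size):
    total = np.sum([A @ B for A, B in zip(A_per_rank, B_per_rank)], axis=0)
    N_local = total.shape[1] // world_size
    return [total[:, r * N_local:(r + 1) * N_local] for r in range(world_size)]

rng = np.random.default_rng(0)
world_size, M, K_local, N = 2, 4, 3, 6
A_shards = [rng.standard_normal((M, K_local)) for _ in range(world_size)]
B_shards = [rng.standard_normal((K_local, N)) for _ in range(world_size)]
outs = matmul_reduce_scatter_reference(A_shards, B_shards, world_size)
# outs[r] is rank r's (M, N_local) slice of the reduced product.
```

Concatenating the per-rank slices along N reconstructs the full reduced product, which is the defining property of a reduce-scatter.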