Fused GEMM + CCL Operations#
Fused matrix multiplication and collective communication operations accessible via the ops property on the Iris instance (e.g. ctx.ops.matmul_all_reduce(...)).
matmul_all_reduce#
- OpsNamespace.matmul_all_reduce(output_tensor, A, B, bias=None, async_op=False, config=None, workspace=None)[source]#
Fused matrix multiplication and all-reduce.
Computes: output = all_reduce(A @ B + bias)
- Parameters:
output_tensor – Output tensor (M, N)
A – Input matrix A (M, K)
B – Input matrix B (K, N)
bias – Optional bias vector (M,) or (N,)
async_op – If False, a barrier is performed at the end of the operation
config – Optional FusedConfig for tuning
workspace – Optional pre-allocated workspace
- Returns:
Updated workspace object
- Return type:
workspace
Example
>>> output = shmem.zeros((M, N), dtype=torch.float16)
>>> shmem.ops.matmul_all_reduce(output, A, B)
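The semantics can be checked without a multi-GPU setup. Below is a minimal single-process NumPy sketch, not the Iris API: each simulated rank holds its own `A` and `B`, computes `A @ B + bias` locally, and the all-reduce sums the per-rank results. `world_size` and all shapes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
world_size, M, K, N = 4, 8, 16, 8

# One (A, B) pair per simulated rank.
As = [rng.standard_normal((M, K)) for _ in range(world_size)]
Bs = [rng.standard_normal((K, N)) for _ in range(world_size)]
bias = rng.standard_normal(N)

# output = all_reduce(A @ B + bias): sum the per-rank local results.
output = sum(A @ B + bias for A, B in zip(As, Bs))
```

Every simulated rank ends up with the same `(M, N)` tensor, which is what the fused kernel writes into `output_tensor` on each rank.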
all_gather_matmul#
- OpsNamespace.all_gather_matmul(output_tensor, A_sharded, B, bias=None, async_op=False, config=None, workspace=None)[source]#
Fused all-gather and matrix multiplication.
Computes: output = all_gather(A_sharded) @ B + bias
- Parameters:
output_tensor – Output tensor (M, N)
A_sharded – Sharded input matrix (M, K_local)
B – Input matrix B (K, N) where K = K_local * world_size
bias – Optional bias vector (M,) or (N,)
async_op – If False, a barrier is performed at the end of the operation
config – Optional FusedConfig for tuning
workspace – Optional pre-allocated workspace
- Returns:
Updated workspace object
- Return type:
workspace
Example
>>> K_local = K // world_size
>>> A_sharded = shmem.randn((M, K_local), dtype=torch.float16)
>>> output = shmem.zeros((M, N), dtype=torch.float16)
>>> shmem.ops.all_gather_matmul(output, A_sharded, B)
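As a sanity check of the math, here is a single-process NumPy sketch (illustrative names, not the Iris API): the K-sharded `A` pieces are concatenated along the K axis before the multiply, which is equivalent to summing each shard's product with the matching row block of `B`.

```python
import numpy as np

rng = np.random.default_rng(0)
world_size, M, K, N = 4, 8, 16, 8
K_local = K // world_size

# Each simulated rank holds one (M, K_local) shard of A; B is replicated.
A_shards = [rng.standard_normal((M, K_local)) for _ in range(world_size)]
B = rng.standard_normal((K, N))

# output = all_gather(A_sharded) @ B
A_full = np.concatenate(A_shards, axis=1)  # (M, K)
output = A_full @ B
```

The equivalence to a sum of per-shard partial products is why the fused kernel can overlap the gather with the GEMM: each arriving shard contributes an independent rank-`K_local` update.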
matmul_all_gather#
- OpsNamespace.matmul_all_gather(output_tensor, A, B, bias=None, async_op=False, config=None, workspace=None)[source]#
Fused matrix multiplication and all-gather.
Computes: output = all_gather(A @ B + bias) along M dimension
- Parameters:
output_tensor – Output tensor (M*world_size, N)
A – Input matrix A (M, K)
B – Input matrix B (K, N)
bias – Optional bias vector (M,) or (N,)
async_op – If False, a barrier is performed at the end of the operation
config – Optional FusedConfig for tuning
workspace – Optional pre-allocated workspace
- Returns:
Updated workspace object
- Return type:
workspace
Example
>>> M_local = M // world_size
>>> A = shmem.randn((M_local, K), dtype=torch.float16)
>>> output = shmem.zeros((M, N), dtype=torch.float16)
>>> shmem.ops.matmul_all_gather(output, A, B)
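A single-process NumPy sketch of these semantics (illustrative, not the Iris API): each simulated rank computes its local `(M_local, N)` slab, and the all-gather concatenates the slabs along the M axis.

```python
import numpy as np

rng = np.random.default_rng(0)
world_size, M, K, N = 4, 8, 16, 8
M_local = M // world_size

# Each simulated rank holds an (M_local, K) row block of A; B is replicated.
A_slabs = [rng.standard_normal((M_local, K)) for _ in range(world_size)]
B = rng.standard_normal((K, N))

# output = all_gather(A @ B) along M: stack the per-rank slabs.
output = np.concatenate([A @ B for A in A_slabs], axis=0)  # (M, N)
```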
matmul_reduce_scatter#
- OpsNamespace.matmul_reduce_scatter(output_tensor, A, B, bias=None, async_op=False, config=None, workspace=None)[source]#
Fused matrix multiplication and reduce-scatter.
Computes: output = reduce_scatter(A @ B + bias) along N dimension
- Parameters:
output_tensor – Output tensor (M, N_local) where N_local = N / world_size
A – Input matrix A (M, K)
B – Input matrix B (K, N)
bias – Optional bias vector (M,) or (N,)
async_op – If False, a barrier is performed at the end of the operation
config – Optional FusedConfig for tuning
workspace – Optional pre-allocated workspace
- Returns:
Updated workspace object
- Return type:
workspace
Example
>>> N_local = N // world_size
>>> output = shmem.zeros((M, N_local), dtype=torch.float16)
>>> shmem.ops.matmul_reduce_scatter(output, A, B)
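A single-process NumPy sketch of the reduce-scatter semantics (illustrative names, not the Iris API): the per-rank products are summed (reduce), then each simulated rank keeps only its `N_local`-wide column block of the result (scatter along N).

```python
import numpy as np

rng = np.random.default_rng(0)
world_size, M, K, N = 4, 8, 16, 8
N_local = N // world_size

# One (A, B) pair per simulated rank.
As = [rng.standard_normal((M, K)) for _ in range(world_size)]
Bs = [rng.standard_normal((K, N)) for _ in range(world_size)]

# reduce: sum the per-rank products across ranks.
reduced = sum(A @ B for A, B in zip(As, Bs))  # (M, N)

# scatter: rank r keeps its N_local-wide column block.
outputs = [reduced[:, r * N_local:(r + 1) * N_local] for r in range(world_size)]
```

Concatenating the per-rank `outputs` along N reconstructs the full reduced matrix, which is the usual consistency relation between reduce-scatter and all-reduce.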