Core API Reference#
This documentation is automatically generated from source code docstrings.
tritonblas Module#
Matrix Multiplication Functions#
matmul#
matmul_lt#
matmul_a8w8#
matmul_a8w8_lt#
matmul_fp4#
- matmul_fp4(a, b, c, a_scales, b_scales, block_m=None, block_n=None, block_k=None, group_size_m=8, num_warps=8, num_stages=2)[source]#
FP4 matrix multiplication: C = A @ B
- Parameters:
a (torch.Tensor) – Input matrix A in FP4 format with shape (M, K//2), two FP4 values packed per uint8
b (torch.Tensor) – Input matrix B in FP4 format with shape (N, K//2), two FP4 values packed per uint8
c (torch.Tensor) – Output matrix C (M, N) in bfloat16 or float16
a_scales (torch.Tensor) – Scales for A in e8m0 format (M, K // 32)
b_scales (torch.Tensor) – Scales for B in e8m0 format (N, K // 32)
block_m (int) – Block size for M dimension
block_n (int) – Block size for N dimension
block_k (int) – Block size for K dimension (must be multiple of 64 for FP4)
group_size_m (int) – Group size for M dimension tiling
num_warps (int) – Number of warps per thread block (default: 8)
num_stages (int) – Number of pipeline stages (default: 2)
- Returns:
Output matrix C
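The shape and packing constraints above can be sketched without the library itself. The following is an illustrative helper, not part of tritonblas: the nibble order used in `pack_fp4_pair` is an assumption for demonstration, while the shapes in `expected_shapes` follow directly from the documented signature (packed inputs of width K//2, one e8m0 scale per 32-element block, block_k a multiple of 64).

```python
def pack_fp4_pair(lo_nibble: int, hi_nibble: int) -> int:
    """Pack two 4-bit FP4 codes into one uint8 byte.

    Nibble order (low nibble first) is an assumption for illustration,
    not a documented tritonblas convention.
    """
    assert 0 <= lo_nibble < 16 and 0 <= hi_nibble < 16
    return (hi_nibble << 4) | lo_nibble


def expected_shapes(m: int, n: int, k: int) -> dict:
    """Tensor shapes implied by the matmul_fp4 signature above."""
    # block_k must be a multiple of 64 for FP4, so K itself must be too.
    assert k % 64 == 0, "K must be a multiple of 64 for FP4"
    return {
        "a": (m, k // 2),          # two FP4 values packed per uint8
        "b": (n, k // 2),
        "a_scales": (m, k // 32),  # one e8m0 scale per 32-element block
        "b_scales": (n, k // 32),
        "c": (m, n),               # bfloat16 or float16 output
    }
```

For example, a 128 x 256 x 64 problem needs a packed A of shape (128, 32) and A scales of shape (128, 2).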
Configuration Classes#
OrigamiMatmulSelector#
- class OrigamiMatmulSelector(m, n, k, a_dtype, b_dtype, out_dtype, device, mx_block_size=0, streamk=False)[source]#
Bases: object
- Parameters:
- dtype_to_str = {torch.bfloat16: 'bf16', torch.complex128: 'c64', torch.complex64: 'c32', torch.float16: 'f16', torch.float32: 'f32', torch.float64: 'f64', torch.float8_e4m3fn: 'f8', torch.float8_e4m3fnuz: 'f8', torch.float8_e5m2: 'f8', torch.float8_e5m2fnuz: 'f8', torch.int32: 'i32', torch.int8: 'i8'}#
- property block_m#
- property block_n#
- property block_k#
- property group_m#
- property num_sms#
- property waves_per_eu#
- property even_k#
- property sk_grid#
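The `dtype_to_str` table above keys kernel configurations by a short dtype tag; notably, all four float8 variants collapse to the single tag `'f8'`. A plain-Python mirror of that mapping (an illustration only, using dtype names as strings so it runs without torch; the real attribute maps torch dtype objects):

```python
# Mirror of the documented dtype_to_str table, keyed by dtype name
# strings instead of torch dtype objects (illustration only).
DTYPE_TO_STR = {
    "torch.bfloat16": "bf16",
    "torch.complex128": "c64",
    "torch.complex64": "c32",
    "torch.float16": "f16",
    "torch.float32": "f32",
    "torch.float64": "f64",
    "torch.float8_e4m3fn": "f8",
    "torch.float8_e4m3fnuz": "f8",
    "torch.float8_e5m2": "f8",
    "torch.float8_e5m2fnuz": "f8",
    "torch.int32": "i32",
    "torch.int8": "i8",
}


def dtype_str(name: str) -> str:
    """Look up the short tag used when selecting a kernel config."""
    return DTYPE_TO_STR[name]
```

Collapsing the float8 variants means one tuned configuration can serve every 8-bit float encoding, since tile-size selection depends on element width rather than the exact bit layout.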