Core API Reference

This documentation is automatically generated from source code docstrings.

tritonblas Module

Matrix Multiplication Functions

matmul

matmul(a, b, out=None, enable_streamk=False, sk_grid=None)
Parameters:
  • a (torch.Tensor)

  • b (torch.Tensor)

  • out (torch.Tensor | None)

  • enable_streamk (bool | None)

  • sk_grid (int | None)

Return type:

torch.Tensor | None
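The docstring does not describe enable_streamk or sk_grid. Assuming enable_streamk selects a Stream-K decomposition and sk_grid is the number of persistent workgroups launched (both are assumptions), the plain-Python sketch below shows the core Stream-K idea. The K-loop iterations of all output tiles are pooled and split as evenly as possible across sk_grid workers, so one worker may finish part of a tile and then start the next. It is not tritonblas code, and streamk_partition is a hypothetical helper.

```python
import math

def streamk_partition(m, n, k, block_m, block_n, block_k, sk_grid):
    """Split every K-loop iteration of a tiled GEMM evenly across sk_grid workers.

    Hypothetical helper illustrating Stream-K. Returns one list per worker of
    (tile_id, k_iter_start, k_iter_end) half-open ranges, in K-loop iterations.
    """
    tiles = math.ceil(m / block_m) * math.ceil(n / block_n)
    iters_per_tile = math.ceil(k / block_k)
    total = tiles * iters_per_tile
    base, extra = divmod(total, sk_grid)

    work, start = [], 0
    for worker in range(sk_grid):
        # The first `extra` workers take one additional iteration each.
        end = start + base + (1 if worker < extra else 0)
        ranges, it = [], start
        while it < end:
            tile = it // iters_per_tile
            # Stop at whichever comes first: this worker's share or the tile's end.
            stop = min(end, (tile + 1) * iters_per_tile)
            offset = tile * iters_per_tile
            ranges.append((tile, it - offset, stop - offset))
            it = stop
        work.append(ranges)
        start = end
    return work

# 4 output tiles x 8 K-iterations each = 32 iterations shared by 3 workers.
work = streamk_partition(256, 256, 512, 128, 128, 64, sk_grid=3)
for worker, ranges in enumerate(work):
    print(worker, ranges)
```

In this example tile 1 is split between workers 0 and 1, so their partial accumulators must be combined in a fix-up step before that tile of C is written. This extra reduction is the price Stream-K pays for keeping every workgroup busy when the tile count does not divide evenly across compute units.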

matmul_lt

matmul_lt(a, b, c, selector, enable_streamk=False)
Parameters:
  • a (torch.Tensor)

  • b (torch.Tensor)

  • c (torch.Tensor)
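matmul_lt receives a pre-built selector instead of choosing a configuration on every call, which suggests building one selector per problem shape and reusing it. The docstring does not state the selector's type; OrigamiMatmulSelector, documented under Configuration Classes, is the likely candidate. The sketch below memoizes selector construction. make_selector_cache is a hypothetical helper, and a stand-in factory replaces the real constructor so the snippet runs without a GPU.

```python
from functools import lru_cache

def make_selector_cache(factory, maxsize=128):
    """Return a memoized selector constructor keyed on the GEMM problem.

    Hypothetical helper. `factory` is assumed to accept
    (m, n, k, a_dtype, b_dtype, out_dtype, device), the positional order of
    the OrigamiMatmulSelector signature. All arguments must be hashable.
    """
    @lru_cache(maxsize=maxsize)
    def cached_selector(m, n, k, a_dtype, b_dtype, out_dtype, device):
        return factory(m, n, k, a_dtype, b_dtype, out_dtype, device)
    return cached_selector

# Stand-in factory that records how often it is actually invoked.
calls = []
def fake_factory(*args):
    calls.append(args)
    return ("selector", args)

get_selector = make_selector_cache(fake_factory)
s1 = get_selector(4096, 4096, 4096, "bf16", "bf16", "bf16", "cuda:0")
s2 = get_selector(4096, 4096, 4096, "bf16", "bf16", "bf16", "cuda:0")
print(s1 is s2, len(calls))
```

In real code the factory would presumably be tritonblas.OrigamiMatmulSelector (torch.dtype and torch.device values are hashable), and the cached object would be passed as the selector argument. Whether a selector is safe to reuse across calls with identical shapes is an assumption worth confirming against the library source.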

matmul_a8w8

matmul_a8w8(a, b, a_scale, b_scale, c, enable_streamk=False, sk_grid=None)
Parameters:
  • a (torch.Tensor)

  • b (torch.Tensor)

  • a_scale (torch.Tensor)

  • b_scale (torch.Tensor)

  • c (torch.Tensor)
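The docstring does not say whether a and b are int8 or FP8, what shapes the scales have, or how the scales are applied. Under a common A8W8 convention, assumed here, there is one scale per row of A and one per column of B. The 8-bit product is accumulated first and rescaled afterwards: C[i][j] = a_scale[i] * b_scale[j] * sum_k A[i][k] * B[k][j]. The pure-Python reference below implements that formula for int8-range values, which is handy for checking small results. reference_a8w8 is a hypothetical name, and the (K, N) layout of b is also an assumption.

```python
def reference_a8w8(a, b, a_scale, b_scale):
    """Reference A8W8 GEMM on nested lists (hypothetical helper).

    Assumes a is M x K and b is K x N, both holding int8-range integers,
    a_scale holds one float per row of a, and b_scale one float per column of b.
    """
    m, k, n = len(a), len(b), len(b[0])
    assert all(len(row) == k for row in a), "inner dimensions must match"
    assert all(-128 <= v <= 127 for row in a + b for v in row), "values must fit in int8"
    c = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            # Exact integer dot product, standing in for the kernel's integer accumulator.
            acc = sum(a[i][p] * b[p][j] for p in range(k))
            # Dequantize once, after accumulation.
            c[i][j] = acc * a_scale[i] * b_scale[j]
    return c

result = reference_a8w8(
    a=[[1, 2], [-3, 4]],
    b=[[5, -6], [7, 8]],
    a_scale=[0.5, 2.0],
    b_scale=[1.0, 0.25],
)
print(result)
```

Applying the scales after accumulation keeps the inner loop in pure 8-bit arithmetic, which is what makes A8W8 GEMMs fast. Per-tensor scaling is the special case in which every entry of a_scale (or b_scale) is the same value.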

matmul_a8w8_lt

matmul_a8w8_lt(a, b, a_scale, b_scale, c, selector, enable_streamk=False)
Parameters:
  • a (torch.Tensor)

  • b (torch.Tensor)

  • a_scale (torch.Tensor)

  • b_scale (torch.Tensor)

  • c (torch.Tensor)

matmul_fp4

matmul_fp4(a, b, c, a_scales, b_scales, block_m=None, block_n=None, block_k=None, group_size_m=8, num_warps=8, num_stages=2)

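FP4 matrix multiplication: C = A @ B^T (B is supplied with shape (N, K//2), i.e. one row per output column, so the product contracts over K for both inputs)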

Parameters:
  • a (torch.Tensor) – Input matrix A in FP4 format (M, K//2), packed 2 elements per uint8

  • b (torch.Tensor) – Input matrix B in FP4 format (N, K//2), packed 2 elements per uint8

  • c (torch.Tensor) – Output matrix C (M, N) in bfloat16 or float16

  • a_scales (torch.Tensor) – Scales for A in e8m0 format (M, K // 32)

  • b_scales (torch.Tensor) – Scales for B in e8m0 format (N, K // 32)

  • block_m (int | None) – Block size for the M dimension (default: None)

  • block_n (int | None) – Block size for the N dimension (default: None)

  • block_k (int | None) – Block size for the K dimension; must be a multiple of 64 for FP4 (default: None)

  • group_size_m (int) – Group size for M-dimension tiling (default: 8)

  • num_warps (int) – Number of warps per thread block (default: 8)

  • num_stages (int) – Number of pipeline stages (default: 2)

Returns:

Output matrix C
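The input layout above appears to follow the OCP Microscaling (MX) MXFP4 format. Each 4-bit element is an E2M1 value, and each block of 32 consecutive K elements shares one E8M0 scale equal to 2^(e - 127). The sketch below dequantizes one packed row back to floats, which is useful for building a reference result. The docstring does not say which nibble holds the first of the two packed elements; the low-nibble-first order used here is an assumption, and dequantize_mxfp4_row is a hypothetical helper.

```python
import math

# E2M1 magnitudes indexed by the low 3 bits (2 exponent bits + 1 mantissa bit).
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
MX_BLOCK = 32  # consecutive K elements sharing one scale

def decode_e2m1(nibble):
    """Decode a 4-bit E2M1 code; bit 3 is the sign."""
    value = E2M1_MAGNITUDES[nibble & 0x7]
    return -value if nibble & 0x8 else value

def decode_e8m0(byte):
    """Decode an E8M0 scale byte: 2**(byte - 127), with 0xFF reserved for NaN."""
    return math.nan if byte == 0xFF else 2.0 ** (byte - 127)

def dequantize_mxfp4_row(packed, scales):
    """Expand K//2 packed bytes plus K//32 scale bytes into K floats.

    Assumes the low nibble stores the even-indexed element (hypothetical layout).
    """
    values = []
    for byte in packed:
        values.append(decode_e2m1(byte & 0x0F))
        values.append(decode_e2m1(byte >> 4))
    assert len(values) == MX_BLOCK * len(scales), "K must equal 32 * len(scales)"
    return [v * decode_e8m0(scales[i // MX_BLOCK]) for i, v in enumerate(values)]

# K = 32: 16 packed bytes and one scale byte (128 encodes 2.0).
row = dequantize_mxfp4_row([0x21, 0xF7] + [0x00] * 14, [128])
print(row[:4])
```

With rows of A and B dequantized this way, a reference output is c[i][j] = sum_k A[i][k] * B[j][k], because B is stored as (N, K//2) with K along each row.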

Configuration Classes

OrigamiMatmulSelector

class OrigamiMatmulSelector(m, n, k, a_dtype, b_dtype, out_dtype, device, mx_block_size=0, streamk=False)

Bases: object

Parameters:
  • m (int)

  • n (int)

  • k (int)

  • a_dtype (torch.dtype)

  • b_dtype (torch.dtype)

  • out_dtype (torch.dtype)

  • device (torch.device)

dtype_to_str = {torch.bfloat16: 'bf16', torch.complex128: 'c64', torch.complex64: 'c32', torch.float16: 'f16', torch.float32: 'f32', torch.float64: 'f64', torch.float8_e4m3fn: 'f8', torch.float8_e4m3fnuz: 'f8', torch.float8_e5m2: 'f8', torch.float8_e5m2fnuz: 'f8', torch.int32: 'i32', torch.int8: 'i8'}
__init__(m, n, k, a_dtype, b_dtype, out_dtype, device, mx_block_size=0, streamk=False)
Parameters:
  • m (int)

  • n (int)

  • k (int)

  • a_dtype (torch.dtype)

  • b_dtype (torch.dtype)

  • out_dtype (torch.dtype)

  • device (torch.device)

property block_m
property block_n
property block_k
property group_m
property num_sms
property waves_per_eu
property even_k
property sk_grid
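The block_m, block_n, block_k, and group_m properties expose the tile configuration chosen by the selector. Assuming they play the same roles as block_m, block_n, block_k, and group_size_m in matmul_fp4, the sketch below shows how they determine the launch grid. It also shows the grouped tile order from the Triton matrix-multiplication tutorial, in which group_m rows of tiles are visited before moving along N so that neighbouring programs reuse data in L2. Reading even_k as "K is divisible by block_k" is an assumption, and tile_grid and grouped_tile are hypothetical names.

```python
import math

def tile_grid(m, n, k, block_m, block_n, block_k):
    """Return (tiles along M, tiles along N, whether K needs no masked tail)."""
    grid_m = math.ceil(m / block_m)
    grid_n = math.ceil(n / block_n)
    even_k = k % block_k == 0  # assumed meaning of the even_k property
    return grid_m, grid_n, even_k

def grouped_tile(pid, grid_m, grid_n, group_m):
    """Map a linear program id to (tile_m, tile_n) using grouped ordering.

    Same arithmetic as the GROUP_SIZE_M swizzle in the Triton matmul tutorial.
    """
    tiles_per_group = group_m * grid_n
    group_id = pid // tiles_per_group
    first_m = group_id * group_m
    # The last group may contain fewer than group_m rows of tiles.
    rows_in_group = min(grid_m - first_m, group_m)
    tile_m = first_m + (pid % tiles_per_group) % rows_in_group
    tile_n = (pid % tiles_per_group) // rows_in_group
    return tile_m, tile_n

grid_m, grid_n, even_k = tile_grid(512, 384, 1000, 128, 128, 64)
order = [grouped_tile(pid, grid_m, grid_n, group_m=2) for pid in range(grid_m * grid_n)]
print(grid_m, grid_n, even_k)
print(order)
```

Here K = 1000 is not a multiple of 64, so even_k is False under this reading and the last K iteration would need masked loads. The grouped order walks a 2-tile-tall column strip across N before moving down to the next strip, and it still visits every tile exactly once.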