Source code for tritonblas.matmul

import functools
import random
import time
from typing import Any, Dict, Optional, Tuple

import torch
from torch.library import triton_op, wrap_triton
import triton

from .kernels import persistent_matmul, streamk_matmul
from .kernels.fp4_matmul import fp4_matmul
from .origami import OrigamiMatmulSelector

_tensor_cache = {}
current_device_index = torch.cuda.current_device()
current_device = torch.cuda.get_device_properties(current_device_index)
MAX_SMS = current_device.multi_processor_count
# TODO: 256x256 for fp16/bf16, need adjust for fp8/fp4
MAX_BLOCK_SIZE = 65536

# Global pre-allocated buffers
_global_locks = torch.empty(MAX_SMS, device="cuda", dtype=torch.uint8)
_global_P = torch.empty(MAX_SMS, MAX_BLOCK_SIZE, device="cuda", dtype=torch.float32)


# Intended to behave like an LRU cache of heuristic results: re-enabling the
# decorator below saves several microseconds for previously seen problems by
# not rerunning the heuristic unnecessarily.
# @functools.lru_cache(maxsize=1024)
def _make_matmul_selector(
    M: int,
    N: int,
    K: int,
    a_dtype: torch.dtype,
    b_dtype: torch.dtype,
    c_dtype: torch.dtype,
    device: torch.device,
    mx_block_size: int = 0,
    streamk: bool = False,
):
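    """Build an OrigamiMatmulSelector (block sizes, group size, launch grid)
    for the given problem shape, dtypes, and device."""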
    # Run the heuristic and construct a selector for this problem.
    return OrigamiMatmulSelector(
            M,
            N,
            K,
            a_dtype,
            b_dtype,
            c_dtype,
            device,
            mx_block_size=mx_block_size,
            streamk=streamk)


def persistent_matmul_lt(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    selector,
    bias: Optional[torch.Tensor] = None,
    a_scale: Optional[torch.Tensor] = None,
    b_scale: Optional[torch.Tensor] = None,
    quantized: bool = False,
):
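    """Compute c = a @ b (optionally + bias) with the persistent, data-parallel kernel.

    Block/group sizes and the launch grid are taken from ``selector``. When
    ``quantized`` is True, ``a_scale`` and ``b_scale`` are forwarded to the
    kernel. The result is written into ``c``, which is also returned.
    """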
    assert a.shape[1] == b.shape[0], "Incompatible Dimensions"
    M, K = a.shape
    _, N = b.shape

    BLK_M    = selector.block_m
    BLK_N    = selector.block_n
    BLK_K    = selector.block_k
    gsize_m  = selector.group_m
    num_xcds = selector.num_sms

    total_blocks_M = triton.cdiv(M, BLK_M)
    total_blocks_N = triton.cdiv(N, BLK_N)
    total_tiles = total_blocks_M * total_blocks_N
    total_programs = total_tiles
    even_k = K % BLK_K == 0

    # TODO: Separate these configs.
    # Basic config for most compute-bound sizes.
    # TODO: set these values analytically?
    num_stages = 2
    num_warps = 8
    waves_per_eu = 0
    mfmaInstrSize = 16
    kpack = 1
    # For skinny sizes (e.g., 4 x 5120 x 2880), use CACHE_MODIFIER=".cg".
    CACHE_MODIFIER_A = None
    CACHE_MODIFIER_B = None

    # Run in Data-parallel mode.
    grids = total_tiles

    # Set chunk size to same area as L2 tiles.
    chunk_size = gsize_m * gsize_m
    chunk_size = min(chunk_size, total_programs // num_xcds)

    # TODO: Support other matmul algs.
    #kk = persistent_matmul[(grids,)](
    kk = wrap_triton(persistent_matmul)[(grids,)](
        a,
        b,
        c,
        a_scale if quantized else None,  # A_scale_ptr
        b_scale if quantized else None,  # B_scale_ptr
        bias if bias is not None else None,
        M,
        N,
        K,
        a.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        bias.stride(0) if bias is not None else 0,
        stride_ak=a.stride(1),
        stride_bk=b.stride(0),
        BLOCK_SIZE_M=BLK_M,
        BLOCK_SIZE_N=BLK_N,
        BLOCK_SIZE_K=BLK_K,
        GROUP_SIZE_M=gsize_m,
        NUM_SMS=total_programs,
        NUM_XCDS=num_xcds,
        CHUNK_SIZE=chunk_size,
        BIAS=bias is not None,
        EVEN_K=even_k,
        CACHE_MODIFIER_A=CACHE_MODIFIER_A,
        CACHE_MODIFIER_B=CACHE_MODIFIER_B,
        QUANTIZED=quantized,
        num_stages=num_stages,
        num_warps=num_warps,
        waves_per_eu=waves_per_eu,
        matrix_instr_nonkdim=mfmaInstrSize,
        kpack=kpack,
        ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32,
    )

    return c

def streamk_matmul_lt(
    a: torch.Tensor, 
    b: torch.Tensor, 
    c: torch.Tensor, 
    selector, 
    bias: Optional[torch.Tensor] = None,
    sk_grid: Optional[int] = None,
    a_scale: Optional[torch.Tensor] = None,
    b_scale: Optional[torch.Tensor] = None,
    quantized: bool = False,
):
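    """Compute c = a @ b (optionally + bias) with the Stream-K kernel.

    The Stream-K grid comes from ``selector.sk_grid`` unless overridden via
    ``sk_grid``. Partial tiles are accumulated through the ``P``/``locks``
    workspaces, reusing the pre-allocated global buffers when they are large
    enough. The result is written into ``c``, which is also returned.
    """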
    assert a.shape[1] == b.shape[0], "Incompatible Dimensions"
    M, K = a.shape
    _, N = b.shape

    BLK_M    = selector.block_m
    BLK_N    = selector.block_n
    BLK_K    = selector.block_k
    gsize_m  = selector.group_m
    num_xcds = selector.num_sms

    total_blocks_M = triton.cdiv(M, BLK_M)
    total_blocks_N = triton.cdiv(N, BLK_N)
    total_tiles = total_blocks_M * total_blocks_N
    even_k = K % BLK_K == 0

    ##
    # Grid Size
    ##
    total_programs_streamk = selector.sk_grid
    # Apply the caller's grid override before deriving the Stream-K tile count,
    # so the two stay consistent.
    if sk_grid is not None:
        total_programs_streamk = sk_grid

    if total_programs_streamk > 0:  # Stream-K
        total_tiles_streamk = total_tiles % total_programs_streamk
    else:  # all tiles are computed using classical blocking
        total_tiles_streamk = 0

    num_stages = 2
    num_warps = 8
    waves_per_eu = 0
    mfmaInstrSize = 16
    kpack = 1
    # For skinny sizes (e.g., 4 x 5120 x 2880), use CACHE_MODIFIER=".cg".
    CACHE_MODIFIER_A = None
    CACHE_MODIFIER_B = None

    grids = total_programs_streamk
    block_size = BLK_M * BLK_N

    # Reuse the pre-allocated global buffers when the problem fits; otherwise allocate per call.
    if grids <= MAX_SMS and block_size <= MAX_BLOCK_SIZE:
        locks = _global_locks[:grids]
        P = _global_P[:grids, :block_size]
    else:
        locks = torch.empty(grids, device="cuda", dtype=torch.uint8)
        P = torch.empty(grids, block_size, device="cuda", dtype=torch.float32)

    # Set chunk size to same area as L2 tiles.
    chunk_size = gsize_m * gsize_m
    chunk_size = min(chunk_size, grids // num_xcds) 

    #kk = streamk_matmul[(grids,)](
    kk = wrap_triton(streamk_matmul)[(grids,)](
        a,
        b,
        c,
        a_scale if quantized else None,  # A_scale_ptr
        b_scale if quantized else None,  # B_scale_ptr
        bias if bias is not None else None,
        P,
        locks,
        M,
        N,
        K,
        a.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        bias.stride(0) if bias is not None else 0,
        stride_ak=a.stride(1),
        stride_bk=b.stride(0),
        BLOCK_SIZE_M=BLK_M,
        BLOCK_SIZE_N=BLK_N,
        BLOCK_SIZE_K=BLK_K,
        GROUP_SIZE_M=gsize_m,
        NUM_SMS=grids,
        NUM_XCDS=num_xcds,
        CHUNK_SIZE=chunk_size,
        STREAMK_TILES=total_tiles_streamk,
        BIAS=bias is not None,
        EVEN_K=even_k,
        CACHE_MODIFIER_A=CACHE_MODIFIER_A,
        CACHE_MODIFIER_B=CACHE_MODIFIER_B,
        QUANTIZED=quantized,
        num_stages=num_stages,
        num_warps=num_warps,
        waves_per_eu=waves_per_eu,
        matrix_instr_nonkdim=mfmaInstrSize,
        kpack=kpack,
    )

    return c


def matmul_lt(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    selector,
    enable_streamk=False,
):
    assert a.shape[1] == b.shape[0], "Incompatible Dimensions"
    if enable_streamk:
        return streamk_matmul_lt(a, b, c, selector)
    else:
        return persistent_matmul_lt(a, b, c, selector)
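
# Example usage of matmul_lt (a minimal sketch; shapes and dtypes are illustrative
# assumptions, not requirements):
#   M, N, K = 4096, 4096, 4096
#   a = torch.randn(M, K, device="cuda", dtype=torch.float16)
#   b = torch.randn(K, N, device="cuda", dtype=torch.float16)
#   c = torch.empty(M, N, device="cuda", dtype=torch.float16)
#   selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, c.dtype, a.device)
#   matmul_lt(a, b, c, selector)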

def matmul_a8w8_lt(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    c: torch.Tensor,
    selector,
    enable_streamk=False,
):
    assert a.shape[1] == b.shape[0], "Incompatible Dimensions"
    if enable_streamk:
        return streamk_matmul_lt(a, b, c, selector, a_scale=a_scale, b_scale=b_scale, quantized=True)
    else:
        return persistent_matmul_lt(a, b, c, selector, a_scale=a_scale, b_scale=b_scale, quantized=True)

@triton_op("tritonblas::_matmul", mutates_args={})
def _matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    enable_streamk: Optional[bool] = False,
    sk_grid: Optional[int] = None,
) -> torch.Tensor:
    assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions"
    M, K = a.shape
    _, N = b.shape

    # Allocate an output tensor
    out = a.new_empty(M, N)

    # Query Origami for solution
    selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, out.dtype, a.device, streamk=enable_streamk)

    if enable_streamk:
        return streamk_matmul_lt(a, b, out, selector, sk_grid=sk_grid)
    else:
        return persistent_matmul_lt(a, b, out, selector)


def _setup_context_matmul_backwards(ctx: Any, inputs: tuple[Any, ...], output: Any):
    a, b, enable_streamk, sk_grid = inputs
    ctx.save_for_backward(a, b)
    ctx.enable_streamk = enable_streamk
    ctx.sk_grid = sk_grid


def _matmul_backwards(ctx: Any, grad_output: torch.Tensor):
    a, b = ctx.saved_tensors
    enable_streamk = ctx.enable_streamk
    sk_grid = ctx.sk_grid

    # Make grad_output contiguous
    grad_output_cont = grad_output.contiguous()

    # grad_a = grad_output @ b^T
    b_t = b.T.contiguous()
    grad_a = matmul(grad_output_cont, b_t, enable_streamk=enable_streamk, sk_grid=sk_grid)

    # grad_b = a^T @ grad_output
    a_t = a.T.contiguous()
    grad_b = matmul(a_t, grad_output_cont, enable_streamk=enable_streamk, sk_grid=sk_grid)

    # tuple[a, b, enable_streamk, sk_grid]
    # First 2 must be in the order that matches matmul()'s forward args
    # Last 2 are not part of the gradient and so are None
    return grad_a, grad_b, None, None


_matmul.register_autograd(_matmul_backwards, setup_context=_setup_context_matmul_backwards)


@triton_op("tritonblas::_matmul_out", mutates_args={'out'})
def _matmul_out(
    a: torch.Tensor,
    b: torch.Tensor,
    out: torch.Tensor,
    enable_streamk: Optional[bool] = False,
    sk_grid: Optional[int] = None,
) -> None:
    assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions"
    M, K = a.shape
    _, N = b.shape

    # Query Origami for solution
    selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, out.dtype, a.device, streamk=enable_streamk)

    if enable_streamk:
        streamk_matmul_lt(a, b, out, selector, sk_grid=sk_grid)
    else:
        persistent_matmul_lt(a, b, out, selector)

    # Custom torch ops cannot return a value which is an alias of an input. So
    # even though torch returns a pointer to the out arg when used, we can't.
    return None

def matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    enable_streamk: Optional[bool] = False,
    sk_grid: Optional[int] = None,
) -> Optional[torch.Tensor]:
    # If no out tensor provided - we do the allocation - we support autograd
    if out is None:
        return _matmul(a, b, enable_streamk, sk_grid)

    # If out tensor provided - in-place - we do NOT support autograd
    # Check for autograd conditions (global and per-tensor)
    if torch.is_grad_enabled() and (
        a.requires_grad or b.requires_grad or out.requires_grad
    ):
        raise RuntimeError(
            "tritonblas.matmul(): functions with out=... arguments don't support "
            "automatic differentiation, but one of the arguments requires grad."
        )

    return _matmul_out(a, b, out, enable_streamk, sk_grid)
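
# Example usage of matmul (a minimal sketch; it assumes the function is re-exported as
# tritonblas.matmul and that fp16 CUDA tensors are used):
#   a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
#   b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
#   c = tritonblas.matmul(a, b)                      # output allocated internally, autograd-capable
#   out = torch.empty(4096, 4096, device="cuda", dtype=torch.float16)
#   tritonblas.matmul(a, b, out=out, enable_streamk=True)  # writes into out, no autograd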

def matmul_a8w8(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    c: torch.Tensor,
    enable_streamk=False,
    sk_grid=None,
):
    assert a.shape[1] == b.shape[0], "Incompatible Dimensions"
    M, K = a.shape
    _, N = b.shape

    selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, c.dtype, a.device, streamk=enable_streamk)

    if enable_streamk:
        return streamk_matmul_lt(a, b, c, selector, sk_grid=sk_grid, a_scale=a_scale, b_scale=b_scale, quantized=True)
    else:
        return persistent_matmul_lt(a, b, c, selector, a_scale=a_scale, b_scale=b_scale, quantized=True)
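
# Example usage of matmul_a8w8 (a minimal sketch; the int8 inputs and the per-row /
# per-column fp32 scale layout are assumptions, so consult the kernel for the exact
# scale semantics):
#   a = torch.randint(-128, 127, (2048, 1024), device="cuda", dtype=torch.int8)
#   b = torch.randint(-128, 127, (1024, 4096), device="cuda", dtype=torch.int8)
#   a_scale = torch.rand(2048, device="cuda", dtype=torch.float32)
#   b_scale = torch.rand(4096, device="cuda", dtype=torch.float32)
#   c = torch.empty(2048, 4096, device="cuda", dtype=torch.float16)
#   matmul_a8w8(a, b, a_scale, b_scale, c)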

def matmul_fp4(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    a_scales: torch.Tensor,
    b_scales: torch.Tensor,
    block_m: Optional[int] = None,    # Overrides Origami value
    block_n: Optional[int] = None,    # Overrides Origami value
    block_k: Optional[int] = None,    # Overrides Origami value
    group_size_m: int = 8,            # Overrides Origami value
    num_warps: int = 8,
    num_stages: int = 2,
):
    """
    FP4 matrix multiplication: C = A @ B

    Args:
        a: Input matrix A in FP4 format (M, K//2), packed 2 elements per uint8
        b: Input matrix B in FP4 format (N, K//2), packed 2 elements per uint8
        c: Output matrix C (M, N) in bfloat16 or float16
        a_scales: Scales for A in e8m0 format (M, K // 32)
        b_scales: Scales for B in e8m0 format (N, K // 32)
        block_m: Block size for M dimension
        block_n: Block size for N dimension
        block_k: Block size for K dimension (must be multiple of 64 for FP4)
        group_size_m: Group size for M dimension tiling
        num_warps: Number of warps per thread block (default: 8)
        num_stages: Number of pipeline stages (default: 2)

    Returns:
        Output matrix C
    """
    M, K = a.shape
    _, N = b.shape
    num_xcds = 8
    if block_m is None:
        selector = _make_matmul_selector(M, N, K, "f4", "f4", c.dtype, a.device, mx_block_size=32)
        block_m = selector.block_m
        block_n = selector.block_n
        block_k = selector.block_k
        group_size_m = selector.group_m
        num_xcds = selector.num_sms
        if block_m < M:
            block_m = 128
        if block_n < N:
            block_n = 128
        if block_k < K:
            block_k = 128
        # print(f"Selected {block_m}x{block_n}x{block_k}")

    # Get actual dimensions (accounting for packing)
    M = a.shape[0]
    K = a.shape[1] * 2  # Unpacked K dimension
    N = b.shape[0]  # B has shape (N, K//2)

    # Verify dimensions are compatible
    assert b.shape[1] * 2 == K, f"Incompatible Dimensions: A has K={K}, B has K={b.shape[1] * 2}"

    # Transpose B to match kernel expectations (kernel expects B as K x N)
    b = b.T

    # Ensure block_k is appropriate for FP4 (must be multiple of 64)
    assert block_k % 64 == 0, "BLOCK_K must be multiple of 64 for FP4"

    total_blocks_M = triton.cdiv(M, block_m)
    total_blocks_N = triton.cdiv(N, block_n)
    total_tiles = total_blocks_M * total_blocks_N

    # Set chunk size to same area as L2 tiles
    chunk_size = group_size_m * group_size_m
    chunk_size = min(chunk_size, max(1, total_tiles // num_xcds))

    grid = (total_tiles,)
    fp4_matmul[grid](
        a,
        b,
        c,
        a_scales,
        b_scales,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        a_scales.stride(0),
        a_scales.stride(1),
        b_scales.stride(0),
        b_scales.stride(1),
        BLOCK_SIZE_M=block_m,
        BLOCK_SIZE_N=block_n,
        BLOCK_SIZE_K=block_k,
        GROUP_SIZE_M=group_size_m,
        NUM_SMS=total_tiles,
        NUM_XCDS=num_xcds,
        CHUNK_SIZE=chunk_size,
        num_stages=num_stages,
        num_warps=num_warps,
    )

    return c
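
# Example usage of matmul_fp4 (a minimal sketch; random packed data is used only to
# illustrate shapes, and uint8 storage for the e8m0 scales is an assumption):
#   M, N, K = 1024, 1024, 1024
#   a = torch.randint(0, 256, (M, K // 2), device="cuda", dtype=torch.uint8)  # packed FP4
#   b = torch.randint(0, 256, (N, K // 2), device="cuda", dtype=torch.uint8)  # packed FP4
#   a_scales = torch.ones(M, K // 32, device="cuda", dtype=torch.uint8)
#   b_scales = torch.ones(N, K // 32, device="cuda", dtype=torch.uint8)
#   c = torch.empty(M, N, device="cuda", dtype=torch.bfloat16)
#   matmul_fp4(a, b, c, a_scales, b_scales)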

@triton_op("tritonblas::_addmm", mutates_args={})
def _addmm(
    bias: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    enable_streamk: Optional[bool] = False,
    sk_grid: Optional[int] = None,
) -> torch.Tensor:
    assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions"
    M, K = a.shape
    _, N = b.shape

    # Query Origami for solution
    selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, bias.dtype, a.device, streamk=enable_streamk)

    # Allocate an output tensor
    out = a.new_empty(M, N)

    if enable_streamk:
        return streamk_matmul_lt(a, b, out, selector, bias=bias, sk_grid=sk_grid)
    else:
        return persistent_matmul_lt(a, b, out, selector, bias=bias)


def _setup_context_addmm_backwards(ctx: Any, inputs: tuple[Any, ...], output: Any):
    bias, a, b, enable_streamk, sk_grid = inputs
    ctx.save_for_backward(a, b)
    ctx.enable_streamk = enable_streamk
    ctx.sk_grid = sk_grid


def _addmm_backwards(ctx: Any, grad_output: torch.Tensor):
    a, b = ctx.saved_tensors
    enable_streamk = ctx.enable_streamk
    sk_grid = ctx.sk_grid

    # Make grad_output contiguous
    grad_output_cont = grad_output.contiguous()

    # grad_a = grad_output @ b^T
    b_t = b.T.contiguous()
    grad_a = matmul(grad_output_cont, b_t, enable_streamk=enable_streamk, sk_grid=sk_grid)

    # grad_b = a^T @ grad_output
    a_t = a.T.contiguous()
    grad_b = matmul(a_t, grad_output_cont, enable_streamk=enable_streamk, sk_grid=sk_grid)

    # grad_bias = sum(grad_output)
    grad_bias = grad_output.sum(dim=0)

    # tuple[bias, a, b, enable_streamk, sk_grid]
    # First 3 must be in the order that matches addmm()'s forward args
    # Last 2 are not part of the gradient and so are None
    return grad_bias, grad_a, grad_b, None, None


_addmm.register_autograd(_addmm_backwards, setup_context=_setup_context_addmm_backwards)


@triton_op("tritonblas::_addmm_out", mutates_args={'out'})
def _addmm_out(
    bias: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    out: torch.Tensor,
    enable_streamk: Optional[bool] = False,
    sk_grid: Optional[int] = None,
) -> None:
    assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions"
    M, K = a.shape
    _, N = b.shape

    # Query Origami for solution
    selector = _make_matmul_selector(M, N, K, a.dtype, b.dtype, bias.dtype, a.device, streamk=enable_streamk)

    if enable_streamk:
        streamk_matmul_lt(a, b, out, selector, bias=bias, sk_grid=sk_grid)
    else:
        persistent_matmul_lt(a, b, out, selector, bias=bias)

    # Custom torch ops cannot return a value which is an alias of an input. So
    # even though torch returns a pointer to the out arg when used, we can't.
    return None


def addmm(
    bias: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    enable_streamk: Optional[bool] = False,
    sk_grid: Optional[int] = None,
) -> Optional[torch.Tensor]:
    # If no out tensor provided - we do the allocation - we support autograd
    if out is None:
        return _addmm(bias, a, b, enable_streamk, sk_grid)

    # If out tensor provided - in-place - we do NOT support autograd
    # Check for autograd conditions (global and per-tensor)
    if torch.is_grad_enabled() and (
        bias.requires_grad or a.requires_grad or b.requires_grad or out.requires_grad
    ):
        raise RuntimeError(
            "tritonblas.addmm(): functions with out=... arguments don't support "
            "automatic differentiation, but one of the arguments requires grad."
        )

    return _addmm_out(bias, a, b, out, enable_streamk, sk_grid)
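
# Example usage of addmm (a minimal sketch; the length-N bias follows from the backward
# pass summing grad_output over dim 0, while the other shapes/dtypes are assumptions):
#   bias = torch.randn(4096, device="cuda", dtype=torch.float16)
#   a = torch.randn(2048, 1024, device="cuda", dtype=torch.float16)
#   b = torch.randn(1024, 4096, device="cuda", dtype=torch.float16)
#   c = addmm(bias, a, b)   # c = a @ b + bias, output allocated internally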