import functools
import random
import time
from typing import Any, Dict, Optional, Tuple
import torch
from torch.library import triton_op, wrap_triton
import triton
from .kernels import persistent_matmul, streamk_matmul
from .kernels.fp4_matmul import fp4_matmul
from .origami import OrigamiMatmulSelector
# Per-process tensor cache (currently unused within this module).
_tensor_cache = {}

# Query the active GPU once at import time.
# NOTE(review): this runs at import and requires a visible CUDA/ROCm device —
# confirm that importing this module on a CPU-only host is unsupported.
current_device_index = torch.cuda.current_device()
current_device = torch.cuda.get_device_properties(current_device_index)
# Number of SMs/CUs on the device; sizes the global stream-k scratch buffers.
MAX_SMS = current_device.multi_processor_count
# TODO: 256x256 for fp16/bf16, need adjust for fp8/fp4
# Largest per-program tile area (BLOCK_M * BLOCK_N) the global P buffer supports.
MAX_BLOCK_SIZE = 65536
# Global pre-allocated buffers: one spin-lock byte and one partial-result row
# per program, reused across stream-k launches to avoid per-call allocation.
_global_locks = torch.empty(MAX_SMS, device="cuda", dtype=torch.uint8)
_global_P = torch.empty(MAX_SMS, MAX_BLOCK_SIZE, device="cuda", dtype=torch.float32)
# Function will behave like an LRU-Cache of heuristic results
# Saves several microseconds for previously seen problems by not rerunning the heuristic unnecessarily
# NOTE(review): the decorator below is commented out, so the heuristic is
# currently re-run on every call — confirm whether caching should be re-enabled.
#@functools.lru_cache(maxsize=1024)
def _make_matmul_selector(
    M: int,
    N: int,
    K: int,
    a_dtype: torch.dtype,
    b_dtype: torch.dtype,
    c_dtype: torch.dtype,
    device: torch.device,
    mx_block_size=0,
    streamk=False,
):
    """Build an OrigamiMatmulSelector (tile-size heuristic) for one problem.

    The selector exposes the chosen kernel configuration (block_m/block_n/
    block_k/group_m/num_sms, and sk_grid when streamk=True).
    """
    # Run the heuristic for this (shape, dtype, device) combination.
    return OrigamiMatmulSelector(
        M, N, K,
        a_dtype, b_dtype, c_dtype,
        device,
        mx_block_size=mx_block_size,
        streamk=streamk,
    )
def persistent_matmul_lt(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    selector,
    bias: Optional[torch.Tensor] = None,
    a_scale: Optional[torch.Tensor] = None,
    b_scale: Optional[torch.Tensor] = None,
    quantized: bool = False,
):
    """Launch the data-parallel (persistent) matmul kernel: c = a @ b (+ bias).

    Args:
        a: Left operand of shape (M, K).
        b: Right operand of shape (K, N).
        c: Pre-allocated output of shape (M, N); written in place.
        selector: Tile-size heuristic exposing block_m/block_n/block_k/
            group_m/num_sms (e.g. an OrigamiMatmulSelector).
        bias: Optional bias tensor added by the kernel when present.
        a_scale: Dequantization scale for `a`; forwarded only when quantized.
        b_scale: Dequantization scale for `b`; forwarded only when quantized.
        quantized: Whether a/b are quantized and scales should be applied.

    Returns:
        The output tensor c.
    """
    assert a.shape[1] == b.shape[0], "Incompatible Dimensions"
    M, K = a.shape
    _, N = b.shape

    BLK_M = selector.block_m
    BLK_N = selector.block_n
    BLK_K = selector.block_k
    gsize_m = selector.group_m
    num_xcds = selector.num_sms

    total_blocks_M = triton.cdiv(M, BLK_M)
    total_blocks_N = triton.cdiv(N, BLK_N)
    total_tiles = total_blocks_M * total_blocks_N
    total_programs = total_tiles
    # Whether K divides evenly into K-blocks (lets the kernel skip masking).
    even_k = K % BLK_K == 0

    # TODO: Separate these configs.
    # Basic configs for most compute-bound sizes.
    # TODO: set these values analytically?
    num_stages = 2
    num_warps = 8
    waves_per_eu = 0
    mfmaInstrSize = 16
    kpack = 1
    # For skinny sizes like 4, 5120, 2880, use CACHE_MODIFIER=".cg"
    CACHE_MODIFIER_A = None
    CACHE_MODIFIER_B = None

    # Run in data-parallel mode: one program per output tile.
    grids = total_tiles

    # Set chunk size to the same area as the L2 tiles.
    # BUGFIX: clamp to at least 1 — for tiny problems total_programs may be
    # smaller than num_xcds, which previously produced CHUNK_SIZE=0 (the fp4
    # path already guards against this).
    chunk_size = gsize_m * gsize_m
    chunk_size = min(chunk_size, max(1, total_programs // num_xcds))

    # TODO: Support other matmul algs.
    wrap_triton(persistent_matmul)[(grids,)](
        a,
        b,
        c,
        a_scale if quantized else None,  # A_scale_ptr
        b_scale if quantized else None,  # B_scale_ptr
        bias,  # kernel receives None when no bias is supplied
        M,
        N,
        K,
        a.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        bias.stride(0) if bias is not None else 0,
        stride_ak=a.stride(1),
        stride_bk=b.stride(0),
        BLOCK_SIZE_M=BLK_M,
        BLOCK_SIZE_N=BLK_N,
        BLOCK_SIZE_K=BLK_K,
        GROUP_SIZE_M=gsize_m,
        NUM_SMS=total_programs,
        NUM_XCDS=num_xcds,
        CHUNK_SIZE=chunk_size,
        BIAS=bias is not None,
        EVEN_K=even_k,
        CACHE_MODIFIER_A=CACHE_MODIFIER_A,
        CACHE_MODIFIER_B=CACHE_MODIFIER_B,
        QUANTIZED=quantized,
        num_stages=num_stages,
        num_warps=num_warps,
        waves_per_eu=waves_per_eu,
        matrix_instr_nonkdim=mfmaInstrSize,
        kpack=kpack,
        ALLOW_TF32=torch.backends.cuda.matmul.allow_tf32,
    )
    return c
def streamk_matmul_lt(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    selector,
    bias: Optional[torch.Tensor] = None,
    sk_grid: Optional[int] = None,
    a_scale: Optional[torch.Tensor] = None,
    b_scale: Optional[torch.Tensor] = None,
    quantized: bool = False,
):
    """Launch the stream-k matmul kernel: c = a @ b (+ bias).

    Stream-K runs on a fixed-size grid; leftover tiles are split across
    programs, which accumulate partial results in the scratch buffer `P`
    coordinated through `locks`.

    Args:
        a: Left operand of shape (M, K).
        b: Right operand of shape (K, N).
        c: Pre-allocated output of shape (M, N); written in place.
        selector: Tile-size heuristic exposing block_m/block_n/block_k/
            group_m/num_sms/sk_grid.
        bias: Optional bias tensor added by the kernel when present.
        sk_grid: Optional caller override for the stream-k grid size.
        a_scale: Dequantization scale for `a`; forwarded only when quantized.
        b_scale: Dequantization scale for `b`; forwarded only when quantized.
        quantized: Whether a/b are quantized and scales should be applied.

    Returns:
        The output tensor c.
    """
    assert a.shape[1] == b.shape[0], "Incompatible Dimensions"
    M, K = a.shape
    _, N = b.shape

    BLK_M = selector.block_m
    BLK_N = selector.block_n
    BLK_K = selector.block_k
    gsize_m = selector.group_m
    num_xcds = selector.num_sms

    total_blocks_M = triton.cdiv(M, BLK_M)
    total_blocks_N = triton.cdiv(N, BLK_N)
    total_tiles = total_blocks_M * total_blocks_N
    even_k = K % BLK_K == 0

    ##
    # Grid Size
    ##
    # BUGFIX: apply the caller's sk_grid override *before* deriving the
    # stream-k tile count. Previously STREAMK_TILES was computed from
    # selector.sk_grid even when sk_grid overrode the launch grid, so the
    # kernel was given an inconsistent tile split.
    total_programs_streamk = sk_grid if sk_grid is not None else selector.sk_grid
    if total_programs_streamk > 0:  # Stream-K
        total_tiles_streamk = total_tiles % total_programs_streamk
    else:  # all tiles are computed using classical blocking
        total_tiles_streamk = 0

    num_stages = 2
    num_warps = 8
    waves_per_eu = 0
    mfmaInstrSize = 16
    kpack = 1
    # For skinny sizes like 4, 5120, 2880, use CACHE_MODIFIER=".cg"
    CACHE_MODIFIER_A = None
    CACHE_MODIFIER_B = None

    grids = total_programs_streamk
    block_size = BLK_M * BLK_N
    # Reuse the global scratch buffers when they are large enough; otherwise
    # allocate per-call.
    # NOTE(review): neither `locks` nor `P` is zeroed here — presumably the
    # kernel initializes them before use; confirm against the kernel source.
    if grids <= MAX_SMS and block_size <= MAX_BLOCK_SIZE:
        locks = _global_locks[:grids]
        P = _global_P[:grids, :block_size]
    else:
        locks = torch.empty(grids, device="cuda", dtype=torch.uint8)
        P = torch.empty(grids, block_size, device="cuda", dtype=torch.float32)

    # Set chunk size to the same area as the L2 tiles.
    # BUGFIX: clamp to at least 1 so small grids (grids < num_xcds) don't
    # produce CHUNK_SIZE=0 (the fp4 path already guards against this).
    chunk_size = gsize_m * gsize_m
    chunk_size = min(chunk_size, max(1, grids // num_xcds))

    wrap_triton(streamk_matmul)[(grids,)](
        a,
        b,
        c,
        a_scale if quantized else None,  # A_scale_ptr
        b_scale if quantized else None,  # B_scale_ptr
        bias,  # kernel receives None when no bias is supplied
        P,
        locks,
        M,
        N,
        K,
        a.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        # BUGFIX: pass 0 (not None) when there is no bias, matching the
        # persistent launcher — the kernel expects an integer stride.
        bias.stride(0) if bias is not None else 0,
        stride_ak=a.stride(1),
        stride_bk=b.stride(0),
        BLOCK_SIZE_M=BLK_M,
        BLOCK_SIZE_N=BLK_N,
        BLOCK_SIZE_K=BLK_K,
        GROUP_SIZE_M=gsize_m,
        NUM_SMS=grids,
        NUM_XCDS=num_xcds,
        CHUNK_SIZE=chunk_size,
        STREAMK_TILES=total_tiles_streamk,
        BIAS=bias is not None,
        EVEN_K=even_k,
        CACHE_MODIFIER_A=CACHE_MODIFIER_A,
        CACHE_MODIFIER_B=CACHE_MODIFIER_B,
        QUANTIZED=quantized,
        num_stages=num_stages,
        num_warps=num_warps,
        waves_per_eu=waves_per_eu,
        matrix_instr_nonkdim=mfmaInstrSize,
        kpack=kpack,
    )
    return c
# (removed stray Sphinx "[docs]" link artifact)
def matmul_lt(
    a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, selector, enable_streamk=False
):
    """Dispatch c = a @ b to the stream-k or persistent kernel launcher."""
    assert a.shape[1] == b.shape[0], "Incompatible Dimensions"
    launcher = streamk_matmul_lt if enable_streamk else persistent_matmul_lt
    return launcher(a, b, c, selector)
# (removed stray Sphinx "[docs]" link artifact)
def matmul_a8w8_lt(
    a: torch.Tensor, b: torch.Tensor, a_scale: torch.Tensor, b_scale: torch.Tensor, c: torch.Tensor, selector, enable_streamk=False
):
    """Dispatch a quantized matmul (with per-operand scales) into c."""
    assert a.shape[1] == b.shape[0], "Incompatible Dimensions"
    launcher = streamk_matmul_lt if enable_streamk else persistent_matmul_lt
    return launcher(
        a, b, c, selector, a_scale=a_scale, b_scale=b_scale, quantized=True
    )
@triton_op("tritonblas::_matmul", mutates_args={})
def _matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    enable_streamk: Optional[bool] = False,
    sk_grid: Optional[int] = None,
) -> torch.Tensor:
    """Allocate an output tensor and compute a @ b into it."""
    assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions"
    M, K = a.shape
    N = b.shape[1]
    # Output dtype/device follow `a`.
    out = a.new_empty(M, N)
    # Query Origami's heuristic for a kernel configuration.
    selector = _make_matmul_selector(
        M, N, K, a.dtype, b.dtype, out.dtype, a.device, streamk=enable_streamk
    )
    if enable_streamk:
        return streamk_matmul_lt(a, b, out, selector, sk_grid=sk_grid)
    return persistent_matmul_lt(a, b, out, selector)
def _setup_context_matmul_backwards(
    ctx: Any,
    inputs: tuple[Any, ...],
    output: Any,
):
    """Stash the forward inputs of _matmul on ctx for the backward pass."""
    a, b, streamk_flag, grid = inputs
    ctx.save_for_backward(a, b)
    ctx.enable_streamk = streamk_flag
    ctx.sk_grid = grid
def _matmul_backwards(
    ctx: Any,
    grad_output: torch.Tensor,
):
    """Backward for _matmul: grad_a = gO @ b^T, grad_b = a^T @ gO."""
    a, b = ctx.saved_tensors
    streamk_flag = ctx.enable_streamk
    grid = ctx.sk_grid
    # Kernels expect contiguous operands.
    g = grad_output.contiguous()
    grad_a = matmul(g, b.T.contiguous(), enable_streamk=streamk_flag, sk_grid=grid)
    grad_b = matmul(a.T.contiguous(), g, enable_streamk=streamk_flag, sk_grid=grid)
    # Return order matches _matmul's forward args (a, b, enable_streamk,
    # sk_grid); the last two carry no gradient.
    return grad_a, grad_b, None, None
# Wire the custom op into autograd: the setup hook stores the tensors the
# backward needs, and _matmul_backwards produces the gradients.
_matmul.register_autograd(_matmul_backwards,
                          setup_context=_setup_context_matmul_backwards)
@triton_op("tritonblas::_matmul_out", mutates_args={'out'})
def _matmul_out(
    a: torch.Tensor,
    b: torch.Tensor,
    out: torch.Tensor,
    enable_streamk: Optional[bool] = False,
    sk_grid: Optional[int] = None,
) -> None:
    """Compute a @ b into the caller-provided `out` tensor (in place)."""
    assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions"
    M, K = a.shape
    N = b.shape[1]
    # Query Origami's heuristic for a kernel configuration.
    selector = _make_matmul_selector(
        M, N, K, a.dtype, b.dtype, out.dtype, a.device, streamk=enable_streamk
    )
    if enable_streamk:
        streamk_matmul_lt(a, b, out, selector, sk_grid=sk_grid)
    else:
        persistent_matmul_lt(a, b, out, selector)
    # Custom torch ops cannot return a value which is an alias of an input. So
    # even though torch returns a pointer to the out arg when used, we can't.
    return None
# (removed stray Sphinx "[docs]" link artifact)
def matmul(
    a: torch.Tensor,
    b: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    enable_streamk: Optional[bool] = False,
    sk_grid: Optional[int] = None
) -> Optional[torch.Tensor]:
    """Public matmul entry point.

    With out=None an output is allocated and autograd is supported; with an
    explicit `out` the result is written in place and autograd is refused.
    """
    if out is None:
        # Allocation path — autograd supported via the registered backward.
        return _matmul(a, b, enable_streamk, sk_grid)
    # In-place path: refuse when autograd is on and any tensor needs grad.
    any_requires_grad = a.requires_grad or b.requires_grad or out.requires_grad
    if torch.is_grad_enabled() and any_requires_grad:
        raise RuntimeError(
            "tritonblas.matmul(): functions with out=... arguments don't support "
            "automatic differentiation, but one of the arguments requires grad."
        )
    return _matmul_out(a, b, out, enable_streamk, sk_grid)
# (removed stray Sphinx "[docs]" link artifact)
def matmul_a8w8(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    c: torch.Tensor,
    enable_streamk=False,
    sk_grid=None,
):
    """Quantized matmul into pre-allocated c, applying a_scale/b_scale."""
    assert a.shape[1] == b.shape[0], "Incompatible Dimensions"
    M, K = a.shape
    N = b.shape[1]
    # Query Origami's heuristic for a kernel configuration.
    selector = _make_matmul_selector(
        M, N, K, a.dtype, b.dtype, c.dtype, a.device, streamk=enable_streamk
    )
    if enable_streamk:
        return streamk_matmul_lt(
            a, b, c, selector, sk_grid=sk_grid,
            a_scale=a_scale, b_scale=b_scale, quantized=True,
        )
    return persistent_matmul_lt(
        a, b, c, selector, a_scale=a_scale, b_scale=b_scale, quantized=True
    )
# (removed stray Sphinx "[docs]" link artifact)
def matmul_fp4(
    a: torch.Tensor,
    b: torch.Tensor,
    c: torch.Tensor,
    a_scales: torch.Tensor,
    b_scales: torch.Tensor,
    block_m: Optional[int] = None,  # Overrides Origami value
    block_n: Optional[int] = None,  # Overrides Origami value
    block_k: Optional[int] = None,  # Overrides Origami value
    group_size_m: int = 8,  # Overrides Origami value
    num_warps: int = 8,
    num_stages: int = 2,
):
    """
    FP4 matrix multiplication: C = A @ B

    Args:
        a: Input matrix A in FP4 format (M, K//2), packed 2 elements per uint8
        b: Input matrix B in FP4 format (N, K//2), packed 2 elements per uint8
        c: Output matrix C (M, N) in bfloat16 or float16
        a_scales: Scales for A in e8m0 format (M, K // 32)
        b_scales: Scales for B in e8m0 format (N, K // 32)
        block_m: Block size for M dimension
        block_n: Block size for N dimension
        block_k: Block size for K dimension (must be multiple of 64 for FP4)
        group_size_m: Group size for M dimension tiling
        num_warps: Number of warps per thread block (default: 8)
        num_stages: Number of pipeline stages (default: 2)
    Returns:
        Output matrix C
    """
    M, K = a.shape
    _, N = b.shape
    num_xcds = 8

    # No explicit tile sizes: ask the Origami heuristic, then apply overrides.
    if block_m is None:  # BUGFIX: idiomatic `is None` instead of `== None`
        # NOTE(review): the selector is queried with the *packed* K
        # (a.shape[1] == unpacked_K // 2) — confirm Origami expects packed K
        # for mx/fp4 problems.
        selector = _make_matmul_selector(
            M, N, K, "f4", "f4", c.dtype, a.device, mx_block_size=32
        )
        block_m = selector.block_m
        block_n = selector.block_n
        block_k = selector.block_k
        group_size_m = selector.group_m
        num_xcds = selector.num_sms
        # NOTE(review): forces 128 tiles whenever the selected tile does not
        # cover the full dimension — intent unclear; confirm this override.
        if block_m < M:
            block_m = 128
        if block_n < N:
            block_n = 128
        if block_k < K:
            block_k = 128

    # Get actual dimensions (accounting for packing).
    M = a.shape[0]
    K = a.shape[1] * 2  # Unpacked K dimension
    N = b.shape[0]  # B has shape (N, K//2)

    # Verify dimensions are compatible.
    assert b.shape[1] * 2 == K, f"Incompatible Dimensions: A has K={K}, B has K={b.shape[1] * 2}"

    # Transpose B to match kernel expectations (kernel expects B as K x N).
    b = b.T

    # Ensure block_k is appropriate for FP4 (must be multiple of 64).
    assert block_k % 64 == 0, "BLOCK_K must be multiple of 64 for FP4"

    total_blocks_M = triton.cdiv(M, block_m)
    total_blocks_N = triton.cdiv(N, block_n)
    total_tiles = total_blocks_M * total_blocks_N

    # Set chunk size to same area as L2 tiles; never below 1.
    chunk_size = group_size_m * group_size_m
    chunk_size = min(chunk_size, max(1, total_tiles // num_xcds))

    grid = (total_tiles,)
    fp4_matmul[grid](
        a,
        b,
        c,
        a_scales,
        b_scales,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        a_scales.stride(0),
        a_scales.stride(1),
        b_scales.stride(0),
        b_scales.stride(1),
        BLOCK_SIZE_M=block_m,
        BLOCK_SIZE_N=block_n,
        BLOCK_SIZE_K=block_k,
        GROUP_SIZE_M=group_size_m,
        NUM_SMS=total_tiles,
        NUM_XCDS=num_xcds,
        CHUNK_SIZE=chunk_size,
        num_stages=num_stages,
        num_warps=num_warps,
    )
    return c
@triton_op("tritonblas::_addmm", mutates_args={})
def _addmm(
    bias: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    enable_streamk: Optional[bool] = False,
    sk_grid: Optional[int] = None,
) -> torch.Tensor:
    """Allocate an output tensor and compute a @ b with bias into it."""
    assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions"
    M, K = a.shape
    N = b.shape[1]
    # Query Origami's heuristic for a kernel configuration.
    selector = _make_matmul_selector(
        M, N, K, a.dtype, b.dtype, bias.dtype, a.device, streamk=enable_streamk
    )
    # Output dtype/device follow `a`.
    out = a.new_empty(M, N)
    if enable_streamk:
        return streamk_matmul_lt(a, b, out, selector, bias=bias, sk_grid=sk_grid)
    return persistent_matmul_lt(a, b, out, selector, bias=bias)
def _setup_context_addmm_backwards(
    ctx: Any,
    inputs: tuple[Any, ...],
    output: Any,
):
    """Stash the forward inputs of _addmm on ctx for the backward pass."""
    bias, a, b, streamk_flag, grid = inputs
    # bias itself is not needed for the backward (its grad is a plain sum).
    ctx.save_for_backward(a, b)
    ctx.enable_streamk = streamk_flag
    ctx.sk_grid = grid
def _addmm_backwards(
    ctx: Any,
    grad_output: torch.Tensor,
):
    """Backward for _addmm: grads for bias, a, and b."""
    a, b = ctx.saved_tensors
    streamk_flag = ctx.enable_streamk
    grid = ctx.sk_grid
    # Kernels expect contiguous operands.
    g = grad_output.contiguous()
    # grad_a = gO @ b^T ; grad_b = a^T @ gO ; grad_bias = column-sum of gO.
    grad_a = matmul(g, b.T.contiguous(), enable_streamk=streamk_flag, sk_grid=grid)
    grad_b = matmul(a.T.contiguous(), g, enable_streamk=streamk_flag, sk_grid=grid)
    grad_bias = grad_output.sum(dim=0)
    # Return order matches _addmm's forward args (bias, a, b, enable_streamk,
    # sk_grid); the last two carry no gradient.
    return grad_bias, grad_a, grad_b, None, None
# Wire the custom op into autograd: the setup hook stores the tensors the
# backward needs, and _addmm_backwards produces the gradients.
_addmm.register_autograd(_addmm_backwards,
                         setup_context=_setup_context_addmm_backwards)
@triton_op("tritonblas::_addmm_out", mutates_args={'out'})
def _addmm_out(
    bias: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    out: torch.Tensor,
    enable_streamk: Optional[bool] = False,
    sk_grid: Optional[int] = None,
) -> None:
    """Compute a @ b with bias into the caller-provided `out` (in place)."""
    assert a.shape[1] == b.shape[0], "Incompatible A-B Dimensions"
    M, K = a.shape
    N = b.shape[1]
    # Query Origami's heuristic for a kernel configuration.
    selector = _make_matmul_selector(
        M, N, K, a.dtype, b.dtype, bias.dtype, a.device, streamk=enable_streamk
    )
    if enable_streamk:
        streamk_matmul_lt(a, b, out, selector, bias=bias, sk_grid=sk_grid)
    else:
        persistent_matmul_lt(a, b, out, selector, bias=bias)
    # Custom torch ops cannot return a value which is an alias of an input. So
    # even though torch returns a pointer to the out arg when used, we can't.
    return None
def addmm(
    bias: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    enable_streamk: Optional[bool] = False,
    sk_grid: Optional[int] = None
) -> Optional[torch.Tensor]:
    """Public addmm entry point (bias + a @ b).

    With out=None an output is allocated and autograd is supported; with an
    explicit `out` the result is written in place and autograd is refused.
    """
    if out is None:
        # Allocation path — autograd supported via the registered backward.
        return _addmm(bias, a, b, enable_streamk, sk_grid)
    # In-place path: refuse when autograd is on and any tensor needs grad.
    any_requires_grad = (
        bias.requires_grad
        or a.requires_grad
        or b.requires_grad
        or out.requires_grad
    )
    if torch.is_grad_enabled() and any_requires_grad:
        raise RuntimeError(
            "tritonblas.addmm(): functions with out=... arguments don't support "
            "automatic differentiation, but one of the arguments requires grad."
        )
    return _addmm_out(bias, a, b, out, enable_streamk, sk_grid)