Stages Reference#
This documentation is automatically generated from source code docstrings.
The stages module provides composable device-side abstractions for building
high-performance GEMM kernels in Triton.
Note
All stages APIs are device-side only — they execute within @triton.jit
kernels on the GPU. They cannot be called from host Python code.
Type annotations shown as MagicMock represent Triton types (tl.tensor,
tl.constexpr) that are mocked during documentation generation.
Module Reference#
Composable stage abstractions for tritonblas GEMM kernels.
Key abstractions:
GemmContext: Accumulator context with reduce_axis() plus all config parameters
ScheduleContext: Unified scheduling with persistent_tile_range() / get_tile()
InputView, OutputView: Matrix views with tile_ptrs() for memory access
ScaleView, BiasView: Epilogue views for quantized GEMM (scale and bias)
Tile: 2D tile with coordinates (pid_m, pid_n) and shape (block_m, block_n)
Factory functions:
make_input_view(), make_tensor_view(): Create InputView for A and B matrices
make_output_view(): Create OutputView for C matrix with epilogue support
make_scale_view(): Create ScaleView for quantization scales
make_bias_view(): Create BiasView for bias addition
Example
from tritonblas.kernels.stages import (
ScheduleContext, GemmContext,
make_tensor_view, make_output_view,
make_scale_view, make_bias_view,
)
@triton.jit
def kernel(A, B, C, A_scale_ptr, B_scale_ptr, bias_ptr, M, N, K,
stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
stride_bias, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
NUM_SMS: tl.constexpr, NUM_XCDS: tl.constexpr,
EVEN_K: tl.constexpr, BIAS: tl.constexpr, ...):
# Create matrix views - just describe your matrices
tensorA = make_tensor_view(A, M, K, stride_am, stride_ak)
tensorB = make_tensor_view(B, K, N, stride_bk, stride_bn)
tensorC = make_output_view(C, M, N, stride_cm, stride_cn)
# Create epilogue views (optional scale and bias)
scale_view = make_scale_view(A_scale_ptr, B_scale_ptr, M, N) if A_scale_ptr is not None else None
bias_view = make_bias_view(bias_ptr, M, stride_bias) if BIAS else None
# Construct GemmContext on device with ALL parameters
ctx = GemmContext(
BLOCK_M, BLOCK_N, BLOCK_K, NUM_SMS, NUM_XCDS, GROUP_SIZE_M,
even_k=EVEN_K,
)
# Create schedule from GemmContext
sched = ScheduleContext(M, N, K, ctx)
# Persistent loop
start, total, stride = sched.persistent_tile_range()
for tile_id in range(start, total, stride):
out_tile = sched.get_tile_from_idx(tile_id)
# Compute GEMM
acc = ctx.reduce_axis(tensorA, tensorB, out_tile)
# Store with epilogue: scale -> bias -> convert -> store
tensorC.store(acc, out_tile, scale=scale_view, bias=bias_view)
- class Tile(pid_m, pid_n, block_m, block_n)[source]#
Bases: object
2D tile with coordinates and shape.
Stores runtime coordinates (pid_m, pid_n) and compile-time block sizes.
Use ScheduleContext methods to create tiles:

# From linear tile_id
out_tile = sched.get_tile_from_idx(tile_id)
# From 2D coordinates
out_tile = sched.get_tile_from_coord(pid_m, pid_n)
# Direct construction is also supported
out_tile = Tile(pid_m, pid_n, BLOCK_M, BLOCK_N)
Example
# Output tile from scheduler
out_tile = sched.get_tile_from_idx(tile_id)
# Input A tile at k offset
a_tile = sched.get_tile_from_coord(pid_m, k // BLOCK_K)
# Input B tile at k offset
b_tile = sched.get_tile_from_coord(k // BLOCK_K, pid_n)
- Parameters:
pid_m (tl.tensor)
pid_n (tl.tensor)
block_m (tl.constexpr)
block_n (tl.constexpr)
- __init__(pid_m, pid_n, block_m, block_n)[source]#
Create a tile with runtime coordinates and compile-time sizes.
- Parameters:
pid_m – Tile coordinate in M dimension
pid_n – Tile coordinate in N dimension
block_m – Block size in M dimension (constexpr)
block_n – Block size in N dimension (constexpr)
- pid_m: tl.tensor#
- pid_n: tl.tensor#
- block_m: tl.constexpr#
- block_n: tl.constexpr#
- indices()[source]#
Compute row and column indices for this tile.
- Returns:
Row indices [BLOCK_M], column indices [BLOCK_N]
- Return type:
rm, rn
- layout(M, N)[source]#
Compute memory layout with bounds checking.
- Parameters:
M – Total rows
N – Total columns
- Returns:
Row indices, column indices, bounds mask
- Return type:
rm, rn, mask
Note
The mask is computed from RAW indices before modulo wrapping. This is critical for correct boundary handling - e.g., when K doesn’t divide evenly by BLOCK_K, the boundary tile needs proper masking to avoid reading garbage data.
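The mask-before-modulo rule above can be sketched in plain Python (not Triton code; names are illustrative). The key point is that the bounds mask is derived from the raw, unwrapped indices, so out-of-range lanes on a boundary tile are disabled rather than silently wrapped:

```python
def tile_layout(pid_m, pid_n, block_m, block_n, M, N):
    # Raw (unwrapped) row/column indices for this tile
    rm = [pid_m * block_m + i for i in range(block_m)]
    rn = [pid_n * block_n + j for j in range(block_n)]
    # Mask from the RAW indices, computed BEFORE any modulo wrapping,
    # so lanes past M or N on a boundary tile are masked out
    mask = [[(r < M) and (c < N) for c in rn] for r in rm]
    return rm, rn, mask
```

For M = 10 with block_m = 4, the tile at pid_m = 2 covers raw rows 8..11; rows 10 and 11 fall outside the matrix and are masked.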
- scale(acc, A_scale_ptr, B_scale_ptr, M, N, stride_a=1, stride_b=1)[source]#
Apply quantization scales to accumulator.
- Parameters:
acc – Accumulator tensor [BLOCK_M, BLOCK_N]
A_scale_ptr – Pointer to A scales (per-row)
B_scale_ptr – Pointer to B scales (per-column)
M – Number of rows (bounds for A scales)
N – Number of columns (bounds for B scales)
stride_a – Stride for A scales (default: 1)
stride_b – Stride for B scales (default: 1)
- Returns:
Scaled accumulator as float32
Example:
acc = tile.scale(acc, A_scale_ptr, B_scale_ptr, M, N)
- bias(acc, bias_ptr, M, stride_bias=1)[source]#
Add bias vector to accumulator.
- Parameters:
acc – Accumulator tensor [BLOCK_M, BLOCK_N]
bias_ptr – Pointer to bias vector
M – Matrix dimension for bounds checking
stride_bias – Stride for bias vector (default: 1)
- Returns:
Accumulator with bias added
Example:
acc = tile.bias(acc, bias_ptr, M)
- class GemmContext(block_m, block_n, block_k, num_sms, num_xcds=1, group_size_m=8, chunk_size=1, cache_modifier_a='.cg', cache_modifier_b='.cg', acc_dtype=tl.float32, allow_tf32=True, even_k=True, quantized=False)[source]#
Bases: object
GEMM context with all configuration parameters and accumulator management.
Bundles together all compile-time GEMM parameters:
Block sizes (M, N, K)
Hardware configuration (NUM_SMS, NUM_XCDS)
Scheduling parameters (GROUP_SIZE_M, CHUNK_SIZE)
Cache modifiers
Computation options (acc_dtype, allow_tf32, even_k, quantized)
Provides two execution modes:
reduce_tile(): Single BLOCK_K iteration (one dot product)
reduce_axis(): Full K loop
Tile Creation Convention
For A [M, K] and B [K, N]:
A tiles: (pid_m, k_idx) with shape (BLOCK_M, BLOCK_K)
B tiles: (k_idx, pid_n) with shape (BLOCK_K, BLOCK_N)
The InputView handles the pointer arithmetic based on its stored layout.
Example
tensorA = make_tensor_view(A, M, K, stride_am, stride_ak)
tensorB = make_tensor_view(B, K, N, stride_bk, stride_bn)
ctx = GemmContext(
    block_m=128, block_n=256, block_k=64,
    num_sms=NUM_SMS, num_xcds=NUM_XCDS,
    group_size_m=8, even_k=EVEN_K,
)
# Use in ScheduleContext
sched = ScheduleContext(M, N, K, ctx)
acc = ctx.reduce_axis(tensorA, tensorB, out_tile)
- Parameters:
block_m (tl.constexpr)
block_n (tl.constexpr)
block_k (tl.constexpr)
num_sms (tl.constexpr)
num_xcds (tl.constexpr)
group_size_m (tl.constexpr)
chunk_size (tl.constexpr)
cache_modifier_a (tl.constexpr)
cache_modifier_b (tl.constexpr)
acc_dtype (tl.constexpr)
allow_tf32 (tl.constexpr)
even_k (tl.constexpr)
quantized (tl.constexpr)
- __init__(block_m, block_n, block_k, num_sms, num_xcds=1, group_size_m=8, chunk_size=1, cache_modifier_a='.cg', cache_modifier_b='.cg', acc_dtype=tl.float32, allow_tf32=True, even_k=True, quantized=False)[source]#
Create a GEMM context with all configuration parameters.
- Parameters:
block_m – Block size M (constexpr)
block_n – Block size N (constexpr)
block_k – Block size K (constexpr)
num_sms – Number of SMs/CUs (constexpr)
num_xcds – Number of XCDs for chiplet transform (default: 1)
group_size_m – Group size for tile scheduling (default: 8)
chunk_size – Chunk size for chiplet scheduling (default: 1)
cache_modifier_a – Cache modifier for A loads (default: “.cg”)
cache_modifier_b – Cache modifier for B loads (default: “.cg”)
acc_dtype – Accumulator dtype (default: tl.float32)
allow_tf32 – Allow TF32 for matmul (default: True)
even_k – Whether K is evenly divisible by BLOCK_K (default: True)
quantized – Use int32 accumulation for quantized inputs (default: False)
- block_m: tl.constexpr#
- block_n: tl.constexpr#
- block_k: tl.constexpr#
- num_sms: tl.constexpr#
- num_xcds: tl.constexpr#
- group_size_m: tl.constexpr#
- chunk_size: tl.constexpr#
- cache_modifier_a: tl.constexpr#
- cache_modifier_b: tl.constexpr#
- acc_dtype: tl.constexpr#
- allow_tf32: tl.constexpr#
- even_k: tl.constexpr#
- quantized: tl.constexpr#
- init_accumulator()[source]#
Initialize and return a zero accumulator.
- Returns:
Accumulator tensor [BLOCK_M, BLOCK_N] initialized to zeros
Example:
acc = ctx.init_accumulator()
- reduce_tile(A, B, out_tile, k_idx, acc, boundary=False)[source]#
Execute a single K step (one BLOCK_K iteration).
Creates tiles for A and B at the given K index and loads them using the InputView’s tile_ptrs method (which handles layout internally).
- Parameters:
A (InputView) – InputView for matrix A [M, K] with strides already stored
B (InputView) – InputView for matrix B [K, N] with strides already stored
out_tile (Tile) – Output Tile with (pid_m, pid_n, BLOCK_M, BLOCK_N)
k_idx – Current K tile index
acc – Accumulator to add to
boundary (tl.constexpr) – Whether this is a boundary iteration needing masking
- Returns:
Updated accumulator tensor [BLOCK_M, BLOCK_N]
Example:
A = make_tensor_view(A_ptr, M, K, stride_am, stride_ak)
B = make_tensor_view(B_ptr, K, N, stride_bk, stride_bn)
acc = ctx.init_accumulator()
for k_idx in range(num_k_tiles):
    acc = ctx.reduce_tile(A, B, out_tile, k_idx, acc)
- reduce_axis(A, B, out_tile)[source]#
Execute the full GEMM K loop and return the accumulator.
Iterates over all K tiles, loading from A and B using their stored layout information, and accumulates the dot products.
- Parameters:
A (InputView) – InputView for matrix A [M, K] with strides already stored
B (InputView) – InputView for matrix B [K, N] with strides already stored
out_tile (Tile) – Output Tile with (pid_m, pid_n, BLOCK_M, BLOCK_N)
- Returns:
Accumulator tensor [BLOCK_M, BLOCK_N]
Example:
A = make_tensor_view(A_ptr, M, K, stride_am, stride_ak)
B = make_tensor_view(B_ptr, K, N, stride_bk, stride_bn)
ctx = GemmContext(block_m=128, block_n=256, block_k=64, ...)
acc = ctx.reduce_axis(A, B, out_tile)
- class ScheduleContext(M, N, K, ctx, streamk_tiles=0)[source]#
Bases: object
Unified scheduling context that hides persistent GEMM loop complexity.
Two simple iteration patterns:
Tile loop: for tile_id in range(start, total, stride) with get_tile(tile_id)
Iter loop: for iter_id in range(start, end) with get_iter(iter_id)
Example (persistent GEMM)#
ctx = GemmContext(block_m=128, block_n=256, block_k=64,
                  num_sms=NUM_SMS, num_xcds=NUM_XCDS)
sched = ScheduleContext(M, N, K, ctx)
start, total, stride = sched.persistent_tile_range()
for tile_id in range(start, total, stride):
    out_tile = sched.get_tile_from_idx(tile_id)
    # Process full tile
Example (Stream-K)#
ctx = GemmContext(block_m=128, block_n=256, block_k=64, num_sms=NUM_SMS)
sched = ScheduleContext(M, N, K, ctx, streamk_tiles=STREAMK_TILES)
start, end = sched.iter_range()
for iter_id in range(start, end):
    pid_m, pid_n, k_iter = sched.get_iter(iter_id)
    # Process single K iteration at (pid_m, pid_n, k_iter)
- __init__(M, N, K, ctx, streamk_tiles=0)[source]#
Create a ScheduleContext from a GemmContext.
- Parameters:
M – Problem dimension M (rows of A and C)
N – Problem dimension N (columns of B and C)
K – Problem dimension K (reduction axis)
ctx (GemmContext) – GemmContext with block sizes and scheduling parameters
streamk_tiles – Number of tiles for Stream-K (0 = persistent only)
- M: tl.tensor#
- N: tl.tensor#
- K: tl.tensor#
- ctx: GemmContext#
- streamk_tiles: tl.constexpr#
- persistent_tile_range()[source]#
Get tile iteration range for this workgroup (persistent GEMM).
- Returns:
Use as range(start, total, stride)
- Return type:
(start, total, stride)
Example:
start, total, stride = sched.persistent_tile_range()
for tile_id in range(start, total, stride):
    pid_m, pid_n = sched.get_tile(tile_id)
    ...
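The persistent pattern above can be sketched in plain Python (an illustrative strided partition, not the library's implementation — the actual mapping may also apply XCD/chunk swizzling): each workgroup starts at its own program id and steps by the workgroup count, so every tile is processed exactly once.

```python
def persistent_tiles(pid, num_tiles, num_sms):
    # Workgroup `pid` processes tiles pid, pid + num_sms, pid + 2*num_sms, ...
    # i.e. range(start, total, stride) with start=pid, total=num_tiles, stride=num_sms
    return list(range(pid, num_tiles, num_sms))
```

Across all workgroups the strided ranges are disjoint and cover every tile id once.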
- get_tile_from_idx(tile_id)[source]#
Get a Tile for a given tile ID.
- Parameters:
tile_id – Linear tile index
- Returns:
Tile object with computed coordinates and ctx block sizes
- Return type:
Tile
- get_tile_from_coord(pid_m, pid_n)[source]#
Get a Tile from 2D coordinates.
- Parameters:
pid_m – Tile coordinate in M dimension
pid_n – Tile coordinate in N dimension
- Returns:
Tile object with the given coordinates and ctx block sizes
- Return type:
Tile
- iter_range()[source]#
Get iteration range for this workgroup (Stream-K mode).
- Returns:
Iteration range [start, end)
- Return type:
(start_iter, end_iter)
- Parameters:
M (tl.tensor)
N (tl.tensor)
K (tl.tensor)
ctx (GemmContext)
streamk_tiles (tl.constexpr)
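In Stream-K mode, iter_range() assigns each workgroup a contiguous slice of the global K-iteration space. A hypothetical even split is sketched below in plain Python; the real Stream-K assignment is implementation-specific, so treat the formula (and the remainder handling) as an illustrative assumption:

```python
def iter_range(pid, total_iters, num_sms):
    # Even split of total_iters across num_sms workgroups; the first
    # `rem` workgroups take one extra iteration each.
    base, rem = divmod(total_iters, num_sms)
    start = pid * base + min(pid, rem)
    end = start + base + (1 if pid < rem else 0)
    return start, end
```

The half-open ranges [start, end) tile the full iteration space without gaps or overlap, which is what lets boundary tiles be split across workgroups.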
- class InputView(ptr, rows, cols, stride_row, stride_col)[source]#
Bases:
objectInput matrix view for GEMM.
Stores the matrix pointer, dimensions, and both strides. The tile_ptrs() method computes pointers using the general formula that works for any memory layout.
- Parameters:
ptr (tl.tensor)
rows (tl.tensor)
cols (tl.tensor)
stride_row (tl.tensor)
stride_col (tl.tensor)
- ptr#
Base pointer to matrix data
- Type:
tl.tensor
- rows#
Number of rows
- Type:
tl.tensor
- cols#
Number of columns
- Type:
tl.tensor
- stride_row#
Stride when moving along rows (first dimension)
- Type:
tl.tensor
- stride_col#
Stride when moving along columns (second dimension)
- Type:
tl.tensor
- ptr: tl.tensor#
- rows: tl.tensor#
- cols: tl.tensor#
- stride_row: tl.tensor#
- stride_col: tl.tensor#
- tile_ptrs(tile)[source]#
Compute pointer array and bounds mask for a tile.
Uses the general formula: ptr[i,j] = base + i*stride_row + j*stride_col. This works for any memory layout (row-major, col-major, or other).
- Parameters:
tile (Tile) – Tile object with (pid_row, pid_col, block_row, block_col)
- Returns:
ptrs: 2D pointer array [BLOCK_ROW, BLOCK_COL]
mask: 2D boolean mask for boundary handling
- Return type:
(ptrs, mask)
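The general-formula pointer arithmetic can be sketched in plain Python (illustrative names; in the kernel these are Triton tensor ops, not list comprehensions):

```python
def tile_ptrs(base, rm, rn, stride_row, stride_col):
    # General layout formula: ptr[i][j] = base + rm[i]*stride_row + rn[j]*stride_col
    # Row-major, column-major, and strided layouts all reduce to this.
    return [[base + r * stride_row + c * stride_col for c in rn] for r in rm]
```

For a row-major matrix with 8 columns (stride_row=8, stride_col=1), a 2x3 tile at the origin yields addresses base+0..2 on row 0 and base+8..10 on row 1; a column-major layout just swaps the stride values.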
- load(tile, boundary=False, cache_modifier='.cg')[source]#
Load a tile from this matrix.
- Parameters:
tile (Tile) – Tile with coordinates and shape
boundary (tl.constexpr) – If True, apply boundary masking for partial tiles
cache_modifier (tl.constexpr) – Cache modifier for load instruction
- Returns:
Loaded tile data [BLOCK_ROW, BLOCK_COL]
- class OutputView(ptr, rows, cols, stride_row, stride_col)[source]#
Bases: object
Output matrix view for GEMM.
Same design as InputView - stores pointer, dimensions, and strides. Provides store() with optional epilogue (scaling, bias, type conversion).
The store() method can optionally apply: quantization scaling, bias addition, and type conversion before the store.
Example
tensorC = make_output_view(C, M, N, stride_cm, stride_cn)
scale_view = make_scale_view(A_scale, B_scale, M, N)
bias_view = make_bias_view(bias, M, stride_bias)
# Full epilogue: scale -> bias -> convert -> store
tensorC.store(acc, out_tile, scale=scale_view, bias=bias_view)
- Parameters:
ptr (tl.tensor)
rows (tl.tensor)
cols (tl.tensor)
stride_row (tl.tensor)
stride_col (tl.tensor)
- ptr: tl.tensor#
- rows: tl.tensor#
- cols: tl.tensor#
- stride_row: tl.tensor#
- stride_col: tl.tensor#
- store(data, tile, mask=None, scale=None, bias=None)[source]#
Store data to a tile with optional epilogue operations.
Applies epilogue in order: scale -> bias -> type convert -> store
- Parameters:
data – Tile data to store [BLOCK_M, BLOCK_N]
tile (Tile) – Output tile with coordinates and shape
mask – Optional precomputed bounds mask (default: None)
scale (ScaleView) – Optional ScaleView for quantization scaling (default: None)
bias (BiasView) – Optional BiasView for bias addition (default: None)
Example:
# Simple store (no epilogue)
tensorC.store(acc.to(C.type.element_ty), out_tile)
# With full epilogue
tensorC.store(acc, out_tile, scale=scale_view, bias=bias_view)
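The epilogue order (scale -> bias -> convert, with the store itself omitted) can be sketched in plain Python. This assumes per-row A scales, per-column B scales, and a per-row bias broadcast across columns — the broadcast directions are an assumption from the surrounding docstrings, not a statement about the implementation:

```python
def store_epilogue(acc, a_scale, b_scale, bias, convert):
    # 1. scale:  acc[i][j] * a_scale[i] * b_scale[j]
    # 2. bias:   + bias[i]  (per-row, broadcast across columns)
    # 3. convert: cast to the output element type
    out = []
    for i, row in enumerate(acc):
        out.append([convert(v * a_scale[i] * b_scale[j] + bias[i])
                    for j, v in enumerate(row)])
    return out
```

Applying the steps in this order matters: bias is added to the already-scaled accumulator, and conversion happens last so the scale and bias arithmetic stays in float32.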
- class ScaleView(a_scale_ptr, b_scale_ptr, M, N, stride_a, stride_b)[source]#
Bases: object
Scale vectors view for quantized GEMM epilogue.
Stores pointers to per-row A scales and per-column B scales, along with dimensions for bounds checking.
- Parameters:
a_scale_ptr (tl.tensor)
b_scale_ptr (tl.tensor)
M (tl.tensor)
N (tl.tensor)
stride_a (tl.tensor)
stride_b (tl.tensor)
- a_scale_ptr#
Pointer to A scale vector (per-row, length M)
- Type:
tl.tensor
- b_scale_ptr#
Pointer to B scale vector (per-column, length N)
- Type:
tl.tensor
- M#
Number of rows (for A scale bounds)
- Type:
tl.tensor
- N#
Number of columns (for B scale bounds)
- Type:
tl.tensor
- stride_a#
Stride for A scales (default: 1)
- Type:
tl.tensor
- stride_b#
Stride for B scales (default: 1)
- Type:
tl.tensor
- a_scale_ptr: tl.tensor#
- b_scale_ptr: tl.tensor#
- M: tl.tensor#
- N: tl.tensor#
- stride_a: tl.tensor#
- stride_b: tl.tensor#
- class BiasView(ptr, N, stride)[source]#
Bases: object
Bias vector view for GEMM epilogue.
Stores pointer to bias vector and dimension for bounds checking.
- Parameters:
ptr (tl.tensor)
N (tl.tensor)
stride (tl.tensor)
- ptr#
Pointer to bias vector (length N)
- Type:
tl.tensor
- N#
Number of bias elements (for bounds checking)
- stride#
Stride for bias vector (default: 1)
- Type:
tl.tensor
- ptr: tl.tensor#
- N: tl.tensor#
- stride: tl.tensor#
- make_input_view(ptr, rows, cols, stride_row, stride_col)[source]#
Create an InputView with automatic stride type coercion.
This factory ensures strides are always tensor-typed, handling the case where contiguous dimensions have stride=1 (Python int) while other dimensions have tensor-typed strides.
- Parameters:
ptr – Base pointer to matrix data
rows – Number of rows (first dimension)
cols – Number of columns (second dimension)
stride_row – Stride when moving along rows
stride_col – Stride when moving along columns
- Returns:
InputView with all fields as tensors
Example:
# A is [M, K] matrix - strides can be int or tensor
tensorA = make_input_view(A, M, K, stride_am, stride_ak)
# B is [K, N] matrix
tensorB = make_input_view(B, K, N, stride_bk, stride_bn)
- make_tensor_view(ptr, rows, cols, stride_row, stride_col)#
Create an InputView with automatic stride type coercion.
This factory ensures strides are always tensor-typed, handling the case where contiguous dimensions have stride=1 (Python int) while other dimensions have tensor-typed strides.
- Parameters:
ptr – Base pointer to matrix data
rows – Number of rows (first dimension)
cols – Number of columns (second dimension)
stride_row – Stride when moving along rows
stride_col – Stride when moving along columns
- Returns:
InputView with all fields as tensors
Example:
# A is [M, K] matrix - strides can be int or tensor
tensorA = make_input_view(A, M, K, stride_am, stride_ak)
# B is [K, N] matrix
tensorB = make_input_view(B, K, N, stride_bk, stride_bn)
- make_output_view(ptr, rows, cols, stride_row, stride_col)[source]#
Create an OutputView with automatic stride type coercion.
Same as make_input_view() but returns an OutputView, which has a store() method in addition to load().
- Parameters:
ptr – Base pointer to matrix data
rows – Number of rows (first dimension)
cols – Number of columns (second dimension)
stride_row – Stride when moving along rows
stride_col – Stride when moving along columns
- Returns:
OutputView with all fields as tensors
Example:
# C is [M, N] output matrix
tensorC = make_output_view(C, M, N, stride_cm, stride_cn)
- make_scale_view(a_scale_ptr, b_scale_ptr, M, N, stride_a=1, stride_b=1)[source]#
Create a ScaleView for quantized GEMM epilogue.
Stores per-row A scales and per-column B scales with automatic stride type coercion.
- Parameters:
a_scale_ptr – Pointer to A scale vector (per-row, length M)
b_scale_ptr – Pointer to B scale vector (per-column, length N)
M – Number of rows (for A scale bounds) - must be a tensor
N – Number of columns (for B scale bounds)
stride_a – Stride for A scales (default: 1)
stride_b – Stride for B scales (default: 1)
- Returns:
ScaleView with all fields as tensors
Example:
scale_view = make_scale_view(A_scale_ptr, B_scale_ptr, M, N)
tensorC.store(acc, out_tile, scale=scale_view)
- make_bias_view(bias_ptr, N, stride=1)[source]#
Create a BiasView for GEMM epilogue.
Stores bias vector pointer with automatic stride type coercion.
- Parameters:
bias_ptr – Pointer to bias vector (length N)
N – Number of columns (for bounds checking)
stride – Stride for bias vector (default: 1)
- Returns:
BiasView with all fields as tensors
Example:
bias_view = make_bias_view(bias_ptr, N, stride_bias)
tensorC.store(acc, out_tile, bias=bias_view)