Stages Reference


This documentation is automatically generated from source code docstrings.

The stages module provides composable device-side abstractions for building high-performance GEMM kernels in Triton.

Note

All stages APIs are device-side only — they execute within @triton.jit kernels on the GPU. They cannot be called from host Python code.

Type annotations shown as MagicMock represent Triton types (tl.tensor, tl.constexpr) that are mocked during documentation generation.

Module Reference

Composable stage abstractions for tritonblas GEMM kernels.

Key abstractions:

  • GemmContext: Accumulator context with reduce_axis() plus all config parameters

  • ScheduleContext: Unified scheduling with persistent_tile_range()/get_tile_from_idx()

  • InputView, OutputView: Matrix views with tile_ptrs() for memory access

  • ScaleView, BiasView: Epilogue views for quantized GEMM (scale and bias)

  • Tile: 2D tile with coordinates (pid_m, pid_n) and shape (block_m, block_n)

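Factory functions:

  • make_input_view() / make_tensor_view(), make_output_view(): Create matrix views with automatic stride type coercion

  • make_scale_view(), make_bias_view(): Create epilogue views for scale and bias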

Example

import triton
import triton.language as tl

from tritonblas.kernels.stages import (
    ScheduleContext, GemmContext,
    make_tensor_view, make_output_view,
    make_scale_view, make_bias_view,
)

@triton.jit
def kernel(A, B, C, A_scale_ptr, B_scale_ptr, bias_ptr, M, N, K,
           stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
           stride_bias, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
           BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
           NUM_SMS: tl.constexpr, NUM_XCDS: tl.constexpr,
           EVEN_K: tl.constexpr, BIAS: tl.constexpr, ...):

    # Create matrix views: describe each matrix by pointer, shape, and strides
    tensorA = make_tensor_view(A, M, K, stride_am, stride_ak)
    tensorB = make_tensor_view(B, K, N, stride_bk, stride_bn)
    tensorC = make_output_view(C, M, N, stride_cm, stride_cn)

    # Create epilogue views (optional scale and bias)
    scale_view = make_scale_view(A_scale_ptr, B_scale_ptr, M, N) if A_scale_ptr is not None else None
    bias_view = make_bias_view(bias_ptr, M, stride_bias) if BIAS else None

    # Construct GemmContext on device with ALL parameters
    ctx = GemmContext(
        BLOCK_M, BLOCK_N, BLOCK_K, NUM_SMS, NUM_XCDS, GROUP_SIZE_M,
        even_k=EVEN_K,
    )

    # Create schedule from GemmContext
    sched = ScheduleContext(M, N, K, ctx)

    # Persistent loop
    start, total, stride = sched.persistent_tile_range()
    for tile_id in range(start, total, stride):
        out_tile = sched.get_tile_from_idx(tile_id)

        # Compute GEMM
        acc = ctx.reduce_axis(tensorA, tensorB, out_tile)

        # Store with epilogue: scale -> bias -> convert -> store
        tensorC.store(acc, out_tile, scale=scale_view, bias=bias_view)

class Tile(pid_m, pid_n, block_m, block_n)

Bases: object

2D tile with coordinates and shape.

Stores runtime coordinates (pid_m, pid_n) and compile-time block sizes.

Use ScheduleContext methods to create tiles:

# From linear tile_id
out_tile = sched.get_tile_from_idx(tile_id)

# From 2D coordinates
out_tile = sched.get_tile_from_coord(pid_m, pid_n)

# Direct construction is also supported
out_tile = Tile(pid_m, pid_n, BLOCK_M, BLOCK_N)

Example

# Output tile from scheduler
out_tile = sched.get_tile_from_idx(tile_id)

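# Input A tile at k offset: shape (BLOCK_M, BLOCK_K), built directly because
# get_tile_from_coord() uses the ctx output block sizes (BLOCK_M, BLOCK_N)
a_tile = Tile(pid_m, k // BLOCK_K, BLOCK_M, BLOCK_K)

# Input B tile at k offset: shape (BLOCK_K, BLOCK_N)
b_tile = Tile(k // BLOCK_K, pid_n, BLOCK_K, BLOCK_N)

Parameters: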
  • pid_m (tl.tensor)

  • pid_n (tl.tensor)

  • block_m (tl.constexpr)

  • block_n (tl.constexpr)

__init__(pid_m, pid_n, block_m, block_n)

Create a tile with runtime coordinates and compile-time sizes.

Parameters:
  • pid_m – Tile coordinate in M dimension

  • pid_n – Tile coordinate in N dimension

  • block_m – Block size in M dimension (constexpr)

  • block_n – Block size in N dimension (constexpr)

pid_m: tl.tensor
pid_n: tl.tensor
block_m: tl.constexpr
block_n: tl.constexpr

indices()

Compute row and column indices for this tile.

Returns:

rm: row indices [BLOCK_M]; rn: column indices [BLOCK_N]

Return type:

(rm, rn)

layout(M, N)

Compute memory layout with bounds checking.

Parameters:
  • M – Total rows

  • N – Total columns

Returns:

rm: row indices; rn: column indices; mask: bounds mask

Return type:

(rm, rn, mask)

Note

The mask is computed from the raw indices, before modulo wrapping. This is critical for correct boundary handling: for example, when K is not evenly divisible by BLOCK_K, the boundary tile must be masked so that it does not read garbage data.
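The following is a minimal pure-Python model of indices() and layout(), written to illustrate the Note above rather than reproduce the library code. It assumes both row and column indices are wrapped with % M and % N after the mask is built; tile_indices and tile_layout are hypothetical names, and plain lists stand in for tl.tensor values.

```python
# Hypothetical pure-Python model of Tile.indices() / Tile.layout(M, N).
# Assumption: indices are wrapped with % M and % N AFTER the mask is built.

def tile_indices(pid_m, pid_n, block_m, block_n):
    """Raw row/column indices covered by tile (pid_m, pid_n)."""
    rm = [pid_m * block_m + i for i in range(block_m)]
    rn = [pid_n * block_n + j for j in range(block_n)]
    return rm, rn


def tile_layout(pid_m, pid_n, block_m, block_n, M, N):
    """Wrapped indices plus a bounds mask computed from the RAW indices."""
    rm, rn = tile_indices(pid_m, pid_n, block_m, block_n)
    # Mask first: an element is valid only if its raw index is in bounds.
    mask = [[r < M and c < N for c in rn] for r in rm]
    # Then wrap, so every generated address stays inside the matrix.
    rm_wrapped = [r % M for r in rm]
    rn_wrapped = [c % N for c in rn]
    return rm_wrapped, rn_wrapped, mask


# Edge tile of a 5x6 matrix with 4x4 blocks: raw rows 4..7, raw cols 4..7.
rm, rn, mask = tile_layout(1, 1, 4, 4, M=5, N=6)
print(rm)       # [4, 0, 1, 2]
print(rn)       # [4, 5, 0, 1]
print(mask[0])  # [True, True, False, False]
print(mask[1])  # [False, False, False, False]
```

Raw row 5 wraps to row 0, which is a valid address; only the mask computed from the raw index marks it as out of bounds.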

scale(acc, A_scale_ptr, B_scale_ptr, M, N, stride_a=1, stride_b=1)

Apply quantization scales to accumulator.

Parameters:
  • acc – Accumulator tensor [BLOCK_M, BLOCK_N]

  • A_scale_ptr – Pointer to A scales (per-row)

  • B_scale_ptr – Pointer to B scales (per-column)

  • M – Number of rows, for bounds checking on A scales

  • N – Number of columns, for bounds checking on B scales

  • stride_a – Stride for A scales (default: 1)

  • stride_b – Stride for B scales (default: 1)

Returns:

Scaled accumulator as float32

Example:

acc = tile.scale(acc, A_scale_ptr, B_scale_ptr, M, N)

bias(acc, bias_ptr, M, stride_bias=1)

Add bias vector to accumulator.

Parameters:
  • acc – Accumulator tensor [BLOCK_M, BLOCK_N]

  • bias_ptr – Pointer to bias vector

  • M – Matrix dimension for bounds checking

  • stride_bias – Stride for bias vector (default: 1)

Returns:

Accumulator with bias added

Example:

acc = tile.bias(acc, bias_ptr, M)
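A minimal sketch of the per-row bias epilogue, assuming (as the M parameter suggests) that the bias vector has one entry per row and is broadcast across the tile's columns. add_row_bias and bias_mem are hypothetical; bias_mem stands in for the memory behind bias_ptr, and masked rows receive no bias.

```python
# Hypothetical model of Tile.bias(acc, bias_ptr, M, stride_bias).

def add_row_bias(acc, bias_mem, rm, M, stride_bias=1):
    """Add bias[row] to every column of that row; rows >= M are masked."""
    out = []
    for i, row in enumerate(acc):
        r = rm[i]
        b = bias_mem[r * stride_bias] if r < M else 0.0
        out.append([v + b for v in row])
    return out


acc = [[1.0, 2.0], [3.0, 4.0]]
bias_mem = [10.0, 99.0, 20.0, 99.0]   # with stride 2 the bias values are 10.0, 20.0
print(add_row_bias(acc, bias_mem, rm=[0, 1], M=2, stride_bias=2))
# [[11.0, 12.0], [23.0, 24.0]]
```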

class GemmContext(block_m, block_n, block_k, num_sms, num_xcds=1, group_size_m=8, chunk_size=1, cache_modifier_a='.cg', cache_modifier_b='.cg', acc_dtype=tl.float32, allow_tf32=True, even_k=True, quantized=False)

Bases: object

GEMM context with all configuration parameters and accumulator management.

Bundles together all compile-time GEMM parameters:

  • Block sizes (M, N, K)

  • Hardware configuration (NUM_SMS, NUM_XCDS)

  • Scheduling parameters (GROUP_SIZE_M, CHUNK_SIZE)

  • Cache modifiers

  • Computation options (acc_dtype, allow_tf32, even_k, quantized)

Provides two execution modes:

  • reduce_tile(): Single BLOCK_K iteration (one dot product)

  • reduce_axis(): Full K loop

Tile Creation Convention

For A [M, K] and B [K, N]:

  • A tiles: (pid_m, k_idx) with shape (BLOCK_M, BLOCK_K)

  • B tiles: (k_idx, pid_n) with shape (BLOCK_K, BLOCK_N)

The InputView handles the pointer arithmetic based on its stored layout.

Example

tensorA = make_tensor_view(A, M, K, stride_am, stride_ak)
tensorB = make_tensor_view(B, K, N, stride_bk, stride_bn)

ctx = GemmContext(
    block_m=128, block_n=256, block_k=64,
    num_sms=NUM_SMS, num_xcds=NUM_XCDS,
    group_size_m=8, even_k=EVEN_K,
)

# Use in ScheduleContext
sched = ScheduleContext(M, N, K, ctx)

acc = ctx.reduce_axis(tensorA, tensorB, out_tile)

Parameters:
  • block_m (tl.constexpr)

  • block_n (tl.constexpr)

  • block_k (tl.constexpr)

  • num_sms (tl.constexpr)

  • num_xcds (tl.constexpr)

  • group_size_m (tl.constexpr)

  • chunk_size (tl.constexpr)

  • cache_modifier_a (tl.constexpr)

  • cache_modifier_b (tl.constexpr)

  • acc_dtype (tl.constexpr)

  • allow_tf32 (tl.constexpr)

  • even_k (tl.constexpr)

  • quantized (tl.constexpr)

__init__(block_m, block_n, block_k, num_sms, num_xcds=1, group_size_m=8, chunk_size=1, cache_modifier_a='.cg', cache_modifier_b='.cg', acc_dtype=tl.float32, allow_tf32=True, even_k=True, quantized=False)

Create a GEMM context with all configuration parameters.

Parameters:
  • block_m – Block size M (constexpr)

  • block_n – Block size N (constexpr)

  • block_k – Block size K (constexpr)

  • num_sms – Number of SMs/CUs (constexpr)

  • num_xcds – Number of XCDs for chiplet transform (default: 1)

  • group_size_m – Group size for tile scheduling (default: 8)

  • chunk_size – Chunk size for chiplet scheduling (default: 1)

  • cache_modifier_a – Cache modifier for A loads (default: “.cg”)

  • cache_modifier_b – Cache modifier for B loads (default: “.cg”)

  • acc_dtype – Accumulator dtype (default: tl.float32)

  • allow_tf32 – Allow TF32 for matmul (default: True)

  • even_k – Whether K is evenly divisible by BLOCK_K (default: True)

  • quantized – Use int32 accumulation for quantized inputs (default: False)

block_m: tl.constexpr
block_n: tl.constexpr
block_k: tl.constexpr
num_sms: tl.constexpr
num_xcds: tl.constexpr
group_size_m: tl.constexpr
chunk_size: tl.constexpr
cache_modifier_a: tl.constexpr
cache_modifier_b: tl.constexpr
acc_dtype: tl.constexpr
allow_tf32: tl.constexpr
even_k: tl.constexpr
quantized: tl.constexpr

init_accumulator()

Initialize and return a zero accumulator.

Returns:

Accumulator tensor [BLOCK_M, BLOCK_N] initialized to zeros

Example:

acc = ctx.init_accumulator()
reduce_tile(A, B, out_tile, k_idx, acc, boundary=False)

Execute a single K step (one BLOCK_K iteration).

Creates tiles for A and B at the given K index and loads them using the InputView’s tile_ptrs method (which handles layout internally).

Parameters:
  • A (InputView) – InputView for matrix A [M, K] with strides already stored

  • B (InputView) – InputView for matrix B [K, N] with strides already stored

  • out_tile (Tile) – Output Tile with (pid_m, pid_n, BLOCK_M, BLOCK_N)

  • k_idx – Current K tile index

  • acc – Accumulator to add to

  • boundary (tl.constexpr) – Whether this is a boundary iteration needing masking

Returns:

Updated accumulator tensor [BLOCK_M, BLOCK_N]

Example:

A = make_tensor_view(A_ptr, M, K, stride_am, stride_ak)
B = make_tensor_view(B_ptr, K, N, stride_bk, stride_bn)
acc = ctx.init_accumulator()
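num_k_tiles = tl.cdiv(K, BLOCK_K)
if not EVEN_K:  # EVEN_K: kernel constexpr, True when K % BLOCK_K == 0
    num_k_tiles -= 1  # the final, partial K tile is handled below
for k_idx in range(num_k_tiles):
    acc = ctx.reduce_tile(A, B, out_tile, k_idx, acc)
if not EVEN_K:
    acc = ctx.reduce_tile(A, B, out_tile, num_k_tiles, acc, boundary=True)

reduce_axis(A, B, out_tile)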

Execute the full GEMM K loop and return the accumulator.

Iterates over all K tiles, loading from A and B using their stored layout information, and accumulates the dot products.

Parameters:
  • A (InputView) – InputView for matrix A [M, K] with strides already stored

  • B (InputView) – InputView for matrix B [K, N] with strides already stored

  • out_tile (Tile) – Output Tile with (pid_m, pid_n, BLOCK_M, BLOCK_N)

Returns:

Accumulator tensor [BLOCK_M, BLOCK_N]

Example:

A = make_tensor_view(A_ptr, M, K, stride_am, stride_ak)
B = make_tensor_view(B_ptr, K, N, stride_bk, stride_bn)
ctx = GemmContext(block_m=128, block_n=256, block_k=64, ...)
acc = ctx.reduce_axis(A, B, out_tile)
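To make the K loop concrete, here is a pure-Python model of the arithmetic reduce_axis() performs for one output tile: BLOCK_K-wide steps over K, with the final partial step masked when K is not a multiple of BLOCK_K (the even_k=False case). reduce_axis_model is a hypothetical name; loads, cache modifiers, and accumulator dtypes are not modelled.

```python
# Pure-Python model of the arithmetic in GemmContext.reduce_axis() for one tile.

def cdiv(a, b):
    return (a + b - 1) // b


def reduce_axis_model(A, B, rm, rn, block_k):
    """Accumulate A[rm, :] @ B[:, rn] in BLOCK_K-wide steps over K."""
    K = len(A[0])
    acc = [[0.0 for _ in rn] for _ in rm]            # like init_accumulator()
    for k_idx in range(cdiv(K, block_k)):            # one reduce_tile() per step
        for kk in range(k_idx * block_k, (k_idx + 1) * block_k):
            if kk >= K:
                continue                             # masked: contributes 0
            for i, r in enumerate(rm):
                for j, c in enumerate(rn):
                    acc[i][j] += A[r][kk] * B[kk][c]
    return acc


A = [[1, 2, 3], [4, 5, 6]]     # M=2, K=3
B = [[1, 0], [0, 1], [1, 1]]   # K=3, N=2
print(reduce_axis_model(A, B, rm=[0, 1], rn=[0, 1], block_k=2))
# [[4.0, 5.0], [10.0, 11.0]]
```

Changing block_k does not change the result; that invariance is what the boundary masking preserves.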

class ScheduleContext(M, N, K, ctx, streamk_tiles=0)

Bases: object

Unified scheduling context that hides persistent GEMM loop complexity.

Two simple iteration patterns:

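  • Tile loop: for tile_id in range(start, total, stride), with the range from persistent_tile_range() and tiles from get_tile_from_idx(tile_id)

  • Iter loop: for iter_id in range(start, end), with the range from iter_range() and coordinates from get_iter(iter_id)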

Example (persistent GEMM)

ctx = GemmContext(block_m=128, block_n=256, block_k=64,
                  num_sms=NUM_SMS, num_xcds=NUM_XCDS)
sched = ScheduleContext(M, N, K, ctx)

start, total, stride = sched.persistent_tile_range()
for tile_id in range(start, total, stride):
    out_tile = sched.get_tile_from_idx(tile_id)
    # Process full tile

Example (Stream-K)

ctx = GemmContext(block_m=128, block_n=256, block_k=64, num_sms=NUM_SMS)
sched = ScheduleContext(M, N, K, ctx, streamk_tiles=STREAMK_TILES)

start, end = sched.iter_range()
for iter_id in range(start, end):
    pid_m, pid_n, k_iter = sched.get_iter(iter_id)
    # Process single K iteration at (pid_m, pid_n, k_iter)

__init__(M, N, K, ctx, streamk_tiles=0)

Create a ScheduleContext from a GemmContext.

Parameters:
  • M – Number of rows of A and C

  • N – Number of columns of B and C

  • K – Inner (reduction) dimension shared by A and B

  • ctx (GemmContext) – GemmContext with block sizes and scheduling parameters

  • streamk_tiles – Number of tiles for Stream-K (0 = persistent only)

M: tl.tensor
N: tl.tensor
K: tl.tensor
ctx: GemmContext
streamk_tiles: tl.constexpr

persistent_tile_range()

Get tile iteration range for this workgroup (persistent GEMM).

Returns:

Use as range(start, total, stride)

Return type:

(start, total, stride)

Example:

start, total, stride = sched.persistent_tile_range()
for tile_id in range(start, total, stride):
    out_tile = sched.get_tile_from_idx(tile_id)
    ...
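A hypothetical model of how persistent_tile_range() distributes tiles, assuming start is the workgroup's program id (ignoring any chiplet remapping), stride is NUM_SMS, and total is cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N). Under these assumptions each tile is visited by exactly one workgroup.

```python
# Hypothetical model of ScheduleContext.persistent_tile_range():
# start = program id, stride = NUM_SMS, total = number of output tiles.

def cdiv(a, b):
    return (a + b - 1) // b


def total_tiles(M, N, block_m, block_n):
    return cdiv(M, block_m) * cdiv(N, block_n)


def persistent_tile_range(pid, M, N, block_m, block_n, num_sms):
    # Use the result as range(start, total, stride).
    return pid, total_tiles(M, N, block_m, block_n), num_sms


M, N, BLOCK_M, BLOCK_N, NUM_SMS = 300, 500, 128, 256, 4   # 3 x 2 = 6 tiles
for pid in range(NUM_SMS):
    start, total, stride = persistent_tile_range(pid, M, N, BLOCK_M, BLOCK_N, NUM_SMS)
    print(pid, list(range(start, total, stride)))
# 0 [0, 4]
# 1 [1, 5]
# 2 [2]
# 3 [3]
```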
get_tile_from_idx(tile_id)

Get a Tile for a given tile ID.

Parameters:

tile_id – Linear tile index

Returns:

Tile object with computed coordinates and ctx block sizes

Return type:

Tile
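One common way to turn a linear tile_id into (pid_m, pid_n) is GROUP_SIZE_M-grouped ordering, as in the Triton matmul tutorial: consecutive tile ids walk down a band of GROUP_SIZE_M tile-rows before moving to the next column, which improves L2 cache reuse. Whether get_tile_from_idx() uses exactly this order is an assumption; tile_coords is a hypothetical name.

```python
# Hypothetical model of get_tile_from_idx(): GROUP_SIZE_M-grouped tile order.

def cdiv(a, b):
    return (a + b - 1) // b


def tile_coords(tile_id, M, N, block_m, block_n, group_size_m):
    num_pid_m = cdiv(M, block_m)
    num_pid_n = cdiv(N, block_n)
    num_pid_in_group = group_size_m * num_pid_n
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * group_size_m
    # The last group may hold fewer than group_size_m rows of tiles.
    group_rows = min(num_pid_m - first_pid_m, group_size_m)
    local = tile_id % num_pid_in_group
    pid_m = first_pid_m + local % group_rows
    pid_n = local // group_rows
    return pid_m, pid_n


# 3 x 2 tile grid, groups of 2 tile-rows
print([tile_coords(t, 384, 512, 128, 256, 2) for t in range(6)])
# [(0, 0), (1, 0), (0, 1), (1, 1), (2, 0), (2, 1)]
```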

get_tile_from_coord(pid_m, pid_n)

Get a Tile from 2D coordinates.

Parameters:
  • pid_m – Tile coordinate in M dimension

  • pid_n – Tile coordinate in N dimension

Returns:

Tile object with the given coordinates and ctx block sizes

Return type:

Tile

iter_range()

Get iteration range for this workgroup (Stream-K mode).

Returns:

Iteration range [start, end)

Return type:

(start_iter, end_iter)

get_iter(global_iter)

Get coordinates for a given global iteration.

Parameters:

global_iter – Global iteration index

Returns:

Tile coordinates and K iteration index

Return type:

(pid_m, pid_n, k_iter)
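A hypothetical model of the Stream-K iteration space. It assumes the streamk_tiles tiles are flattened into streamk_tiles * iters_per_tile() K-iterations, split into contiguous, near-equal ranges (one per workgroup), and that tiles are ordered row-major. The even split and all function bodies below are illustrative, not the library's code.

```python
# Hypothetical model of Stream-K scheduling: iters_per_tile(), iter_range(),
# and get_iter(). The even split and row-major tile order are assumptions.

def cdiv(a, b):
    return (a + b - 1) // b


def iters_per_tile(K, block_k):
    return cdiv(K, block_k)


def iter_range(pid, num_sms, streamk_tiles, ipt):
    """Contiguous [start, end) slice of all Stream-K iterations for one workgroup."""
    total_iters = streamk_tiles * ipt
    base, extra = divmod(total_iters, num_sms)
    start = pid * base + min(pid, extra)      # first `extra` workgroups get one more
    end = start + base + (1 if pid < extra else 0)
    return start, end


def get_iter(global_iter, ipt, num_pid_n):
    """Map a global iteration to (pid_m, pid_n, k_iter)."""
    tile_id, k_iter = divmod(global_iter, ipt)
    pid_m, pid_n = divmod(tile_id, num_pid_n)  # row-major tile order
    return pid_m, pid_n, k_iter


ipt = iters_per_tile(K=256, block_k=64)      # 4 K-steps per tile
for pid in range(3):
    s, e = iter_range(pid, num_sms=3, streamk_tiles=2, ipt=ipt)
    print(pid, [get_iter(g, ipt, num_pid_n=2) for g in range(s, e)])
# 0 [(0, 0, 0), (0, 0, 1), (0, 0, 2)]
# 1 [(0, 0, 3), (0, 1, 0), (0, 1, 1)]
# 2 [(0, 1, 2), (0, 1, 3)]
```

Workgroup ranges can start or end in the middle of a tile (here tile (0, 0) is split between workgroups 0 and 1), so a Stream-K kernel must combine the partial accumulators of such tiles.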

iters_per_tile()

Number of K iterations per tile.

total_tiles()

Total number of tiles.

ScheduleContext parameters:
  • M (tl.tensor)

  • N (tl.tensor)

  • K (tl.tensor)

  • ctx (GemmContext)

  • streamk_tiles (tl.constexpr)

class InputView(ptr, rows, cols, stride_row, stride_col)

Bases: object

Input matrix view for GEMM.

Stores the matrix pointer, dimensions, and both strides. The tile_ptrs() method computes pointers with a single stride-based formula, so it works for any strided layout (row-major, column-major, or a transposed view).

Parameters:
  • ptr (tl.tensor)

  • rows (tl.tensor)

  • cols (tl.tensor)

  • stride_row (tl.tensor)

  • stride_col (tl.tensor)

ptr

Base pointer to matrix data

Type:

tl.tensor

rows

Number of rows

Type:

tl.tensor

cols

Number of columns

Type:

tl.tensor

stride_row

Stride when moving along rows (first dimension)

Type:

tl.tensor

stride_col

Stride when moving along columns (second dimension)

Type:

tl.tensor

__init__(ptr, rows, cols, stride_row, stride_col)
ptr: tl.tensor
rows: tl.tensor
cols: tl.tensor
stride_row: tl.tensor
stride_col: tl.tensor

tile_ptrs(tile)

Compute pointer array and bounds mask for a tile.

Uses the general formula ptr[i, j] = base + i*stride_row + j*stride_col, which works for any strided memory layout (row-major, column-major, or other).

Parameters:

tile (Tile) – Tile whose (pid_m, pid_n) select its row and column position in the tile grid and whose (block_m, block_n) give its shape

Returns:

ptrs: 2D pointer array [BLOCK_ROW, BLOCK_COL]; mask: 2D boolean mask for boundary handling

Return type:

(ptrs, mask)
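The stride formula can be checked with a small pure-Python model: the same logical 2x3 matrix stored row-major and column-major yields the same tile once each layout's strides are used. tile_offsets is a hypothetical helper; integer offsets stand in for pointers.

```python
# Pure-Python model of InputView.tile_ptrs(): offset = i*stride_row + j*stride_col.

def tile_offsets(pid_row, pid_col, block_row, block_col,
                 rows, cols, stride_row, stride_col):
    ri = [pid_row * block_row + i for i in range(block_row)]
    ci = [pid_col * block_col + j for j in range(block_col)]
    offs = [[r * stride_row + c * stride_col for c in ci] for r in ri]
    # Out-of-bounds elements are masked and must not be dereferenced.
    mask = [[r < rows and c < cols for c in ci] for r in ri]
    return offs, mask


# The same logical 2x3 matrix [[0, 1, 2], [3, 4, 5]] in two storage orders.
row_major = [0, 1, 2, 3, 4, 5]   # stride_row=3, stride_col=1
col_major = [0, 3, 1, 4, 2, 5]   # stride_row=1, stride_col=2

offs_r, _ = tile_offsets(0, 0, 2, 2, 2, 3, stride_row=3, stride_col=1)
offs_c, _ = tile_offsets(0, 0, 2, 2, 2, 3, stride_row=1, stride_col=2)
print([[row_major[o] for o in row] for row in offs_r])  # [[0, 1], [3, 4]]
print([[col_major[o] for o in row] for row in offs_c])  # [[0, 1], [3, 4]]
```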

load(tile, boundary=False, cache_modifier='.cg')

Load a tile from this matrix.

Parameters:
  • tile (Tile) – Tile with coordinates and shape

  • boundary (tl.constexpr) – If True, apply boundary masking for partial tiles

  • cache_modifier (tl.constexpr) – Cache modifier for load instruction

Returns:

Loaded tile data [BLOCK_ROW, BLOCK_COL]

class OutputView(ptr, rows, cols, stride_row, stride_col)

Bases: object

Output matrix view for GEMM.

Same design as InputView: stores pointer, dimensions, and strides. Provides store() with an optional epilogue (scaling, bias, type conversion).

The store() method can optionally apply:

  • Quantization scales (from ScaleView)

  • Bias addition (from BiasView)

  • Type conversion to output dtype

Example

tensorC = make_output_view(C, M, N, stride_cm, stride_cn)
scale_view = make_scale_view(A_scale, B_scale, M, N)
bias_view = make_bias_view(bias, M, stride_bias)

# Full epilogue: scale -> bias -> convert -> store
tensorC.store(acc, out_tile, scale=scale_view, bias=bias_view)

Parameters:
  • ptr (tl.tensor)

  • rows (tl.tensor)

  • cols (tl.tensor)

  • stride_row (tl.tensor)

  • stride_col (tl.tensor)

__init__(ptr, rows, cols, stride_row, stride_col)
ptr: tl.tensor
rows: tl.tensor
cols: tl.tensor
stride_row: tl.tensor
stride_col: tl.tensor

tile_ptrs(tile)

Compute pointer array and bounds mask for a tile.

Parameters:

tile (Tile)

store(data, tile, mask=None, scale=None, bias=None)

Store data to a tile with optional epilogue operations.

Applies epilogue in order: scale -> bias -> type convert -> store

Parameters:
  • data – Data to store [BLOCK_ROW, BLOCK_COL]

  • tile (Tile) – Tile with coordinates and shape

  • mask – Optional mask (if None, the mask is computed from the matrix bounds)

  • scale (ScaleView) – Optional ScaleView for quantization scaling

  • bias (BiasView) – Optional BiasView for bias addition

Example:

# Simple store (no epilogue)
tensorC.store(acc.to(C.type.element_ty), out_tile)

# With full epilogue
tensorC.store(acc, out_tile, scale=scale_view, bias=bias_view)

load(tile, boundary=False, cache_modifier='.cg')

Load a tile from this matrix (for read-modify-write patterns).

Parameters:
  • tile (Tile)

  • boundary (tl.constexpr)

  • cache_modifier (tl.constexpr)

class ScaleView(a_scale_ptr, b_scale_ptr, M, N, stride_a, stride_b)

Bases: object

Scale vectors view for quantized GEMM epilogue.

Stores pointers to per-row A scales and per-column B scales, along with dimensions for bounds checking.

Parameters:
  • a_scale_ptr (tl.tensor)

  • b_scale_ptr (tl.tensor)

  • M (tl.tensor)

  • N (tl.tensor)

  • stride_a (tl.tensor)

  • stride_b (tl.tensor)

a_scale_ptr

Pointer to A scale vector (per-row, length M)

Type:

tl.tensor

b_scale_ptr

Pointer to B scale vector (per-column, length N)

Type:

tl.tensor

M

Number of rows (for A scale bounds)

Type:

tl.tensor

N

Number of columns (for B scale bounds)

Type:

tl.tensor

stride_a

Stride for A scales (default: 1)

Type:

tl.tensor

stride_b

Stride for B scales (default: 1)

Type:

tl.tensor

__init__(a_scale_ptr, b_scale_ptr, M, N, stride_a, stride_b)
a_scale_ptr: tl.tensor
b_scale_ptr: tl.tensor
M: tl.tensor
N: tl.tensor
stride_a: tl.tensor
stride_b: tl.tensor

apply(acc, tile)

Apply quantization scales to accumulator.

Parameters:
  • acc – Accumulator tensor [BLOCK_M, BLOCK_N]

  • tile (Tile) – Tile with coordinates for indexing

Returns:

Scaled accumulator as float32
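A minimal sketch of the scaling step, assuming the usual per-row/per-column dequantization acc[i, j] * a_scale[row i] * b_scale[col j] with bounds masking. apply_scales is a hypothetical name, and using a scale of 1.0 for masked entries is an assumption (those entries are not stored).

```python
# Hypothetical model of ScaleView.apply(acc, tile).

def apply_scales(acc, a_scale, b_scale, rm, rn, M, N, stride_a=1, stride_b=1):
    """Multiply acc[i][j] by the row's A scale and the column's B scale."""
    out = []
    for i, r in enumerate(rm):
        sa = a_scale[r * stride_a] if r < M else 1.0   # masked row: neutral scale
        row = []
        for j, c in enumerate(rn):
            sb = b_scale[c * stride_b] if c < N else 1.0
            row.append(float(acc[i][j]) * sa * sb)     # result is float
        out.append(row)
    return out


acc = [[2, 4], [6, 8]]   # e.g. an int32 accumulator from a quantized GEMM
print(apply_scales(acc, a_scale=[0.5, 2.0], b_scale=[1.0, 0.25],
                   rm=[0, 1], rn=[0, 1], M=2, N=2))
# [[1.0, 0.5], [12.0, 4.0]]
```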

class BiasView(ptr, N, stride)

Bases: object

Bias vector view for GEMM epilogue.

Stores pointer to bias vector and dimension for bounds checking.

Parameters:
  • ptr (tl.tensor)

  • N (tl.tensor)

  • stride (tl.tensor)

ptr

Pointer to bias vector (length M, broadcast across columns)

Type:

tl.tensor

M

Number of rows (for bounds checking)

stride

Stride for bias vector (default: 1)

Type:

tl.tensor

__init__(ptr, N, stride)
ptr: tl.tensor
N: tl.tensor
stride: tl.tensor

apply(acc, tile)

Add bias vector to accumulator.

Parameters:
  • acc – Accumulator tensor [BLOCK_M, BLOCK_N]

  • tile (Tile) – Tile with coordinates for indexing

Returns:

Accumulator with bias added

make_input_view(ptr, rows, cols, stride_row, stride_col)

Create an InputView with automatic stride type coercion.

This factory ensures strides are always tensor-typed, handling the case where contiguous dimensions have stride=1 (Python int) while other dimensions have tensor-typed strides.

Parameters:
  • ptr – Base pointer to matrix data

  • rows – Number of rows (first dimension)

  • cols – Number of columns (second dimension)

  • stride_row – Stride when moving along rows

  • stride_col – Stride when moving along columns

Returns:

InputView with all fields as tensors

Example:

# A is [M, K] matrix - strides can be int or tensor
tensorA = make_input_view(A, M, K, stride_am, stride_ak)

# B is [K, N] matrix
tensorB = make_input_view(B, K, N, stride_bk, stride_bn)

make_tensor_view(ptr, rows, cols, stride_row, stride_col)

Same as make_input_view(): creates an InputView with automatic stride type coercion, taking the same parameters and returning the same InputView.


make_output_view(ptr, rows, cols, stride_row, stride_col)

Create an OutputView with automatic stride type coercion.

Same as make_input_view(), but returns an OutputView, which provides a store() method in addition to load().

Parameters:
  • ptr – Base pointer to matrix data

  • rows – Number of rows (first dimension)

  • cols – Number of columns (second dimension)

  • stride_row – Stride when moving along rows

  • stride_col – Stride when moving along columns

Returns:

OutputView with all fields as tensors

Example:

# C is [M, N] output matrix
tensorC = make_output_view(C, M, N, stride_cm, stride_cn)

make_scale_view(a_scale_ptr, b_scale_ptr, M, N, stride_a=1, stride_b=1)

Create a ScaleView for quantized GEMM epilogue.

Stores per-row A scales and per-column B scales with automatic stride type coercion.

Parameters:
  • a_scale_ptr – Pointer to A scale vector (per-row, length M)

  • b_scale_ptr – Pointer to B scale vector (per-column, length N)

  • M – Number of rows (for A scale bounds); must be a tensor

  • N – Number of columns (for B scale bounds)

  • stride_a – Stride for A scales (default: 1)

  • stride_b – Stride for B scales (default: 1)

Returns:

ScaleView with all fields as tensors

Example:

scale_view = make_scale_view(A_scale_ptr, B_scale_ptr, M, N)
tensorC.store(acc, out_tile, scale=scale_view)

make_bias_view(bias_ptr, N, stride=1)

Create a BiasView for GEMM epilogue.

Stores bias vector pointer with automatic stride type coercion.

Parameters:
  • bias_ptr – Pointer to bias vector (length N)

  • N – Number of columns (for bounds checking)

  • stride – Stride for bias vector (default: 1)

Returns:

BiasView with all fields as tensors

Example:

bias_view = make_bias_view(bias_ptr, N, stride_bias)
tensorC.store(acc, out_tile, bias=bias_view)

chiplet_transform(pid, num_workgroups, num_xcds)

Transform PID for basic chiplet-aware mapping.

Parameters:
  • pid – Program ID to remap

  • num_workgroups (tl.constexpr)

  • num_xcds (tl.constexpr)
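A sketch of one chiplet-aware remapping, assuming the hardware assigns workgroups to XCDs round-robin (pid 0 to XCD 0, pid 1 to XCD 1, and so on). The remap gives workgroups on the same XCD consecutive logical ids, so they process neighbouring tiles and can share that XCD's L2 cache. The formula is illustrative and may differ from the library's implementation.

```python
# Hypothetical sketch of chiplet_transform(pid, num_workgroups, num_xcds).

def chiplet_transform(pid, num_workgroups, num_xcds):
    xcd = pid % num_xcds                  # XCD this workgroup runs on
    pos_in_xcd = pid // num_xcds          # its position within that XCD
    min_per_xcd = num_workgroups // num_xcds
    extra = num_workgroups % num_xcds     # the first `extra` XCDs get one more
    offset = xcd * min_per_xcd + min(xcd, extra)
    return offset + pos_in_xcd


print([chiplet_transform(p, 8, 2) for p in range(8)])
# [0, 4, 1, 5, 2, 6, 3, 7]
```

Physical pids 0, 2, 4, 6 (all on XCD 0) become logical ids 0 to 3, so a contiguous block of tiles stays on one chiplet.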

chiplet_transform_chunked(pid, num_workgroups, num_xcds, chunk_size)

Transform PID for chunked chiplet-aware mapping.

Parameters:
  • pid – Program ID to remap

  • num_workgroups (tl.constexpr)

  • num_xcds (tl.constexpr)

  • chunk_size (tl.constexpr)