# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
"""
Tile aggregate for tritonblas shards.
"""
import triton
import triton.language as tl
from triton.language.core import _aggregate as aggregate
# NOTE(review): this block was recovered from rendered Sphinx output — the
# "[docs]" link artifacts were removed and the lost indentation restored.
# All executable statements are unchanged.
@aggregate
class Tile:
    """
    2D tile with coordinates and shape.

    Stores runtime coordinates (pid_m, pid_n) and compile-time block sizes.
    Use :class:`ScheduleContext` methods to create tiles::

        # From linear tile_id
        out_tile = sched.get_tile_from_idx(tile_id)

        # From 2D coordinates
        out_tile = sched.get_tile_from_coord(pid_m, pid_n)

        # Direct construction is also supported
        out_tile = Tile(pid_m, pid_n, BLOCK_M, BLOCK_N)

    Example
    -------
    .. code-block:: python

        # Output tile from scheduler
        out_tile = sched.get_tile_from_idx(tile_id)

        # Input A tile at k offset
        a_tile = sched.get_tile_from_coord(pid_m, k // BLOCK_K)

        # Input B tile at k offset
        b_tile = sched.get_tile_from_coord(k // BLOCK_K, pid_n)
    """

    pid_m: tl.tensor  # Tile coordinate in M dimension (runtime value)
    pid_n: tl.tensor  # Tile coordinate in N dimension (runtime value)
    block_m: tl.constexpr  # Block size M (compile-time constant)
    block_n: tl.constexpr  # Block size N (compile-time constant)

    @triton.constexpr_function
    def __init__(self, pid_m, pid_n, block_m, block_n):
        """
        Create a tile with runtime coordinates and compile-time sizes.

        Args:
            pid_m: Tile coordinate in M dimension
            pid_n: Tile coordinate in N dimension
            block_m: Block size in M dimension (constexpr)
            block_n: Block size in N dimension (constexpr)
        """
        self.pid_m = pid_m
        self.pid_n = pid_n
        # Wrap the block sizes so they are always carried as compile-time
        # constants, even when plain Python ints are passed in.
        self.block_m = tl.constexpr(block_m)
        self.block_n = tl.constexpr(block_n)

    @triton.jit
    def indices(self):
        """
        Compute row and column indices for this tile.

        Returns:
            rm, rn: Row indices [BLOCK_M], column indices [BLOCK_N]
        """
        rm = self.pid_m * self.block_m + tl.arange(0, self.block_m)
        rn = self.pid_n * self.block_n + tl.arange(0, self.block_n)
        return rm, rn

    @triton.jit
    def layout(self, M, N):
        """
        Compute memory layout with bounds checking.

        Args:
            M: Total rows
            N: Total columns

        Returns:
            rm, rn, mask: Row indices, column indices, bounds mask

        Note:
            The mask is computed from RAW indices before modulo wrapping.
            This is critical for correct boundary handling - e.g., when K
            doesn't divide evenly by BLOCK_K, the boundary tile needs proper
            masking to avoid reading garbage data.
        """
        rm, rn = self.indices()
        # MASK COMPUTATION: use the raw indices BEFORE modulo wrapping.
        # After the modulo every index would be < M and < N, which would
        # make the mask useless for spotting out-of-bounds elements.
        mask = (rm[:, None] < M) & (rn[None, :] < N)
        # INDEX WRAPPING: apply modulo for the pointer computation. The
        # modulo + max_contiguous hints help the compiler with memory
        # access patterns, but must come AFTER the mask computation.
        rm = tl.max_contiguous(tl.multiple_of(rm % M, self.block_m), self.block_m)
        rn = tl.max_contiguous(tl.multiple_of(rn % N, self.block_n), self.block_n)
        return rm, rn, mask

    @triton.jit
    def scale(self, acc, A_scale_ptr, B_scale_ptr, M, N, stride_a=1, stride_b=1):
        """
        Apply quantization scales to accumulator.

        Args:
            acc: Accumulator tensor [BLOCK_M, BLOCK_N]
            A_scale_ptr: Pointer to A scales (per-row)
            B_scale_ptr: Pointer to B scales (per-column)
            M, N: Matrix dimensions for bounds checking
            stride_a: Stride for A scales (default: 1)
            stride_b: Stride for B scales (default: 1)

        Returns:
            Scaled accumulator as float32

        Example::

            acc = tile.scale(acc, A_scale_ptr, B_scale_ptr, M, N)
        """
        rm, rn = self.indices()
        # Out-of-bounds lanes load 1.0 so they leave acc unchanged.
        a_scales = tl.load(A_scale_ptr + rm * stride_a, mask=rm < M, other=1.0)
        b_scales = tl.load(B_scale_ptr + rn * stride_b, mask=rn < N, other=1.0)
        acc = acc.to(tl.float32)
        acc = acc * a_scales[:, None]  # broadcast per-row scales
        acc = acc * b_scales[None, :]  # broadcast per-column scales
        return acc

    @triton.jit
    def bias(self, acc, bias_ptr, M, stride_bias=1):
        """
        Add bias vector to accumulator.

        Args:
            acc: Accumulator tensor [BLOCK_M, BLOCK_N]
            bias_ptr: Pointer to bias vector
            M: Matrix dimension for bounds checking
            stride_bias: Stride for bias vector (default: 1)

        Returns:
            Accumulator with bias added

        Example::

            acc = tile.bias(acc, bias_ptr, M)
        """
        rm, _ = self.indices()
        # Out-of-bounds rows load 0.0 so they contribute nothing.
        bias_vector = tl.load(bias_ptr + rm * stride_bias, mask=rm < M, other=0.0)
        acc = acc + bias_vector[:, None]
        return acc