# Source code for tritonblas.kernels.stages.schedule

# SPDX-License-Identifier: MIT
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

"""
ScheduleContext aggregate for tritonblas shards.

Provides a simple iterator interface that hides the complexity of persistent
GEMM and Stream-K scheduling. Just call next_tile() or next_iter() to get
the work unit coordinates.
"""

import triton
import triton.language as tl
from triton.language.core import _aggregate as aggregate

from .grid import chiplet_transform_chunked
from .gemm_context import GemmContext
from .tile import Tile


@aggregate
class ScheduleContext:
    """
    Unified scheduling context that hides persistent GEMM loop complexity.

    Two simple iteration patterns:

    * **Tile loop**: ``for tile_id in range(start, total, stride)`` with
      ``get_tile_from_idx(tile_id)``
    * **Iter loop**: ``for iter_id in range(start, end)`` with
      ``get_iter(iter_id)``

    Example (persistent GEMM)
    -------------------------
    .. code-block:: python

        ctx = GemmContext(block_m=128, block_n=256, block_k=64,
                          num_sms=NUM_SMS, num_xcds=NUM_XCDS)
        sched = ScheduleContext(M, N, K, ctx)
        start, total, stride = sched.persistent_tile_range()
        for tile_id in range(start, total, stride):
            out_tile = sched.get_tile_from_idx(tile_id)
            # Process full tile

    Example (Stream-K)
    ------------------
    .. code-block:: python

        ctx = GemmContext(block_m=128, block_n=256, block_k=64, num_sms=NUM_SMS)
        sched = ScheduleContext(M, N, K, ctx, streamk_tiles=STREAMK_TILES)
        start, end = sched.iter_range()
        for iter_id in range(start, end):
            pid_m, pid_n, k_iter = sched.get_iter(iter_id)
            # Process single K iteration at (pid_m, pid_n, k_iter)
    """

    # Problem dimensions
    M: tl.tensor
    N: tl.tensor
    K: tl.tensor

    # GemmContext with all block sizes and scheduling params
    ctx: GemmContext

    # Stream-K specific: number of output tiles handled by the Stream-K
    # region (0 = pure persistent scheduling)
    streamk_tiles: tl.constexpr
[docs] @triton.constexpr_function def __init__( self, M, N, K, ctx: GemmContext, streamk_tiles=0, ): """ Create a ScheduleContext from a GemmContext. Args: M, N, K: Problem dimensions ctx: GemmContext with block sizes and scheduling parameters streamk_tiles: Number of tiles for Stream-K (0 = persistent only) """ self.M = M self.N = N self.K = K self.ctx = ctx self.streamk_tiles = tl.constexpr(streamk_tiles)
# ================================================================ # Tile-level iteration (for persistent GEMM) # ================================================================
[docs] @triton.jit def persistent_tile_range(self): """ Get tile iteration range for this workgroup (persistent GEMM). Returns: (start, total, stride): Use as ``range(start, total, stride)`` Example:: start, total, stride = sched.persistent_tile_range() for tile_id in range(start, total, stride): pid_m, pid_n = sched.get_tile(tile_id) ... """ num_pid_m = tl.cdiv(self.M, self.ctx.block_m) num_pid_n = tl.cdiv(self.N, self.ctx.block_n) total_tiles = num_pid_m * num_pid_n # Get transformed program ID pid = tl.program_id(0) if self.ctx.num_xcds != 1: pid = chiplet_transform_chunked(pid, self.ctx.num_sms, self.ctx.num_xcds, self.ctx.chunk_size) return pid, total_tiles, self.ctx.num_sms
[docs] @triton.jit def get_tile_from_idx(self, tile_id): """ Get a Tile for a given tile ID. Args: tile_id: Linear tile index Returns: Tile: Tile object with computed coordinates and ctx block sizes """ num_pid_m = tl.cdiv(self.M, self.ctx.block_m) num_pid_n = tl.cdiv(self.N, self.ctx.block_n) num_pid_in_group = self.ctx.group_size_m * num_pid_n group_id = tile_id // num_pid_in_group first_pid_m = group_id * self.ctx.group_size_m group_size_m = tl.minimum(num_pid_m - first_pid_m, self.ctx.group_size_m) pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) pid_n = (tile_id % num_pid_in_group) // group_size_m tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) return Tile(pid_m, pid_n, self.ctx.block_m, self.ctx.block_n)
[docs] @triton.jit def get_tile_from_coord(self, pid_m, pid_n): """ Get a Tile from 2D coordinates. Args: pid_m: Tile coordinate in M dimension pid_n: Tile coordinate in N dimension Returns: Tile: Tile object with the given coordinates and ctx block sizes """ return Tile(pid_m, pid_n, self.ctx.block_m, self.ctx.block_n)
@triton.jit def _tile_idx_to_coord(self, tile_id): """ Internal: Convert tile ID to coordinates (returns tuple). Args: tile_id: Linear tile index Returns: (pid_m, pid_n): Tile coordinates as tuple """ num_pid_m = tl.cdiv(self.M, self.ctx.block_m) num_pid_n = tl.cdiv(self.N, self.ctx.block_n) num_pid_in_group = self.ctx.group_size_m * num_pid_n group_id = tile_id // num_pid_in_group first_pid_m = group_id * self.ctx.group_size_m group_size_m = tl.minimum(num_pid_m - first_pid_m, self.ctx.group_size_m) pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) pid_n = (tile_id % num_pid_in_group) // group_size_m tl.assume(pid_m >= 0) tl.assume(pid_n >= 0) return pid_m, pid_n # ================================================================ # Iteration-level iteration (for Stream-K) # ================================================================
[docs] @triton.jit def iter_range(self): """ Get iteration range for this workgroup (Stream-K mode). Returns: (start_iter, end_iter): Iteration range [start, end) """ num_pid_m = tl.cdiv(self.M, self.ctx.block_m) num_pid_n = tl.cdiv(self.N, self.ctx.block_n) total_tiles = num_pid_m * num_pid_n iters_per_tile = tl.cdiv(self.K, self.ctx.block_k) # Get transformed program ID pid = tl.program_id(0) if self.ctx.num_xcds != 1: pid = chiplet_transform_chunked(pid, self.ctx.num_sms, self.ctx.num_xcds, self.ctx.chunk_size) total_full_tiles = total_tiles - self.streamk_tiles total_streamk_iters = self.streamk_tiles * iters_per_tile streamk_iters_pcu = total_streamk_iters // self.ctx.num_sms streamk_remainder_iters = total_streamk_iters % self.ctx.num_sms start_iter = ( total_full_tiles * iters_per_tile + pid * streamk_iters_pcu + tl.minimum(pid, streamk_remainder_iters) ) end_iter = ( total_full_tiles * iters_per_tile + (pid + 1) * streamk_iters_pcu + tl.minimum(pid + 1, streamk_remainder_iters) ) return start_iter, end_iter
[docs] @triton.jit def get_iter(self, global_iter): """ Get coordinates for a given global iteration. Args: global_iter: Global iteration index Returns: (pid_m, pid_n, k_iter): Tile coordinates and K iteration index """ iters_per_tile = tl.cdiv(self.K, self.ctx.block_k) # Convert global iteration to (tile_id, k_iter) tile_id = global_iter // iters_per_tile k_iter = global_iter % iters_per_tile # Convert tile_id to (pid_m, pid_n) pid_m, pid_n = self._tile_idx_to_coord(tile_id) return pid_m, pid_n, k_iter
[docs] @triton.jit def iters_per_tile(self): """Number of K iterations per tile.""" return tl.cdiv(self.K, self.ctx.block_k)
[docs] @triton.jit def total_tiles(self): """Total number of tiles.""" num_pid_m = tl.cdiv(self.M, self.ctx.block_m) num_pid_n = tl.cdiv(self.N, self.ctx.block_n) return num_pid_m * num_pid_n
@triton.jit
def make_schedule_context(M, N, K, ctx: GemmContext, streamk_tiles=0):
    """
    Create a ScheduleContext from a GemmContext.

    Args:
        M, N, K: Problem dimensions
        ctx: GemmContext with block sizes and scheduling parameters
        streamk_tiles: Number of tiles for Stream-K (0 = persistent only)
    """
    # NOTE(review): the `+ 0 * M` arithmetic is numerically a no-op; it looks
    # like a deliberate trick to promote M/N/K to tl.tensor values (matching
    # ScheduleContext's field annotations) even when callers pass constexpr
    # ints. All three use `0 * M` rather than `0 * N` / `0 * K`, presumably
    # because a single tensor operand suffices for promotion — confirm before
    # "simplifying" this away.
    M_t = M + 0 * M
    N_t = N + 0 * M
    K_t = K + 0 * M
    return ScheduleContext(M_t, N_t, K_t, ctx, streamk_tiles)