Source code for tritonblas.origami

from __future__ import annotations
import itertools
import torch
import origami
import math
from math import ceil


def estimate_triton_lds_bytes(
    block_m: int,
    block_n: int,
    block_k: int,
    bytes_a: float,
    bytes_b: float,
    num_stages: int = 2,
) -> float:
    """
    Return the estimated LDS (shared memory) bytes a Triton GEMM kernel uses on AMD GPUs.

    The AMD backend stores operand tiles with swizzled_shared /
    amd_rotating_shared encodings. These permute bank addressing but add no
    padding bytes, so the footprint is the raw tile size scaled by the number
    of live pipeline buffers:

      num_stages <= 1:  max(A_tile, B_tile)                  (no pipelining)
      num_stages >= 2:  (num_stages - 1) * (A_tile + B_tile) (software-pipelined)

    where A_tile = block_m * block_k * bytes_a and
          B_tile = block_n * block_k * bytes_b.

    Validated against metadata.shared from compiled Triton kernels on gfx942
    (Triton 3.6.0+rocm7.2.0): 35/35 configs matched exactly.

    Args:
        block_m, block_n, block_k: Macro-tile dimensions (MT_M, MT_N, MT_K).
        bytes_a, bytes_b: Bytes per element of A and B (e.g. 2 for bf16/fp16).
        num_stages: Pipeline stages (1, 2, or 3); Triton matmul uses 2 by default.

    Returns:
        Estimated total LDS usage in bytes.
    """
    tile_a = block_m * block_k * bytes_a
    tile_b = block_n * block_k * bytes_b
    if num_stages > 1:
        # Pipelined: (num_stages - 1) copies of both tiles are resident at once.
        return (tile_a + tile_b) * (num_stages - 1)
    # Unpipelined: tiles are allocated sequentially, so only the larger counts.
    return max(tile_a, tile_b)


def check_triton_lds_capacity(
    block_m: int,
    block_n: int,
    block_k: int,
    bytes_a: float,
    bytes_b: float,
    lds_capacity: int,
    num_stages: int = 2,
) -> bool:
    """Report whether a tile config's estimated Triton LDS usage fits the device.

    Args:
        block_m, block_n, block_k: Macro-tile dimensions.
        bytes_a, bytes_b: Bytes per element of A and B.
        lds_capacity: Available LDS in bytes.
        num_stages: Pipeline stages forwarded to estimate_triton_lds_bytes.

    Returns:
        True when estimate_triton_lds_bytes(...) <= lds_capacity.
    """
    return (
        estimate_triton_lds_bytes(
            block_m, block_n, block_k, bytes_a, bytes_b, num_stages
        )
        <= lds_capacity
    )


[docs] class OrigamiMatmulSelector:
[docs] @staticmethod def estimate_triton_lds( block_m: int, block_n: int, block_k: int, bytes_a: float, bytes_b: float, num_stages: int = 2, ) -> float: """Class-level wrapper for estimate_triton_lds_bytes.""" return estimate_triton_lds_bytes( block_m, block_n, block_k, bytes_a, bytes_b, num_stages )
# https://docs.pytorch.org/docs/stable/tensors.html dtype_to_str = { torch.float32: "f32", torch.complex64: "c32", torch.complex128: "c64", torch.float64: "f64", torch.float16: "f16", torch.int32: "i32", torch.bfloat16: "bf16", torch.int8: "i8", torch.float8_e5m2: "f8", torch.float8_e4m3fn: "f8", } # Add FP8 FNUZ variants if available (for non-gfx950 architectures) if hasattr(torch, "float8_e5m2fnuz"): dtype_to_str[torch.float8_e5m2fnuz] = "f8" if hasattr(torch, "float8_e4m3fnuz"): dtype_to_str[torch.float8_e4m3fnuz] = "f8" COUNTERS_PER_XCD = 4 # work-stealing: default, overridden by _select_ws_params()
[docs] def __init__( self, m: int, n: int, k: int, a_dtype: torch.dtype, b_dtype: torch.dtype, out_dtype: torch.dtype, device: torch.device, mx_block_size=0, streamk=False, total_cus: int = None, active_cus: int = None, num_stages: int = 2, ): # Save tensor sizes self._m = m self._n = n self._k = k self.streamk = streamk self._num_stages = num_stages # Save tensor dtypes as strings self._a_dtype_str = OrigamiMatmulSelector.dtype_to_str.get(a_dtype, a_dtype) self._b_dtype_str = OrigamiMatmulSelector.dtype_to_str.get(b_dtype, b_dtype) self._out_dtype_str = OrigamiMatmulSelector.dtype_to_str.get( out_dtype, out_dtype ) # Save MX block size self._mx_block_size = mx_block_size ##### # Helper function to get bits for both float, int, and MX dtypes mx_types = ["f4"] def get_dtype_bits(dtype): # Handle MX types (string-based) if dtype in mx_types: return origami.datatype_to_bits(origami.string_to_datatype(dtype)) # Handle torch dtypes try: return torch.finfo(dtype).bits except TypeError: return torch.iinfo(dtype).bits self._a_dtype_bitsize = get_dtype_bits(a_dtype) self._b_dtype_bitsize = get_dtype_bits(b_dtype) self._out_dtype_bitsize = get_dtype_bits(out_dtype) # For matrix instruction latency lookup, use input dtype (not output dtype) # because the matrix instruction type is determined by input operand types # Example: FP8 inputs with BF16 output still uses FP8 matrix instructions # Set MI dtype - use string for MX types, otherwise lookup from dict if a_dtype in mx_types: self.mi_dtype = a_dtype else: input_dtype_for_mi = ( a_dtype if get_dtype_bits(a_dtype) <= get_dtype_bits(b_dtype) else b_dtype ) self.mi_dtype = OrigamiMatmulSelector.dtype_to_str.get( input_dtype_for_mi, OrigamiMatmulSelector.dtype_to_str.get(out_dtype) ) ##### # Get hardware info from Origami self._hardware = origami.get_hardware_for_device(device.index) # Detect architecture name for MI instruction selection. 
# Prefer origami's hardware_t.arch if available; fall back to # torch's gcnArchName property (strip suffix like ":sramecc+:xnack-"). if hasattr(self._hardware, 'arch') and hasattr(self._hardware.arch, 'name'): self._arch_name = self._hardware.arch.name else: import torch as _torch _gcn = getattr(_torch.cuda.get_device_properties(device), "gcnArchName", "") self._arch_name = _gcn.split(":")[0] if _gcn else "unknown" # The GPU-reported N_CU reflects any active CU mask. Save it # before overriding so Stream-K can size its grid to the real # number of schedulable CUs. self._active_cus = active_cus # When running under a CU mask (e.g. cu-sweep), the GPU reports a # reduced N_CU. Override with the real total so architecture # detection and config generation use the correct value. if total_cus is not None: self._hardware.N_CU = total_cus self._N_CU = self._hardware.N_CU self._ACTIVE_CU = active_cus if active_cus is not None else self._N_CU # Create list of Origami config_t objects from defaults. self._block_mn_range = [16, 32, 64, 128, 256] self._block_k_range = [16, 32, 64, 128, 256, 512] self._kernel_occupancy_range = [1] self._configs = self._generate_default_configs() # Create Origami problem_t based on problem metadata (needed for fallback) self._problem = self._make_problem() # Filter configs by Triton LDS capacity (async_copy + num_stages + padding). # Origami's check_lds_capacity uses raw tile size only; Triton allocates # num_stages buffers with padding for bank conflicts. # LDS issues only affect largest tiles; smaller configs should always pass. bytes_a = self._a_dtype_bitsize / 8 bytes_b = self._b_dtype_bitsize / 8 lds_cap = self._hardware.lds_capacity self._configs = [ c for c in self._configs if check_triton_lds_capacity( c.mt.m, c.mt.n, c.mt.k, bytes_a, bytes_b, lds_cap, self._num_stages ) ] if not self._configs: # Fallback: origami's raw check (no Triton padding/stages) is more permissive. 
# Used when Triton filter is overly conservative; smaller tiles should pass. self._configs = self._generate_default_configs() self._configs = [ c for c in self._configs if origami.check_lds_capacity( self._hardware, c.mt, self._problem.a_dtype, self._problem.b_dtype ) ] if not self._configs: # Should not happen on supported hardware (64KB+ LDS); small tiles always fit. raise RuntimeError( "No configs passed LDS checks; unexpected for supported hardware" ) # Run Origami solution selection self._result = origami.select_config( self._problem, self._hardware, self._configs ) # Heuristic to favor 256x256x64 tile when close~ if (check_triton_lds_capacity(256, 256, 64, bytes_a, bytes_b, lds_cap, self._num_stages) and ((self._result.config.mt.m == 256 and self._result.config.mt.n != 256) or (self._result.config.mt.m != 256 and self._result.config.mt.n == 256))): self._result.config.mt.m = 256 self._result.config.mt.n = 256 self._result.config.mt.k = 64 if streamk: self._grid = self._compute_sk_grid() else: self._grid = self._hardware.N_CU # Handle different origami API versions for workgroup mapping _wg_result = origami.select_workgroup_mapping( self._problem, self._hardware, self._result.config, self._grid ) if isinstance(_wg_result, tuple): # Older origami: returns (mode, xcc_mapping, mapping) or (xcc_mapping, mapping) if len(_wg_result) == 3: _, self._xcc_workgroup_mapping, self._workgroup_mapping = _wg_result else: self._xcc_workgroup_mapping, self._workgroup_mapping = _wg_result else: # origami >= 0.1.0: returns workgroup_mapping_t object self._xcc_workgroup_mapping = _wg_result.wgmxcc self._workgroup_mapping = _wg_result.wgm self._select_ws_params()
def _select_ws_params(self): """Select work-stealing parameters based on tile count. Empirically tuned on MI300X (8 XCDs, 304 CUs) via autotune sweeps across GEMM sizes 1K-16K. """ bm = self._result.config.mt.m bn = self._result.config.mt.n total_tiles = ((self._m + bm - 1) // bm) * ((self._n + bn - 1) // bn) tiles_m = (self._m + bm - 1) // bm if total_tiles <= 512: self.COUNTERS_PER_XCD = 8 elif total_tiles <= 1536: self.COUNTERS_PER_XCD = 4 elif total_tiles <= 2048: self.COUNTERS_PER_XCD = 2 else: self.COUNTERS_PER_XCD = 1 self._workgroup_mapping = min(8, tiles_m)
[docs] def hierarchical_split(self, num_xcds: int) -> tuple: """Compute optimal local/global tile split for hierarchical WS. Uses the full hardware CU count (not active CUs) so that the split is a topology-level constant, avoiding Triton recompilation when the active CU mask changes. Adaptive split based on tiles-per-CU density: - <=4 tiles/CU: 100% local (global counter overhead dominates) - >4 tiles/CU: local_frac decreases linearly, floor at 50% Returns (local_per_xcd, global_tiles). """ bm = self._result.config.mt.m bn = self._result.config.mt.n total_tiles = ((self._m + bm - 1) // bm) * ((self._n + bn - 1) // bn) hw_cus = self._hardware.NUM_XCD * self._hardware.CU_per_L2 tiles_per_cu = total_tiles / max(hw_cus, 1) local_frac = max(0.5, 1.0 - max(0.0, tiles_per_cu - 4.0) * 0.05) local_per_xcd = int(total_tiles * local_frac) // num_xcds local_per_xcd = max(local_per_xcd, 1) global_tiles = total_tiles - local_per_xcd * num_xcds return local_per_xcd, global_tiles
@property def block_m(self): return self._result.config.mt.m @property def block_n(self): return self._result.config.mt.n @property def block_k(self): return self._result.config.mt.k @property def group_m(self): return self._workgroup_mapping @property def num_sms(self): return self._xcc_workgroup_mapping @property def num_stages(self): return self._num_stages @property def waves_per_eu(self): return self._result.config.occupancy @property def even_k(self): return self._k % self.block_k == 0 @property def sk_grid(self): return self._grid def _compute_sk_grid(self): # Grid model constants for StreamK split_factors = [8, 6, 4, 3, 2, 1] tile_fractions = [0.0, 1.0 / 2.0, 1.0 / 8.0, 1.0 / 5.0, 1.0 / 4.0, 1.0 / 3.0] max_workspace = 128 * 1024 * 1024 M, N, K = self._m, self._n, self._k BLK_M, BLK_N, BLK_K = self.block_m, self.block_n, self.block_k cu_count = self._hardware.N_CU # Fallback if no better fractional split is found tiles = ceil(M / BLK_M) * ceil(N / BLK_N) sk_grid = tiles iters_per_tile = max(1, ceil(K / BLK_K)) # More tiles than CUs: try fractional splits to distribute work if tiles > cu_count: virt_cu_count = cu_count # if size_mapping.CUOccupancy > 1: # virt_cu_count *= size_mapping.CUOccupancy # Try these fractional denominators in order min_even_tiles = tiles / virt_cu_count for frac in tile_fractions: # Compute candidate grid with rounding frac_grid = int((tiles / (min_even_tiles + frac)) + 0.5) # Skip if this split leaves a remainder AND workspace is too large if ( tiles % frac_grid != 0 and self._partial_tile_size(frac_grid) > max_workspace ): continue # Accept the first grid no larger than the virtual CU count if frac_grid <= virt_cu_count: sk_grid = frac_grid break # Fewer tiles than CUs: split along k-dimension up to some factor elif tiles < cu_count: for factor in split_factors: split_grid = tiles * factor iters_per_cu = iters_per_tile // factor if split_grid <= cu_count and iters_per_cu >= 8: sk_grid = split_grid break # Final check: if the chosen 
grid leaves a remainder AND # workspace exceeds what the problem allows, fall back to no split if tiles % sk_grid != 0: sk_grid = tiles if tiles >= cu_count: last_wave_remainder = tiles % cu_count last_wave_occupancy = last_wave_remainder / cu_count # Really bad last wave, which would have originally been compensated for # by changing tile size, but triton tile sizes are limited if ( last_wave_remainder < 128 and last_wave_remainder > 0 and cu_count in [304, 80, 64] ): # gfx942 sk_grid = 256 if cu_count == 304 else 64 return sk_grid def _partial_tile_size(self, sk_grid: int) -> int: """ Python equivalent of ContractionSolution::partialTileSize. workspaceSizePerElemC = (element_size_out bits) / 8 → bytes per output element tileSize = BLK_M * BLK_N * workspaceSizePerElemC return tileSize * sk_grid """ # get the macro-tile dims you already compute BLK_M, BLK_N = self.block_m, self.block_n # bytes per C element bytes_per_elem = self._out_dtype_bitsize // 8 # size of one partial tile per WG tile_size = BLK_M * BLK_N * bytes_per_elem # scale by the number of partial‑tiles per WG return tile_size * sk_grid def _generate_default_configs(self): config_list = [] mi = self._infer_matrix_instruction_dimensions() for blk_m, blk_n, blk_k, occupancy in itertools.product( self._block_mn_range, self._block_mn_range, self._block_k_range, self._kernel_occupancy_range, ): # Create special dim3_t object for BLK_* sizes mt = origami.dim3_t(blk_m, blk_n, blk_k) # Create and set new config_t values new_config = origami.config_t() new_config.mt = mt new_config.mi = mi new_config.occupancy = occupancy if self.streamk: new_config.grid_selection = origami.grid_selection_t.k_split_aware else: new_config.grid_selection = origami.grid_selection_t.data_parallel config_list.append(new_config) return config_list def _make_problem(self) -> origami.problem_t: # Create special dim3_t object for problem sizes size = origami.dim3_t(self._m, self._n, self._k) # Convert torch dtypes to Origami dtypes 
based on problem metadata a_origami_dtype = origami.string_to_datatype(self._a_dtype_str) b_origami_dtype = origami.string_to_datatype(self._b_dtype_str) c_origami_dtype = origami.string_to_datatype(self._out_dtype_str) # Create and set new problem_t values problem = origami.problem_t() problem.size = size problem.batch = 1 problem.a_transpose = origami.transpose_t.T problem.b_transpose = origami.transpose_t.N problem.a_dtype = a_origami_dtype problem.b_dtype = b_origami_dtype problem.c_dtype = c_origami_dtype problem.d_dtype = c_origami_dtype problem.mi_dtype = c_origami_dtype problem.a_mx_block_size = self._mx_block_size problem.b_mx_block_size = self._mx_block_size return problem def _infer_matrix_instruction_dimensions(self): """ Infers the matrix instruction dimensions based on the hardware configuration and the sizes of the input data types. The input dtype sizes are retrieved from local object variables. Returns: origami.dim3_t: An Origami dimension trio containing the matrixinstruction dimensions [M, N, K]. Raises: ValueError: If the hardware architecture is unsupported or if the data type sizes are not compatible with the detected hardware. 
""" largest_bitsize = max(self._a_dtype_bitsize, self._b_dtype_bitsize) mi_dim = None # gfx950 if self._arch_name == "gfx950": # FP32 if largest_bitsize == 32: mi_dim = origami.dim3_t(16, 16, 4) # FP16/BF16 if largest_bitsize == 16: mi_dim = origami.dim3_t(16, 16, 32) # F4F6F8 if largest_bitsize <= 8: if self._k % 256 == 0: self._block_k_range = self._block_k_range + [256] else: self._block_k_range = self._block_k_range + [128] self._block_mn_range = [32, 64, 128, 256] mi_dim = origami.dim3_t(16, 16, 128) # gfx942 (304 CUs full, 80 CUs partitioned, 64 CUs) if self._arch_name == "gfx942": # FP32 if largest_bitsize == 32: mi_dim = origami.dim3_t(16, 16, 4) # FP16/BF16 if largest_bitsize == 16: mi_dim = origami.dim3_t(16, 16, 16) # F8 if largest_bitsize == 8: self._block_mn_range = self._block_mn_range + [512] self._block_k_range = self._block_k_range + [128, 256] mi_dim = origami.dim3_t(16, 16, 32) # F4F6 -> Unsupported on gfx942 if largest_bitsize < 8: raise ValueError("gfx942 doesn't support F4/F6") if self._hardware.N_CU == 228: # FP32 if largest_bitsize == 32: mi_dim = origami.dim3_t(16, 16, 4) # FP16/BF16 if largest_bitsize == 16: mi_dim = origami.dim3_t(16, 16, 16) # F8 if largest_bitsize == 8: self._block_mn_range = self._block_mn_range + [512] self._block_k_range = self._block_k_range + [128, 256] mi_dim = origami.dim3_t(16, 16, 32) # F4F6 -> Unsupported on MI300A if largest_bitsize < 8: raise ValueError("MI300A doesn't support F4/F6") # gfx90a if self._arch_name == "gfx90a": # FP32 if largest_bitsize == 32: mi_dim = origami.dim3_t(16, 16, 4) # FP16/BF16 if largest_bitsize == 16: mi_dim = origami.dim3_t(16, 16, 16) if largest_bitsize == 8: raise ValueError("MI200 doesn't support F8") if largest_bitsize < 8: raise ValueError("MI200 doesn't support F4/F6") # Architecture Detected is not valid if mi_dim == None: raise ValueError( f"No Valid Matrix Instruction for {self._a_dtype_bitsize}-bit/{self._b_dtype_bitsize}-bit dtypes " f"on hardware with 
N_CU={self._hardware.N_CU}" ) return mi_dim