# Source code for tritonblas.origami

import itertools
import torch
import origami
import math
from math import ceil


[docs] class OrigamiMatmulSelector: # https://docs.pytorch.org/docs/stable/tensors.html dtype_to_str = { torch.float32: "f32", torch.complex64: "c32", torch.complex128: "c64", torch.float64: "f64", torch.float16: "f16", torch.int32: "i32", torch.bfloat16: "bf16", torch.int8: "i8", torch.float8_e5m2: "f8", torch.float8_e4m3fn: "f8", } # Add FP8 FNUZ variants if available (for non-gfx950 architectures) if hasattr(torch, "float8_e5m2fnuz"): dtype_to_str[torch.float8_e5m2fnuz] = "f8" if hasattr(torch, "float8_e4m3fnuz"): dtype_to_str[torch.float8_e4m3fnuz] = "f8"
[docs] def __init__( self, m: int, n: int, k: int, a_dtype: torch.dtype, b_dtype: torch.dtype, out_dtype: torch.dtype, device: torch.device, mx_block_size=0, streamk=False, ): # Save tensor sizes self._m = m self._n = n self._k = k self.streamk = streamk # Save tensor dtypes as strings self._a_dtype_str = OrigamiMatmulSelector.dtype_to_str.get(a_dtype, a_dtype) self._b_dtype_str = OrigamiMatmulSelector.dtype_to_str.get(b_dtype, b_dtype) self._out_dtype_str = OrigamiMatmulSelector.dtype_to_str.get( out_dtype, out_dtype ) # Save MX block size self._mx_block_size = mx_block_size ##### # Helper function to get bits for both float, int, and MX dtypes mx_types = ["f4"] def get_dtype_bits(dtype): # Handle MX types (string-based) if dtype in mx_types: return origami.datatype_to_bits(origami.string_to_datatype(dtype)) # Handle torch dtypes try: return torch.finfo(dtype).bits except TypeError: return torch.iinfo(dtype).bits self._a_dtype_bitsize = get_dtype_bits(a_dtype) self._b_dtype_bitsize = get_dtype_bits(a_dtype) self._out_dtype_bitsize = get_dtype_bits(a_dtype) # For matrix instruction latency lookup, use input dtype (not output dtype) # because the matrix instruction type is determined by input operand types # Example: FP8 inputs with BF16 output still uses FP8 matrix instructions # Set MI dtype - use string for MX types, otherwise lookup from dict if a_dtype in mx_types: self.mi_dtype = a_dtype else: input_dtype_for_mi = ( a_dtype if get_dtype_bits(a_dtype) <= get_dtype_bits(b_dtype) else b_dtype ) self.mi_dtype = OrigamiMatmulSelector.dtype_to_str.get( input_dtype_for_mi, OrigamiMatmulSelector.dtype_to_str.get(out_dtype) ) ##### # Get hardware info from Origami self._hardware = origami.get_hardware_for_device(device.index) self._N_CU = self._hardware.N_CU # Create list of Origami config_t objects from defaults. 
self._block_mn_range = [16, 32, 64, 128, 256] self._block_k_range = [16, 32, 64, 128, 256, 512] self._kernel_occupancy_range = [1] self._configs = self._generate_default_configs() # Create Origami problem_t based on problem metadata self._problem = self._make_problem() # Run Origami solution selection self._result = origami.select_config( self._problem, self._hardware, self._configs ) if streamk: self._grid = self._compute_sk_grid() else: self._grid = self._hardware.N_CU # Try both workgroup mapping modes for compatibility with Origami Versions try: _mapping_mode, self._xcc_workgroup_mapping, self._workgroup_mapping = ( origami.select_workgroup_mapping( self._problem, self._hardware, self._result.config, self._grid ) ) except ValueError: self._xcc_workgroup_mapping, self._workgroup_mapping = ( origami.select_workgroup_mapping( self._problem, self._hardware, self._result.config, self._grid ) )
@property def block_m(self): return self._result.config.mt.m @property def block_n(self): return self._result.config.mt.n @property def block_k(self): return self._result.config.mt.k @property def group_m(self): return self._workgroup_mapping @property def num_sms(self): return self._xcc_workgroup_mapping @property def waves_per_eu(self): return self._result.config.occupancy @property def even_k(self): return math.gcd(self._k, self.block_k) == self.block_k @property def sk_grid(self): return self._grid def _compute_sk_grid(self): # Grid model constants for StreamK split_factors = [8, 6, 4, 3, 2, 1] tile_fractions = [0.0, 1.0 / 2.0, 1.0 / 8.0, 1.0 / 5.0, 1.0 / 4.0, 1.0 / 3.0] max_workspace = 128 * 1024 * 1024 M, N, K = self._m, self._n, self._k BLK_M, BLK_N, BLK_K = self.block_m, self.block_n, self.block_k cu_count = self._hardware.N_CU # Fallback if no better fractional split is found tiles = ceil(M / BLK_M) * ceil(N / BLK_N) sk_grid = tiles iters_per_tile = max(1, ceil(K / BLK_K)) # More tiles than CUs: try fractional splits to distribute work if tiles > cu_count: virt_cu_count = cu_count # if size_mapping.CUOccupancy > 1: # virt_cu_count *= size_mapping.CUOccupancy # Try these fractional denominators in order min_even_tiles = tiles / virt_cu_count for frac in tile_fractions: # Compute candidate grid with rounding frac_grid = int((tiles / (min_even_tiles + frac)) + 0.5) # Skip if this split leaves a remainder AND workspace is too large if ( tiles % frac_grid != 0 and self._partial_tile_size(frac_grid) > max_workspace ): continue # Accept the first grid no larger than the virtual CU count if frac_grid <= virt_cu_count: sk_grid = frac_grid break # Fewer tiles than CUs: split along k-dimension up to some factor elif tiles < cu_count: for factor in split_factors: split_grid = tiles * factor iters_per_cu = iters_per_tile // factor if split_grid <= cu_count and iters_per_cu >= 8: sk_grid = split_grid break # Final check: if the chosen grid leaves a remainder AND # 
workspace exceeds what the problem allows, fall back to no split if tiles % sk_grid != 0: sk_grid = tiles if tiles >= cu_count: last_wave_remainder = tiles % cu_count last_wave_occupancy = last_wave_remainder / cu_count # Really bad last wave, which would have originally been compensated for # by changing tile size, but triton tile sizes are limited if ( last_wave_remainder < 128 and last_wave_remainder > 0 and cu_count in [304, 80, 64] ): # gfx942 sk_grid = 256 if cu_count == 304 else 64 return sk_grid def _partial_tile_size(self, sk_grid: int) -> int: """ Python equivalent of ContractionSolution::partialTileSize. workspaceSizePerElemC = (element_size_out bits) / 8 → bytes per output element tileSize = BLK_M * BLK_N * workspaceSizePerElemC return tileSize * sk_grid """ # get the macro-tile dims you already compute BLK_M, BLK_N = self.block_m, self.block_n # bytes per C element bytes_per_elem = self._out_dtype_bitsize // 8 # size of one partial tile per WG tile_size = BLK_M * BLK_N * bytes_per_elem # scale by the number of partial‑tiles per WG return tile_size * sk_grid def _generate_default_configs(self): config_list = [] mi = self._infer_matrix_instruction_dimensions() for blk_m, blk_n, blk_k, occupancy in itertools.product( self._block_mn_range, self._block_mn_range, self._block_k_range, self._kernel_occupancy_range, ): # Create special dim3_t object for BLK_* sizes mt = origami.dim3_t(blk_m, blk_n, blk_k) # Create and set new config_t values new_config = origami.config_t() new_config.mt = mt new_config.mi = mi new_config.occupancy = occupancy if self.streamk: new_config.grid_selection = origami.grid_selection_t.k_split_aware else: new_config.grid_selection = origami.grid_selection_t.data_parallel config_list.append(new_config) return config_list def _make_problem(self) -> origami.problem_t: # Create special dim3_t object for problem sizes size = origami.dim3_t(self._m, self._n, self._k) # Convert torch dtypes to Origami dtypes based on problem metadata 
a_origami_dtype = origami.string_to_datatype(self._a_dtype_str) b_origami_dtype = origami.string_to_datatype(self._b_dtype_str) c_origami_dtype = origami.string_to_datatype(self._out_dtype_str) # Create and set new problem_t values problem = origami.problem_t() problem.size = size problem.batch = 1 problem.a_transpose = origami.transpose_t.T problem.b_transpose = origami.transpose_t.N problem.a_dtype = a_origami_dtype problem.b_dtype = b_origami_dtype problem.c_dtype = c_origami_dtype problem.d_dtype = c_origami_dtype problem.mi_dtype = c_origami_dtype problem.a_mx_block_size = self._mx_block_size problem.b_mx_block_size = self._mx_block_size return problem def _infer_matrix_instruction_dimensions(self): """ Infers the matrix instruction dimensions based on the hardware configuration and the sizes of the input data types. The input dtype sizes are retrieved from local object variables. Returns: origami.dim3_t: An Origami dimension trio containing the matrixinstruction dimensions [M, N, K]. Raises: ValueError: If the hardware architecture is unsupported or if the data type sizes are not compatible with the detected hardware. 
""" largest_bitsize = max(self._a_dtype_bitsize, self._b_dtype_bitsize) mi_dim = None # gfx950 if self._hardware.N_CU == 256: # FP32 if largest_bitsize == 32: mi_dim = origami.dim3_t(16, 16, 4) # FP16/BF16 if largest_bitsize == 16: mi_dim = origami.dim3_t(16, 16, 32) # F4F6F8 if largest_bitsize <= 8: if self._k % 256 == 0: self._block_k_range = self._block_k_range + [256] else: self._block_k_range = self._block_k_range + [128] self._block_mn_range = [32, 64, 128, 256] mi_dim = origami.dim3_t(16, 16, 128) # gfx942 (304 CUs full, 80 CUs partitioned, 64 CUs) is_gfx942 = self._hardware.N_CU in [304, 80, 64] if is_gfx942: # FP32 if largest_bitsize == 32: mi_dim = origami.dim3_t(16, 16, 4) # FP16/BF16 if largest_bitsize == 16: mi_dim = origami.dim3_t(16, 16, 16) # F8 if largest_bitsize == 8: self._block_mn_range = self._block_mn_range + [512] self._block_k_range = self._block_k_range + [128, 256] mi_dim = origami.dim3_t(16, 16, 32) # F4F6 -> Unsupported on gfx942 if largest_bitsize < 8: raise ValueError("gfx942 doesn't support F4/F6") if self._hardware.N_CU == 228: # FP32 if largest_bitsize == 32: mi_dim = origami.dim3_t(16, 16, 4) # FP16/BF16 if largest_bitsize == 16: mi_dim = origami.dim3_t(16, 16, 16) # F8 if largest_bitsize == 8: self._block_mn_range = self._block_mn_range + [512] self._block_k_range = self._block_k_range + [128, 256] mi_dim = origami.dim3_t(16, 16, 32) # F4F6 -> Unsupported on MI300A if largest_bitsize < 8: raise ValueError("MI300A doesn't support F4/F6") # gfx90a if self._hardware.N_CU == 104: # FP32 if largest_bitsize == 32: mi_dim = origami.dim3_t(16, 16, 4) # FP16/BF16 if largest_bitsize == 16: mi_dim = origami.dim3_t(16, 16, 16) if largest_bitsize == 8: raise ValueError("MI200 doesn't support F8") if largest_bitsize < 8: raise ValueError("MI200 doesn't support F4/F6") # Architecture Detected is not valid if mi_dim == None: raise ValueError( f"No Valid Matrix Instruction integrated for {element_size_A}-bit or {element_size_B}-bit datatypes" ) 
return mi_dim