# Pre-built Kernel Library Guide

This guide covers the available FlyDSL kernels (Normalization, Softmax, GEMM): configuration, data types, pipelines, and shared utilities.
## Quick Reference

| Kernel | Builder Function | API Style | Dtypes | Key Feature |
|---|---|---|---|---|
| LayerNorm | `build_layernorm_module` | Legacy (pending migration) | f32, f16, bf16 | Two-pass vectorized normalization |
| RMSNorm | `build_rmsnorm_module` | Legacy (pending migration) | f32, f16, bf16 | LDS-cached 3-pass pipeline |
| Softmax | `build_softmax_module` | Legacy (pending migration) | f32, f16, bf16 | Online softmax, adaptive block size |
| GEMM | `compile_preshuffle_gemm_a8` | New (`@flyc.kernel`/`@flyc.jit`) | fp8, int8, int4, fp16, bf16, fp4 | Preshuffle B, ping-pong LDS, MFMA 16x16 |
Note on API styles: the new `@flyc.kernel`/`@flyc.jit` API comes from `flydsl.compiler` and `flydsl.expr` (`python/flydsl/`); kernels marked Legacy are pending migration to it.
## 1. Normalization Kernels

### 1.1 LayerNorm (`kernels/layernorm_kernel.py`)

Computes `LayerNorm(x) = (x - mean) / sqrt(var + eps) * gamma + beta` for each row.

Builder:

```python
from kernels.layernorm_kernel import build_layernorm_module

executor = build_layernorm_module(M=32768, N=8192, dtype_str="bf16")
```
Configuration Constants:

| Constant | Value | Description |
|---|---|---|
| `BLOCK_THREADS` | 256 | Threads per block |
| | 64 | AMD wavefront size |
| `VEC_WIDTH` | 8 | Vector load/store width |
| | 16 | Alignment for vector ops (bytes) |
| | 1e-5 | Numerical stability epsilon |
| | True | Non-temporal stores for output |
Algorithm:

- Two-pass normalization: Pass 1 computes mean and variance, Pass 2 applies the affine transform
- Fast path: when `N == BLOCK_THREADS * VEC_WIDTH * 4` (e.g., N=8192), uses fully register-resident computation with no scalar tail
- Generic path: handles arbitrary N with a vector body plus a scalar tail
- bf16 handling: software round-to-nearest-even (RNE) pack on gfx942; hardware `cvt_pk_bf16_f32` on gfx950+
- Warp reduction: XOR-shuffle-based intra-wave reduction (shifts: 32, 16, 8, 4, 2, 1), then LDS-based cross-wave synchronization
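The two-pass scheme maps directly onto a host-side NumPy reference, useful for validating kernel output. This is an illustrative sketch (`layernorm_ref` is not part of the library):

```python
import numpy as np

def layernorm_ref(x, gamma, beta, eps=1e-5):
    """Row-wise LayerNorm mirroring the two-pass kernel:
    Pass 1 computes per-row mean and variance, Pass 2 applies the affine."""
    x = x.astype(np.float32)
    mean = x.mean(axis=-1, keepdims=True)                  # Pass 1: mean
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)   # Pass 1: variance
    return (x - mean) / np.sqrt(var + eps) * gamma + beta  # Pass 2: affine

x = np.random.randn(4, 8192).astype(np.float32)
gamma = np.ones(8192, np.float32)
beta = np.zeros(8192, np.float32)
y = layernorm_ref(x, gamma, beta)
```

With `gamma=1` and `beta=0`, each output row has (approximately) zero mean and unit variance.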
Kernel signature (using the `@flyc.kernel` API):

```python
GPU_MODULE_NAME = "layernorm_module"

@kernel
def layernorm_kernel(self, Input, Gamma, Beta, Output, m_in): ...

@jit
def __call__(self, Input, Gamma, Beta, Output, m_in): ...
```
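The XOR-shuffle warp reduction used above can be simulated on the host. A pure-Python sketch of one 64-lane wave (no GPU intrinsics; the butterfly exchange stands in for the shuffle-XOR instruction):

```python
def wave_reduce_sum(lanes):
    """Simulate an intra-wave XOR (butterfly) sum reduction over 64 lanes.
    After the shifts 32, 16, 8, 4, 2, 1, every lane holds the full wave sum,
    mirroring a shuffle-XOR based reduction."""
    vals = list(lanes)
    assert len(vals) == 64
    for shift in (32, 16, 8, 4, 2, 1):
        # each lane adds the value held by its XOR-partner lane
        vals = [vals[lane] + vals[lane ^ shift] for lane in range(64)]
    return vals

result = wave_reduce_sum(range(64))  # every lane ends up with sum(0..63) = 2016
```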
### 1.2 RMSNorm (`kernels/rmsnorm_kernel.py`)

Computes `RMSNorm(x) = x / sqrt(mean(x^2) + eps) * gamma`.

Builder:

```python
from kernels.rmsnorm_kernel import build_rmsnorm_module

executor = build_rmsnorm_module(M=32768, N=8192, dtype_str="bf16")
```
Configuration Constants: Same as LayerNorm (`BLOCK_THREADS=256`, `VEC_WIDTH=8`, etc.)

Algorithm (3-pass with LDS caching):

- Pass 0: Global → LDS row cache (one-pass global read, vectorized)
- Pass 1: Sum-of-squares computation from the LDS row cache
- Pass 2: Normalize + gamma multiply + store, with a software pipeline for the Gamma prefetch
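The math the three passes implement reduces to a short NumPy reference (a host-side sketch for validation; `rmsnorm_ref` is illustrative, not the kernel):

```python
import numpy as np

def rmsnorm_ref(x, gamma, eps=1e-5):
    """Row-wise RMSNorm: sum-of-squares pass over the (cached) row,
    then normalize and scale by gamma."""
    x = x.astype(np.float32)
    ms = np.mean(x * x, axis=-1, keepdims=True)  # Pass 1: mean of squares
    return x / np.sqrt(ms + eps) * gamma         # Pass 2: normalize + gamma

x = np.random.randn(4, 8192).astype(np.float32)
y = rmsnorm_ref(x, np.ones(8192, np.float32))
```

With `gamma=1`, each output row has (approximately) unit root-mean-square.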
Kernel signature:

```python
GPU_MODULE_NAME = "rmsnorm_module"

@kernel
def rmsnorm_kernel(self, Input, Gamma, Output, m_in): ...
```
## 2. Softmax Kernel

### 2.1 Softmax (`kernels/softmax_kernel.py`)

Computes row-wise softmax: `softmax(x)_i = exp(x_i - max(x)) / sum(exp(x - max(x)))`.

Builder:

```python
from kernels.softmax_kernel import build_softmax_module

executor = build_softmax_module(M=32768, N=8192, dtype_str="bf16")
```
Configuration:

| Parameter | Value | Description |
|---|---|---|
| | | Adaptive block size |
| | 8 | Vector load/store width |
| | 64 | AMD wavefront size |
Algorithm (6 stages):

1. Load Data: vectorized global loads into a register buffer with validity masks
2. Local Max: per-thread vector reduction (`maxnumf`)
3. Global Max: block-wide shuffle reduction (intra-wave XOR → wave0 finalize via LDS)
4. Local Exp + Sum: `exp2(x * log2(e))` approximation, accumulating partial sums
5. Global Sum: block-wide reduction for the sum
6. Normalize + Store: divide by the sum, convert to the output dtype, vectorized store
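The stages above collapse into a few NumPy lines, including the `exp2` identity the kernel relies on (`exp(x) = exp2(x * log2(e))`). A host-side reference sketch, not the kernel:

```python
import numpy as np

def softmax_ref(x):
    """Row-wise softmax mirroring the kernel's stages: row max (stages 2-3),
    exp via exp2(x * log2(e)) (stage 4), row sum (stage 5), normalize (stage 6)."""
    x = x.astype(np.float32)
    m = x.max(axis=-1, keepdims=True)        # stages 2-3: global (row) max
    e = np.exp2((x - m) * np.log2(np.e))     # stage 4: exp2 approximation of exp
    return e / e.sum(axis=-1, keepdims=True) # stages 5-6: sum + normalize

p = softmax_ref(np.random.randn(4, 8192).astype(np.float32) * 10.0)
```

Subtracting the row max before exponentiating is what keeps the computation finite even for large-magnitude inputs.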
Kernel signature:

```python
GPU_MODULE_NAME = f"softmax_{dtype_str}"

@kernel
def softmax_kernel(self, A, C, m_in): ...
```
## 3. GEMM Kernel

### 3.1 Preshuffle GEMM (`kernels/preshuffle_gemm.py`)

MFMA 16x16-based GEMM with a B-matrix preshuffle layout: `C[M,N] = A[M,K] @ B[N,K]^T`. Uses the new `@flyc.kernel` / `@flyc.jit` API.

Builder:

```python
from kernels.preshuffle_gemm import compile_preshuffle_gemm_a8

launch_fn = compile_preshuffle_gemm_a8(
    M=16, N=5120, K=8192,
    tile_m=16, tile_n=128, tile_k=256,
    in_dtype="fp8",
    lds_stage=2,
    use_cshuffle_epilog=False,
)
```

Returns a `@flyc.jit`-decorated function that auto-compiles on first call.
Parameters:

| Parameter | Type | Description |
|---|---|---|
| `M`, `N`, `K` | int | GEMM dimensions: A[M,K], B[N,K], C[M,N]. M and N can be 0 (dynamic). |
| `tile_m`, `tile_n`, `tile_k` | int | Block tile sizes |
| `in_dtype` | str | Input dtype: `fp8`, `int8`, `int4` (W4A8), `fp16`, `bf16`, `fp4` |
| `lds_stage` | int | LDS pipeline stages: 1 (single buffer) or 2 (ping-pong) |
| `use_cshuffle_epilog` | bool | CK-style LDS CShuffle epilogue |
| | int | Occupancy hint (None = default, 1-4 = limit occupancy) |
| | bool | Use async DMA for A tile global-to-LDS transfer |
Key constraints:

- `tile_k * elem_bytes` must be divisible by 64 (K64-byte micro-step)
- INT4 is W4A8: A is int8, B is packed int4 (2 values/byte), unpacked to int8 in-kernel
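The W4A8 packing (two signed int4 values per byte, unpacked to int8 in-kernel) can be illustrated in NumPy. A sketch only: the nibble order chosen here (even index in the low nibble) is an assumption for illustration, not necessarily the kernel's layout:

```python
import numpy as np

def pack_int4(vals):
    """Pack signed int4 values (range [-8, 7], even count) two-per-byte:
    even index in the low nibble, odd index in the high nibble."""
    u = (vals.astype(np.int16) & 0xF).astype(np.uint8)  # two's-complement nibbles
    return (u[0::2] | (u[1::2] << 4)).astype(np.uint8)

def unpack_int4(packed):
    """Unpack bytes back to int8 with sign extension, as done in-kernel."""
    lo = (packed & 0xF).astype(np.int8)
    hi = (packed >> 4).astype(np.int8)
    lo = np.where(lo > 7, lo - 16, lo).astype(np.int8)  # sign-extend low nibble
    hi = np.where(hi > 7, hi - 16, hi).astype(np.int8)  # sign-extend high nibble
    out = np.empty(packed.size * 2, np.int8)
    out[0::2], out[1::2] = lo, hi
    return out

vals = np.array([-8, -1, 0, 7, 3, -4], np.int8)
roundtrip = unpack_int4(pack_int4(vals))  # restores the original int4 values
```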
Pipeline details:

- `lds_stage=2` (ping-pong): two LDS buffers for A tiles; cross-tile A0 prefetch overlaps VMEM with LDS reads
- `lds_stage=1` (single): CK-style intrawave schedule with a single LDS buffer
- K64-byte micro-step: each step issues 2x K32 MFMA operations
- XOR16 swizzle: byte-level swizzle on LDS addresses to avoid bank conflicts
- B-preshuffle: shape `(N0, K0, KLane, NLane, KPackBytes) = (N/16, K/64, 4, 16, kpack_bytes)`
- CShuffle epilogue: write the C tile to LDS in row-major, remap threads for half2 packing via `ds_bpermute`
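The B-preshuffle shape can be produced from a row-major `B[N, K]` with plain reshape/transpose. This is a plausible NumPy sketch for a 1-byte dtype (fp8/int8); both the axis ordering and `kpack_bytes=16` (implied by 64 bytes per K step / 4 K lanes) are assumptions, and the kernel's exact byte ordering may differ:

```python
import numpy as np

def preshuffle_b(b_bytes, kpack_bytes=16):
    """Rearrange row-major B[N, K] (1 byte/element) into the preshuffled
    (N0, K0, KLane, NLane, KPackBytes) = (N/16, K/64, 4, 16, kpack_bytes)
    layout. Illustrative sketch, not the library's implementation."""
    n, k = b_bytes.shape
    klane, nlane = 4, 16
    assert n % nlane == 0 and k % (klane * kpack_bytes) == 0
    # split N into (N0, NLane) and K into (K0, KLane, KPackBytes)
    t = b_bytes.reshape(n // nlane, nlane, k // (klane * kpack_bytes), klane, kpack_bytes)
    # reorder axes to (N0, K0, KLane, NLane, KPackBytes)
    return np.ascontiguousarray(t.transpose(0, 2, 3, 1, 4))

b = (np.arange(128 * 256) % 256).astype(np.uint8).reshape(128, 256)
bp = preshuffle_b(b)  # shape (8, 4, 4, 16, 16)
```

The rearrangement is a pure permutation, so it is losslessly invertible by transposing the axes back.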
Launch function signature:

```python
launch_fn(arg_c, arg_a, arg_b, arg_scale_a, arg_scale_b, M_val, N_val, stream)
```

Where:

- `arg_c`, `arg_a`, `arg_b`, `arg_scale_a`, `arg_scale_b`: PyTorch tensors (auto-converted to memref)
- `M_val`, `N_val`: Python int (auto-converted to Int32)
- `stream`: `fx.Stream` (default stream if omitted)
## 5. Kernel API Comparison

### New API (GEMM)

Used by `preshuffle_gemm.py`:

```python
import flydsl.compiler as flyc
import flydsl.expr as fx
from flydsl.expr import arith, gpu, buffer_ops, rocdl

@flyc.kernel
def gemm_kernel(arg_c: fx.Tensor, arg_a: fx.Tensor, ...):
    tid = gpu.thread_idx.x
    # ... uses fx.*, arith.*, buffer_ops.*, rocdl.* ...

@flyc.jit
def launch_fn(arg_c: fx.Tensor, ..., stream: fx.Stream = fx.Stream(None)):
    gemm_kernel(arg_c, ...).launch(grid=..., block=..., stream=stream)
```
## 6. Kernel Decision Tree

```text
What operation do you need?
│
├── Normalization
│   ├── Need bias (beta) term? → LayerNorm (layernorm_kernel.py)
│   └── No bias term? → RMSNorm (rmsnorm_kernel.py)
│
├── Softmax
│   └── Row-wise softmax → Softmax (softmax_kernel.py)
│
├── Matrix Multiply (GEMM)
│   ├── Standard GEMM (uniform precision)
│   │   ├── FP8 / INT8 / INT4 (W4A8) / FP16 / BF16 / FP4
│   │   └── → compile_preshuffle_gemm_a8()
│   │
│   └── Uses new @flyc.kernel API
│       └── See kernels/preshuffle_gemm.py
│
└── Building blocks
    ├── Warp/block reduction → reduce.py
    ├── MFMA epilogue selection → mfma_epilogues.py
    ├── Preshuffle data movement → mfma_preshuffle_pipeline.py
    └── Static layout helpers → layout_utils.py
```
## 7. Source Files

| File | Description |
|---|---|
| `kernels/__init__.py` | Package marker |
| `kernels/layernorm_kernel.py` | LayerNorm builder (`build_layernorm_module`) |
| `kernels/rmsnorm_kernel.py` | RMSNorm builder (`build_rmsnorm_module`) |
| `kernels/softmax_kernel.py` | Softmax builder (`build_softmax_module`) |
| `kernels/preshuffle_gemm.py` | Preshuffle GEMM builder (new `@flyc.kernel`/`@flyc.jit` API) |
| `kernels/reduce.py` | Shared warp/block reduction helpers |
| `kernels/mfma_epilogues.py` | MFMA epilogue strategies (default, CShuffle) |
| `kernels/mfma_preshuffle_pipeline.py` | Preshuffle data movement and layout utilities |
| `kernels/layout_utils.py` | Pure-arith layout helpers |
## 8. Test Files

| File | Tests |
|---|---|
| | LayerNorm correctness + perf |
| | RMSNorm correctness + perf |
| | Softmax correctness + perf |
| | GEMM fp8/int8/int4/bf16/fp4 correctness + perf |
| | Shared benchmark infrastructure |