Kernel Authoring Guide
Writing GPU kernels with FlyDSL: `@flyc.jit`, `@flyc.kernel`, the expression API, launch configuration, shared memory, and synchronization.
API: This guide documents the `@flyc.kernel`/`@flyc.jit` API from `flydsl.compiler` and `flydsl.expr` (`python/flydsl/`).
Quick Reference
| Concept | API | Description |
|---|---|---|
| JIT host func | `@flyc.jit` | Emit host-side launcher with JIT compilation |
| GPU kernel | `@flyc.kernel` | Define GPU kernel function |
| Launch | `.launch(grid=..., block=...)` | Configure and emit GPU launch |
| Thread ID | `gpu.thread_idx.x` | Get thread index in workgroup |
| Block ID | `gpu.block_idx.x` | Get block/workgroup index |
| Block dim | `gpu.block_dim.x` | Get block dimension size |
| Compile-time | `fx.Constexpr[T]` | Compile-time constant parameter |
| Tensor arg | `fx.Tensor` | GPU tensor argument (via DLPack) |
| Stream arg | `fx.Stream` | CUDA/HIP stream argument |
| Barrier | `gpu.barrier()` | Workgroup synchronization |
| Constants | `arith.constant(...)` | Create MLIR constant value |
| Range loop | `range_constexpr(...)` | Compile-time unrolled loop |
| Buffer load | `buffer_ops.buffer_load(...)` | AMD buffer load intrinsic |
1. Basic Kernel Pattern
1.1 @flyc.kernel + @flyc.jit
```python
import flydsl.compiler as flyc
import flydsl.expr as fx
from flydsl.expr import arith, gpu


@flyc.kernel
def vec_add_kernel(
    A: fx.Tensor,
    B: fx.Tensor,
    C: fx.Tensor,
    N: fx.Constexpr[int],
):
    tid = gpu.thread_idx.x
    bid = gpu.block_idx.x
    idx = bid * 256 + tid
    # ... kernel body using arith/vector/buffer ops ...


@flyc.jit
def vec_add(
    A: fx.Tensor,
    B: fx.Tensor,
    C: fx.Tensor,
    N: fx.Constexpr[int],
    stream: fx.Stream = fx.Stream(None),
):
    vec_add_kernel(A, B, C, N).launch(
        grid=(N // 256,),
        block=(256,),
        stream=stream,
    )


# Usage:
import torch

A = torch.randn(1024, device="cuda", dtype=torch.float32)
B = torch.randn(1024, device="cuda", dtype=torch.float32)
C = torch.empty(1024, device="cuda", dtype=torch.float32)
vec_add(A, B, C, 1024)
```
1.2 How It Works
- `@flyc.kernel` wraps the function as a `KernelFunction`; `@flyc.jit` wraps the function as a `JitFunction`.
- On first call, `JitFunction.__call__` triggers:
  - AST rewriting (Python loops/ifs → MLIR scf ops)
  - MLIR module creation with `gpu.container_module`
  - Tracing of the jit function body to generate MLIR ops; calling `vec_add_kernel(...)` emits a `gpu.func` in the `gpu.module`, and `.launch()` emits `gpu.launch_func`
  - `MlirCompiler.compile()` runs the full pass pipeline
  - `JITCFunction` wraps the resulting ExecutionEngine
- Subsequent calls with the same type signature use the cached binary.
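The compile-on-first-call, cache-by-signature behavior can be sketched in plain Python. `SketchJitFunction` and `compile_fn` are illustrative stand-ins, not FlyDSL's actual classes; the real `JitFunction` also hashes source code, closures, and toolchain versions:

```python
# Minimal sketch of signature-keyed JIT caching (illustrative only).
class SketchJitFunction:
    def __init__(self, fn, compile_fn):
        self.fn = fn
        self.compile_fn = compile_fn   # stands in for MlirCompiler.compile()
        self.cache = {}                # type signature -> compiled artifact
        self.compile_count = 0

    def __call__(self, *args):
        key = tuple(type(a).__name__ for a in args)
        if key not in self.cache:      # first call: compile and cache
            self.compile_count += 1
            self.cache[key] = self.compile_fn(self.fn)
        return self.cache[key](*args)  # cached artifact on later calls


jitted = SketchJitFunction(lambda x, y: x + y, compile_fn=lambda f: f)
jitted(1, 2)
jitted(3, 4)        # same signature: no recompile
jitted(1.0, 2.0)    # new signature: recompiles
print(jitted.compile_count)  # → 2
```

The point of the sketch: recompilation is driven by the argument type signature, not by call count.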
2. Parameter Types
2.1 fx.Tensor
Maps a PyTorch tensor to an MLIR memref descriptor via DLPack:
```python
@flyc.kernel
def my_kernel(input: fx.Tensor, output: fx.Tensor):
    # input and output are Tensor wrappers around ir.Value (memref)
    ...
```
At the host boundary, a `torch.Tensor` is automatically converted via `TensorAdaptor`.
2.2 fx.Constexpr[T]
Compile-time constant. Value is embedded directly in the generated IR:
```python
from flydsl.expr import range_constexpr


@flyc.kernel
def my_kernel(data: fx.Tensor, N: fx.Constexpr[int], dtype: fx.Constexpr[str]):
    for i in range_constexpr(N // 64):  # unrolled at compile time
        ...
```
Different `Constexpr` values produce different compiled kernels (separate cache entries).
2.3 fx.Int32
Runtime integer parameter (passed as i32):
```python
@flyc.jit
def launch(data: fx.Tensor, size: fx.Int32, stream: fx.Stream = fx.Stream(None)):
    ...
```
Python `int` values are automatically converted to `Int32` via the `JitArgumentRegistry`.
2.4 fx.Stream
CUDA/HIP stream for asynchronous kernel launch:
```python
@flyc.jit
def launch(data: fx.Tensor, stream: fx.Stream = fx.Stream(None)):
    my_kernel(data).launch(grid=(1,), block=(256,), stream=stream)


# Launch on a specific stream:
stream = torch.cuda.Stream()
launch(data, stream=fx.Stream(stream))
```
2.5 Custom Argument Types
Register new Python types for the JIT boundary:
```python
from flydsl.compiler import JitArgumentRegistry


@JitArgumentRegistry.register(MyCustomType, dsl_type=MyDslType)
class MyCustomAdaptor:
    def __init__(self, value: MyCustomType):
        self.value = value

    def __fly_types__(self):
        return [...]  # MLIR types for this argument

    def __fly_ptrs__(self):
        return [...]  # ctypes pointers for invocation
```
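The shape of such a registry can be sketched in plain Python. `SketchRegistry` and `IntAdaptor` are hypothetical simplifications, not the real `JitArgumentRegistry`:

```python
# Sketch of a type -> adaptor registry (illustrative, not the real API).
class SketchRegistry:
    _adaptors = {}

    @classmethod
    def register(cls, py_type):
        def deco(adaptor_cls):
            cls._adaptors[py_type] = adaptor_cls
            return adaptor_cls
        return deco

    @classmethod
    def adapt(cls, value):
        # Look up the adaptor registered for this Python type.
        return cls._adaptors[type(value)](value)


@SketchRegistry.register(int)
class IntAdaptor:
    def __init__(self, value):
        self.value = value

    def __fly_types__(self):
        return ["i32"]  # MLIR type names this argument lowers to


adaptor = SketchRegistry.adapt(7)
print(adaptor.__fly_types__())  # → ['i32']
```

The decorator pattern keeps registration next to the adaptor definition, so adding a new host type never touches the dispatch code.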
3. Thread / Block Hierarchy
```python
from flydsl.expr import gpu

# Thread index within workgroup (returns Int32)
tid_x = gpu.thread_idx.x
tid_y = gpu.thread_idx.y
tid_z = gpu.thread_idx.z

# Block (workgroup) index within grid
bid_x = gpu.block_idx.x
bid_y = gpu.block_idx.y

# Block dimensions
bdim_x = gpu.block_dim.x

# Grid dimensions
gdim_x = gpu.grid_dim.x

# Low-level (returns raw ir.Value)
raw_tid = gpu.thread_id("x")
raw_bid = gpu.block_id("x")
```
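These values combine into the usual global index, `idx = block_idx.x * block_dim.x + thread_idx.x`, as used in section 1.1. A host-side emulation in plain Python (no FlyDSL or GPU required) checks that this formula visits every element exactly once:

```python
# Emulate a 1-D launch to check that idx = bid * bdim + tid
# covers 0 .. grid_x * block_x - 1 with no gaps or repeats.
def launched_indices(grid_x, block_x):
    seen = []
    for bid in range(grid_x):          # plays the role of gpu.block_idx.x
        for tid in range(block_x):     # plays the role of gpu.thread_idx.x
            idx = bid * block_x + tid  # global element index
            seen.append(idx)
    return seen


idx = launched_indices(grid_x=4, block_x=256)
print(len(idx), idx[0], idx[-1])  # → 1024 0 1023
```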
4. Expression API (flydsl.expr)
4.1 Arithmetic (fx.arith)
```python
from flydsl.expr import arith
from flydsl.expr.typing import T

# Constants
c42 = arith.constant(42, index=True)  # index type
c3_14 = arith.constant(3.14, T.f32)   # f32 type

# Arithmetic (operator overloading via ArithValue)
result = a + b
result = a * 2
result = a // 4
result = a % 16

# Cast
idx = arith.index_cast(T.index, int_val)

# Select
result = arith.select(cond, true_val, false_val)

# Bitwise
result = arith.andi(a, b)
result = arith.xori(a, b)
result = arith.shli(a, b)
```
4.2 Vector Operations (fx.vector)
```python
from flydsl.expr import vector

# Build a vector from elements
vec = vector.from_elements(vec_type, [a, b, c, d])

# Vector store to a memref
vector.store(vec, memref, [idx])

# Extract/insert
elem = vector.extractelement(vec, idx)
vec2 = vector.insertelement(vec, elem, idx)
```
4.3 Buffer Operations (fx.buffer_ops)
AMD buffer load/store intrinsics for efficient global memory access:
```python
from flydsl.expr import buffer_ops

# Create a buffer resource descriptor from a memref
rsrc = buffer_ops.create_buffer_resource(memref_value)

# Buffer load (vectorized)
data = buffer_ops.buffer_load(rsrc, byte_offset, vec_width=4)

# Buffer store
buffer_ops.buffer_store(data, rsrc, byte_offset)
```
4.4 ROCm Intrinsics (fx.rocdl)
```python
from flydsl.expr import rocdl

# MFMA instructions
result = rocdl.mfma_f32_16x16x16_f16(a, b, acc)
result = rocdl.mfma_f32_16x16x32_fp8(a, b, acc)
result = rocdl.mfma_i32_16x16x32i8(a, b, acc)

# Warp shuffle
val = rocdl.ds_bpermute(idx, src)

# LDS operations
rocdl.ds_write_b128(lds_ptr, offset, data)
data = rocdl.ds_read_b128(lds_ptr, offset)
```
4.5 GPU Operations (fx.gpu)
```python
from flydsl.expr import gpu

# Barrier (workgroup synchronization)
gpu.barrier()

# Shared memory address space attribute
addrspace = gpu.smem_space()
addrspace_int = gpu.smem_space(int=True)
```
5. Control Flow
5.1 Python Loops → MLIR SCF
The `ASTRewriter` automatically transforms Python `for` loops:
```python
from flydsl.expr import range_constexpr


@flyc.kernel
def my_kernel(data: fx.Tensor, N: fx.Constexpr[int]):
    # Compile-time unrolled loop: fully unrolled in the generated IR
    for i in range_constexpr(N):
        ...

    # Runtime loop (lowered to scf.for)
    for i in range(runtime_value):
        ...
```
5.2 const_expr()
Mark a value as compile-time constant:
```python
from flydsl.expr import const_expr, range_constexpr


@flyc.kernel
def my_kernel(data: fx.Tensor, N: fx.Constexpr[int]):
    tile_size = const_expr(N // 4)
    for i in range_constexpr(tile_size):
        ...
```
6. Shared Memory (LDS)
6.1 SmemAllocator
```python
from flydsl.utils.smem_allocator import SmemAllocator
from flydsl.expr.typing import T

# Create an allocator for the target architecture
allocator = SmemAllocator(None, arch="gfx942", global_sym_name="smem0")

# Allocate typed arrays
lds_a = allocator.allocate_array(T.f16, 8192)
lds_b = allocator.allocate_array(T.f16, 8192)

# Inside the kernel: get the base pointer and typed views
lds_base = allocator.get_base()
lds_a_ptr = lds_a(lds_base)  # SmemPtr
lds_b_ptr = lds_b(lds_base)  # SmemPtr

# Load/store through SmemPtr
val = lds_a_ptr.load([idx])
lds_b_ptr.store(val, [idx])
```
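The allocator's essential job, handing out non-overlapping aligned offsets into a single LDS arena, can be sketched as a bump allocator. The class name, alignment, and capacity below are illustrative, not `SmemAllocator`'s actual internals:

```python
# Sketch: bump allocator over a shared-memory arena (illustrative).
class SketchSmemAllocator:
    def __init__(self, capacity_bytes, align=16):
        self.capacity = capacity_bytes
        self.align = align
        self.offset = 0  # next free byte

    def allocate(self, nbytes):
        # Round the current offset up to the alignment boundary.
        start = (self.offset + self.align - 1) // self.align * self.align
        if start + nbytes > self.capacity:
            raise MemoryError("out of LDS")
        self.offset = start + nbytes
        return start  # byte offset of this allocation within the arena


alloc = SketchSmemAllocator(capacity_bytes=64 * 1024)
a = alloc.allocate(8192 * 2)  # e.g. 8192 f16 elements (2 bytes each)
b = alloc.allocate(8192 * 2)
print(a, b)  # → 0 16384
```

Typed views like `SmemPtr` would then interpret `base + offset` as an array of a given element type.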
6.2 Finalizing LDS Allocation
For `@flyc.kernel`-style kernels, emit the `memref.global` in the GPU module:
```python
comp_ctx = CompilationContext.get_current()
with ir.InsertionPoint(comp_ctx.gpu_module_body):
    allocator.finalize()
```
6.3 LDS Capacity
| Architecture | LDS per CU |
|---|---|
| gfx942 | 64 KB |
| gfx950 | 160 KB |
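A configuration's LDS footprint should be checked against this budget before launch. A small sketch using the two f16 arrays from section 6.1 against a 64 KB budget (the helper name and staging parameter are illustrative):

```python
# Check a tile configuration against an LDS budget (illustrative numbers).
def lds_bytes(n_elems, elem_bytes, stages=1):
    # stages > 1 models multi-buffered (e.g. double-buffered) tiles.
    return n_elems * elem_bytes * stages


# Two 8192-element f16 arrays, as allocated in section 6.1:
total = lds_bytes(8192, 2) + lds_bytes(8192, 2)
print(total, total <= 64 * 1024)  # → 32768 True
```

Doubling the staging (`stages=2`) would still fit in 64 KB here, but larger tiles quickly hit the limit.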
7. Launch Configuration
7.1 KernelLauncher.launch()
```python
@flyc.jit
def launch(data: fx.Tensor, stream: fx.Stream = fx.Stream(None)):
    my_kernel(data).launch(
        grid=(num_blocks_x, num_blocks_y, num_blocks_z),
        block=(threads_x, threads_y, threads_z),
        smem=shared_mem_bytes,  # dynamic shared memory
        stream=stream,          # CUDA/HIP stream
    )
```
Grid and block dimensions accept:
- `int`: a static value
- `ir.Value`: a dynamic MLIR value
- a tuple of 1–3 values; missing dimensions default to 1
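The "missing dimensions default to 1" rule can be sketched as a small normalizer. `normalize_dims` is a hypothetical helper, not FlyDSL's actual implementation:

```python
# Sketch: normalize a 1-3 element launch-dim tuple to exactly (x, y, z).
def normalize_dims(dims):
    dims = tuple(dims)
    if not 1 <= len(dims) <= 3:
        raise ValueError("expected 1-3 dimensions")
    # Pad missing trailing dimensions with 1.
    return dims + (1,) * (3 - len(dims))


print(normalize_dims((256,)))   # → (256, 1, 1)
print(normalize_dims((32, 8)))  # → (32, 8, 1)
```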
7.2 Dynamic Grid/Block Dimensions
```python
@flyc.jit
def launch(data: fx.Tensor, M: fx.Int32, stream: fx.Stream = fx.Stream(None)):
    grid_x = M // 256
    my_kernel(data, M).launch(
        grid=(grid_x, 1, 1),
        block=(256, 1, 1),
        stream=stream,
    )
```
8. Synchronization
```python
from flydsl.expr import gpu

# Workgroup barrier (s_barrier)
gpu.barrier()
```
9. Compilation & Caching
9.1 Automatic Caching
JIT-compiled functions are cached automatically:
- In-memory cache: keyed by argument type signature
- Disk cache: stored in `~/.flydsl/cache/` (configurable via `FLYDSL_RUNTIME_CACHE_DIR`)
- The cache key includes: source code hash, dependency sources, closure values, FlyDSL version, LLVM version
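Composing a cache key from those ingredients can be sketched with `hashlib`. The field names and ordering below are illustrative, not FlyDSL's actual key layout:

```python
import hashlib


# Sketch: compose a disk-cache key from the listed ingredients.
def cache_key(source, dep_sources, closure_values, flydsl_ver, llvm_ver):
    h = hashlib.sha256()
    for part in [source, *dep_sources, repr(sorted(closure_values.items())),
                 flydsl_ver, llvm_ver]:
        h.update(part.encode())
        h.update(b"\x00")  # separator so field boundaries are unambiguous
    return h.hexdigest()


k1 = cache_key("def f(): ...", [], {"N": 256}, "0.1", "18")
k2 = cache_key("def f(): ...", [], {"N": 512}, "0.1", "18")
print(k1 != k2)  # → True: different closure values, different cache entry
```

Any change to any ingredient produces a new digest, which is exactly the invalidation behavior described in 9.2.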
9.2 Cache Invalidation
Cache is invalidated when:
- Source code of the function or its dependencies changes
- Argument types change (different tensor shapes/dtypes)
- `Constexpr` values change
- The FlyDSL or LLVM version changes
9.3 Disabling Cache
```shell
FLYDSL_RUNTIME_ENABLE_CACHE=0 python my_script.py
```
9.4 Compile-Only Mode
```shell
COMPILE_ONLY=1 python my_script.py
```
10. Debugging
10.1 Dumping IR
```shell
FLYDSL_DUMP_IR=1 FLYDSL_DUMP_DIR=./my_dumps python my_script.py
```
10.2 Printing IR
```python
# After compilation, access IR from the compiled function:
result = launch(A, B, C, 1024)

# Or use JITCFunction directly:
compiled_func.print_ir()                # compiled MLIR IR
compiled_func.print_ir(compiled=False)  # original IR before passes
```
10.3 AST Diff
```shell
FLYDSL_DEBUG_AST_DIFF=1 python my_script.py
```
Shows the diff between original and rewritten AST for debugging control flow transformations.
11. Complete Example: Preshuffle GEMM
From `kernels/preshuffle_gemm.py`:
```python
import flydsl.compiler as flyc
import flydsl.expr as fx
from flydsl.expr import arith, vector, gpu, buffer_ops, rocdl, range_constexpr
from flydsl.expr.typing import T
from flydsl.utils.smem_allocator import SmemAllocator


def compile_preshuffle_gemm_a8(*, M, N, K, tile_m, tile_n, tile_k,
                               in_dtype="fp8", lds_stage=2, ...):
    allocator = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem0")
    lds_a = allocator.allocate_array(T.i8, tile_m * tile_k)
    # ... more allocations ...

    @flyc.kernel
    def gemm_kernel(
        arg_c: fx.Tensor, arg_a: fx.Tensor, arg_b: fx.Tensor,
        arg_scale_a: fx.Tensor, arg_scale_b: fx.Tensor,
        m_in: fx.Int32, n_in: fx.Int32,
    ):
        tid = gpu.thread_idx.x
        bid = gpu.block_idx.x
        # ... complex GEMM implementation using MFMA, LDS, tiling ...

    @flyc.jit
    def launch_fn(
        arg_c: fx.Tensor, arg_a: fx.Tensor, arg_b: fx.Tensor,
        arg_scale_a: fx.Tensor, arg_scale_b: fx.Tensor,
        M_val: fx.Int32, N_val: fx.Int32,
        stream: fx.Stream = fx.Stream(None),
    ):
        gemm_kernel(arg_c, arg_a, arg_b, arg_scale_a, arg_scale_b,
                    M_val, N_val).launch(
            grid=(grid_x, grid_y), block=(256,),
            smem=smem_bytes, stream=stream,
        )

    return launch_fn
```
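One host-side detail elided above is how `grid_x`/`grid_y` are sized: for a tiled GEMM they typically come from ceil-division of the problem dimensions by the tile dimensions, so partial tiles at the edges still get a block. A sketch with hypothetical helper names:

```python
# Sketch: grid sizing by ceil-division of problem dims by tile dims.
def ceil_div(a, b):
    return -(-a // b)  # integer ceiling division


def gemm_grid(M, N, tile_m, tile_n):
    # One block per (tile_m x tile_n) output tile, rounding up so
    # ragged edges are covered.
    return ceil_div(M, tile_m), ceil_div(N, tile_n)


print(gemm_grid(4096, 4096, tile_m=256, tile_n=256))  # → (16, 16)
```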
12. Decision Tree
```
Writing a new kernel?
│
├── Simple element-wise?
│   ├── Use @flyc.kernel + @flyc.jit
│   ├── fx.gpu.thread_idx.x for thread indexing
│   └── See tests/kernels/test_vec_add.py
│
├── Reduction (norm, softmax)?
│   ├── Use warp_reduce / block_reduce from kernels/reduce.py
│   └── See kernels/layernorm_kernel.py, kernels/softmax_kernel.py
│
├── Matrix multiply (GEMM)?
│   ├── Use @flyc.kernel + SmemAllocator + MFMA
│   ├── B-preshuffle layout from mfma_preshuffle_pipeline.py
│   └── See kernels/preshuffle_gemm.py
│
├── Need shared memory?
│   ├── Use SmemAllocator with target arch
│   ├── Call finalize() in GPU module body
│   └── Call get_base() inside @kernel
│
└── Need compile-time specialization?
    ├── Use Constexpr[T] parameters
    └── Use range_constexpr() for unrolled loops
```
13. Source Files
| File | Description |
|---|---|
| `flydsl.compiler` | Public API: `@flyc.jit`, `@flyc.kernel`, `JitArgumentRegistry` |
| `flydsl.expr.arith` | Arithmetic operations |
| `flydsl.expr.vector` | Vector dialect operations |
| `flydsl.expr.gpu` | GPU operations (thread_id, barrier, …) |
| `flydsl.expr.buffer_ops` | AMD buffer load/store operations |
| `flydsl.expr.rocdl` | ROCm dialect intrinsics |
| | Layout algebra primitives (make_shape, crd2idx, etc.) |
| `kernels/preshuffle_gemm.py` | Preshuffle GEMM kernel example |
| `kernels/reduce.py` | Warp/block reduction primitives |
| `tests/kernels/test_vec_add.py` | Vector add kernel test |
| | Preshuffle GEMM test |