Architecture & Compilation Pipeline Guide
FlyDSL project structure, compilation stages, key abstractions, and configuration.
Quick Reference
| Component | Description | Key File |
|---|---|---|
| FlyDSL | Python DSL front-end for authoring GPU kernels | `python/flydsl/` |
| FlyDSL Compiler | `@jit`/`@kernel` decorators, compilation, caching, and execution | `python/flydsl/compiler/` |
| FlyDSL Expr | DSL expression ops (arith, vector, gpu, buffer, rocdl) | `python/flydsl/expr/` |
| Fly Dialect | Flexible Layout IR — MLIR dialect with layout algebra | `include/flydsl/Dialect/Fly/` |
| MlirCompiler | End-to-end MLIR pass pipeline (DSL → binary) | `python/flydsl/compiler/jit_function.py` |
| JITCFunction | MLIR ExecutionEngine wrapper for JIT execution | `python/flydsl/compiler/jit_executor.py` |
1. Project Structure
FlyDSL/
├── include/flydsl/ # C++ dialect headers
│ └── Dialect/
│ ├── Fly/ # Fly layout dialect
│ │ ├── IR/
│ │ │ ├── FlyDialect.td # Dialect declaration (name = "fly")
│ │ │ ├── FlyOps.td # Layout ops (make_shape, crd2idx, composition, ...)
│ │ │ ├── FlyTypeDefs.td # Custom types (!fly.int_tuple, !fly.layout, ...)
│ │ │ ├── FlyAttrDefs.td # Attributes
│ │ │ └── FlyInterfaces.td # Op interfaces
│ │ └── Transforms/
│ │ ├── Passes.td # Pass declarations (fly-layout-lowering, etc.)
│ │ └── LayoutLowering.td # Layout lowering pass
│ └── FlyROCDL/ # FlyROCDL dialect (copy/MMA atoms)
│ └── IR/
│ ├── Dialect.td # FlyROCDL dialect declaration
│ ├── CopyAtom.td # Copy atom ops
│ └── MmaAtom.td # MMA atom ops
│
├── lib/ # C++ dialect implementation
│ ├── Dialect/Fly/ # Fly dialect ops, type inference, lowering
│ ├── Dialect/FlyROCDL/ # FlyROCDL dialect implementation
│ ├── Conversion/ # Dialect conversion passes
│ └── Transforms/ # Optimization passes
│
├── python/flydsl/ # Python DSL package
│ ├── __init__.py # Package version
│ ├── compiler/
│ │ ├── __init__.py # Public API: jit, kernel, from_dlpack
│ │ ├── jit_function.py # @jit decorator, MlirCompiler, JitCacheManager
│ │ ├── kernel_function.py # @kernel decorator, KernelFunction, KernelLauncher
│ │ ├── jit_executor.py # JITCFunction (ExecutionEngine wrapper)
│ │ ├── jit_argument.py # Argument conversion (Tensor, Stream, Int32)
│ │ ├── ast_rewriter.py # AST rewriting for Python control flow → MLIR
│ │ └── protocol.py # DslType / JitArgument protocols
│ ├── expr/
│ │ ├── __init__.py # Public expr API
│ │ ├── typing.py # Types (T.f32, Tensor, Stream, Constexpr)
│ │ ├── numeric.py # DSL numeric types (Float32, Int32, ...)
│ │ ├── primitive.py # Primitive operations (layout algebra, copy, gemm)
│ │ ├── derived.py # Derived types (CopyAtom, MmaAtom, TiledCopy)
│ │ ├── arith.py # Arithmetic dialect ops
│ │ ├── vector.py # Vector dialect ops
│ │ ├── gpu.py # GPU dialect ops (thread_idx, block_idx, barrier)
│ │ ├── buffer_ops.py # Buffer / memory operations
│ │ └── rocdl.py # ROCm-specific intrinsics
│ ├── lang/ir/
│ │ └── types.py # T / Types helper
│ ├── runtime/
│ │ └── device.py # get_rocm_arch() — GPU architecture detection
│ └── utils/
│ ├── env.py # EnvManager — typed environment config
│ ├── logger.py # Logging utilities
│ └── smem_allocator.py # SmemAllocator for LDS management
│
├── examples/ # Runnable examples
│ ├── 01-vectorAdd.py # Vector addition with layout algebra
│ └── 02-tiledCopy.py # Tiled copy with partitioned tensors
│
├── kernels/ # Pre-built GPU kernels
│ ├── preshuffle_gemm.py # GEMM with B-preshuffle (@flyc.kernel API)
│ ├── layernorm_kernel.py # LayerNorm
│ ├── rmsnorm_kernel.py # RMSNorm
│ ├── softmax_kernel.py # Softmax
│ ├── reduce.py # Warp/block reduction primitives
│ ├── mfma_epilogues.py # MFMA result writeback patterns
│ ├── mfma_preshuffle_pipeline.py # Preshuffle helpers for MFMA kernels
│ └── layout_utils.py # Layout computation utilities
│
├── tests/
│ ├── mlir/ # MLIR IR tests (no GPU required)
│ ├── pyir/ # Python IR tests (no GPU required)
│ ├── kernels/ # GPU kernel tests + benchmarks
│ ├── python/ # Python DSL tests
│ │ ├── examples/ # Example-based tests
│ │ ├── gpu/ # GPU tests
│ │ └── ir/ # IR generation tests
│ ├── conftest.py # Pytest fixtures
│ ├── test_common.py # Shared test utilities
│ └── utils.py # Compilation helpers
│
└── scripts/ # Build and test helpers
├── build.sh # Build FlyDSL (CMake + ninja)
├── build_llvm.sh # Build MLIR from ROCm llvm-project
├── run_tests.sh # Run GEMM test suite
├── run_benchmark.sh # Run benchmarks
└── dumpir.sh # Dump intermediate IR
2. Architecture
The user-facing API lives in python/flydsl/. Kernel authors use @flyc.jit and @flyc.kernel decorators with expression operations from flydsl.expr:
- Traces Python functions via AST rewriting and execution
- Generates Fly dialect ops + standard MLIR dialects (gpu, arith, scf, memref, vector, rocdl)
- Compiles through the `MlirCompiler` pass pipeline (Fly → ROCDL → LLVM → HSACO)
- Caches compiled kernels to disk for fast re-use
- Executes via MLIR ExecutionEngine
The Fly dialect (include/flydsl/Dialect/Fly/) provides the MLIR-level layout algebra (composition, product, divide, coordinate mapping). Python DSL operations in flydsl.expr lower to Fly dialect ops during tracing, which are then compiled through the MlirCompiler pipeline.
3. Compilation Pipeline
3.1 High-Level Flow
Python Function (@flyc.kernel / @flyc.jit)
│
▼ AST Rewriting
Transformed Python Function
│
▼ Tracing (execution inside MLIR Context)
MLIR Module (gpu, arith, scf, memref dialects)
│
▼ MlirCompiler.compile()
┌────────────────────────────────────────────────┐
│ gpu-kernel-outlining │ Outline GPU kernels
│ fly-canonicalize │ FlyDSL-specific canonicalization
│ fly-layout-lowering │ Layout algebra lowering
│ convert-fly-to-rocdl │ Fly ops → ROCDL intrinsics
│ canonicalize │ Standard MLIR canonicalization
│ gpu.module(convert-scf-to-cf, │ SCF → ControlFlow
│ convert-gpu-to-rocdl{...}) │ GPU → ROCDL (inside gpu.module)
│ rocdl-attach-target{chip=gfxNNN} │ Attach ROCm target
│ convert-scf-to-cf │ Host-side SCF → CF
│ convert-cf-to-llvm │ CF → LLVM dialect
│ gpu-to-llvm │ GPU types → LLVM types
│ convert-arith-to-llvm │ Arith → LLVM
│ convert-func-to-llvm │ Func → LLVM
│ reconcile-unrealized-casts │ Clean up casts
│ gpu-module-to-binary{format=fatbin} │ Emit HSACO binary
└────────────────────────────────────────────────┘
│
▼
JITCFunction (ExecutionEngine)
3.2 Pipeline Stages in Detail
The pipeline is defined in MlirCompiler._pipeline_fragments():
| Stage | Pass | Description |
|---|---|---|
| 1 | `gpu-kernel-outlining` | Moves GPU kernel bodies into `gpu.module`. |
| 2 | `fly-canonicalize` | FlyDSL-specific canonicalization (custom pass). |
| 3 | `fly-layout-lowering` | Lowers layout algebra operations to standard arithmetic. |
| 4 | `convert-fly-to-rocdl` | Converts FlyDSL ops to ROCDL intrinsics. |
| 5 | `canonicalize` | Standard MLIR canonicalization (constant folding, etc.). |
| 6 | `gpu.module(convert-scf-to-cf, convert-gpu-to-rocdl)` | Lowers SCF and GPU ops to ROCDL (inside `gpu.module`). |
| 7 | `rocdl-attach-target` | Attaches the ROCm target (`chip=gfxNNN`). |
| 8 | `convert-scf-to-cf` | Host-side SCF lowering. |
| 9 | `convert-cf-to-llvm` | ControlFlow → LLVM dialect. |
| 10 | `gpu-to-llvm` | GPU types/ops → LLVM dialect (host-side launch). |
| 11 | `convert-arith-to-llvm` | Arithmetic → LLVM. |
| 12 | `convert-func-to-llvm` | Function → LLVM. |
| 13 | `reconcile-unrealized-casts` | Final cast cleanup. |
| 14 | `gpu-module-to-binary` | Compiles GPU module to HSACO binary (fatbin). |
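The pass names above follow MLIR's textual pipeline syntax, in which comma-joined passes nest inside `builtin.module(...)`. As a rough illustration (the fragment grouping inside `MlirCompiler._pipeline_fragments()` is not shown here, so treat the list structure below as an assumption), the stages can be joined into a single pipeline string:

```python
# Pass names transcribed from the pipeline diagram above; the grouping
# into Python-level fragments is an assumption for illustration.
PIPELINE_FRAGMENTS = [
    "gpu-kernel-outlining",
    "fly-canonicalize",
    "fly-layout-lowering",
    "convert-fly-to-rocdl",
    "canonicalize",
    "gpu.module(convert-scf-to-cf,convert-gpu-to-rocdl)",
    "rocdl-attach-target{chip=gfx942}",
    "convert-scf-to-cf",
    "convert-cf-to-llvm",
    "gpu-to-llvm",
    "convert-arith-to-llvm",
    "convert-func-to-llvm",
    "reconcile-unrealized-casts",
    "gpu-module-to-binary{format=fatbin}",
]

def pipeline_string(fragments=PIPELINE_FRAGMENTS):
    """Join pass fragments using MLIR's textual pipeline syntax."""
    return "builtin.module(" + ",".join(fragments) + ")"

print(pipeline_string())
```

A string in this shape is what MLIR's pass-pipeline parser accepts, which makes the stage ordering easy to inspect or diff.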
3.3 JIT Compilation Flow
When a @flyc.jit function is called:
1. **Cache check** — look up by argument type signature (in-memory → disk)
2. **AST rewriting** — `ASTRewriter.transform` converts Python `for`/`if` to MLIR `scf.for`/`scf.if`
3. **MLIR module creation** — sets up `gpu.container_module` with target
4. **Argument conversion** — `convert_to_jit_arguments` maps Python args to IR types
5. **Function tracing** — execute transformed function body to generate MLIR ops
6. **GPU kernel emission** — `@kernel` calls emit `gpu.func` into `gpu.module`
7. **Pipeline compilation** — `MlirCompiler.compile()` runs the full pass pipeline
8. **Execution** — `JITCFunction` wraps MLIR ExecutionEngine for invoking the compiled code
9. **Cache store** — compiled function is serialized to disk for future runs
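The cache check keys on the argument type signature plus compile-time `Constexpr` values. The exact derivation in `JitCacheManager` is not shown here, so the following is only a minimal sketch of the idea (the `cache_key` helper is hypothetical):

```python
import hashlib

def cache_key(func_name, args, constexpr_values=()):
    """Hypothetical sketch: key on argument *types* plus compile-time
    Constexpr values, so each specialization is cached separately."""
    sig = tuple(type(a).__name__ for a in args) + tuple(constexpr_values)
    return hashlib.sha256(f"{func_name}:{sig}".encode()).hexdigest()[:16]

# Same types and same Constexpr value -> same key (cache hit) ...
assert cache_key("launch", [1.0, 2.0], (256,)) == cache_key("launch", [3.0, 4.0], (256,))
# ... while a different Constexpr value forces a fresh specialization.
assert cache_key("launch", [1.0, 2.0], (256,)) != cache_key("launch", [1.0, 2.0], (512,))
```

This is why changing a `Constexpr[T]` argument triggers recompilation while changing a plain tensor's contents does not.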
4. Key Abstractions
4.1 @flyc.jit — Host Launcher
Decorates a Python function as a JIT-compiled host launcher:
```python
import flydsl.compiler as flyc
import flydsl.expr as fx

@flyc.jit
def launch(a: fx.Tensor, b: fx.Tensor, n: fx.Constexpr[int],
           stream: fx.Stream = fx.Stream(None)):
    my_kernel(a, b, n).launch(grid=(n // 256,), block=(256,), stream=stream)
```
Key behaviors:
- First call triggers compilation; subsequent calls with the same type signature use the cached binary
- `Constexpr[T]` parameters become compile-time constants (they affect the cache key)
- `Tensor` parameters map to memref descriptors via DLPack
- `Stream` parameters pass a CUDA/HIP stream to the GPU runtime
- When called inside an existing MLIR context, acts as a normal function (composable)
4.2 @flyc.kernel — GPU Kernel
Decorates a Python function as a GPU kernel:
```python
@flyc.kernel
def my_kernel(a: fx.Tensor, b: fx.Tensor, n: fx.Constexpr[int]):
    tid = fx.gpu.thread_id("x")
    bid = fx.gpu.block_id("x")
    # ... kernel body ...
```
Key behaviors:
- Can only be called inside a `@flyc.jit` function
- Calling returns a `KernelLauncher` — you must call `.launch()` to emit the launch op
- Supports `Constexpr[T]` for compile-time specialization
- Emits a `gpu.func` with the `gpu.kernel` attribute into the `gpu.module`
4.3 KernelLauncher
Returned by calling a @kernel function. Use .launch() to configure and emit the GPU launch:
```python
launcher = my_kernel(a, b, 1024)
launcher.launch(
    grid=(num_blocks, 1, 1),
    block=(256, 1, 1),
    smem=shared_mem_bytes,
    stream=stream_value,
)
```
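A common detail when choosing `grid` is rounding up so the last, partially filled block is still launched. A small sketch of that arithmetic (the `launch_config` helper is illustrative, not part of the FlyDSL API):

```python
def ceil_div(a, b):
    # Integer ceiling division without going through floats.
    return -(-a // b)

def launch_config(n_elements, block_x=256):
    """Compute a 1-D (grid, block) pair covering n_elements, rounding the
    grid up so a partially filled final block is still launched."""
    return (ceil_div(n_elements, block_x), 1, 1), (block_x, 1, 1)

grid, block = launch_config(1000)
# 1000 elements with 256 threads per block -> 4 blocks, the last one partial.
assert grid == (4, 1, 1) and block == (256, 1, 1)
```

Kernels sized this way typically guard the tail with a bounds check on the global thread index.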
4.4 JITCFunction
Wraps MLIR’s ExecutionEngine for JIT execution:
- Thread-safe with lazy engine initialization
- Serializable (pickle) for disk caching
- Supports the packed calling convention via `ctypes`
- Provides `.print_ir()` for debugging compiled/original IR
4.5 DslType / JitArgument Protocols
Extensible type system for mapping Python values to MLIR:
```python
# DslType protocol — for values used inside kernel/jit functions
class DslType(Protocol):
    @classmethod
    def __fly_construct__(cls, values: List[ir.Value]) -> "DslType": ...
    def __fly_values__(self) -> List[ir.Value]: ...

# JitArgument protocol — for values passed at the host boundary
class JitArgument(Protocol):
    def __fly_types__(self) -> List[ir.Type]: ...
    def __fly_ptrs__(self) -> List[ctypes.c_void_p]: ...
```
Built-in types: Tensor, Stream, Int32, Constexpr[T]
Register custom types:
```python
from flydsl.compiler import JitArgumentRegistry

@JitArgumentRegistry.register(MyPythonType, dsl_type=MyDslType)
class MyJitArg:
    def __fly_types__(self): ...
    def __fly_ptrs__(self): ...
```
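To make the `__fly_ptrs__` side of the protocol concrete, here is a self-contained sketch of a hypothetical scalar argument backed by `ctypes` (the `ScalarF32Arg` class and its `"f32"` stand-in for a real `ir.Type` are illustrative assumptions, since constructing MLIR types needs a live context):

```python
import ctypes

class ScalarF32Arg:
    """Hypothetical JitArgument for a host-side float: the value lives in a
    ctypes buffer and is passed by pointer, matching a packed convention."""
    def __init__(self, value):
        self._storage = ctypes.c_float(value)

    def __fly_types__(self):
        # A real implementation would return [ir.F32Type.get()] inside a
        # live MLIR context; a plain string stands in for it here.
        return ["f32"]

    def __fly_ptrs__(self):
        return [ctypes.cast(ctypes.pointer(self._storage), ctypes.c_void_p)]

arg = ScalarF32Arg(3.5)
ptr = arg.__fly_ptrs__()[0]
# Reading the float back through the void pointer recovers the value.
assert ctypes.cast(ptr, ctypes.POINTER(ctypes.c_float)).contents.value == 3.5
```

The key point is that the argument object owns the storage, so the pointer stays valid for the duration of the call.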
4.6 ASTRewriter
Transforms Python control flow to MLIR ops at the AST level:
- `for i in range(n)` → `scf.for`
- `for i in range_constexpr(n)` → compile-time unrolled loop
- `if condition` → `scf.if`
- `const_expr(value)` → compile-time constant
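The first step of such a rewrite can be sketched with the standard `ast` module. The real `ASTRewriter` emits `scf.for` builder calls into an MLIR module; this minimal sketch only recognizes `for i in range(...)` loops and renames the iterator to a hypothetical `scf_range` marker:

```python
import ast

class ForRangeRewriter(ast.NodeTransformer):
    """Sketch: detect `for i in range(...)` loops at the AST level.
    `scf_range` is a placeholder name standing in for scf.for emission."""
    def visit_For(self, node):
        self.generic_visit(node)  # rewrite nested loops first
        it = node.iter
        if (isinstance(it, ast.Call) and isinstance(it.func, ast.Name)
                and it.func.id == "range"):
            it.func.id = "scf_range"  # placeholder for scf.for emission
        return node

tree = ast.parse("for i in range(n):\n    acc += i\n")
tree = ast.fix_missing_locations(ForRangeRewriter().visit(tree))
assert "scf_range(n)" in ast.unparse(tree)
```

Pattern-matching on `ast.Call`/`ast.Name` is what lets the rewriter distinguish `range` from `range_constexpr` and pick the structured-loop versus unrolled-loop path.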
5. Environment Variables
5.1 Compilation Options (FLYDSL_COMPILE_*)
| Variable | Default | Description |
|---|---|---|
|  |  | Optimization level (0–3) |
|  |  | If |
| `FLYDSL_GPU_ARCH` | auto-detect | Override target GPU architecture (e.g., `gfx942`) |
5.2 Debug Options (FLYDSL_DEBUG_*)
| Variable | Default | Description |
|---|---|---|
| `FLYDSL_DUMP_IR` |  | Dump intermediate IR at each pipeline stage. |
| `FLYDSL_DUMP_DIR` |  | Directory for IR dumps. |
|  |  | Dump final AMD ISA assembly. |
|  |  | Print AST diff during rewrite. |
|  |  | Print origin IR before compilation. |
|  |  | Print IR after each MLIR pass. |
|  |  | Generate debug info in compiled code. |
|  |  | Verify IR module. |
|  |  | Logging level (DEBUG, INFO, WARNING, ERROR). |
5.3 Runtime Options (FLYDSL_RUNTIME_*)
| Variable | Default | Description |
|---|---|---|
|  |  | Directory for caching compiled kernels. |
|  |  | Enable kernel caching. |
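Boolean options like the caching toggle are typically read through `EnvManager`-style typed accessors. A minimal sketch of such a parser, assuming a common "truthy string" convention (the `env_flag` helper and the `FLYDSL_RUNTIME_DEMO_FLAG` name are hypothetical, not the exact FlyDSL spellings):

```python
import os

def env_flag(name, default=False):
    """Parse a boolean-ish environment variable; '1', 'true', 'on', 'yes'
    (any case) count as enabled. Assumed convention, not FlyDSL's exact one."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "on", "yes")

# Hypothetical variable name, used only to exercise the parser.
os.environ["FLYDSL_RUNTIME_DEMO_FLAG"] = "1"
assert env_flag("FLYDSL_RUNTIME_DEMO_FLAG") is True
assert env_flag("FLYDSL_RUNTIME_UNSET_FLAG", default=False) is False
```

Centralizing the parsing this way keeps "0"/"false"/unset behavior consistent across all `FLYDSL_*` switches.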
5.4 Architecture Detection Priority
get_rocm_arch() in runtime/device.py checks in order:
1. `FLYDSL_GPU_ARCH` env var
2. `HSA_OVERRIDE_GFX_VERSION` env var (supports `9.4.2` → `gfx942` format)
3. `rocm_agent_enumerator` system tool
4. Default: `gfx942`
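The priority chain can be sketched as plain Python. This is an illustration of the documented order, not the actual `get_rocm_arch()` source; the `detect_arch` helper and its `run_enumerator` callable (standing in for shelling out to `rocm_agent_enumerator`) are assumptions:

```python
import os

def detect_arch(run_enumerator=lambda: None):
    """Sketch of the documented detection priority: env override, HSA
    version override, enumerator tool, then the gfx942 default."""
    arch = os.environ.get("FLYDSL_GPU_ARCH")
    if arch:
        return arch
    override = os.environ.get("HSA_OVERRIDE_GFX_VERSION")
    if override:
        major, minor, step = override.split(".")
        return f"gfx{major}{minor}{step}"  # e.g. "9.4.2" -> "gfx942"
    return run_enumerator() or "gfx942"

os.environ.pop("FLYDSL_GPU_ARCH", None)
os.environ["HSA_OVERRIDE_GFX_VERSION"] = "9.4.2"
assert detect_arch() == "gfx942"
os.environ.pop("HSA_OVERRIDE_GFX_VERSION", None)
assert detect_arch() == "gfx942"  # falls through to the default
```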
6. Target Hardware
| Architecture | GPU | LDS per CU | Notes |
|---|---|---|---|
| `gfx942` | MI300A / MI300X | 64 KB | CDNA 3, primary development target |
| `gfx950` | MI350 | 160 KB | CDNA 4, larger LDS |
| `gfx90a` | MI250X | 64 KB | CDNA 2 (verified platform) |
7. IR Dump Workflow
Enable with FLYDSL_DUMP_IR=1:
FLYDSL_DUMP_IR=1 FLYDSL_DUMP_DIR=./dumps python test_my_kernel.py
Produces numbered .mlir files:
dumps/my_func_name/
├── 00_original.mlir
├── 01_gpu-kernel-outlining.mlir
├── 02_fly-canonicalize.mlir
├── 03_fly-layout-lowering.mlir
├── 04_convert-fly-to-rocdl.mlir
├── 05_canonicalize.mlir
├── 06_convert-scf-to-cf.mlir
├── 07_rocdl-attach-target.mlir
├── 08_convert-scf-to-cf.mlir
├── 09_convert-cf-to-llvm.mlir
├── 10_gpu-to-llvm.mlir
├── 11_convert-arith-to-llvm.mlir
├── 12_convert-func-to-llvm.mlir
├── 13_reconcile-unrealized-casts.mlir
├── 14_gpu-module-to-binary.mlir
└── final_isa.s # AMD ISA assembly (best-effort)
8. Source Files
| File | Description |
|---|---|
| `python/flydsl/expr/primitive.py` | Layout algebra primitives (make_shape, crd2idx, copy, gemm) |
| `python/flydsl/expr/derived.py` | Derived types (`CopyAtom`, `MmaAtom`, `TiledCopy`) |
| `python/flydsl/expr/numeric.py` | DSL numeric types (Float32, Int32, …) |
| `include/flydsl/Dialect/Fly/IR/FlyOps.td` | Fly dialect op definitions |
| `include/flydsl/Dialect/Fly/Transforms/Passes.td` | Pass declarations (fly-layout-lowering, etc.) |