ATOM Compilation & CUDA Graphs Guide
Quick Reference
| Concept | Key Class / Enum | Import |
|---|---|---|
| Compilation Levels | CompilationLevel | from atom.config import CompilationLevel |
| Compilation Config | CompilationConfig | from atom.config import CompilationConfig |
| CUDA Graph Modes | CUDAGraphMode | from atom.config import CUDAGraphMode |
| CUDA Graph Wrapper | CUDAGraphWrapper | from atom.utils.cuda_graph import CUDAGraphWrapper |
| Forward Context | ForwardContext | from atom.utils.forward_context import ForwardContext |
| Compiler Backend | VllmBackend | from atom.utils.backends import VllmBackend |
| Compiler Manager | CompilerManager | from atom.utils.backends import CompilerManager |
| Compiler Interface | CompilerInterface | from atom.utils.compiler_inferface import CompilerInterface |
| Inductor Adaptor | InductorAdaptor | from atom.utils.compiler_inferface import InductorAdaptor |
| Piecewise Backend | PiecewiseBackend | from atom.utils.cuda_piecewise_backend import PiecewiseBackend |
| Compile Decorator | @support_torch_compile | from atom.utils.decorators import support_torch_compile |
| Custom Op Registration | direct_register_custom_op | from atom.utils.custom_register import direct_register_custom_op |

Compilation Levels at a Glance
| Level | Name | Behavior |
|---|---|---|
| 0 | NO_COMPILATION | Pure eager execution, no torch.compile |
| 1 | DYNAMO_AS_IS | torch.compile with backend="eager" |
| 2 | DYNAMO_ONCE | torch.compile with Inductor |
| 3 | PIECEWISE | Piecewise compilation with CUDA graph capture (production default) |
CUDA Graph Modes at a Glance
| Mode | Value | Behavior |
|---|---|---|
| NONE | 0 | No graph capture |
| PIECEWISE | 1 | Per-subgraph capture (default for level 3) |
| FULL | 2 | Whole-model capture |
| FULL_DECODE_ONLY | (FULL, NONE) | Full for decode, none for mixed batches |
| FULL_AND_PIECEWISE | (FULL, PIECEWISE) | Full for decode, piecewise for prefill and mixed batches |
1. Compilation Levels
ATOM provides four compilation levels via the CompilationLevel class in atom/config.py. The level is set through CompilationConfig.level and controls how torch.compile is applied to the model.
Level 0 – NO_COMPILATION
No torch.compile is applied. The model runs in pure eager mode. This is the simplest mode and is useful for debugging or when using models that are incompatible with torch.compile.
When level=0, the @support_torch_compile decorator sets self.do_not_compile = True and the model’s __call__ method bypasses compilation entirely, calling self.forward() directly.
Level 1 – DYNAMO_AS_IS
Uses torch.compile with backend="eager" and fullgraph=True. This runs Dynamo’s bytecode analysis and graph capture but does not apply any compiler optimizations. It is useful as a quick check to verify that a model is compatible with Dynamo’s tracing.
Like level 0, DYNAMO_AS_IS causes the decorator to set self.do_not_compile = True, since the model runner (rather than the decorator) handles the compilation at this level.
Level 2 – DYNAMO_ONCE
Uses torch.compile with the Inductor backend. The model graph is traced by Dynamo and compiled once through Inductor for optimized GPU kernel generation. The @support_torch_compile decorator’s custom dispatcher is activated when compilation_level >= DYNAMO_ONCE, allowing compiled bytecode to be dispatched directly after the first compilation without repeated guard evaluation.
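The decorator behavior described for levels 0 through 2 comes down to two flags. The sketch below mirrors the CompilationLevel constants as plain integers and derives both flags from a level. The helper `decorator_flags` and the `DecoratorFlags` record are hypothetical and are not part of ATOM's API. The level-3 result (compile, and use the custom dispatcher) is inferred from the `>= DYNAMO_ONCE` rule above.

```python
from dataclasses import dataclass


class CompilationLevel:
    """Stand-in mirroring the four level constants described above."""
    NO_COMPILATION = 0
    DYNAMO_AS_IS = 1
    DYNAMO_ONCE = 2
    PIECEWISE = 3


@dataclass(frozen=True)
class DecoratorFlags:
    do_not_compile: bool         # __call__ runs self.forward() directly
    use_custom_dispatcher: bool  # skip guard evaluation after the first compile


def decorator_flags(level: int) -> DecoratorFlags:
    """Hypothetical helper: derive @support_torch_compile flags from a level."""
    # Level 0 never compiles; at level 1 the model runner, not the decorator, compiles.
    do_not_compile = level in (CompilationLevel.NO_COMPILATION,
                               CompilationLevel.DYNAMO_AS_IS)
    # The custom dispatcher is active from DYNAMO_ONCE upward.
    use_custom_dispatcher = level >= CompilationLevel.DYNAMO_ONCE
    return DecoratorFlags(do_not_compile, use_custom_dispatcher)


for lvl in range(4):
    print(lvl, decorator_flags(lvl))
```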
Level 3 – PIECEWISE (Production Default)
The most advanced level. When Config.__post_init__ detects level == PIECEWISE, it:
1. Calls CompilationConfig.set_splitting_ops_for_v1() to configure the splitting operations (default: ["aiter.unified_attention_with_output", "aiter.mla_attention"]).
2. Calls Config._set_cudagraph_sizes() to compute the graph batch sizes.
3. Sets cudagraph_mode = CUDAGraphMode.PIECEWISE.
4. Calls CompilationConfig.init_with_cudagraph_sizes() to finalize compile sizes.
The VllmBackend is then used as the torch.compile backend. It splits the model graph into subgraphs at the splitting operations and compiles each subgraph independently via PiecewiseBackend.
2. CUDA Graph Modes
The CUDAGraphMode enum in atom/config.py controls how CUDA graphs are captured and replayed. CUDA graphs record a sequence of GPU operations and replay them with minimal CPU overhead, which is critical for low-latency decode steps.
NONE (value: 0)
No CUDA graph capture or replay. Every forward pass launches kernels individually. This mode is used during profiling, warmup, or when CUDA graphs are not supported.
PIECEWISE (value: 1)
The default mode for level 3 compilation. CUDA graphs are captured per subgraph (one for each piecewise-compiled region). Attention operations, which are split out by splitting_ops, run outside CUDA graphs because they may need dynamic metadata that changes between steps.
The CUDAGraphWrapper class wraps each subgraph with runtime_mode=CUDAGraphMode.PIECEWISE for capture and replay.
FULL (value: 2)
The entire model forward pass is captured as a single CUDA graph. This is suitable for small models or workloads with small, uniform batch sizes. Not all attention backends support full CUDA graph capture.
FULL_DECODE_ONLY (value: (FULL, NONE))
A tuple mode that applies different strategies to different batch types:
Decode batches: Captured with full CUDA graphs.
Mixed prefill-decode batches: Run without CUDA graphs.
This is useful for prefill/decode disaggregated (P/D) setups where decode latency matters more than prefill performance.
FULL_AND_PIECEWISE (value: (FULL, PIECEWISE))
A tuple mode combining both strategies:
Decode batches: Captured with full CUDA graphs.
Prefill and mixed batches: Captured with piecewise CUDA graphs.
This is described in the code as “the most performant mode for most models.”
Helper Methods
The CUDAGraphMode enum provides several helper methods for runtime dispatch:
- Decode mode: returns the mode to use for decode batches. For tuple modes, this is the first element.
- Mixed mode: returns the mode to use for mixed batches. For tuple modes, this is the second element.
- Maximum mode: returns the highest-valued mode across both the decode and mixed modes.
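To make the tuple-valued modes and these helpers concrete, here is a self-contained sketch of such an enum. The method names `decode_mode`, `mixed_mode`, `max_cudagraph_mode`, and `separate_routine` are assumptions modeled on vLLM's CUDAGraphMode, which ATOM's enum appears to resemble. Check atom/config.py for the actual names and for the remaining helpers.

```python
import enum


class CUDAGraphMode(enum.Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2
    # Inside the class body FULL/NONE/PIECEWISE are still plain ints,
    # so these values are the tuples (2, 0) and (2, 1).
    FULL_DECODE_ONLY = (FULL, NONE)
    FULL_AND_PIECEWISE = (FULL, PIECEWISE)

    def separate_routine(self) -> bool:
        # Tuple modes use different strategies for decode and mixed batches.
        return isinstance(self.value, tuple)

    def decode_mode(self) -> "CUDAGraphMode":
        return CUDAGraphMode(self.value[0]) if self.separate_routine() else self

    def mixed_mode(self) -> "CUDAGraphMode":
        return CUDAGraphMode(self.value[1]) if self.separate_routine() else self

    def max_cudagraph_mode(self) -> "CUDAGraphMode":
        return CUDAGraphMode(max(self.value)) if self.separate_routine() else self


print(CUDAGraphMode.FULL_AND_PIECEWISE.decode_mode())  # CUDAGraphMode.FULL
print(CUDAGraphMode.FULL_AND_PIECEWISE.mixed_mode())   # CUDAGraphMode.PIECEWISE
```

For single-valued modes, all three helpers return the mode itself, so callers can dispatch uniformly without checking which kind of mode they hold.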
3. CUDA Graph Capture
CUDA graph capture is handled by ModelRunner.capture_cudagraph() in atom/model_engine/model_runner.py. This method is called at startup (under @torch.inference_mode()) to pre-capture graphs for a set of batch sizes.
Capture Flow
```text
capture_cudagraph()
  |
  +-- Determine graph_bs list
  |     |-- If cudagraph_capture_sizes is set: use directly
  |     |-- If cuda_graph_sizes has 1 value N: [1, 2, 4, 8, 16, 32, ..., N]
  |     +-- If cuda_graph_sizes has >1 values: use the provided list
  |
  +-- Sort graph_bs in descending order (largest batch first)
  |
  +-- Assert max batch size <= max_num_seqs
  |
  +-- Initialize graph storage: self.graphs = dict()
  |
  +-- For each batch size bs (with progress bar on rank 0):
  |     |
  |     +-- Compute max_q_len (= mtp_k + 1 if MTP drafter, else 1)
  |     +-- Compute num_tokens = bs * max_q_len
  |     +-- Prepare cu_seqlens_q, positions
  |     +-- Build attn_metadata and context via attn_metadata_builder
  |     +-- Handle DP padding via get_dp_padding()
  |     +-- Set forward context (set_forward_context)
  |     +-- Warmup run: model(input_ids[:num_tokens], positions[:num_tokens])
  |     +-- Capture: torch.cuda.graph(graph, self.graph_pool, stream=gc.stream)
  |     +-- Share graph_pool across captures (set on first capture)
  |     +-- Store: self.graphs[(bs, max_q_len)] = graph
  |     +-- torch.cuda.synchronize()
  |
  +-- Sort graph_bs back to ascending order
  +-- Return (elapsed_time, graph_bs)
```
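The ordering and bookkeeping in this flow can be sketched without a GPU. In the sketch below, `capture_fn` is a hypothetical stand-in for the warmup-plus-`torch.cuda.graph` step, and the "pool" it returns stands in for `graph.pool()`. Only the control flow is taken from the capture flow above: descending order, the size check, the shared pool, and the `(bs, max_q_len)` keys. Timing (`elapsed_time`) is omitted.

```python
from typing import Callable, Dict, List, Optional, Tuple


def capture_all(
    graph_bs: List[int],
    max_num_seqs: int,
    mtp_k: Optional[int],
    capture_fn: Callable[[int, object], Tuple[str, object]],
) -> Tuple[Dict[Tuple[int, int], str], List[int]]:
    """Hypothetical sketch of capture_cudagraph()'s control flow."""
    graph_bs = sorted(graph_bs, reverse=True)  # largest batch first
    assert graph_bs[0] <= max_num_seqs, "capture size exceeds max_num_seqs"
    max_q_len = mtp_k + 1 if mtp_k is not None else 1  # MTP drafter vs plain decode
    graphs: Dict[Tuple[int, int], str] = {}
    graph_pool = None
    for bs in graph_bs:
        num_tokens = bs * max_q_len
        graph, pool = capture_fn(num_tokens, graph_pool)
        if graph_pool is None:
            graph_pool = pool  # the first capture creates the shared pool
        graphs[(bs, max_q_len)] = graph
    return graphs, sorted(graph_bs)  # back to ascending order


pools_seen = []


def fake_capture(num_tokens, pool):
    pools_seen.append(pool)  # record which pool each capture received
    return f"graph[{num_tokens}]", pool or "pool-0"


graphs, bs_list = capture_all([1, 2, 4, 8], max_num_seqs=16, mtp_k=1,
                              capture_fn=fake_capture)
print(bs_list)         # [1, 2, 4, 8]
print(graphs[(8, 2)])  # graph[16]
```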
Graph Keying
Each captured graph is stored in a dictionary keyed by a (graph_bs, max_q_len) tuple:
self.graphs: dict[tuple[int, int], torch.cuda.CUDAGraph] = dict()
- graph_bs: The padded batch size used during capture.
- max_q_len: The maximum query length per sequence. For standard decode, this is 1. For MTP (Multi-Token Prediction) speculative decoding, this is mtp_k + 1.
Graph Pool Sharing
The first captured graph creates a CUDA memory pool via graph.pool(). All subsequent captures share this pool through the self.graph_pool parameter, enabling memory reuse across different batch sizes.
```python
if self.graph_pool is None:
    self.graph_pool = graph.pool()
```
Default Capture Sizes
When cuda_graph_sizes has a single value (e.g., [512], the default), the capture sizes follow this pattern:
```python
# Example with the default cuda_graph_sizes = [512]:
cuda_graph_sizes = [512]
graph_bs = [1, 2, 4, 8] + [i for i in range(16, cuda_graph_sizes[0] + 1, 16)]
# -> [1, 2, 4, 8, 16, 32, 48, 64, ..., 496, 512]
```
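The three selection rules from the capture flow can be combined into a single function. The function name `resolve_capture_sizes` is hypothetical. The single-value pattern is the one shown above, and explicit lists are returned as given because the capture step sorts them anyway.

```python
from typing import List, Optional


def resolve_capture_sizes(
    cuda_graph_sizes: List[int],
    cudagraph_capture_sizes: Optional[List[int]] = None,
) -> List[int]:
    """Hypothetical helper reproducing the graph_bs selection rules."""
    if cudagraph_capture_sizes:
        # An explicit capture list takes precedence.
        return list(cudagraph_capture_sizes)
    if len(cuda_graph_sizes) == 1:
        # A single value N expands to [1, 2, 4, 8, 16, 32, ..., N].
        n = cuda_graph_sizes[0]
        return [1, 2, 4, 8] + [i for i in range(16, n + 1, 16)]
    # Multiple values are used as provided.
    return list(cuda_graph_sizes)


sizes = resolve_capture_sizes([512])
print(len(sizes), sizes[:6], sizes[-1])  # 36 [1, 2, 4, 8, 16, 32] 512
```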
Graph Replay in run_model()
During inference, ModelRunner.run_model() decides whether to use eager execution or graph replay:
```python
def run_model(self, input_ids):
    forward_context = get_forward_context()
    context = forward_context.context
    bs = context.batch_size
    is_prefill = context.is_prefill
    positions = context.positions
    if is_prefill or self.enforce_eager or bs > self.graph_bs[-1]:
        # Eager path: prefills, enforce_eager mode, or oversized batches
        hidden_states = self.model(input_ids, positions)
    else:
        # Graph replay path: decode batches within captured range
        graph_bs = context.graph_bs
        max_q_len = forward_context.attn_metadata.max_seqlen_q
        graph_key = (graph_bs, max_q_len)
        self.graphs[graph_key].replay()
        num_tokens = context.batch_size * max_q_len
        hidden_states = self.forward_vars["outputs"][:num_tokens]
    return self.model.compute_logits(hidden_states), hidden_states
```
Key decisions:
Prefill: Always eager (variable sequence lengths make CUDA graphs impractical).
Decode with bs <= max captured size: Replay the pre-captured graph.
Decode with bs > max captured size: Fall back to eager execution.
enforce_eager=True: Always eager, regardless of batch size.
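These four rules can be condensed into a pure function. The padding step, which rounds `bs` up to the smallest captured size, is an assumption about how `context.graph_bs` is derived; this guide only says that it holds the padded batch size. `choose_path` is a hypothetical name.

```python
import bisect
from typing import List, Optional, Tuple


def choose_path(
    bs: int, is_prefill: bool, enforce_eager: bool, graph_bs: List[int]
) -> Tuple[str, Optional[int]]:
    """Return ("eager", None) or ("replay", padded_bs).

    graph_bs must be sorted in ascending order.
    """
    if is_prefill or enforce_eager or bs > graph_bs[-1]:
        return "eager", None
    # Assumed padding rule: the smallest captured size that fits bs.
    padded = graph_bs[bisect.bisect_left(graph_bs, bs)]
    return "replay", padded


captured = [1, 2, 4, 8, 16, 32]
print(choose_path(5, False, False, captured))   # ('replay', 8)
print(choose_path(5, True, False, captured))    # ('eager', None)
print(choose_path(64, False, False, captured))  # ('eager', None)
```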
4. Piecewise Compilation
Piecewise compilation splits the model’s computation graph at specified operations and compiles each subgraph independently. This enables CUDA graph capture for the compilable parts while leaving incompatible operations (primarily attention) to run eagerly.
Splitting Operations
The splitting_ops field in CompilationConfig defines which operations split the graph. When set_splitting_ops_for_v1() is called (automatically at level 3), the default splitting ops are:
["aiter.unified_attention_with_output", "aiter.mla_attention"]
These attention operations are split out because:
They require dynamic metadata (sequence lengths, block tables) that changes per step.
Some attention backends are not compatible with CUDA graph capture.
Attention kernels are already highly optimized, so Inductor compilation provides minimal additional benefit.
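Conceptually, splitting turns one linear sequence of operations into alternating compilable and eager segments. The sketch works on a flat list of op names rather than an FX graph. Its `SplitItem` is a simplified stand-in for the class of that name produced by `split_graph()`, and the non-attention op names are made up for illustration.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class SplitItem:
    """Simplified stand-in: one contiguous run of ops."""
    graph_id: int
    is_splitting_graph: bool
    ops: List[str] = field(default_factory=list)


def split_ops(ops: List[str], splitting_ops: List[str]) -> List[SplitItem]:
    items: List[SplitItem] = []
    current: List[str] = []

    def flush():
        # Close the current compilable run, if any.
        if current:
            items.append(SplitItem(len(items), False, list(current)))
            current.clear()

    for op in ops:
        if op in splitting_ops:
            flush()
            items.append(SplitItem(len(items), True, [op]))  # runs eagerly
        else:
            current.append(op)
    flush()
    return items


layer = ["rmsnorm", "qkv_proj", "aiter.unified_attention_with_output",
         "o_proj", "rmsnorm", "mlp"]
for item in split_ops(layer, ["aiter.unified_attention_with_output",
                              "aiter.mla_attention"]):
    print(item.graph_id, item.is_splitting_graph, item.ops)
```

Only the non-splitting items (ids 0 and 2 here) would be handed to the compiler. The attention item stays an eager call between them.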
Compilation Pipeline
The VllmBackend.__call__ method orchestrates the piecewise compilation:
1. Graph splitting: split_graph() divides the traced model graph at the splitting operations into a sequence of SplitItem objects, each containing a subgraph.
2. Submodule identification: Subgraphs that are not splitting operations are identified as candidates for compilation.
3. Dynamic-shape compilation: PiecewiseCompileInterpreter runs the split graph with fake inputs and compiles each non-splitting subgraph via CompilerManager.compile() for a general (dynamic) shape.
4. Backend creation: For each compiled subgraph, a PiecewiseBackend instance is created. It holds:
   - compiled_graph_for_general_shape: The Inductor-compiled graph for dynamic shapes.
   - concrete_size_entries: A dictionary mapping specific runtime shapes to ConcreteSizeEntry objects for shape-specialized compilation.
5. Runtime dispatch: When PiecewiseBackend.__call__ is invoked:
   - On the first run, it uses the general-shape compiled graph.
   - For subsequent runs, if the runtime shape is in compile_sizes, it lazily compiles a shape-specialized version via CompilerManager.compile() and caches it.
   - For shapes not in compile_sizes, it falls back to the general-shape compiled graph.
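The runtime-dispatch rules above amount to a small lazy cache. The sketch below passes a hypothetical `compile_fn` in place of `CompilerManager.compile()`, keys shapes by a plain int, and tracks "first run" with a simple flag. It illustrates the rules and is not ATOM's PiecewiseBackend.

```python
from typing import Callable, Dict, Iterable


class SketchPiecewiseBackend:
    def __init__(self, general_graph: Callable[[int], str],
                 compile_sizes: Iterable[int],
                 compile_fn: Callable[[int], Callable[[int], str]]):
        self.compiled_graph_for_general_shape = general_graph
        self.compile_sizes = set(compile_sizes)
        self.compile_fn = compile_fn
        # runtime shape -> specialized callable (ConcreteSizeEntry stand-in)
        self.concrete_size_entries: Dict[int, Callable[[int], str]] = {}
        self.first_run_finished = False

    def __call__(self, runtime_shape: int) -> str:
        if not self.first_run_finished:
            self.first_run_finished = True
            return self.compiled_graph_for_general_shape(runtime_shape)
        if runtime_shape not in self.compile_sizes:
            return self.compiled_graph_for_general_shape(runtime_shape)
        entry = self.concrete_size_entries.get(runtime_shape)
        if entry is None:
            # Lazily compile a shape-specialized graph and cache it.
            entry = self.compile_fn(runtime_shape)
            self.concrete_size_entries[runtime_shape] = entry
        return entry(runtime_shape)


compiles = []


def specialize(size):
    compiles.append(size)
    return lambda n: f"specialized[{size}]"


backend = SketchPiecewiseBackend(lambda n: "general", [8, 16], specialize)
print(backend(8), backend(8), backend(8), backend(3))
# general specialized[8] specialized[8] general
print(compiles)  # [8]
```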
Cache Management
The CompilerManager caches compiled graphs using a key of (runtime_shape, graph_index, backend_name). The cache is stored in a Python file (vllm_compile_cache.py) at the local cache directory (~/.cache/atom/torch_compile_cache/<hash>/rank_<i>/<prefix>/).
On subsequent runs with the same model and configuration, compiled graphs are loaded from the cache, bypassing Inductor compilation entirely.
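The sketch below persists such a cache as a Python-literal file. It assumes the file is read back with `ast.literal_eval`; this guide only states that `pprint` is used for writing. Using `None` as the runtime shape for the general-shape graph is also an assumption. The file name matches the one above, and the handle values are made up.

```python
import ast
import os
import pprint
import tempfile
from typing import Any, Dict, Optional, Tuple

# (runtime_shape, graph_index, backend_name)
CacheKey = Tuple[Optional[int], int, str]


def save_cache(path: str, cache: Dict[CacheKey, Any]) -> None:
    with open(path, "w") as f:
        f.write(pprint.pformat(cache))  # a readable Python literal


def load_cache(path: str) -> Dict[CacheKey, Any]:
    if not os.path.exists(path):
        return {}  # no cache yet: everything will be compiled
    with open(path) as f:
        return ast.literal_eval(f.read())


cache_dir = tempfile.mkdtemp()
path = os.path.join(cache_dir, "vllm_compile_cache.py")
cache = {(None, 0, "inductor"): ("hash_a", "/cache/a"),
         (8, 0, "inductor"): ("hash_b", "/cache/b")}
save_cache(path, cache)
assert load_cache(path) == cache
```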
5. Forward Context & Stateless Dispatch
The ForwardContext dataclass in atom/utils/forward_context.py provides a module-level global mechanism for passing metadata to layers during the forward pass. This is critical for CUDA graphs because captured graphs cannot accept new arguments – all dynamic metadata must be accessible through a side channel.
ForwardContext Fields
The dataclass carries:
- Layers that should skip compilation.
- attn_metadata: Attention-specific metadata (sequence lengths, block tables, etc.).
- KV cache tensors for each layer.
- context: The basic forward pass Context (positions, is_prefill, batch_size, graph_bs).
- Data-parallel metadata (token counts across DP ranks).
- Speculative decoding metadata (draft tokens, logits indices).
Lifecycle
The forward context follows a set-use-reset lifecycle:
1. Set: Before each forward pass, set_forward_context() is called with attention metadata, the ATOM config, a Context object, and optional DP/speculative decoding metadata.
2. Access: During the forward pass, any layer can call get_forward_context() to retrieve the current metadata without needing it passed as a function argument. This is used by both eager execution and CUDA graph replay paths.
3. Reset: After the forward pass, reset_forward_context() replaces the global context with an empty ForwardContext().
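The set-use-reset pattern is a module-level global. Here is a minimal sketch with a simplified ForwardContext. Its two fields, the function signatures, and the `forward_context()` context manager are assumptions for illustration and do not match ATOM's exact signatures.

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any


@dataclass
class ForwardContext:
    attn_metadata: Any = None
    context: Any = None


_forward_context = ForwardContext()


def set_forward_context(attn_metadata: Any, context: Any) -> None:
    global _forward_context
    _forward_context = ForwardContext(attn_metadata=attn_metadata, context=context)


def get_forward_context() -> ForwardContext:
    return _forward_context


def reset_forward_context() -> None:
    global _forward_context
    _forward_context = ForwardContext()


@contextmanager
def forward_context(attn_metadata: Any, context: Any):
    """Hypothetical convenience wrapper that guarantees the reset step."""
    set_forward_context(attn_metadata, context)
    try:
        yield get_forward_context()
    finally:
        reset_forward_context()


def attention_layer() -> Any:
    # A layer reads metadata from the side channel, not from its arguments.
    return get_forward_context().attn_metadata


with forward_context({"max_seqlen_q": 1}, context="step-0"):
    print(attention_layer())  # {'max_seqlen_q': 1}
print(get_forward_context())  # ForwardContext(attn_metadata=None, context=None)
```

Because nothing is passed through the layer's arguments, the same layer code works whether it runs eagerly or inside a captured region.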
Context Dataclass
The Context object carries the most frequently accessed per-step state:
```python
from dataclasses import dataclass

import torch


@dataclass
class Context:
    positions: torch.Tensor   # Token position IDs
    is_prefill: bool = False  # Whether this is a prefill step
    batch_size: int = 0       # Number of sequences in the batch
    graph_bs: int = 0         # Padded batch size for graph lookup
    is_draft: bool = False    # Whether this is a draft model forward
```
The graph_bs field is particularly important for CUDA graph dispatch: it holds the padded batch size that maps to a pre-captured graph key.
Integration with CUDA Graphs
For ModelRunner’s direct CUDA graph path (non-piecewise), the forward context is set before run_model() via set_forward_context(), and run_model() reads context.graph_bs and attn_metadata.max_seqlen_q to look up the correct pre-captured graph.
For the piecewise path, CUDAGraphWrapper (in atom/utils/cuda_graph.py) expects batch_descriptor and cudagraph_runtime_mode fields on the forward context to decide whether to capture, replay, or run eagerly:
```python
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
cudagraph_runtime_mode = forward_context.cudagraph_runtime_mode
```
Note: The piecewise CUDAGraphWrapper integration is under development. The batch_descriptor and cudagraph_runtime_mode fields are expected by CUDAGraphWrapper.__call__() but are not currently defined on the ForwardContext dataclass. The per-subgraph wrapping in backends.py is also currently commented out. The direct CUDA graph path in ModelRunner is the active production path.
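For reference, the capture/replay/eager decision that such a wrapper makes can be sketched as a pure function. It is modeled loosely on the vLLM design that CUDAGraphWrapper appears to follow, not on ATOM's code. The string modes, the tuple used as a batch descriptor, and the name `wrapper_action` are placeholders.

```python
from collections.abc import Hashable
from typing import Optional, Set


def wrapper_action(wrapper_mode: str, runtime_mode: str,
                   batch_descriptor: Optional[Hashable],
                   captured: Set[Hashable]) -> str:
    """Hypothetical decision rule for a per-subgraph CUDA graph wrapper."""
    # A wrapper acts only in its own mode; anything else runs eagerly.
    if runtime_mode != wrapper_mode or batch_descriptor is None:
        return "eager"
    # The first time a batch shape is seen, a graph is recorded for it.
    if batch_descriptor not in captured:
        captured.add(batch_descriptor)
        return "capture"
    # Afterwards the recorded graph is replayed.
    return "replay"


seen: Set[Hashable] = set()
desc = (8, 1)  # placeholder descriptor, e.g. (num_tokens, max_q_len)
print(wrapper_action("PIECEWISE", "PIECEWISE", desc, seen))  # capture
print(wrapper_action("PIECEWISE", "PIECEWISE", desc, seen))  # replay
print(wrapper_action("PIECEWISE", "NONE", desc, seen))       # eager
```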
6. Compiler Backend
CompilerManager
CompilerManager in atom/utils/backends.py manages the full compilation lifecycle:
- Initialization: Creates a CompilerInterface via make_compiler(). Uses InductorStandaloneAdaptor for PyTorch 2.8+ or InductorAdaptor for earlier versions.
- Caching: Maintains a dictionary mapping (runtime_shape, graph_index, backend_name) to compiler-specific handles. Caches are serialized to a Python file using pprint.
- Compile-or-load: On each call to compile(), first attempts load() from the cache. On a miss, delegates to the compiler and stores the result.
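Below is a sketch of the compile-or-load flow and the version-based adaptor choice. The adaptor names come from this guide. `pick_adaptor`, `SketchCompilerManager`, and the `compile_fn`/`load_fn` signatures are hypothetical.

```python
from typing import Any, Callable, Dict, Optional, Tuple

Key = Tuple[Optional[int], int, str]  # (runtime_shape, graph_index, backend_name)


def pick_adaptor(torch_version: str) -> str:
    """Choose an adaptor name from a version string such as '2.8.0+rocm6.4'."""
    major, minor = (int(part) for part in torch_version.split(".")[:2])
    return "InductorStandaloneAdaptor" if (major, minor) >= (2, 8) else "InductorAdaptor"


class SketchCompilerManager:
    def __init__(self, compile_fn: Callable[[Any], Tuple[Any, Any]],
                 load_fn: Callable[[Any], Any], backend_name: str = "inductor"):
        self.compile_fn = compile_fn  # returns (compiled_graph, handle)
        self.load_fn = load_fn        # rebuilds a compiled graph from a handle
        self.backend_name = backend_name
        self.cache: Dict[Key, Any] = {}

    def compile(self, graph: Any, graph_index: int,
                runtime_shape: Optional[int]) -> Any:
        key = (runtime_shape, graph_index, self.backend_name)
        if key in self.cache:
            return self.load_fn(self.cache[key])  # hit: skip compilation
        compiled, handle = self.compile_fn(graph)
        self.cache[key] = handle                   # remember for later loads
        return compiled


print(pick_adaptor("2.8.0+rocm6.4"))  # InductorStandaloneAdaptor
print(pick_adaptor("2.7.1"))          # InductorAdaptor
```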
CompilerInterface
CompilerInterface in atom/utils/compiler_inferface.py (note the typo in the filename) defines the abstract interface that all compiler backends must implement:
- Set up cache directories for the compiler.
- Generate a hash of compiler-specific state for cache invalidation.
- Compile a graph and return a handle for the compiled artifact.
- Load a previously compiled graph from that handle.
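The contract can be written as an abstract base class. The method names `initialize_cache`, `compute_hash`, `compile`, and `load` are assumptions modeled on vLLM's CompilerInterface. `EchoCompiler` is a toy implementation that only tags its input, so the round trip can run without Inductor.

```python
import hashlib
from abc import ABC, abstractmethod
from typing import Any, Optional, Tuple


class SketchCompilerInterface(ABC):
    name = "base"

    @abstractmethod
    def initialize_cache(self, cache_dir: str) -> None:
        """Set up cache directories for the compiler."""

    @abstractmethod
    def compute_hash(self) -> str:
        """Hash compiler-specific state for cache invalidation."""

    @abstractmethod
    def compile(self, graph: Any, runtime_shape: Optional[int]) -> Tuple[Any, Any]:
        """Compile a graph; return (compiled_graph, handle)."""

    @abstractmethod
    def load(self, handle: Any) -> Any:
        """Rebuild a compiled graph from a handle."""


class EchoCompiler(SketchCompilerInterface):
    """Toy backend: 'compiling' just tags the graph with its shape."""
    name = "echo"

    def initialize_cache(self, cache_dir: str) -> None:
        self.cache_dir = cache_dir

    def compute_hash(self) -> str:
        return hashlib.sha256(self.name.encode()).hexdigest()[:10]

    def compile(self, graph: Any, runtime_shape: Optional[int]) -> Tuple[Any, Any]:
        handle = (graph, runtime_shape)
        return self.load(handle), handle

    def load(self, handle: Any) -> Any:
        graph, shape = handle
        return f"{graph}@{shape}"
```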
InductorAdaptor
The default compiler for PyTorch < 2.8. Uses torch._inductor.compile_fx.compile_fx and monkey-patches several internal functions to:
- Extract the compilation hash for caching.
- Provide a dummy shape environment (AlwaysHitShapeEnv) so Inductor cache lookups succeed outside of Dynamo’s tracing context.
- Force caching of graphs that Inductor would normally refuse to cache.
When runtime_shape is an integer (specific batch size), it enables max_autotune and coordinate_descent_tuning for Triton kernel parameter optimization.
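The shape-dependent tuning switch can be sketched as a function that builds per-compile options. `max_autotune` and `coordinate_descent_tuning` are real `torch._inductor` config keys. The helper `inductor_options` and its merge order (base options first, tuning flags on top) are assumptions.

```python
from typing import Any, Dict, Optional


def inductor_options(runtime_shape: Optional[int],
                     base: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Hypothetical helper: enable tuning only for a concrete batch size."""
    options = dict(base or {})
    if isinstance(runtime_shape, int):
        # A fixed shape lets Triton tuning search kernel parameters for it.
        options["max_autotune"] = True
        options["coordinate_descent_tuning"] = True
    return options


print(inductor_options(None))  # {}
print(inductor_options(16))    # {'max_autotune': True, 'coordinate_descent_tuning': True}
```

Restricting tuning to concrete sizes keeps the general-shape compile fast while paying the autotuning cost only for the batch sizes that are actually specialized.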
InductorStandaloneAdaptor
The preferred compiler for PyTorch 2.8+. Uses torch._inductor.standalone_compile which provides a cleaner interface without the monkey-patching required by InductorAdaptor. Compiled artifacts are saved to disk in “unpacked” format and can be loaded directly.
VllmBackend
VllmBackend in atom/utils/backends.py serves as the torch.compile backend for level 3 (piecewise) compilation. When Dynamo calls it:
1. Computes a cache directory hash from config, traced files, and compiler state.
2. Splits the graph at splitting_ops using split_graph().
3. Runs PiecewiseCompileInterpreter to compile each non-splitting subgraph.
4. Saves the computation graph to computation_graph.py for debugging.
5. Returns the stitching graph module (split_gm) as the callable.
If cudagraph_copy_inputs is True, it wraps the callable to copy input tensors into static buffers before each call, ensuring CUDA graph input address stability.
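The idea behind cudagraph_copy_inputs is that a replayed graph always reads from the same memory. The stand-in sketch below uses Python lists in place of tensors; a real implementation would `copy_()` into preallocated device buffers. `copy_inputs_wrapper` is a hypothetical name.

```python
from typing import Callable, List, Sequence


def copy_inputs_wrapper(fn: Callable[..., object], max_sizes: Sequence[int]):
    """Hypothetical wrapper: fn always receives the same preallocated buffers."""
    static_buffers: List[List[float]] = [[0.0] * n for n in max_sizes]

    def wrapped(*inputs: Sequence[float]):
        for buf, x in zip(static_buffers, inputs):
            if len(x) > len(buf):
                raise ValueError("input larger than its static buffer")
            buf[: len(x)] = x  # copy in place; the list object never changes
        return fn(*static_buffers)

    wrapped.static_buffers = static_buffers  # exposed for inspection
    return wrapped


ids = []


def model(buf):
    ids.append(id(buf))  # record which object the "graph" read from
    return buf[0]


run = copy_inputs_wrapper(model, [4])
print(run([1.0, 2.0]), run([3.0]))  # 1.0 3.0
print(ids[0] == ids[1])             # True
```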
@support_torch_compile Decorator
The @support_torch_compile decorator in atom/utils/decorators.py augments a model class to support torch.compile:
- Class modification: Adds TorchCompileWrapperWithCustomDispatcher as a base class and overrides __init__ and __call__.
- Dynamic shape marking: On the first compilation, it inspects the forward method signature, identifies torch.Tensor arguments, and calls torch._dynamo.mark_dynamic() to mark their batch dimensions as dynamic.
- Custom dispatch: After the first compilation, if use_custom_dispatcher is True (levels >= 2), subsequent calls bypass Dynamo’s guard mechanism and dispatch directly to the compiled bytecode via dispatch_to_code(0).
- Safety check: The bytecode hook checks for update in the compiled code’s co_names, raising an error if the model modifies nn.Module buffers during the forward pass (which would cause silent errors with CUDA graphs).
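The marking step and the safety check can both be pictured with the standard library. In this sketch, a local `Tensor` class stands in for torch.Tensor, and `mark_dynamic` is a recording callback in place of `torch._dynamo.mark_dynamic(tensor, 0)`. Matching arguments by runtime type is an assumption. `writes_buffers` applies the `co_names` test to an ordinary function rather than to Dynamo's compiled code object. All helper names are hypothetical.

```python
import inspect
from typing import Callable, List, Tuple


class Tensor:
    """Stand-in for torch.Tensor."""
    def __init__(self, *shape: int):
        self.shape = shape


def mark_batch_dims_dynamic(forward: Callable,
                            mark_dynamic: Callable[[Tensor, int], None],
                            *args, **kwargs) -> List[str]:
    """Mark dim 0 of every Tensor argument; return the marked parameter names."""
    bound = inspect.signature(forward).bind(*args, **kwargs)
    marked = []
    for name, value in bound.arguments.items():
        if isinstance(value, Tensor):
            mark_dynamic(value, 0)  # batch / token dimension
            marked.append(name)
    return marked


def writes_buffers(fn: Callable) -> bool:
    """Flag code that references an attribute or global named 'update'."""
    return "update" in fn.__code__.co_names


def forward(input_ids: Tensor, positions: Tensor, scale: float = 1.0):
    return None


calls: List[Tuple[Tuple[int, ...], int]] = []
names = mark_batch_dims_dynamic(forward, lambda t, d: calls.append((t.shape, d)),
                                Tensor(8), Tensor(8), scale=0.5)
print(names)  # ['input_ids', 'positions']
```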
Custom Op Registration
direct_register_custom_op() in atom/utils/custom_register.py registers custom operators with PyTorch’s torch.library system:
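```python
from atom.utils.custom_register import direct_register_custom_op

# my_kernel is the real implementation; my_fake_impl only computes output shapes.
direct_register_custom_op(
    op_name="my_op",
    op_func=my_kernel,
    mutates_args=["output"],
    fake_impl=my_fake_impl,
)
```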
This registers the op under the "aiter" library namespace (e.g., aiter.my_op), making it visible to Dynamo’s tracing. The fake_impl is used during tracing to compute output shapes without executing the real kernel. The dispatch_key defaults to "CUDA" for GPU operations.
Registered custom ops can be used as splitting_ops in piecewise compilation (e.g., "aiter.unified_attention_with_output").
7. Configuration Options
All compilation-related configuration fields from CompilationConfig:
- level: Compilation level (0-3). See Section 1.
- A flag that controls whether CUDA graph capture is enabled.
- cudagraph_capture_sizes: Explicit list of batch sizes to capture. Takes precedence over cuda_graph_sizes.
- cuda_graph_sizes: Controls auto-generated capture sizes. A single value generates the default pattern; more than one value is used directly. Defaults to [512].
- cudagraph_mode: CUDA graph mode (see Section 2). Set to CUDAGraphMode.PIECEWISE automatically at level 3.
- splitting_ops: Operations that split the graph for piecewise compilation. Auto-set at level 3.
- cudagraph_copy_inputs: Copy input tensors to static buffers for CUDA graph address stability.
- A flag that controls whether the Inductor compiler backend is used.
- compile_sizes: Specific sizes to compile with Inductor.
- Additional Inductor configuration options.
- debug_dump_path: Path to dump debug information (traced graphs, decompiled code).
- A custom cache directory, auto-generated if empty (under ~/.cache/atom/torch_compile_cache/).
Related fields on Config:
- enforce_eager: Force eager execution, skipping all compilation and CUDA graphs.
- The final list of batch sizes for CUDA graph capture, computed by Config._set_cudagraph_sizes().
- The CompilationConfig dataclass holding the settings above.
8. Decision Tree
Use this decision tree to select the right compilation level and CUDA graph mode for your workload:
```text
Is the model supported by torch.compile?
  |
  +-- No --> Level 0 (NO_COMPILATION)
  |          enforce_eager=True
  |
  +-- Yes
        |
        +-- Debugging / profiling?
        |     |
        |     +-- Yes --> Level 0 (NO_COMPILATION)
        |
        +-- Quick compatibility check?
        |     |
        |     +-- Yes --> Level 1 (DYNAMO_AS_IS)
        |
        +-- Want Inductor optimization without CUDA graphs?
        |     |
        |     +-- Yes --> Level 2 (DYNAMO_ONCE)
        |
        +-- Production deployment
              |
              +-- Level 3 (PIECEWISE) [recommended]
                    |
                    +-- Standard serving --> cudagraph_mode=PIECEWISE (default)
                    |
                    +-- Small model / uniform batches --> cudagraph_mode=FULL
                    |
                    +-- P/D disaggregated (decode instance) --> cudagraph_mode=FULL_DECODE_ONLY
                    |
                    +-- Maximum performance --> cudagraph_mode=FULL_AND_PIECEWISE
```
Common Configurations
Default production setup (level 3, piecewise CUDA graphs):
```python
CompilationConfig(level=3)
# Automatically sets:
#   splitting_ops = ["aiter.unified_attention_with_output", "aiter.mla_attention"]
#   cudagraph_mode = CUDAGraphMode.PIECEWISE
# With the default cuda_graph_sizes = [512]:
#   graph_bs = [1, 2, 4, 8, 16, 32, ..., 512]
```

Custom capture sizes:

```python
CompilationConfig(level=3, cudagraph_capture_sizes=[1, 2, 4, 8])
```

Debugging with full eager execution:

```python
Config(model="...", enforce_eager=True)
# or
CompilationConfig(level=0)
```

Inductor with debug dump:

```python
CompilationConfig(level=3, debug_dump_path="/tmp/atom_debug")
# Dumps traced graphs and decompiled code to /tmp/atom_debug/rank_0/
```
Source Files
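| File | Description |
|---|---|
| atom/config.py | CompilationLevel, CompilationConfig, and CUDAGraphMode |
| atom/model_engine/model_runner.py | ModelRunner, including capture_cudagraph() and run_model() |
| atom/utils/backends.py | VllmBackend and CompilerManager |
| atom/utils/compiler_inferface.py | CompilerInterface and InductorAdaptor |
| atom/utils/cuda_piecewise_backend.py | PiecewiseBackend |
| atom/utils/cuda_graph.py | CUDAGraphWrapper |
| atom/utils/forward_context.py | ForwardContext |
| atom/utils/decorators.py | @support_torch_compile |
| atom/utils/custom_register.py | direct_register_custom_op |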