ATOM Model Operations Guide
ATOM (AiTer Optimized Model) wraps AITER kernels with model-level abstractions for LLM inference on AMD ROCm/HIP GPUs. This guide documents every operator class in atom/model_ops/, their AITER kernel mappings, quantization paths, and fused kernel chains.
Quick Reference
| ATOM Class | File | AITER Kernel / Import | Purpose |
|---|---|---|---|
| `LinearBase` | `linear.py` | | Quantized linear dispatch |
| `ColumnParallelLinear` | `linear.py` | (inherits `LinearBase`) | Column-sharded TP linear |
| `RowParallelLinear` | `linear.py` | (inherits `LinearBase`) | Row-sharded TP linear |
| `QKVParallelLinear` | `linear.py` | (inherits `ColumnParallelLinear`) | Fused Q/K/V projection |
| `MergedColumnParallelLinear` | `linear.py` | (inherits `LinearBase`) | Merged gate+up projection |
| `Attention` | `base_attention.py` | `torch.ops.aiter.unified_attention_with_output_base` | Unified attention entry |
| `Attention` (MHA) | `attention_mha.py` | | Multi-head attention |
| `MLAAttention` | `attention_mla.py` | `mla_decode_fwd`, `mla_prefill_fwd` | Multi-head latent attention |
| `FusedMoE` | `moe.py` | | Mixture of experts |
| `RMSNorm` | `layernorm.py` | | RMS normalization |
| `LayerNorm` | `layernorm.py` | `layernorm2d_fwd`, `layernorm2d_fwd_with_add` | Layer normalization |
| `SiluAndMul` | `activation.py` | | SiLU gated activation |
| `VocabParallelEmbedding` | `embed_head.py` | | Vocab embedding |
| `ParallelLMHead` | `embed_head.py` | `tgemm.mm` | LM output head |
| `RotaryEmbedding` | `rotary_embedding.py` | `rope_cached_positions_2c_fwd_inplace` | Rotary position embedding |
| `Sampler` | `sampler.py` | `mixed_sample_outer_exponential` | Token sampling |
| `RejectionSampler` | `rejection_sampler.py` | Triton `rejection_greedy_sample_kernel` | Speculative decoding |
1. AITER Integration Overview
ATOM is a thin model-level inference engine. Every compute-heavy operation delegates to an AITER kernel. The general pattern is:
1. An ATOM `nn.Module` owns model weights and configuration.
2. Its `forward()` method selects the appropriate AITER function based on quantization type, parallelism settings, and phase (prefill vs. decode).
3. Results are optionally reduced across tensor-parallel (TP) or data-parallel (DP) groups.
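This wrapper pattern can be sketched as follows. This is an illustrative pure-Python skeleton, not ATOM's actual code; `LinearLike`, `gemm_bf16`, and `gemm_fp8` are hypothetical stand-ins for the real wrapper classes and AITER kernels:

```python
# Illustrative sketch of the ATOM wrapper pattern: a module owns weights and
# config, and forward() dispatches on quantization type. Kernel names below
# are placeholders, not real AITER entry points.

def gemm_bf16(x, w):  # stand-in for a hipBLASLt BF16 GEMM
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*w)] for row in x]

def gemm_fp8(x, w):   # stand-in for a CK FP8 GEMM (same math in this sketch)
    return gemm_bf16(x, w)

class LinearLike:
    """Owns weight + quant config; forward() picks a kernel at call time."""
    def __init__(self, weight, quant_type=None):
        self.weight = weight
        self.quant_type = quant_type

    def forward(self, x):
        kernel = gemm_fp8 if self.quant_type == "fp8" else gemm_bf16
        return kernel(x, self.weight)

layer = LinearLike([[1.0, 0.0], [0.0, 1.0]], quant_type=None)
out = layer.forward([[2.0, 3.0]])  # identity weight -> input unchanged
```

In the real classes the dispatch key is a `QuantType` enum and the kernels are AITER GEMMs, but the shape of the code is the same.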
AITER Kernel Mapping Table
| ATOM Wrapper | AITER Function / Import Path | Backend Type |
|---|---|---|
| | | hipBLASLt |
| | | hipBLASLt |
| | | CK |
| | | CK |
| | | CK |
| | | CK |
| MHA prefill | `flash_attn_varlen_func` | ASM / CK |
| MHA decode (ASM) | | ASM |
| MHA decode (persistent ASM) | | ASM |
| MHA decode (Triton) | | Triton |
| MHA prefill (Triton unified) | | Triton |
| MLA decode | `mla_decode_fwd` | ASM |
| MLA prefill | `mla_prefill_fwd` | ASM |
| MLA KV cache | | CK |
| RoPE | `rope_cached_positions_2c_fwd_inplace` | Triton |
| RMSNorm | | CK |
| SiLU+Mul | | CK |
| TopK routing | | CK |
| Sampling | `mixed_sample_outer_exponential` | CK |
| FusedMoE | | CK |
| ASM MoE | `asm_moe` (`aiter.fused_moe_bf16_asm`) | ASM |
| Quantization | `get_hip_quant` | CK / Triton |
2. Linear Operations
All linear layers inherit from LinearBase in atom/model_ops/linear.py.
2.1 Class Hierarchy
```
LinearBase (nn.Module)
 +-- ReplicatedLinear             # No TP sharding
 |    +-- MergedReplicatedLinear
 +-- ColumnParallelLinear         # tp_dim=0, shard output
 |    +-- QKVParallelLinear       # Fused Q/K/V with per-head sharding
 +-- MergedColumnParallelLinear   # tp_dim=0, merged gate+up
 +-- RowParallelLinear            # tp_dim=1, shard input, optional all-reduce
```
2.2 Quantization Dispatch
LinearBase.forward() dispatches to different GEMM kernels based on QuantType:
| QuantType | Weight dtype | GEMM Kernel | Scale Shape |
|---|---|---|---|
| | BF16/FP16 | | None |
| | FP8 | | |
| | INT8 | | |
| | FP8 | | |
| | FP8 | | |
| | MXFP4 | | |
When x_scale is not provided, the input is dynamically quantized via get_hip_quant(quant_type).
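Conceptually, dynamic per-tensor quantization derives a scale from the runtime maximum and clamps values to the format's representable range. A minimal pure-Python sketch (not `get_hip_quant` itself; `FP8_MAX` assumes the e4m3 maximum of 448.0, and `dynamic_quant` is a hypothetical helper):

```python
# Hedged sketch of dynamic per-tensor quantization: compute a scale from the
# observed absolute maximum, divide, and clamp to the representable range.
FP8_MAX = 448.0  # assumed e4m3 max; the real path depends on the FP8 variant

def dynamic_quant(x):
    scale = max(abs(v) for v in x) / FP8_MAX          # per-tensor scale
    q = [max(-FP8_MAX, min(FP8_MAX, v / scale)) for v in x]
    return q, scale

q, scale = dynamic_quant([-2.0, 1.0, 4.0])
deq = [v * scale for v in q]   # dequantize to check the round trip
```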
2.3 Tensor Parallel Sharding
- **ColumnParallelLinear** (`tp_dim=0`): Shards weight rows (output dimension) across GPUs. Each GPU owns `output_size / tp_size` rows.
- **RowParallelLinear** (`tp_dim=1`): Shards weight columns (input dimension). If `reduce_results=True`, the output is all-reduced across the TP group.
- **QKVParallelLinear**: Extends `ColumnParallelLinear` with per-head sharding. Q heads are evenly divided; KV heads are either divided or replicated when `num_kv_heads < tp_size`.
- **MergedColumnParallelLinear**: Handles gate and up projections merged into a single weight, with `output_sizes` given as a list (e.g., `[intermediate_size, intermediate_size]`).
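The sharding arithmetic described above can be illustrated with small helpers (hypothetical names; integer division assumes the sizes divide evenly, as the real layers require):

```python
# Sketch of TP sharding arithmetic (illustrative only).
def column_shard(output_size, tp_size, tp_rank):
    """Row range of the weight owned by this rank under column parallelism."""
    per_rank = output_size // tp_size
    return tp_rank * per_rank, (tp_rank + 1) * per_rank

def kv_heads_per_rank(num_kv_heads, tp_size):
    """KV heads are divided across ranks, or replicated when there are
    fewer KV heads than ranks (num_kv_heads < tp_size)."""
    if num_kv_heads >= tp_size:
        return num_kv_heads // tp_size
    return 1  # each rank holds a replica

start, end = column_shard(output_size=4096, tp_size=8, tp_rank=3)
```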
2.4 Weight Processing
After loading, process_weights_after_loading() handles:
- `e4m3fn` to `e4m3fnuz` normalization (AMD FP8 format conversion).
- Weight reshuffling via `shuffle_weights()` for pre-shuffled GEMM kernels.
- Scale reshuffling via `fp4_utils.e8m0_shuffle()` for MXFP4 block scales.
- Per-tensor requantization via `requantize_with_max_scale()` when multiple output partitions have separate scales.
3. Attention Operations
3.1 Base: Attention (base_attention.py)
The top-level Attention class in base_attention.py is a dispatcher. It:
- Selects the backend via `get_attn_backend()` from `atom/utils/selector.py`.
- Instantiates the backend's implementation class (`impl_cls`).
- Registers itself in `compilation_config.static_forward_context` under `layer_name`.
- On `forward()`, calls `torch.ops.aiter.unified_attention_with_output_base`, a custom op decorated with `@mark_spliting_op`. This prevents `torch.compile` from tracing into attention internals, enabling full-graph capture.
Backend selection logic (in selector.py):
| Condition | Backend Class | Implementation |
|---|---|---|
| MLA enabled | `AiterMLABackend` | |
| Otherwise | `AiterBackend` | |
3.2 Multi-Head Attention (attention_mha.py)
The MHA Attention class handles standard models (Llama, Qwen3, Mixtral, etc.).
Forward flow:
1. Reshape Q, K, V to `[num_tokens, num_heads, head_dim]`.
2. Apply RoPE + KV cache write via `rope_cache()`.
3. Dispatch to the appropriate backend via `dispatch_backend()`.
RoPE + KV cache paths:
| Condition | Kernel Chain |
|---|---|
| | |
| Triton path | `fused_qk_rope_reshape_and_cache` |
| ASM path | |
Attention dispatch:
| Phase | Condition | Method | AITER Kernel |
|---|---|---|---|
| Prefill | Always | | |
| Decode | | | |
| Decode | | | |
| Decode | Default | | |
The use_triton_attn flag is set when sliding_window != -1 or head_dim != 128.
3.3 Multi-head Latent Attention (attention_mla.py)
MLAAttention implements DeepSeek’s MLA with a compressed KV representation. Key data structures:
```python
@dataclass
class MLAModules:
    q_lora_rank: Optional[int]
    kv_lora_rank: int
    qk_nope_head_dim: int
    qk_rope_head_dim: int
    qk_head_dim: int
    v_head_dim: int
    rotary_emb: torch.nn.Module
    q_proj: Optional[torch.nn.Module]
    kv_b_proj: torch.nn.Module
    o_proj: torch.nn.Module
    indexer: Optional[torch.nn.Module]
```
Forward flow:
1. **Prefill (non-sparse):** Standard MHA-style prefill with `flash_attn_varlen_func`, preceded by a `kv_b_proj` GEMM to produce K_nope and V from the compressed `kv_c_normed`.
2. **Otherwise:** Fused Q projection + K up-projection via batched FP8/FP4 BMM (`_q_proj_and_k_up_proj`), then:
   - `fused_qk_rope_concat_and_cache_mla` writes to the KV cache.
   - Decode: `mla_decode_fwd` (ASM persistent MLA kernel).
   - Prefill (sparse): `mla_prefill_fwd`.
   - V up-projection + O projection via batched BMM (`_v_up_proj_and_o_proj`).
Batched GEMM backends for MLA projections:
| Condition | Kernel |
|---|---|
| | |
| Default | |
Prefill GEMM optimizations (for kv_b_proj):
| Condition | Kernel |
|---|---|
| | |
| | |
| Default | |
3.4 Backend Abstraction (attentions/backends.py)
The AttentionBackend abstract class defines three required methods:
- `get_name()` – Returns the backend identifier string.
- `get_builder_cls()` – Returns the `AttentionMetadataBuilder` subclass.
- `get_impl_cls()` – Returns the attention implementation class.
CommonAttentionBuilder provides shared metadata preparation (slot mapping, block tables, cumulative sequence lengths) used by both AiterBackend and AiterMLABackend.
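The three-method contract can be sketched as an abstract base class with a toy subclass (`ToyBackend` is hypothetical; the real builder and impl classes come from the backend modules):

```python
# Minimal sketch of the backend contract described above.
from abc import ABC, abstractmethod

class AttentionBackend(ABC):
    @staticmethod
    @abstractmethod
    def get_name() -> str:
        """Backend identifier string."""

    @staticmethod
    @abstractmethod
    def get_builder_cls():
        """AttentionMetadataBuilder subclass."""

    @staticmethod
    @abstractmethod
    def get_impl_cls():
        """Attention implementation class."""

class ToyBackend(AttentionBackend):
    @staticmethod
    def get_name():
        return "toy"
    @staticmethod
    def get_builder_cls():
        return object   # placeholder builder
    @staticmethod
    def get_impl_cls():
        return object   # placeholder impl
```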
3.5 KV Cache Operations
| Operation | AITER Kernel | Used By |
|---|---|---|
| Standard KV cache write | | MHA (BF16 KV) |
| FP8 KV cache write | | MHA (FP8 KV) |
| MLA KV cache write | | MLA prefill |
| Fused QK RoPE + MLA cache | `fused_qk_rope_concat_and_cache_mla` | MLA decode |
4. Mixture of Experts (MoE)
4.1 FusedMoE Class (moe.py)
FusedMoE is the top-level MoE module. It handles:
- Expert routing via `select_experts()`.
- Weight creation and quantization dispatch via `quant_method`.
- Tensor/Expert/Data parallelism via `FusedMoEParallelConfig`.
- Optional shared expert fusion and MORI communication.
Constructor parameters:
```python
FusedMoE(
    num_experts: int,            # Global number of experts
    top_k: int,                  # Experts per token
    hidden_size: int,            # Input hidden dimension
    intermediate_size: int,      # Expert intermediate dimension
    reduce_results: bool,        # Whether to all-reduce output
    renormalize: bool,           # Renormalize routing weights
    use_grouped_topk: bool,      # Use grouped top-k (DeepSeek)
    activation: ActivationType,  # Silu, Gelu, Swiglu, etc.
    ...
)
```
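The routing step that `select_experts()` performs can be illustrated with a plain softmax top-k plus renormalization. This is a conceptual pure-Python sketch, not the AITER CK kernel; grouped and biased variants (and kernel-level tie-breaking) are omitted:

```python
# Hedged sketch of top-k softmax routing with renormalization for one token.
import math

def select_experts_sketch(router_logits, top_k, renormalize=True):
    m = max(router_logits)
    exps = [math.exp(l - m) for l in router_logits]   # stable softmax numerators
    total = sum(exps)
    probs = [e / total for e in exps]
    idx = sorted(range(len(probs)), key=lambda i: -probs[i])[:top_k]
    weights = [probs[i] for i in idx]
    if renormalize:                                    # weights sum to 1 over top-k
        s = sum(weights)
        weights = [w / s for w in weights]
    return idx, weights

idx, weights = select_experts_sketch([0.1, 2.0, -1.0, 2.0], top_k=2)
```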
4.2 Quantization Methods
FusedMoE selects a quant_method at construction time:
| Quant Config | Method Class | GEMM Kernel |
|---|---|---|
| | | |
| FP8 | | |
| FP8 compressed-tensors | | |
| MXFP4 | | |
The ASM MoE path (asm_moe from aiter.fused_moe_bf16_asm) is used by FP8 methods and supports a16 mode where activations remain in BF16/FP16 while weights are FP8/INT8.
4.3 TopK Routing (topK.py)
| Routing Function | AITER Kernel | Used For |
|---|---|---|
| | | Standard top-k (Mixtral) |
| | | Grouped top-k (DeepSeek) |
| | | Biased grouped top-k (DeepSeek V3) |
Shared expert fusion: When is_rocm_aiter_fusion_shared_expert_enabled() returns True, the top-k buffers are extended with shared expert IDs appended after routed expert IDs. This allows shared expert computation to be fused into the same MoE kernel call. The metadata is initialized via init_aiter_topK_meta_data().
4.4 FusedMoEParallelConfig
```python
@dataclass
class FusedMoEParallelConfig:
    tp_size: int        # Tensor parallel size
    dp_size: int        # Data parallel size
    ep_size: int        # Expert parallel size
    tp_rank: int
    dp_rank: int
    ep_rank: int
    use_ep: bool        # Whether expert parallelism is active
    local_ep_size: int  # Local EP size (GPUs per node * TP)
```
Key properties:
- `use_all2all_kernels`: `True` when `dp_size > 1`, EP is enabled, and MORI is available.
- `use_mori_kernels`: Always `True` (currently).
4.5 MORI Integration (fused_moe/mori_prepare_finalize.py)
MORI (MoE Router Infrastructure) provides all-to-all communication kernels for expert parallelism. MoriPrepareAndFinalize implements:
- `prepare()`: Dispatches tokens to remote experts via `mori_op.dispatch()`. Optionally quantizes activations to FP8 before dispatch.
- `finalize()`: Combines expert outputs via `mori_op.combine()` and copies results back.
The FusedMoEModularKernel orchestrates the prepare-compute-finalize pipeline.
4.6 MoE Quantization Config (fused_moe/config.py)
FusedMoEQuantConfig describes activation and weight quantization for MoE layers:
```python
@dataclass
class FusedMoEQuantConfig:
    _a1: FusedMoEQuantDesc  # First activation (input to gate_up)
    _a2: FusedMoEQuantDesc  # Second activation (input to down_proj)
    _w1: FusedMoEQuantDesc  # gate_up_proj weights
    _w2: FusedMoEQuantDesc  # down_proj weights
```
Factory functions:
- `fp8_w8a8_moe_quant_config()` – FP8 weights and activations.
- `mxfp4_w4a16_moe_quant_config()` – MXFP4 weights, unquantized activations.
- `FUSED_MOE_UNQUANTIZED_CONFIG` – No quantization.
4.7 Triton MoE Fallback (fused_moe_triton.py)
triton_kernel_moe_forward() provides a Triton-based MoE path using the triton_kernels library. It uses routing() for expert assignment and matmul_ogs() for the expert GEMM. This path is currently used for MXFP4 MoE on GFX94x hardware.
5. Normalization
5.1 RMSNorm (layernorm.py)
RMSNorm supports multiple forward paths depending on configuration flags:
| Condition | Kernel / Path | Returns |
|---|---|---|
| `x_pad_to_multiple > 0` | | Padded output |
| | | (output, residual) |
| | | (output, residual) |
| `fused_quant=True` (FP8 config) | | (FP8 output, scale) |
| `fused_quant=True` (MXFP4 config) | | (MXFP4 output, scale) |
| Default, no residual | | Output |
| Default, with residual | | (output, residual) |
Constructor parameters:
```python
RMSNorm(
    dim: int,
    eps: float = 1e-6,
    x_pad_to_multiple: int = 0,
    fused_allreduce: bool = False,
    fused_quant: bool = False,
    quant_config: Optional[QuantizationConfig] = None,
)
```
5.2 LayerNorm (layernorm.py)
LayerNorm wraps layernorm2d_fwd and layernorm2d_fwd_with_add (with bias support):
```python
LayerNorm(dim: int, eps: float = 1e-6)
```
- Without residual: `layernorm2d_fwd(x, weight, bias, eps)`
- With residual: `layernorm2d_fwd_with_add(out, x, residual, residual_out, weight, bias, eps)`
6. Activation Functions
6.1 SiluAndMul (activation.py)
SiluAndMul computes SiLU(x_first_half) * x_second_half. It splits the last dimension in half.
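A reference computation for this split-and-gate semantics (pure Python, operating on a flat vector; the CK kernel performs the same element-wise math on tensors):

```python
# Sketch of SiLU(x[:d]) * x[d:] where d is half the last dimension.
import math

def silu_and_mul(x):
    d = len(x) // 2
    gate, up = x[:d], x[d:]
    # SiLU(g) = g * sigmoid(g)
    return [g / (1.0 + math.exp(-g)) * u for g, u in zip(gate, up)]

out = silu_and_mul([0.0, 1.0, 3.0, 2.0])   # gate = [0, 1], up = [3, 2]
```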
| Condition | Kernel | Output |
|---|---|---|
| `fused_quant=True` (FP8 config) | | |
| `fused_quant=True` (MXFP4 config) | | |
| Default | | BF16 output |
Constructor:
```python
SiluAndMul(
    fused_quant: bool = False,
    quant_config: Optional[QuantizationConfig] = None,
)
```
7. Embedding & Output Head
7.1 VocabParallelEmbedding (embed_head.py)
Partitions the vocabulary across TP ranks. Each rank holds num_embeddings / tp_size rows.
Forward:
1. Mask input token IDs to this rank's partition range `[vocab_start_idx, vocab_end_idx)`.
2. `F.embedding()` on the local partition.
3. Zero out out-of-range positions.
4. `all_reduce()` across the TP group.
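The mask/lookup/zero/all-reduce scheme can be sketched with plain lists standing in for tensors and a summing loop standing in for the TP all-reduce (illustrative only; `rank_embed` is a hypothetical helper):

```python
# Sketch of vocab-parallel embedding: each rank looks up only the tokens in
# its vocab partition, emits zeros elsewhere, and the all-reduce sums the
# per-rank partial results into the full embedding.
def rank_embed(token_ids, table, vocab_start, vocab_end):
    out = []
    for t in token_ids:
        if vocab_start <= t < vocab_end:
            out.append(table[t - vocab_start])   # local lookup
        else:
            out.append([0.0] * len(table[0]))    # out-of-range -> zeros
    return out

# 2-way TP over a 4-token vocab, embedding dim 1
tables = [[[10.0], [11.0]], [[12.0], [13.0]]]
partials = [rank_embed([0, 3], tables[r], 2 * r, 2 * r + 2) for r in range(2)]
# simulated all-reduce: elementwise sum across ranks
result = [[a + b for a, b in zip(x, y)] for x, y in zip(*partials)]
```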
7.2 ParallelLMHead (embed_head.py)
Extends VocabParallelEmbedding for the output projection. Key differences:
- Forward extracts only the last token per sequence during prefill (via `cu_seqlens_q[1:] - 1`).
- Uses `tgemm.mm(x, self.weight, self.bias)` for the logit computation (not `F.linear`).
- Calls `tensor_model_parallel_all_gather()` to gather logits across TP ranks.
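The last-token extraction works because `cu_seqlens_q` holds cumulative sequence lengths, so each sequence's final packed position is its exclusive end minus one. A tiny sketch:

```python
# Sketch of cu_seqlens_q[1:] - 1: positions of each sequence's last token
# in the packed (flattened) prefill batch.
def last_token_indices(cu_seqlens_q):
    return [c - 1 for c in cu_seqlens_q[1:]]

# three sequences of lengths 3, 1, 4 packed into 8 tokens
cu = [0, 3, 4, 8]
idx = last_token_indices(cu)
```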
8. Rotary Position Embedding (RoPE)
8.1 RotaryEmbedding (rotary_embedding.py)
Precomputes cos/sin caches at initialization and applies RoPE in-place.
Constructor:
```python
RotaryEmbedding(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool = True,
    dtype: Optional[torch.dtype] = None,
)
```
Forward: Calls aiter.rope_cached_positions_2c_fwd_inplace(query_, key_, cos, sin, positions, rotate_style, ...) which applies RoPE to Q and K tensors in-place using precomputed caches indexed by position IDs.
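The cache-then-rotate scheme can be sketched in pure Python for the NeoX pairing, where dimension `i` rotates with dimension `i + rotary_dim/2`. `build_cache` and `apply_rope_neox` below are illustrative helpers, not the AITER kernel:

```python
# Sketch of RoPE with a precomputed cos/sin cache (NeoX-style pairing).
import math

def build_cache(max_pos, rotary_dim, base):
    half = rotary_dim // 2
    inv_freq = [base ** (-2 * i / rotary_dim) for i in range(half)]
    return [[(math.cos(p * f), math.sin(p * f)) for f in inv_freq]
            for p in range(max_pos)]

def apply_rope_neox(vec, cache_row):
    half = len(vec) // 2
    out = list(vec)
    for i, (c, s) in enumerate(cache_row):
        x1, x2 = vec[i], vec[i + half]       # NeoX pairs (i, i + half)
        out[i] = x1 * c - x2 * s
        out[i + half] = x1 * s + x2 * c
    return out

cache = build_cache(max_pos=16, rotary_dim=2, base=10000.0)
q0 = apply_rope_neox([1.0, 0.0], cache[0])   # position 0 is the identity
```

The fused kernel does the same rotation in-place on Q and K, indexing the cache rows by the position IDs.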
8.2 get_rope() Factory
```python
get_rope(head_size, rotary_dim, max_position, base, rope_scaling=None)
```
Returns a cached RotaryEmbedding instance. Currently rope_scaling must be None.
8.3 Integration in Attention
- **MHA** (`attention_mha.py`): RoPE is applied during the `rope_cache()` phase, either via the fused `fused_qk_norm_rope_cache_quant_shuffle` kernel, via `fused_qk_rope_reshape_and_cache`, or via a standalone `rotary_emb(position, q, k)` call.
- **MLA** (`attention_mla.py`): RoPE is applied to the `q_pe` and `k_rope` tensors. During decode it is fused into `fused_qk_rope_concat_and_cache_mla`; during prefill it is applied via `self.rotary_emb(positions, prefill_q_pe, k_rope)`.
9. Sampling
9.1 Sampler (sampler.py)
Unified sampling supporting both greedy (temperature=0) and random (temperature>0) sampling in a single kernel call.
Forward:
```python
def forward(self, logits, temperatures) -> sampled_tokens:
    mixed_sample_outer_exponential(sampled_tokens, logits, exponential, temperatures, eps)
```
aiter.mixed_sample_outer_exponential performs temperature-scaled exponential sampling: it divides logits by temperature, then uses the Gumbel-max trick with pre-generated exponential random variates.
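The trick can be sketched in pure Python: with `E_i ~ Exp(1)`, the index `argmax_i p_i / E_i` is distributed according to the probabilities `p_i` (this is equivalent to the Gumbel-max formulation), so dividing temperature-scaled softmax weights by pre-drawn exponentials yields a categorical sample. `mixed_sample` below is a hypothetical stand-in, not the AITER kernel:

```python
# Sketch of unified greedy/random sampling via the exponential race.
import math, random

def mixed_sample(logits, temperature, rng):
    if temperature == 0.0:                       # greedy path: plain argmax
        return max(range(len(logits)), key=lambda i: logits[i])
    m = max(logits)
    w = [math.exp((l - m) / temperature) for l in logits]   # softmax numerators
    e = [rng.expovariate(1.0) for _ in logits]   # pre-generated Exp(1) variates
    # argmax_i w_i / E_i  ~  Categorical(softmax(logits / temperature))
    return max(range(len(logits)), key=lambda i: w[i] / e[i])

greedy = mixed_sample([0.1, 3.0, -1.0], temperature=0.0, rng=random.Random(0))
```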
Fallback methods (currently unreachable due to early return):
- `greedy_sample()`: `aiter.ops.triton.topk.topk(logits, 1)`
- `random_sample()`: `aiter.ops.triton.softmax.softmax(logits)` followed by exponential sampling and `topk`.
9.2 RejectionSampler (rejection_sampler.py)
Implements rejection sampling for speculative decoding (MTP). Given draft token IDs and target model logits:
1. Computes `target_argmax = target_logits.argmax(dim=-1)`.
2. Runs a Triton kernel `rejection_greedy_sample_kernel` that sequentially compares draft tokens against the target argmax, accepting until the first mismatch.
3. On full acceptance, appends the bonus token.
4. Returns `(output_token_ids, num_bonus_tokens)`.
10. Fused Kernel Chains
ATOM uses fused kernels to reduce memory traffic by combining multiple operations into a single kernel launch.
| Fused Operation | Components | Controlled By | AITER Kernel |
|---|---|---|---|
| RMSNorm + FP8 quant | RMSNorm, per-tensor FP8 static quant | `fused_quant=True` (FP8 config) | |
| RMSNorm + MXFP4 quant | RMSNorm, per-1x32 MXFP4 quant | `fused_quant=True` (MXFP4 config) | |
| RMSNorm + add + pad | Residual add, RMSNorm, output padding | `x_pad_to_multiple > 0` | |
| AllReduce + RMSNorm | TP all-reduce, RMSNorm | `fused_allreduce=True` | |
| SiLU + mul + FP8 quant | SiLU activation, multiply, FP8 quant | `fused_quant=True` (FP8 config) | |
| SiLU + mul + MXFP4 quant | SiLU activation, multiply, MXFP4 quant | `fused_quant=True` (MXFP4 config) | |
| QK norm + RoPE + cache + quant | Q/K norm, RoPE, KV cache write, optional FP8 quant, weight shuffle | | `fused_qk_norm_rope_cache_quant_shuffle` |
| RoPE + reshape + cache | RoPE, K reshape, KV cache write | Triton attention path | `fused_qk_rope_reshape_and_cache` |
| QK RoPE + MLA cache | Q RoPE, KV concat, MLA cache write, FP8 quant | MLA decode path | `fused_qk_rope_concat_and_cache_mla` |
| GEMM + split + cat (FP4) | KV_b_proj GEMM, split K_nope/V, cat K_rope | | |
| GEMM + split + cat (FP8) | KV_b_proj GEMM, split K_nope/V, cat K_rope | | |
| FP8 BMM + RoPE + cache (MLA) | Batched FP8 BMM, RoPE, MLA KV cache write | MLA decode with FP8 | |
| FP4 BMM + RoPE + cache (MLA) | Batched FP4 BMM, RoPE, MLA KV cache write | MLA decode with MXFP4 | |
Source Files
atom/model_ops/
| File | Description |
|---|---|
| `linear.py` | `LinearBase` hierarchy: replicated, column/row-parallel, QKV, and merged linear layers |
| `base_attention.py` | Top-level `Attention` dispatcher |
| `attention_mha.py` | MHA implementation: prefill (flash), decode (ASM/Triton paged attention) |
| `attention_mla.py` | `MLAAttention` (DeepSeek multi-head latent attention) |
| `moe.py` | `FusedMoE` module |
| `topK.py` | Top-k expert routing |
| `fused_moe_triton.py` | Triton MoE fallback |
| `layernorm.py` | `RMSNorm` and `LayerNorm` |
| `activation.py` | `SiluAndMul` |
| `embed_head.py` | `VocabParallelEmbedding` and `ParallelLMHead` |
| `rotary_embedding.py` | `RotaryEmbedding` and `get_rope()` |
| `sampler.py` | `Sampler` |
| `rejection_sampler.py` | `RejectionSampler` (speculative decoding) |
| | Helper utilities |
atom/model_ops/attentions/
| File | Description |
|---|---|
| `backends.py` | `AttentionBackend` ABC and `CommonAttentionBuilder` |
| | |
| | |
atom/model_ops/fused_moe/
| File | Description |
|---|---|
| `mori_prepare_finalize.py` | `MoriPrepareAndFinalize` all-to-all dispatch/combine |
| `config.py` | `FusedMoEQuantConfig` and quantization descriptors |
| | |
| | MoE utility functions |
atom/utils/
| File | Description |
|---|---|
| `selector.py` | Attention backend selection (`get_attn_backend()`) |