# ATOM Model Support Guide
ATOM (AiTer Optimized Model) is AMD’s lightweight LLM inference engine built on AITER kernels for ROCm/HIP GPUs. This guide covers the supported model architectures, weight loading, and how to add new models.
## Quick Reference
The model registry lives in `atom/model_engine/model_runner.py` as `support_model_arch_dict`:

```python
support_model_arch_dict = {
    "Qwen3ForCausalLM": "atom.models.qwen3.Qwen3ForCausalLM",
    "Qwen3MoeForCausalLM": "atom.models.qwen3_moe.Qwen3MoeForCausalLM",
    "LlamaForCausalLM": "atom.models.llama.LlamaForCausalLM",
    "MixtralForCausalLM": "atom.models.mixtral.MixtralForCausalLM",
    "DeepseekV3ForCausalLM": "atom.models.deepseek_v2.DeepseekV2ForCausalLM",
    "DeepseekV32ForCausalLM": "atom.models.deepseek_v2.DeepseekV2ForCausalLM",
    "GptOssForCausalLM": "atom.models.gpt_oss.GptOssForCausalLM",
    "Glm4MoeForCausalLM": "atom.models.glm4_moe.Glm4MoeForCausalLM",
}
```
ATOM resolves the HuggingFace `architectures` field from a model's `config.json` against this dictionary. If the architecture string matches a key, ATOM imports and instantiates the corresponding class.
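The lookup amounts to splitting the dotted class path and importing the module. A minimal sketch of that resolution step (the helper names are illustrative, not ATOM's API):

```python
import importlib

def split_import_path(path: str):
    """Split a dotted class path into (module name, class name)."""
    module_name, _, class_name = path.rpartition(".")
    return module_name, class_name

def resolve_model_class(architectures, registry):
    """Return the class for the first architecture string found in the registry."""
    for arch in architectures:
        if arch in registry:
            module_name, class_name = split_import_path(registry[arch])
            return getattr(importlib.import_module(module_name), class_name)
    raise ValueError(f"unsupported architectures: {architectures}")
```

An unmatched `architectures` list raises, which is where an "unsupported model" error would surface.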
## 1. Supported Model Architectures
| HF Architecture | ATOM Module | ATOM Class | MoE | MLA | Key Features |
|---|---|---|---|---|---|
| `Qwen3ForCausalLM` | `atom.models.qwen3` | `Qwen3ForCausalLM` | No | No | GQA, QK norm, RoPE |
| `Qwen3MoeForCausalLM` | `atom.models.qwen3_moe` | `Qwen3MoeForCausalLM` | Yes | No | GQA, QK norm, FusedMoE, sparse+dense layer mixing, QK norm+RoPE+cache+quant fusion |
| `LlamaForCausalLM` | `atom.models.llama` | `LlamaForCausalLM` | No | No | GQA, RoPE, fused RMSNorm+quant, fused SiLU+mul+quant |
| `MixtralForCausalLM` | `atom.models.mixtral` | `MixtralForCausalLM` | Yes | No | GQA, RoPE, FusedMoE with TP sharding |
| `DeepseekV3ForCausalLM` | `atom.models.deepseek_v2` | `DeepseekV2ForCausalLM` | Yes | Yes | MLA attention, LoRA-compressed QKV, FusedMoE with shared experts, FP4/FP8 fused kernels |
| `DeepseekV32ForCausalLM` | `atom.models.deepseek_v2` | `DeepseekV2ForCausalLM` | Yes | Yes | Same as above with V3.2 index-based top-k routing |
| `GptOssForCausalLM` | `atom.models.gpt_oss` | `GptOssForCausalLM` | Yes | No | GQA, RoPE, sliding window attention (every other layer), attention sinks, bias in QKV and MoE |
| `Glm4MoeForCausalLM` | `atom.models.glm4_moe` | `Glm4MoeForCausalLM` | Yes | No | GQA, partial RoPE (0.5 factor), QK norm, shared+routed experts, sigmoid scoring, grouped top-k |
**Note:** `DeepSeekMTP` (`atom.models.deepseek_mtp.DeepSeekMTP`) is not in the registry; it is used exclusively as a speculative draft model for DeepSeek multi-token prediction and is loaded separately.
## 2. Model Architecture Details
### Qwen3 (`Qwen3ForCausalLM`)

**Architecture:** Dense transformer with Grouped-Query Attention (GQA).

**Layer structure:** `Qwen3DecoderLayer` containing `Qwen3Attention` + `Qwen3MLP`.

- **Attention:** `QKVParallelLinear` for fused QKV projection, per-head QK RMSNorm (`q_norm`, `k_norm`), RoPE, `RowParallelLinear` for the output projection.
- **MLP:** `MergedColumnParallelLinear` for the gate+up projection, SiLU activation, `RowParallelLinear` for the down projection.
- **Normalization:** RMSNorm on input and post-attention.
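The RMSNorm applied on the input and post-attention paths computes, per token, `x / sqrt(mean(x^2) + eps)` scaled by a learned weight. A pure-Python sketch of just the math (not ATOM's `RMSNorm` class):

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """RMSNorm: divide by the root-mean-square of x, then apply a learned scale."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]
```

With unit weights the output always has a root-mean-square of (approximately) one.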
### Qwen3-MoE (`Qwen3MoeForCausalLM`)

**Architecture:** Mixture-of-Experts transformer with GQA.

**Layer structure:** `Qwen3MoeDecoderLayer` containing `Qwen3MoeAttention` + either `Qwen3MoeSparseMoeBlock` (MoE layers) or `Qwen3MoeMLP` (dense layers, controlled by `mlp_only_layers` and `decoder_sparse_step`).

- **Attention:** Same QKV structure as Qwen3, with QK norm. Supports QK norm + RoPE + cache + quant fusion when `ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION` is set; this precomputes a joint `cos_sin_cache` and passes `q_norm`/`k_norm` to the `Attention` module.
- **MoE:** `FusedMoE` with a `ReplicatedLinear` gate router. Supports allreduce+RMSNorm fusion (`ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION`).
- **Normalization:** RMSNorm with optional fused allreduce.
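The gate router in a sparse MoE block softmaxes the per-token gate logits and keeps the top-k experts. A simplified sketch of that routing step (details such as renormalization of the kept weights vary by model and are an assumption here):

```python
import math

def route_topk(gate_logits, top_k=2, renormalize=True):
    """Softmax the router logits, then keep the top_k experts for this token."""
    m = max(gate_logits)
    exps = [math.exp(v - m) for v in gate_logits]  # numerically stable softmax
    total = sum(exps)
    probs = [v / total for v in exps]
    experts = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:top_k]
    weights = [probs[i] for i in experts]
    if renormalize:
        s = sum(weights)
        weights = [w / s for w in weights]
    return experts, weights
```

The token's output is then the weighted sum of the selected experts' outputs.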
### Llama (`LlamaForCausalLM`)

**Architecture:** Dense transformer with GQA. Covers Llama 2/3 and compatible architectures (InternLM, Mistral-Nemo via the optional `head_dim`).

**Layer structure:** `LlamaDecoderLayer` containing `LlamaAttention` + `LlamaMLP`.

- **Attention:** `QKVParallelLinear`, RoPE (NeoX or original style based on GGUF), per-layer sliding window support via the `layer_types` config.
- **MLP:** `MergedColumnParallelLinear` for gate+up, SiLU+mul activation, `RowParallelLinear` for down.
- **Fused optimizations:** Controlled by environment variables:
  - `ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT` fuses RMSNorm with FP8/MXFP4 quantization.
  - `ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT` fuses the SiLU+mul activation with quantization.
- **Pipeline parallelism:** Full PP support with `PPMissingLayer` placeholders and `IntermediateTensors` for cross-stage communication. Supports auxiliary hidden state extraction for speculative decoding.
### Mixtral (`MixtralForCausalLM`)

**Architecture:** Sparse Mixture-of-Experts with GQA.

**Layer structure:** `MixtralDecoderLayer` containing `MixtralAttention` + `MixtralMoE`.

- **Attention:** Standard GQA with `QKVParallelLinear`, RoPE (NeoX style), `RowParallelLinear`.
- **MoE:** `MixtralMoE` wraps a `ReplicatedLinear` gate + `FusedMoE`. Experts are sharded across TP ranks with a full reduce. Expert checkpoint names use the `w1`/`w2`/`w3` convention (mapped to `gate_proj`/`down_proj`/`up_proj`).
- **Normalization:** RMSNorm.
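The `w1`/`w2`/`w3` to `gate_proj`/`down_proj`/`up_proj` renaming can be pictured with a small helper. This is illustrative only; the actual mapping lives in the model's weight-loading code:

```python
# Mixtral checkpoint convention -> gate/down/up names (from the text above).
MIXTRAL_EXPERT_NAME_MAP = {"w1": "gate_proj", "w2": "down_proj", "w3": "up_proj"}

def remap_expert_name(weight_name):
    """Rename w1/w2/w3 path components; everything else passes through unchanged."""
    return ".".join(MIXTRAL_EXPERT_NAME_MAP.get(part, part)
                    for part in weight_name.split("."))
```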
### DeepSeek V2/V3 (`DeepseekV2ForCausalLM`)

**Architecture:** MoE transformer with Multi-head Latent Attention (MLA).

**Layer structure:** `DeepseekV2DecoderLayer` containing `DeepseekV2MLAAttention` + either `DeepseekV2MoE` (MoE layers) or `DeepseekV2MLP` (dense layers).

- **MLA Attention:** Uses LoRA-compressed QKV (`q_lora_rank`, `kv_lora_rank`), with separate `qk_nope_head_dim` and `qk_rope_head_dim` for the non-positional and rotary-embedded components. Backed by `MLAModules` from `atom.model_ops.attention_mla`.
- **MoE:** `DeepseekV2MoE` with routed + shared experts. Supports shared expert fusion (`is_rocm_aiter_fusion_shared_expert_enabled`), routed scaling factor fusion (`is_rocm_aiter_fuse_routed_scaling_factor`), and grouped top-k routing.
- **Fused optimizations:**
  - `ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION` fuses the input RMSNorm with FP8/FP4 quantization.
  - `ATOM_ENABLE_DS_QKNORM_QUANT_FUSION` fuses QK norm with quantization.
  - `ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION` fuses allreduce with RMSNorm.
  - Dedicated Triton kernels for FP8 MQA logits (`fp8_mqa_logits`), paged MQA logits (`deepgemm_fp8_paged_mqa_logits`), and fused RMSNorm+quantization (`_fuse_rmsnorm_quant`).
- **V3.2 extension:** `DeepseekV32ForCausalLM` is an alias. `DeepseekV2Model` detects V3.2 via `config.index_topk` and allocates a `topk_indices_buffer` for index-based routing.
- **Note:** `DeepseekV3ForCausalLM` is a subclass of `DeepseekV2ForCausalLM` (pass-through, no override).
### DeepSeek MTP (`DeepSeekMTP`)

**Architecture:** Multi-Token Prediction draft model for speculative decoding.

**Layer structure:** `DeepSeekMultiTokenPredictor` containing one or more `DeepSeekMultiTokenPredictorLayer`, each with `enorm` (embedding norm), `hnorm` (hidden state norm), `eh_proj` (linear projection joining embedded + hidden), `mtp_block` (a `DeepseekV2DecoderLayer`), and a `SharedHead` (norm + LM head).

- **Usage:** Not registered in `support_model_arch_dict`. Loaded separately with `spec_decode=True` in `load_model()`, which invokes `rewrite_spec_layer_name()` to remap MTP weight names (e.g., adding the `.mtp_block.` prefix for transformer layer weights, and remapping `embed_tokens` to the top level).
- **Layer indices:** MTP layers start at `config.num_hidden_layers`, i.e., the layer indices following the main model layers.
### GPT-OSS (`GptOssForCausalLM`)

**Architecture:** MoE transformer with GQA and alternating sliding window attention.

**Layer structure:** `TransformerBlock` containing `OAIAttention` + `MLPBlock`.

- **Attention:** `OAIAttention` with bias on the QKV and output projections, attention sinks (learnable per-head parameters), and a sliding window applied on even-indexed layers only.
- **MoE:** `MLPBlock` wraps a `ReplicatedLinear` router (with bias) + `FusedMoE` with SwiGLU activation and bias support. A custom `weights_mapping` translates checkpoint names (`gate_up_proj_blocks` to `w13_weight`, etc.).
- **Normalization:** RMSNorm with `eps=1e-5`; the post-attention norm uses `x_pad_to_multiple=256`.
- **Pipeline parallelism:** Supports auxiliary hidden state layers for EAGLE3 speculative decoding (`get_eagle3_aux_hidden_state_layers`).
### GLM4-MoE (`Glm4MoeForCausalLM`)

**Architecture:** MoE transformer with GQA, shared + routed experts, and partial RoPE.

**Layer structure:** `Glm4MoeDecoderLayer` containing `Glm4MoeAttention` + either `Glm4MoE` (MoE layers, from `first_k_dense_replace` onward) or `Glm4MoeMLP` (dense layers).

- **Attention:** `Glm4MoeAttention` with optional QK norm (`use_qk_norm`) and a partial rotary factor of 0.5.
- **MoE:** `Glm4MoE` with sigmoid scoring, `e_score_correction_bias`, grouped top-k routing (`n_group`, `topk_group`), and a routed scaling factor. Shared experts are handled separately or fused into `FusedMoE` via `is_rocm_aiter_fusion_shared_expert_enabled()`. Expert parallelism (EP) support is built in.
- **Inherits:** The `Glm4MixtureOfExperts` mixin for MoE metadata management and expert load balancing (EPLB) support.
## 3. Weight Loading
Weight loading is handled by `load_model()` in `atom/model_loader/loader.py`.

### Function Signature

```python
def load_model(
    model: nn.Module,
    model_name_or_path: str,
    hf_config: AutoConfig,
    load_dummy: bool = False,
    spec_decode: bool = False,
):
```
### Loading Flow

1. **SafeTensors iteration:** `safetensors_weights_iterator()` discovers and iterates over all `*.safetensors` files in the model directory (or downloads them from the HuggingFace Hub via `download_weights_from_hf()`). Duplicate files are filtered using the `model.safetensors.index.json` weight map. Memory-mapped loading is used by default; set `ATOM_DISABLE_MMAP=true` to disable it.
2. **Weight name rewriting:** Each weight name goes through several transformations:
   - `weight_scale_inv` is renamed to `weight_scale`.
   - Model-specific `weights_mapping` entries are applied (e.g., GPT-OSS maps `gate_up_proj_blocks` to `w13_weight`).
   - For speculative decoding (`spec_decode=True`), MTP layer weights are rewritten via `rewrite_spec_layer_name()`.
   - Shared expert fusion: when enabled, `mlp.shared_experts` is remapped to `mlp.experts.<n_routed_experts>` so the shared expert is loaded as the last expert in the `FusedMoE` module.
3. **Packed module resolution:** The `packed_modules_mapping` dict on each model class defines how HuggingFace checkpoint weight names map to ATOM's fused parameter names. For example, Llama maps:

   ```python
   "q_proj": ("qkv_proj", "q"),
   "k_proj": ("qkv_proj", "k"),
   "v_proj": ("qkv_proj", "v"),
   "gate_proj": ("gate_up_proj", 0),
   "up_proj": ("gate_up_proj", 1),
   ```

   Each packed parameter has a `weight_loader` attribute that knows how to shard and place the weight into the correct slice.
4. **Expert parameter loading:** If the model has a `get_expert_mapping()` method, expert weights are loaded using `FusedMoE.make_expert_params_mapping()`, which generates `(param_name, weight_name, expert_id, shard_id)` tuples. This handles per-expert sharding across TP ranks.
5. **TP sharding:** Parallel linear layers (`ColumnParallelLinear`, `RowParallelLinear`, `QKVParallelLinear`) have custom `weight_loader` methods that automatically select the correct shard for the current TP rank during loading. The default fallback, `default_weight_loader`, handles simple cases where weights need to be sliced by TP rank.
6. **Concurrent loading:** All weight loading calls are submitted to a `ThreadPoolExecutor` for parallel execution.
7. **Post-processing:** After all weights are loaded, `process_weights_after_loading()` is called on each module (e.g., for weight pre-shuffling and scale computation), and `quant_method.process_weights_after_loading()` is invoked for quantized modules. For `FusedMoEMethodBase`, `init_prepare_finalize()` is also called.
## 4. Adding a New Model
Follow these steps to add support for a new model architecture:
### Step 1: Create the Model File

Create a new file in `atom/models/`, e.g., `atom/models/my_model.py`. Follow the existing patterns:

```python
from atom.config import Config, QuantizationConfig
from atom.model_ops.base_attention import Attention
from atom.model_ops.embed_head import ParallelLMHead, VocabParallelEmbedding
from atom.model_ops.layernorm import RMSNorm
from atom.model_ops.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from atom.models.utils import (
    IntermediateTensors,
    PPMissingLayer,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
from atom.utils.decorators import support_torch_compile
```
### Step 2: Implement Layer Classes

Each model typically defines three core module classes:

1. **Attention module** (e.g., `MyModelAttention`):
   - Initialize `QKVParallelLinear` for query/key/value.
   - Initialize `RowParallelLinear` for the output projection.
   - Set up rotary embeddings via `aiter.rotary_embedding.get_rope()`.
   - Create an `Attention` from `atom.model_ops.base_attention`.
2. **MLP module** (e.g., `MyModelMLP`):
   - Use `MergedColumnParallelLinear` for the gate+up projections.
   - Use `RowParallelLinear` for the down projection.
   - For MoE models, use `FusedMoE` from `atom.model_ops.moe`.
3. **Decoder layer** (e.g., `MyModelDecoderLayer`):
   - Combine attention + MLP with RMSNorm layers.
   - Implement the forward pass with residual connections.
### Step 3: Implement the Model and CausalLM Classes

1. **Backbone model** (e.g., `MyModel`):
   - Decorate with `@support_torch_compile`.
   - Initialize `VocabParallelEmbedding`, decoder layers via `make_layers()`, and a final `RMSNorm`.
   - Support pipeline parallelism with `PPMissingLayer` and `IntermediateTensors`.
2. **CausalLM wrapper** (e.g., `MyModelForCausalLM`):
   - Define `packed_modules_mapping` to map checkpoint weight names to ATOM's fused parameter names.
   - Initialize the backbone model and `ParallelLMHead`.
   - Implement `forward()` (returns hidden states) and `compute_logits()` (returns logits via `lm_head`).
   - If the model uses MoE, implement `get_expert_mapping()` returning `FusedMoE.make_expert_params_mapping(...)`.
### Step 4: Register the Model

Add an entry to `support_model_arch_dict` in `atom/model_engine/model_runner.py`:

```python
support_model_arch_dict = {
    ...
    "MyModelForCausalLM": "atom.models.my_model.MyModelForCausalLM",
}
```

The key must exactly match the `architectures` field in the HuggingFace model's `config.json`.
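A quick way to verify the match is to read `architectures` from `config.json` and intersect it with the registry keys. A sketch (the `architectures` field name is standard in HF configs; the helper is hypothetical):

```python
import json

def matching_architectures(config_json_text, registry):
    """Return the architectures from config.json that the registry recognizes."""
    archs = json.loads(config_json_text).get("architectures", [])
    return [a for a in archs if a in registry]
```

An empty result means ATOM would reject the model as unsupported.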
### Step 5: Handle Weight Loading

Ensure your `packed_modules_mapping` correctly maps all checkpoint weight names that differ from ATOM's internal names. Common patterns:
| Checkpoint Name | ATOM Parameter | Shard ID |
|---|---|---|
| `q_proj` | `qkv_proj` | `"q"` |
| `k_proj` | `qkv_proj` | `"k"` |
| `v_proj` | `qkv_proj` | `"v"` |
| `gate_proj` | `gate_up_proj` | `0` |
| `up_proj` | `gate_up_proj` | `1` |
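These common patterns can be exercised with a small helper that rewrites a checkpoint name and reports the shard id. The helper is hypothetical; the real resolution happens inside `load_model()`:

```python
# Llama-style packed mapping, as described in the weight-loading flow.
packed_modules_mapping = {
    "q_proj": ("qkv_proj", "q"),
    "k_proj": ("qkv_proj", "k"),
    "v_proj": ("qkv_proj", "v"),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

def resolve_packed_name(weight_name):
    """Map a checkpoint weight name to the fused parameter name and shard id."""
    for src, (dst, shard_id) in packed_modules_mapping.items():
        if f".{src}." in weight_name:
            return weight_name.replace(src, dst), shard_id
    return weight_name, None  # unpacked weights load as-is
```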
For MoE models, add `get_expert_mapping()` to delegate to `FusedMoE.make_expert_params_mapping()` with the correct gate/down/up projection names and expert count.

If the checkpoint uses non-standard weight names (like GPT-OSS), define a `weights_mapping` class attribute to rename them at load time.
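A `weights_mapping` can be modeled as substring renames applied before packed-module resolution. The `gate_up_proj_blocks` to `w13_weight` pair comes from the GPT-OSS description above; the helper itself is a sketch, not ATOM's implementation:

```python
# Example non-standard checkpoint name, as in GPT-OSS.
weights_mapping = {
    "gate_up_proj_blocks": "w13_weight",
}

def apply_weights_mapping(weight_name, mapping):
    """Apply each rename whose key occurs in the checkpoint weight name."""
    for src, dst in mapping.items():
        if src in weight_name:
            weight_name = weight_name.replace(src, dst)
    return weight_name
```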
## 5. Model-Specific Optimizations
### Llama: Fused RMSNorm+Quant and SiLU+Mul+Quant

Llama supports two AITER Triton fused kernel optimizations:

- `ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT`: Fuses the RMSNorm normalization with FP8 or MXFP4 quantization in a single kernel call. Applied to both `input_layernorm` and `post_attention_layernorm`; eliminates an extra read/write pass over the hidden states.
- `ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT`: Fuses the SiLU activation, element-wise multiply, and quantization in the MLP. The `SiluAndMul` module receives the `fused_quant=True` flag and the quant config, producing quantized output directly for the down projection.

Both are controlled by environment variables read from `atom.utils.envs`.
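Such switches usually follow the boolean-environment-variable pattern. A sketch of how a flag like these might be read (the accepted truthy strings are an assumption; the real parsing lives in `atom.utils.envs`):

```python
import os

def env_flag(name, default=False):
    """Read a boolean feature switch from the environment."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes", "on")
```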
### DeepSeek V2/V3: MLA + Fused Input Norm + QK Norm Fusion

DeepSeek models use Multi-head Latent Attention (MLA) with LoRA-compressed projections (`q_lora_rank`, `kv_lora_rank`). Several fusion optimizations are available:

- `ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION`: Fuses the input RMSNorm with quantization. Implemented via `_fuse_rmsnorm_quant()`, which dispatches to either `_fuse_rmsnorm_fp4_quant()` or `_fused_rms_fp8_group_quant()` based on the quant dtype. When enabled, the allreduce+RMSNorm fusion is disabled for `input_layernorm` but kept for `post_attention_layernorm`.
- `ATOM_ENABLE_DS_QKNORM_QUANT_FUSION`: Fuses the Q/K LoRA layernorm with quantization via `_fuse_qkv_a_proj_reduce_rmsnorm_quant_fp4()` or its FP8 variant, which performs the fused QKV-A projection, RMSNorm on the Q and KV components, and quantization in a single fused operation.
- `ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION`: Fuses the tensor-parallel allreduce with RMSNorm.
- **FP8 MQA logits:** `fp8_mqa_logits` and `deepgemm_fp8_paged_mqa_logits` implement FP8-precision attention score computation for MLA decode.
- **FP4 support:** MXFP4 quantized GEMM kernels (`gemm_afp4wfp4_preshuffle`, `gemm_a16wfp4_preshuffle`) and FP4 block-scale BMM via `is_rocm_aiter_fp4bmm_enabled()`.
### Qwen3-MoE: QK Norm + RoPE + Cache + Quant Fusion

When `ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION` is enabled, the `Qwen3MoeAttention` module:

1. Precomputes a joint `cos_sin_cache` by concatenating the cosine and sine RoPE caches.
2. Passes `q_norm` and `k_norm` directly to the `Attention` module.
3. The attention backend then fuses QK normalization, RoPE application, KV cache write, and optional quantization into a single kernel pass.

Additionally, `ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION` fuses allreduce with RMSNorm for both the attention output and the MoE output, reducing communication overhead.
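The joint cache from step 1 can be sketched in plain Python, assuming the standard RoPE inverse-frequency formula; the exact layout ATOM's kernels expect may differ:

```python
import math

def build_cos_sin_cache(max_pos, rot_dim, base=10000.0):
    """One row per position: the cos half concatenated with the sin half."""
    inv_freq = [base ** (-2 * i / rot_dim) for i in range(rot_dim // 2)]
    cache = []
    for pos in range(max_pos):
        angles = [pos * f for f in inv_freq]
        cache.append([math.cos(a) for a in angles] + [math.sin(a) for a in angles])
    return cache
```

Position 0 is always all-ones cos and all-zeros sin, which makes the layout easy to sanity-check.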
### MTP: DeepSeek Multi-Token Prediction

The `DeepSeekMTP` model serves as a speculative draft model:

- Each `DeepSeekMultiTokenPredictorLayer` takes the previous hidden state and the next token's embedding, normalizes both (`enorm`, `hnorm`), concatenates them, and passes the result through a linear projection (`eh_proj`) followed by a standard `DeepseekV2DecoderLayer`.
- The `SharedHead` provides a per-layer norm + LM head for logit computation.
- For FP4-quantized main models, MTP blocks fall back to a non-FP4 quantization config to maintain draft model accuracy.
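The per-layer data flow reduces to: normalize both inputs, concatenate, and project back to the hidden size. A toy sketch with identity norms standing in for `enorm`/`hnorm` (the concatenation order of embedding versus hidden state is an assumption here):

```python
def matvec(weight, x):
    """Plain matrix-vector product standing in for the eh_proj linear layer."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]

def mtp_combine(prev_hidden, token_embed, eh_proj_weight,
                enorm=lambda v: v, hnorm=lambda v: v):
    """Normalize embedding and hidden state, concatenate, project to hidden size."""
    joined = enorm(token_embed) + hnorm(prev_hidden)  # list concatenation
    return matvec(eh_proj_weight, joined)
```

In the real model the result then flows through a full `DeepseekV2DecoderLayer` and the `SharedHead`.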
## Source Files

| File | Description |
|---|---|
| `atom/model_engine/model_runner.py` | Model registry (`support_model_arch_dict`) |
| `atom/models/llama.py` | Llama model: `LlamaForCausalLM` |
| `atom/models/qwen3.py` | Qwen3 model: `Qwen3ForCausalLM` |
| `atom/models/qwen3_moe.py` | Qwen3-MoE model: `Qwen3MoeForCausalLM` |
| `atom/models/deepseek_v2.py` | DeepSeek V2/V3 model: `DeepseekV2ForCausalLM` |
| `atom/models/deepseek_mtp.py` | DeepSeek MTP draft model: `DeepSeekMTP` |
| `atom/models/mixtral.py` | Mixtral model: `MixtralForCausalLM` |
| `atom/models/gpt_oss.py` | GPT-OSS model: `GptOssForCausalLM` |
| `atom/models/glm4_moe.py` | GLM4-MoE model: `Glm4MoeForCausalLM` |
| `atom/models/utils.py` | Model utilities |
| `atom/model_loader/loader.py` | Weight loading: `load_model()` |
| | Weight utilities |