Attention Operations

AITER provides highly optimized attention kernels for AMD GPUs running ROCm.

Flash Attention

Standard flash attention implementation with optional causal masking.

Parameters:

  • query (torch.Tensor) - Query tensor of shape (batch, seq_len, num_heads, head_dim)

  • key (torch.Tensor) - Key tensor of shape (batch, seq_len, num_heads, head_dim)

  • value (torch.Tensor) - Value tensor of shape (batch, seq_len, num_heads, head_dim)

  • causal (bool, optional) - Whether to apply causal masking. Default: False

  • softmax_scale (float, optional) - Scaling factor for softmax. Default: 1/sqrt(head_dim)

Returns:

  • output (torch.Tensor) - Attention output of shape (batch, seq_len, num_heads, head_dim)

Example:

import torch
import aiter

# (batch=2, seq_len=1024, num_heads=16, head_dim=64)
q = torch.randn(2, 1024, 16, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 1024, 16, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 1024, 16, 64, device='cuda', dtype=torch.float16)

# causal=True applies a lower-triangular (autoregressive) mask
output = aiter.flash_attn_func(q, k, v, causal=True)

Flash Attention with KV Cache

Optimized attention with paged KV cache support for inference.

Parameters:

  • query (torch.Tensor) - Query tensor (batch, seq_len, num_heads, head_dim)

  • kv_cache (torch.Tensor) - Paged KV cache (num_blocks, num_heads, block_size, head_dim)

  • page_table (torch.Tensor) - Page table mapping (batch, max_blocks_per_seq)

  • block_size (int) - Size of each page block (e.g., 128)

  • causal (bool, optional) - Causal masking. Default: True

Returns:

  • output (torch.Tensor) - Attention output (batch, seq_len, num_heads, head_dim)

Example:

import torch
import aiter

# 4 sequences of 128 query tokens, 16 heads, head_dim 64
query = torch.randn(4, 128, 16, 64, device='cuda', dtype=torch.float16)
# Paged cache: 256 blocks x 16 heads x block_size 128 x head_dim 64
kv_cache = torch.randn(256, 16, 128, 64, device='cuda', dtype=torch.float16)
# Page table maps each of the 4 sequences to up to 32 blocks in the cache
page_table = torch.randint(0, 256, (4, 32), device='cuda', dtype=torch.int32)

output = aiter.flash_attn_with_kvcache(
    query, kv_cache, page_table, block_size=128
)

Grouped Query Attention (GQA)

Efficient grouped query attention for models like Llama 2.

Parameters:

  • query (torch.Tensor) - (batch, seq_len, num_q_heads, head_dim)

  • key (torch.Tensor) - (batch, seq_len, num_kv_heads, head_dim)

  • value (torch.Tensor) - (batch, seq_len, num_kv_heads, head_dim)

  • num_groups (int) - Number of query heads per KV head

  • causal (bool, optional) - Causal masking. Default: False

Returns:

  • output (torch.Tensor) - (batch, seq_len, num_q_heads, head_dim)
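
Example:

The snippet below is a shape-level sketch. The entry point name grouped_query_attention is taken from the performance table later on this page; the exact function name and signature may differ in your installed AITER version, so treat the call itself as an assumption.

import torch
import aiter

# 16 query heads share 4 KV heads, so num_groups = 16 // 4 = 4
q = torch.randn(2, 1024, 16, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 1024, 4, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 1024, 4, 64, device='cuda', dtype=torch.float16)

output = aiter.grouped_query_attention(q, k, v, num_groups=4, causal=True)
# output: (2, 1024, 16, 64), matching the query head count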

Multi-Query Attention (MQA)

Multi-query attention, where all query heads share a single key/value head.

Parameters:

  • query (torch.Tensor) - (batch, seq_len, num_heads, head_dim)

  • key (torch.Tensor) - (batch, seq_len, 1, head_dim)

  • value (torch.Tensor) - (batch, seq_len, 1, head_dim)

  • causal (bool, optional) - Causal masking. Default: False

Returns:

  • output (torch.Tensor) - (batch, seq_len, num_heads, head_dim)
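
Example:

MQA is the limiting case of GQA with a single KV head, so one natural way to run it is through the same grouped entry point with num_groups equal to the number of query heads. The sketch below reuses the hypothetical grouped_query_attention name from above; check your AITER version for the actual MQA entry point.

import torch
import aiter

# All 16 query heads attend to one shared key/value head
q = torch.randn(2, 1024, 16, 64, device='cuda', dtype=torch.float16)
k = torch.randn(2, 1024, 1, 64, device='cuda', dtype=torch.float16)
v = torch.randn(2, 1024, 1, 64, device='cuda', dtype=torch.float16)

output = aiter.grouped_query_attention(q, k, v, num_groups=16, causal=True)
# output: (2, 1024, 16, 64)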

Variable Sequence Attention

Attention over batches of variable-length sequences, using per-sequence lengths instead of padding every sequence to the same length.

Parameters:

  • query (torch.Tensor) - Query tensor

  • key (torch.Tensor) - Key tensor

  • value (torch.Tensor) - Value tensor

  • seq_lengths (torch.Tensor) - Actual sequence lengths (batch,)

  • max_seq_len (int) - Maximum sequence length

Returns:

  • output (torch.Tensor) - Attention output
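
Example:

A minimal sketch of the calling convention implied by the parameter list above. The name variable_length_attention is taken from the performance table below and is an assumption; your AITER version may expose this kernel under a different (e.g. varlen-style) name.

import torch
import aiter

# Batch of 4 sequences padded to max_seq_len = 1024; true lengths vary per sequence
q = torch.randn(4, 1024, 16, 64, device='cuda', dtype=torch.float16)
k = torch.randn(4, 1024, 16, 64, device='cuda', dtype=torch.float16)
v = torch.randn(4, 1024, 16, 64, device='cuda', dtype=torch.float16)
seq_lengths = torch.tensor([1024, 512, 768, 96], device='cuda', dtype=torch.int32)

output = aiter.variable_length_attention(
    q, k, v, seq_lengths=seq_lengths, max_seq_len=1024
)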

Supported Architectures

AITER attention kernels are optimized for:

  • AMD Instinct MI300X (gfx942) - Best performance

  • AMD Instinct MI250X (gfx90a) - Fully supported

  • AMD Instinct MI300A (gfx942) - Experimental
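
A quick way to confirm which architecture the active device reports before choosing kernels or benchmarking (this is a generic PyTorch check, not an AITER API; the gcnArchName field is exposed by ROCm builds of recent PyTorch releases):

import torch

# On ROCm builds the reported architecture string looks like 'gfx942' or 'gfx90a'
props = torch.cuda.get_device_properties(0)
print(props.gcnArchName)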

Performance Characteristics

Operation                    Typical Speedup     Memory Efficient    Best For
flash_attn_func              2-4x vs PyTorch     Yes                 Training & Inference
flash_attn_with_kvcache      3-6x vs naive       Yes                 LLM Inference
grouped_query_attention      2-3x vs unfused     Moderate            Llama-style models
variable_length_attention    4-8x vs padded      High                Variable batches

See Also

  • ../tutorials/attention - Attention tutorial

  • ../tutorials/variable_length - Variable-length sequences

  • ../benchmarks - Performance benchmarks