ATOM Scheduling & KV Cache Guide
ATOM (AiTer Optimized Model) uses a prefill-first scheduler with paged KV cache block management to drive LLM inference on AMD ROCm/HIP GPUs. This guide covers the scheduling algorithm, batch construction, block-level KV cache management, prefix caching, postprocessing, speculative decoding integration, and sequence lifecycle.
Quick Reference
| Class | File | Purpose |
|---|---|---|
| `Scheduler` |  | Orchestrates prefill/decode scheduling, preemption, and postprocessing |
| `ScheduledBatch` |  | Immutable snapshot of a scheduled batch sent to the model runner |
| `ScheduledBatchOutput` |  | Holds sampled token IDs and draft token IDs returned from the forward pass |
| `BlockManager` |  | Manages paged KV cache blocks with allocation, deallocation, and prefix caching |
| `Block` |  | Single KV cache block with ID, reference count, hash, and token IDs |
| `Sequence` |  | Tracks a single request through its lifetime (tokens, blocks, status, timing) |
| `SequenceStatus` |  | Enum: sequence lifecycle states (`WAITING`, `RUNNING`, `FINISHED`, plus a shutdown sentinel) |
| `SequenceType` |  | Enum: per-step sequence types (`PREFILL`, `DECODE`, plus an initial pre-scheduling state) |
| `RequestOutput` |  | Dataclass streamed to clients with new tokens and finish status |
| `Config` |  | Scheduling-related fields: `max_num_seqs`, `max_num_batched_tokens`, `kv_cache_block_size`, `enable_prefix_caching`, `scheduler_delay_factor`, `speculative_config` |
Key config defaults:
| Field | Default | Description |
|---|---|---|
| `max_num_seqs` | 512 | Maximum sequences in a single batch |
| `max_num_batched_tokens` | 16384 | Maximum tokens scheduled in a single step |
| `kv_cache_block_size` | 16 | Tokens per KV cache block (must be a multiple of 16, or 1) |
| `enable_prefix_caching` |  | Enable hash-based prefix block sharing |
| `scheduler_delay_factor` | 0.0 | Delay factor for batching prompt requests (0 = no delay) |
|  | 0.9 | Fraction of GPU memory reserved for the KV cache |
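For illustration, the scheduling-related fields might be wired together as below. This is a hypothetical sketch: the field names come from the table above, but the keyword-argument constructor form is an assumption, not the verbatim ATOM API.

```python
# Hypothetical Config construction; the field names are from the table
# above, but the keyword-argument form is an assumption about the API.
config = Config(
    max_num_seqs=512,              # batch at most 512 sequences
    max_num_batched_tokens=16384,  # per-step token budget
    kv_cache_block_size=16,        # 16 tokens per KV cache block
    enable_prefix_caching=True,    # share blocks across common prefixes
    scheduler_delay_factor=0.5,    # wait 0.5x the last prompt latency
)
```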
1. Scheduling Algorithm
The scheduler implements a prefill-first policy: all waiting (prefill) requests are scheduled before any running (decode) requests. The entry point is `Scheduler.schedule()`, which returns a `(ScheduledBatch, dict[int, Sequence])` tuple, or `None` if both queues are empty.
1.1 Scheduler Initialization
```python
class Scheduler:
    def __init__(self, config: Config):
        self.max_num_seqs = config.max_num_seqs
        self.max_num_batched_tokens = config.max_num_batched_tokens
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id
        self.stop_token_ids = config.stop_token_ids
        self.block_manager = BlockManager(config)
        self.waiting: deque[Sequence] = deque()
        self.running: deque[Sequence] = deque()
        self.prev_time = 0.0
        self.prev_prompt = False
        self.last_prompt_latency = 0.0
        self.delay_factor = config.scheduler_delay_factor
        self.use_spec = config.speculative_config is not None
        self.mtp_k: int = (
            config.speculative_config.num_speculative_tokens if self.use_spec else 0
        )
        self.total_draft_tokens = 0
        self.total_accepted_tokens = 0
```
The scheduler maintains two deques – `waiting` (pending prefill) and `running` (active decode) – plus a `BlockManager` for KV cache allocation.
1.2 Schedule Flow
`Scheduler.schedule()` proceeds in two phases:

Phase 1 – Prefill scheduling:

- While the delay gate passes (`_passed_delay`), the waiting queue is non-empty, and `num_seqs_prefill < max_num_seqs`:
  - Peek the first waiting sequence.
  - Compute `num_new_tokens = seq.num_tokens - seq.num_cached_tokens` (prefix cache hits reduce new tokens).
  - If `num_batched_tokens + num_new_tokens > max_num_batched_tokens`, or `block_manager.can_allocate(seq)` returns `False`, break.
  - Otherwise: allocate blocks, set `seq.status = RUNNING` and `seq.type = PREFILL`, and move the sequence from `waiting` to `running`.
- If any prefill sequences were scheduled, return the batch immediately (no decode mixing). A condensed sketch of this admission loop appears after Phase 2.
Phase 2 – Decode scheduling (only when zero prefills were scheduled):

- Pop sequences from `running`, up to `max_num_seqs`.
- For each sequence, check `block_manager.can_append(seq)`. If a block cannot be appended, preempt the last running sequence (move it back to `waiting` with status `WAITING` and deallocate its blocks).
- If the sequence has speculative draft tokens (`seq.spec_token_ids`), record them in `scheduled_spec_decode_tokens`.
- Call `block_manager.may_append(seq, num_new_tokens)` where `num_new_tokens = mtp_k + 1`.
- Re-insert all scheduled sequences back into `running` (preserving order).
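For intuition, here is a condensed, self-contained sketch of the phase-1 admission loop. It mirrors the checks described above with a stand-in sequence type; it is not the verbatim ATOM source.

```python
# Simplified sketch of phase-1 prefill admission (stand-in Seq type;
# the real loop lives in Scheduler.schedule()).
from collections import deque
from dataclasses import dataclass

@dataclass
class Seq:
    num_tokens: int
    num_cached_tokens: int = 0

def admit_prefills(waiting, max_num_seqs, max_num_batched_tokens, can_allocate):
    scheduled, num_batched_tokens = [], 0
    while waiting and len(scheduled) < max_num_seqs:
        seq = waiting[0]  # peek only; pop once all checks pass
        num_new_tokens = seq.num_tokens - seq.num_cached_tokens
        if (num_batched_tokens + num_new_tokens > max_num_batched_tokens
                or not can_allocate(seq)):
            break  # token budget exhausted or no free blocks
        num_batched_tokens += num_new_tokens
        scheduled.append(waiting.popleft())
    return scheduled

# A 10k-token and a 9k-token prompt against the default 16384 budget:
queue = deque([Seq(10_000), Seq(9_000)])
print(len(admit_prefills(queue, 512, 16_384, lambda s: True)))  # 1
```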
1.3 Delay Factor
When `scheduler_delay_factor > 0`, the scheduler delays prefill scheduling to allow the waiting queue to accumulate more requests for better batching:
```python
def _passed_delay(self, now: float) -> bool:
    if self.prev_prompt:
        self.last_prompt_latency = now - self.prev_time
    self.prev_time, self.prev_prompt = now, False
    if self.delay_factor > 0 and self.waiting:
        earliest_arrival_time = min([seq.arrive_time for seq in self.waiting])
        passed_delay = (now - earliest_arrival_time) > (
            self.delay_factor * self.last_prompt_latency
        ) or not self.running
    else:
        passed_delay = True
    return passed_delay
```
A new prefill is scheduled only when the earliest waiting request has waited longer than `delay_factor * last_prompt_latency`, or when there are no running decode requests. For example, with `delay_factor = 0.5` and a last prompt latency of 200 ms, a new prefill is admitted once the earliest waiting request has been queued for more than 100 ms.
1.4 Preemption
When a decode step cannot extend a sequence’s KV cache (no free blocks), the scheduler preempts the last running sequence:
```python
def preempt(self, seq: Sequence):
    seq.status = SequenceStatus.WAITING
    self.block_manager.deallocate(seq)
    self.waiting.appendleft(seq)
```
The preempted sequence is pushed to the front of the waiting queue and its blocks are fully deallocated, so it will be re-prefilled on the next scheduling cycle.
2. ScheduledBatch Structure
`ScheduledBatch` is constructed by `Scheduler.schedule()` and passed to the model runner. It is a frozen snapshot of batch metadata.
2.1 Constructor Signature
```python
class ScheduledBatch:
    def __init__(
        self,
        seqs: dict[int, Sequence],
        num_scheduled_tokens: list[int],
        total_tokens_num: int,
        total_tokens_num_prefill: int = 0,
        total_tokens_num_decode: int = 0,
        total_seqs_num: int = 0,
        total_seqs_num_prefill: int = 0,
        total_seqs_num_decode: int = 0,
        is_dummy_run: bool = False,
        num_spec_step: int = 0,
        scheduled_spec_decode_tokens: dict[int, list[int]] = {},
    ):
```
2.2 Fields
| Field | Type | Description |
|---|---|---|
|  |  | Sequence IDs in batch order |
|  |  | Last token ID per sequence (decode model input) |
|  |  | Sampling temperature per sequence |
|  |  | Total token count per sequence |
|  |  | Block ID tables for sequences that have block tables |
|  |  | Number of valid tokens in each sequence's last block |
|  |  | Number of tokens served from the prefix cache per sequence |
| `num_scheduled_tokens` | `list[int]` | Number of new tokens scheduled per sequence |
| `total_tokens_num` | `int` | Sum of all scheduled tokens across all sequences |
| `total_tokens_num_prefill` | `int` | Total scheduled tokens for prefill sequences |
| `total_tokens_num_decode` | `int` | Total scheduled tokens for decode sequences |
| `total_seqs_num` | `int` | Total number of sequences in the batch |
| `total_seqs_num_prefill` | `int` | Number of prefill sequences |
| `total_seqs_num_decode` | `int` | Number of decode sequences |
| `is_dummy_run` | `bool` | Whether this is a dummy/warmup run |
| `num_spec_step` | `int` | Number of speculative decode steps (= `mtp_k`) |
| `scheduled_spec_decode_tokens` | `dict[int, list[int]]` | Draft token IDs per sequence ID from the prior speculative step |
2.3 ScheduledBatchOutput
Returned by the model runner after a forward pass:
```python
class ScheduledBatchOutput:
    def __init__(
        self,
        token_ids: dict[int, tuple[int, ...]],
        draft_token_ids,
    ):
        self.req_ids = list(token_ids.keys())
        self.token_ids = token_ids              # {seq_id: (accepted_token_ids...)}
        self.draft_token_ids = draft_token_ids  # {seq_id: [draft_ids]} or None
```
- `token_ids` maps sequence ID to a tuple of accepted token IDs.
- `draft_token_ids` maps sequence ID to a list of speculative draft token IDs for the next step (when MTP is active).
- A special key `-1` in `token_ids` signals deferred output mode.
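For example, a plausible output for a two-sequence MTP step (all values invented) could look like this:

```python
# Illustrative ScheduledBatchOutput (invented values): sequence 3 had its
# draft accepted (two tokens emitted), sequence 7 did not.
out = ScheduledBatchOutput(
    token_ids={3: (101, 2045), 7: (88,)},
    draft_token_ids={3: [511], 7: [90]},  # proposals for the next step
)
print(out.req_ids)  # [3, 7]
```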
3. Block Manager
The BlockManager implements paged KV cache management with fixed-size blocks.
3.1 Block Class
```python
class Block:
    def __init__(self, block_id):
        self.block_id = block_id  # Unique integer ID
        self.ref_count = 0        # Number of sequences referencing this block
        self.hash = -1            # xxhash64 digest for prefix caching (-1 = unhashed)
        self.token_ids = []       # Token IDs stored in this block
```
Methods:

- `update(hash, token_ids)` – Sets the block's hash and token content.
- `reset()` – Sets `ref_count = 1`, `hash = -1`, `token_ids = []` (used on fresh allocation).
3.2 BlockManager Initialization
```python
class BlockManager:
    def __init__(self, config: Config):
        block_size = config.kv_cache_block_size  # Tokens per block (default 16)
        num_blocks = config.num_kvcache_blocks   # Total blocks in pool
        self.block_size = block_size
        self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
        self.hash_to_block_id: dict[int, int] = dict()
        self.free_block_ids: deque[int] = deque(range(num_blocks))
        self.used_block_ids: set[int] = set()
        self.enable_prefix_caching = config.enable_prefix_caching
```
The block pool is pre-allocated at startup. `free_block_ids` is a deque for O(1) pop/push, `used_block_ids` tracks active blocks, and `hash_to_block_id` maps content hashes to block IDs for prefix caching.
3.3 Allocation (allocate)
Called during prefill scheduling for new sequences:
```python
def allocate(self, seq: Sequence):
```
1. Iterates over `seq.num_blocks` blocks.
2. For each block, computes a hash if the block is full (`len(token_ids) == block_size`). Partial (last) blocks get `hash = -1`.
3. If prefix caching is enabled, looks up `hash_to_block_id`:
   - Cache hit: verifies that `token_ids` match. If the block is already in `used_block_ids`, increments `ref_count`; if it was evicted but still sits in the free list, re-allocates it. Increments `seq.num_cached_tokens` by `block_size`.
   - Cache miss: allocates from `free_block_ids[0]`.
4. Full blocks are registered in `hash_to_block_id`.

A worked example of the underlying block math follows.
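As a worked example (the same arithmetic the `num_tokens` setter in §7.5 uses), consider a 40-token prompt with the default block size of 16:

```python
# Worked example: 40 prompt tokens, block_size = 16.
block_size = 16
num_tokens = 40
num_blocks = (num_tokens + block_size - 1) // block_size            # 3 blocks
full_blocks = num_tokens // block_size                              # 2 hashed/cacheable
last_block_num_tokens = num_tokens - (num_blocks - 1) * block_size  # 8, hash stays -1
print(num_blocks, full_blocks, last_block_num_tokens)  # 3 2 8
```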
3.4 Deallocation (deallocate)
Called when a sequence finishes or is preempted:
```python
def deallocate(self, seq: Sequence):
    for block_id in reversed(seq.block_table):
        block = self.blocks[block_id]
        block.ref_count -= 1
        if block.ref_count == 0:
            self._deallocate_block(block_id)
    seq.num_cached_tokens = 0
    seq.block_table.clear()
```
Blocks are released in reverse order. Shared blocks (with ref_count > 1 from prefix caching) are not freed until all referencing sequences release them.
3.5 Can-Allocate and Can-Append Checks
```python
def can_allocate(self, seq: Sequence) -> bool:
    return len(self.free_block_ids) >= seq.num_blocks

def can_append(self, seq: Sequence) -> bool:
    # The boolean coerces to int: exactly 1 free block is required only
    # when the previous block just filled up (len(seq) % block_size == 1).
    return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
```
- `can_allocate` checks that enough free blocks exist for the full sequence.
- `can_append` checks whether a decode step needs a new block. A new block is needed only when `len(seq) % block_size == 1` (the previous block just filled up), requiring exactly 1 free block; the boundary cases are checked below.
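The boolean-to-int coercion keeps the rule compact; a quick check of the boundary cases with `block_size = 16`:

```python
# Boundary cases for can_append: a new block is required exactly when
# the previous block just filled up.
block_size = 16
for seq_len in (16, 17, 18, 32, 33):
    print(seq_len, int(seq_len % block_size == 1))
# 16 0   (last block exactly full; its final token already has a slot)
# 17 1   (first token past the boundary needs a fresh block)
# 18 0
# 32 0
# 33 1
```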
3.6 May-Append (Decode Extension)
```python
def may_append(self, seq: Sequence, num_new_tokens: int = 1):
```
Called during decode scheduling to extend a sequence’s block table:
- If the sequence length modulo `block_size` falls within `(0, num_new_tokens]`, or `block_size == 1`, a new block is needed:
  - Allocates from `free_block_ids` and appends to `block_table`.
  - For `block_size == 1`, immediately computes and stores the hash.
- If `seq_len % block_size == 0`, the last block is now full – computes and stores its hash using the chained prefix.
- Otherwise the last block is partially filled with `hash = -1` (the hash is deferred until the block is full).
4. Prefix Caching
Prefix caching enables sharing KV cache blocks across sequences that share a common prompt prefix, avoiding redundant computation.
4.1 Hash Function
ATOM uses xxhash64 (via the xxhash Python library) for fast, collision-resistant block hashing:
```python
@classmethod
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
    h = xxhash.xxh64()
    if prefix != -1:
        h.update(prefix.to_bytes(8, "little"))
    h.update(np.array(token_ids).tobytes())
    return h.intdigest()
```
4.2 Hash Chaining
Blocks form a hash chain: each block’s hash incorporates the previous block’s hash as a prefix. This ensures that two blocks with identical token content but different preceding context produce different hashes.
- First block: `compute_hash(token_ids, prefix=-1)` (no prefix).
- Subsequent blocks: `compute_hash(token_ids, prefix=prev_block.hash)`.
- Only full blocks (where `len(token_ids) == block_size`) receive a hash. Partial blocks keep `hash = -1` and are not cached.

A short demonstration of the chaining follows.
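The demonstration below uses invented block contents, with `compute_hash` reproduced from §4.1 as a plain function:

```python
import numpy as np
import xxhash

def compute_hash(token_ids, prefix=-1):
    # Same logic as the classmethod in section 4.1.
    h = xxhash.xxh64()
    if prefix != -1:
        h.update(prefix.to_bytes(8, "little"))
    h.update(np.array(token_ids).tobytes())
    return h.intdigest()

block0 = list(range(16))        # first full block of a prompt
block1 = list(range(16, 32))    # second full block

h0 = compute_hash(block0)             # first block: no prefix
h1 = compute_hash(block1, prefix=h0)  # chained: folds in h0

# Identical content behind a different prefix yields a different hash:
h1_alt = compute_hash(block1, prefix=compute_hash([0] * 16))
assert h1 != h1_alt
```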
4.3 Cache Lookup During Allocation
During `allocate()`, for each full block:

1. Compute the block hash via the chain.
2. Look up `hash_to_block_id.get(h, -1)`.
3. If found, verify `self.blocks[block_id].token_ids == token_ids` (a guard against hash collisions).
4. Hit: reuse the block. If it is already in `used_block_ids`, increment `ref_count`. Add `block_size` to `seq.num_cached_tokens`.
5. Miss (or first miss in the chain): once a cache miss occurs, all subsequent blocks in the sequence are also misses (`cache_miss = True` is sticky). Allocate fresh blocks from the free list.

A simplified sketch of this lookup is shown below.
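The sketch expresses the sticky miss as an early exit and reuses `compute_hash` from the example above; it is not the verbatim source, which also allocates fresh blocks on a miss:

```python
# Returns the block IDs of the longest reusable cached prefix.
def cached_prefix_blocks(seq_blocks, blocks, hash_to_block_id, block_size):
    prefix_hash = -1
    hits = []
    for token_ids in seq_blocks:       # per-block token-ID lists
        if len(token_ids) < block_size:
            break                      # partial last block: never cached
        prefix_hash = compute_hash(token_ids, prefix=prefix_hash)
        block_id = hash_to_block_id.get(prefix_hash, -1)
        if block_id == -1 or blocks[block_id].token_ids != token_ids:
            break                      # first miss: all later blocks miss too
        hits.append(block_id)
    return hits
```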
4.4 Reference Counting
- On allocation: `block.reset()` sets `ref_count = 1`.
- On a cache hit for an in-use block: `ref_count += 1`.
- On deallocation: `ref_count -= 1`. The block returns to the free list only when `ref_count == 0`.
- Shared blocks (prefix cache hits) have `ref_count > 1`.
4.5 Enabling Prefix Caching
Set `enable_prefix_caching=True` in `Config`. When disabled, the hash lookup in `allocate()` is skipped entirely (`block_id` is always `-1`).
5. Postprocessing
`Scheduler.postprocess()` is called after the model forward pass to update sequences with sampled tokens, check stop conditions, generate streaming output, and clean up finished sequences.
5.1 Signature
```python
def postprocess(
    self,
    seqs: list[Sequence],
    fwd_output: ScheduledBatchOutput,
    stream_output_queue=None,
) -> list[Sequence]:
```
5.2 Token Appending
For each running sequence whose ID appears in `fwd_output.req_ids`:

- Deferred output, or speculative decode with EOS: replaces placeholder tokens in place:

  ```python
  seq.token_ids[-num_placeholder:] = token_ids
  seq.output_tokens[-num_placeholder:] = token_ids
  ```

- Normal path: calls `seq.append_token(token_id)` for each accepted token, which appends to `token_ids` and updates `output_tokens`, `last_token`, and `num_tokens` (see the sketch below).
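A minimal sketch of what `append_token` does, based on the description above (an assumed simplification, not the verbatim method):

```python
# Assumed simplification of Sequence.append_token.
def append_token(self, token_id: int):
    self.token_ids.append(token_id)
    self.output_tokens.append(token_id)
    self.last_token = token_id
    self.num_tokens += 1  # setter refreshes num_blocks / last_block_num_tokens
```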
5.3 Stop Condition Checking
The postprocessor checks stop conditions in priority order:
1. Stop token sequences: compares the tail of `seq.token_ids` against each entry in `seq.stop_token_sequences`, and also checks the MTP-adjusted position for speculative decode. Sets `leave_reason = "stop_sequence"`.
2. EOS token: if `self.eos_token_id` appears in the accepted tokens and `seq.ignore_eos` is `False`. Sets `leave_reason = "eos"`.
3. Stop token IDs: if any accepted token is in `self.stop_token_ids` (from `Config.stop_token_ids`, derived from the model's generation config). Sets `leave_reason = "stop_{token_id}"`.
4. Max tokens: if `seq.num_completion_tokens >= seq.max_tokens`. Sets `leave_reason = "max_tokens"`.

These checks are condensed into a single helper sketch below.
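The condensation is a hypothetical function; in ATOM the checks live inline in `postprocess()`:

```python
# Hypothetical condensation of the stop checks, in priority order.
def check_stop(seq, accepted, eos_token_id, stop_token_ids):
    for stop_seq in seq.stop_token_sequences or []:
        if seq.token_ids[-len(stop_seq):] == stop_seq:
            return "stop_sequence"
    if not seq.ignore_eos and eos_token_id in accepted:
        return "eos"
    for tok in accepted:
        if tok in stop_token_ids:
            return f"stop_{tok}"
    if seq.num_completion_tokens >= seq.max_tokens:
        return "max_tokens"
    return None  # keep decoding
```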
5.4 Stream Output
When `stream_output_queue` is provided, the scheduler creates a `RequestOutput` for each processed sequence:
```python
request_output = RequestOutput(
    request_id=seq.id,
    output_tokens=output_tokens_list,
    finished=(leave_reason is not None),
    finish_reason=leave_reason,
)
```
RequestOutput fields:
| Field | Type | Description |
|---|---|---|
| `request_id` | `int` | Sequence ID |
| `output_tokens` | `list[int]` | Newly generated tokens since the last callback |
| `finished` | `bool` | Whether the sequence is done |
| `finish_reason` | `str` or `None` | One of `"stop_sequence"`, `"eos"`, `"stop_{token_id}"`, `"max_tokens"`, or `None` |
Stream outputs are batched and put onto `stream_output_queue` via `put_nowait`.
5.5 Sequence Cleanup
For finished sequences:
- Set `seq.status = SequenceStatus.FINISHED`.
- Call `block_manager.deallocate(seq)` to free its KV cache blocks.
- Remove the sequence from the `running` deque.
- Return it in the `finished_seqs` list.
5.6 Placeholder Insertion
When speculative decoding or deferred output is active, placeholder EOS tokens are appended to still-running sequences to reserve KV cache slots for the next step:
```python
if need_placeholder:
    for seq in seqs:
        if seq.status == SequenceStatus.RUNNING:
            for _ in range(seq.num_placeholder):
                seq.append_token(self.eos_token_id)
```
The placeholder count is determined as follows:
- For sequences processed in this step (i.e., that had output in `fwd_output`): always `1 + mtp_k`, regardless of mode.
- For sequences not processed (skipped in this step), the count depends on the batch-level mode:
  - Deferred output + speculative: `mtp_k + 1`
  - Deferred output only: `1`
  - Speculative only: `mtp_k`

The rule is expressed as a helper sketch below.
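The helper is hypothetical; in ATOM the logic is inline:

```python
# Hypothetical helper mirroring the placeholder-count rules above.
def placeholder_count(processed_this_step: bool, deferred: bool,
                      speculative: bool, mtp_k: int) -> int:
    if processed_this_step:
        return 1 + mtp_k      # always, regardless of mode
    if deferred and speculative:
        return mtp_k + 1
    if deferred:
        return 1
    return mtp_k              # speculative only
```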
6. Speculative Decoding Integration
ATOM supports Multi-Token Prediction (MTP) speculative decoding, where a draft model proposes `mtp_k` additional tokens per step.
6.1 Scheduler Tracking
```python
self.use_spec = config.speculative_config is not None
self.mtp_k: int = (
    config.speculative_config.num_speculative_tokens if self.use_spec else 0
)
self.total_draft_tokens = 0
self.total_accepted_tokens = 0
```

Note: `SpeculativeConfig` currently enforces `num_speculative_tokens == 1`.
6.2 Draft Tokens in Scheduling
During decode scheduling:
- If `seq.spec_token_ids` is non-empty, the draft tokens are recorded in `scheduled_spec_decode_tokens[seq.id]`.
- `num_new_tokens = mtp_k + 1` (1 target token + `mtp_k` draft tokens), so `may_append` reserves enough block space.
- The `ScheduledBatch` carries `num_spec_step = mtp_k` and the `scheduled_spec_decode_tokens` dict.
6.3 Acceptance Statistics
```python
def update_spec_stats(self, num_accepted_tokens):
    self.total_draft_tokens += self.mtp_k
    self.total_accepted_tokens += num_accepted_tokens - self.mtp_k
```
Every 1000 draft tokens, the acceptance rate is logged:
```
[MTP Stats] Total draft tokens: 5000, Accepted: 3750, Acceptance rate: 75.00%
```
6.4 Draft Token Storage on Sequences
After postprocessing, accepted draft token IDs for the next step are stored on the sequence:
```python
if draft_token_ids and seq.id in draft_token_ids:
    seq.spec_token_ids = draft_token_ids[seq.id]
```
These are picked up by the scheduler on the next schedule() call.
7. Sequence Management
The Sequence class represents a single request throughout its lifecycle.
7.1 Constructor
```python
class Sequence:
    def __init__(
        self,
        token_ids: list[int],
        block_size: int,
        sampling_params=SamplingParams(),
        stop_token_sequences: list[list[int]] = None,
        stream_callback: Optional[Callable[[Any], None]] = None,
        id=None,
    ):
```
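A hypothetical construction (token values invented; `SamplingParams` defaults assumed):

```python
# Hypothetical: wrap an already-tokenized prompt in a Sequence.
prompt_ids = [1, 15043, 3186]        # invented token IDs
seq = Sequence(
    token_ids=prompt_ids,
    block_size=16,                   # matches kv_cache_block_size
    stop_token_sequences=[[2]],      # stop once token 2 is produced
)
```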
7.2 Core Fields
| Field | Type | Description |
|---|---|---|
| `id` | `int` | Auto-incrementing unique ID |
| `token_ids` | `list[int]` | Full token sequence (prompt + completion) |
| `block_size` | `int` | KV cache block size (from config) |
| `status` | `SequenceStatus` | Current lifecycle state |
| `type` | `SequenceType` | Current step type (`PREFILL` / `DECODE`) |
| `num_tokens` | `int` | Total tokens (prompt + completion); a property whose setter also updates `num_blocks` and `last_block_num_tokens` |
| `num_prompt_tokens` | `int` | Number of prompt tokens (fixed at init) |
| `num_cached_tokens` | `int` | Tokens served from the prefix cache |
| `block_table` | `list[int]` | Ordered list of block IDs assigned to this sequence |
| `last_token` | `int` | Most recently appended token ID |
| `temperature` | `float` | Sampling temperature (from `sampling_params`) |
| `max_tokens` | `int` | Max completion tokens (from `sampling_params`) |
| `ignore_eos` | `bool` | Whether to ignore EOS tokens (from `sampling_params`) |
|  |  | Stop strings (from `sampling_params`) |
| `stop_token_sequences` | `list[list[int]]` | Token-level stop sequences |
| `stream_callback` | `Optional[Callable]` | Per-sequence stream callback |
| `output_tokens` | `list[int]` | Cache of newly generated tokens |
| `spec_token_ids` | `list[int]` | Speculative draft token IDs for the next step |
| `num_placeholder` | `int` | Number of placeholder tokens inserted for speculative/deferred output |
7.3 Timing Fields
| Field | Type | Description |
|---|---|---|
| `arrive_time` | `float` | Timestamp when the sequence entered the scheduler |
|  | `float` | Timestamp of the first completion token (TTFT measurement) |
|  | `float` | Timestamp when the sequence finished |
| `leave_reason` | `str` | Reason for finishing (e.g., `"eos"`, `"max_tokens"`, `"stop_sequence"`) |
7.4 Computed Properties
| Property | Returns |
|---|---|
| `num_completion_tokens` | `num_tokens - num_prompt_tokens` |
|  |  |
|  |  |
|  |  |
7.5 num_tokens Setter
Setting `num_tokens` triggers derived field updates:
```python
@num_tokens.setter
def num_tokens(self, value):
    self._num_tokens = value
    self.num_blocks = (value + self.block_size - 1) // self.block_size
    self.last_block_num_tokens = self._num_tokens - (self.num_blocks - 1) * self.block_size
```
7.6 Lifecycle
```
                   allocate blocks
add(seq) -------> WAITING ---------> RUNNING (PREFILL)
                     ^                      |
                     |                      | next schedule() step
                  preempt()                 v
                     |               RUNNING (DECODE) <--+
                     +--- can't append      |     |      |
                                            |     +------+
                                            | stop condition met
                                            v
                                        FINISHED
                                            |
                                            | deallocate blocks
                                            v
                                 (removed from running)
```
7.7 SequenceStatus Enum
| Value | Meaning |
|---|---|
| `WAITING` | In the waiting queue, pending prefill |
| `RUNNING` | Actively being processed (prefill or decode) |
| `FINISHED` | Stop condition met, blocks deallocated |
|  | Sentinel for engine shutdown |
7.8 SequenceType Enum
| Value | Meaning |
|---|---|
|  | Initial state before scheduling |
| `PREFILL` | Currently in prefill phase |
| `DECODE` | Currently in decode phase |