
AI Systems Bridge · 45 min

KV Cache and Attention at Inference Time

Why token-by-token generation stores past keys and values, how KV cache memory is computed, and why MQA, GQA, paging, quantization, and offload matter.

Why This Matters

A 70B decoder with 80 layers, 64 attention heads that each keep their own key and value head, head dimension 128, BF16 cache entries, and a 32k-token prompt stores 80 GiB of KV cache for one sequence. That is only the cache, not model weights, activations, CUDA workspace, or allocator fragmentation.

Generation is token-by-token. At step $t$, each layer computes one new query per attention head, but attention still reads the cached keys and values of tokens $1,\ldots,t-1$ along with the new token's own key and value. Serving throughput at long context is therefore bounded by high-bandwidth memory (HBM) capacity and memory traffic, not only by floating-point operations.

Core Definitions

Definition

Scaled dot-product attention

For queries $Q \in \mathbb{R}^{n_q \times d}$, keys $K \in \mathbb{R}^{n_k \times d}$, and values $V \in \mathbb{R}^{n_k \times d_v}$, scaled dot-product attention is

$$\operatorname{Attention}(Q,K,V)=\operatorname{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V.$$

In autoregressive decoding, $n_q=1$ for each generated token, while $n_k$ grows with context length.

Definition

KV cache

The KV cache stores the keys and values of all past tokens for every decoder layer. At each decode step, the model appends the new token's key and value to the cache, scores the new query against all cached keys, and returns the softmax-weighted sum of the cached values.

Definition

Multi-query and grouped-query attention

Multi-Query Attention, or MQA, shares one key head and one value head across all query heads. Grouped-Query Attention, or GQA, uses fewer key/value heads than query heads, with several query heads sharing each key/value head.

Definition

Paged attention

Paged attention stores KV blocks in fixed-size pages and uses an indirection table from logical token positions to physical cache blocks. It is analogous to virtual memory, but the pages contain key/value vectors instead of process memory bytes.

Attention During Prefill and Decode

In the prefill phase, the model consumes the prompt. If the prompt has $S$ tokens, each layer forms $Q$, $K$, and $V$ for all $S$ positions and computes a triangular causal attention pattern. Computing the attention scores is roughly quadratic in $S$.

In the decode phase, one token is generated at a time. For the new token at position $t$, the layer computes one new query $q_t$, one new key $k_t$, and one new value $v_t$. The attention output is

$$o_t=\operatorname{softmax}\left(\frac{q_tK_{1:t}^T}{\sqrt{d}}\right)V_{1:t}.$$

The keys and values for positions $1,\ldots,t-1$ are not recomputed. They were produced in previous steps and stored.

A minimal decode loop looks like this.

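// Prefill has cached positions 0..prompt_len-1 and sampled `token`,
// the input for position prompt_len. Normalization and positions omitted.
for (int t = prompt_len; t < max_len; ++t) {
  x = embed(token);  // hidden state for the current token only
  for (int layer = 0; layer < n_layers; ++layer) {
    q = Wq[layer] * x;
    k = Wk[layer] * x;
    v = Wv[layer] * x;

    kv_cache[layer].append(k, v);                     // one position appended
    x = x + Wo[layer] * attend(q, kv_cache[layer]);   // reads positions 0..t
    x = mlp_and_residual(layer, x);
  }
  token = sample(lm_head(x));  // becomes the input at position t + 1
}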

The single-token query is small; the cache read is not. Each decode step reads the entire cache. In the 80-layer, 32k-context configuration from the opening example, that is about 43 billion BF16 elements (80 GiB) per generated token.
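The loop's `append` and `attend` calls can be sketched for a single head. This is a minimal version under stated assumptions: one head, `float` storage instead of BF16, and the hypothetical names `KVCache`, `append`, and `attend`. Appending stores one key/value row, and attending scores the query against every cached row, so the work grows with cache length while earlier rows are never recomputed.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical single-head KV cache, row-major [position][dim], float for clarity.
struct KVCache {
  std::size_t dim;
  std::vector<float> keys;    // positions * dim floats
  std::vector<float> values;  // positions * dim floats

  explicit KVCache(std::size_t d) : dim(d) {}

  std::size_t length() const { return keys.size() / dim; }

  // Store the new token's key and value; earlier rows are left untouched.
  void append(const std::vector<float>& k, const std::vector<float>& v) {
    keys.insert(keys.end(), k.begin(), k.end());
    values.insert(values.end(), v.begin(), v.end());
  }
};

// o = softmax(q K^T / sqrt(d)) V over every cached position.
std::vector<float> attend(const std::vector<float>& q, const KVCache& cache) {
  const std::size_t n = cache.length(), d = cache.dim;
  const float scale = 1.0f / std::sqrt(static_cast<float>(d));
  std::vector<float> scores(n);
  float max_score = -std::numeric_limits<float>::infinity();
  for (std::size_t j = 0; j < n; ++j) {
    float s = 0.0f;
    for (std::size_t i = 0; i < d; ++i) s += q[i] * cache.keys[j * d + i];
    scores[j] = s * scale;
    max_score = std::fmax(max_score, scores[j]);
  }
  float denom = 0.0f;
  for (float& s : scores) {  // numerically stable softmax numerators
    s = std::exp(s - max_score);
    denom += s;
  }
  std::vector<float> out(d, 0.0f);
  for (std::size_t j = 0; j < n; ++j)
    for (std::size_t i = 0; i < d; ++i)
      out[i] += scores[j] / denom * cache.values[j * d + i];
  return out;
}
```

With a single cached position the output is exactly that position's value; with equal scores it is the mean of the cached values.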

Byte Layout of a KV Cache

A common layout stores keys and values in separate tensors, with the head-dimension entries contiguous in memory.

#include <stddef.h>
#include <stdint.h>

// Logical shape for one layer.
// K[token][kv_head][head_dim], V[token][kv_head][head_dim]
uint16_t *K;  // BF16 stored as 16-bit payloads
uint16_t *V;

// Element index (not byte offset) of K[token][kv_head][d].
size_t offset(int token, int kv_head, int d,
              int n_kv_heads, int head_dim) {
  return ((size_t)token * n_kv_heads + kv_head) * head_dim + d;
}

For a tiny example with 3 tokens, 2 KV heads, head dimension 4, and BF16 entries, the key tensor contains $3 \times 2 \times 4 \times 2 = 48$ bytes. The value tensor is another 48 bytes, so the full KV cache for that layer is 96 bytes.

One physical byte layout for the key tensor is:

K base
byte 00..07  token 0, head 0, d 0..3
byte 08..15  token 0, head 1, d 0..3
byte 16..23  token 1, head 0, d 0..3
byte 24..31  token 1, head 1, d 0..3
byte 32..39  token 2, head 0, d 0..3
byte 40..47  token 2, head 1, d 0..3
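To connect the `offset` function to the byte table, a sketch can scale the element index by the 2-byte BF16 entry size. `byte_offset` and `tensor_bytes` are illustrative names, and `offset` is repeated so the block compiles on its own.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Element index of K[token][kv_head][d] (same formula as offset() above).
std::size_t offset(int token, int kv_head, int d, int n_kv_heads, int head_dim) {
  return (static_cast<std::size_t>(token) * n_kv_heads + kv_head) * head_dim + d;
}

// Byte position of that element when each BF16 entry occupies 2 bytes.
std::size_t byte_offset(int token, int kv_head, int d, int n_kv_heads, int head_dim) {
  return offset(token, kv_head, d, n_kv_heads, head_dim) * sizeof(std::uint16_t);
}

// Bytes in one layer's K (or V) tensor: tokens * kv_heads * head_dim * 2.
std::size_t tensor_bytes(int tokens, int n_kv_heads, int head_dim) {
  return static_cast<std::size_t>(tokens) * n_kv_heads * head_dim * sizeof(std::uint16_t);
}
```

For example, `byte_offset(1, 1, 0, 2, 4)` is 24, matching the row "token 1, head 1" in the table.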

If the attention kernel assigns one query head to a thread block, each block streams the cached keys and values for its KV head (its own head, or the KV head it shares under MQA or GQA) in tiles. For each tile it computes dot products with the query, updates the online softmax running maximum and normalizer, and accumulates the rescaled weighted values, so the cache is read in a single pass. Within one sequence the reads are long and sequential, but the requests in a serving batch have different sequence lengths and live at different physical locations.
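The online softmax recurrence keeps only a running maximum `m`, a running normalizer `l`, and a rescaled accumulator, so each key can be consumed together with its value. A scalar C++ sketch of one head, assuming `float` inputs, row-major `[position][dim]` arrays, and at least one cached position; `online_softmax_attend` is an illustrative name, not a kernel library API.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// One-pass decode attention for one head using the online softmax recurrence.
// K and V are row-major [n][d]; q has d entries; assumes n >= 1.
std::vector<float> online_softmax_attend(const std::vector<float>& q,
                                         const std::vector<float>& K,
                                         const std::vector<float>& V,
                                         std::size_t n, std::size_t d) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(d));
  float m = -std::numeric_limits<float>::infinity();  // running max score
  float l = 0.0f;                                      // running sum of exp(s - m)
  std::vector<float> acc(d, 0.0f);                     // running sum of exp(s - m) * v
  for (std::size_t j = 0; j < n; ++j) {
    float s = 0.0f;
    for (std::size_t i = 0; i < d; ++i) s += q[i] * K[j * d + i];
    s *= scale;
    const float m_new = std::fmax(m, s);
    const float correction = std::exp(m - m_new);  // rescale old state to the new max
    const float p = std::exp(s - m_new);
    l = l * correction + p;
    for (std::size_t i = 0; i < d; ++i)
      acc[i] = acc[i] * correction + p * V[j * d + i];
    m = m_new;
  }
  for (float& a : acc) a /= l;  // normalize once at the end
  return acc;
}
```

Because the state after each position is just `m`, `l`, and `acc`, a kernel can process the cache in tiles of any size and still match the two-pass softmax result up to floating-point rounding.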

The KV Cache Size Formula

Let $B$ be batch size, $S$ be sequence length, $N_L$ be layer count, $H_{kv}$ be the number of key/value heads, $D_h$ be head dimension, and $P$ be bytes per cache element. The full decoder cache size is

$$M_{KV}=2 \times B \times S \times N_L \times H_{kv} \times D_h \times P.$$

The factor 2 is for keys and values. Some code paths store extra scale factors, padding, page metadata, or transposed copies, but this formula is the first capacity estimate.

For a 70B-style model with $N_L=80$, $H_{kv}=64$, $D_h=128$, $S=32768$, $B=1$, and BF16 with $P=2$:

$$M_{KV}=2 \times 1 \times 32768 \times 80 \times 64 \times 128 \times 2.$$

$$M_{KV}=85{,}899{,}345{,}920 \text{ bytes}=80 \text{ GiB}.$$

The per-layer cache is exactly 1 GiB in this configuration:

$$2 \times 32768 \times 64 \times 128 \times 2=1{,}073{,}741{,}824 \text{ bytes}.$$

This number explains why a long-context request can consume the memory of a full accelerator even when the model weights are already quantized.
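The size formula maps directly onto an integer helper. A sketch with illustrative parameter names, using 64-bit arithmetic because the 70B example already exceeds the range of a 32-bit integer:

```cpp
#include <cassert>
#include <cstdint>

// M_KV = 2 * B * S * N_L * H_kv * D_h * P, in bytes (the 2 covers keys and values).
std::uint64_t kv_cache_bytes(std::uint64_t batch, std::uint64_t seq_len,
                             std::uint64_t n_layers, std::uint64_t n_kv_heads,
                             std::uint64_t head_dim, std::uint64_t bytes_per_elem) {
  return 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_elem;
}

constexpr std::uint64_t kGiB = 1ull << 30;
```

Changing `n_kv_heads` or `bytes_per_elem` reproduces the MQA, GQA, and INT8 figures later in this section.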

A serving capacity estimate follows from the same formula. If only 40 GiB of HBM is available for KV cache, with the same architecture and BF16 full multi-head KV, the maximum total resident tokens across all active sequences is

$$S_{\text{total}}=\frac{40 \times 2^{30}}{2 \times 80 \times 64 \times 128 \times 2}=16384.$$

That could be one 16k-token sequence, eight 2k-token sequences, or any other mix whose lengths sum to at most 16,384 tokens.
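Inverting the formula gives a token budget. The sketch below, with hypothetical names, computes bytes per resident token and checks whether a mix of sequence lengths fits; it ignores page rounding, metadata, and fragmentation.

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Bytes one resident token occupies across all layers: 2 * N_L * H_kv * D_h * P.
std::uint64_t bytes_per_token(std::uint64_t n_layers, std::uint64_t n_kv_heads,
                              std::uint64_t head_dim, std::uint64_t bytes_per_elem) {
  return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem;
}

// Largest total number of resident tokens that fits in the budget.
std::uint64_t max_resident_tokens(std::uint64_t budget_bytes, std::uint64_t per_token) {
  return budget_bytes / per_token;
}

// A batch fits if its summed sequence lengths stay within the token budget.
bool fits(const std::vector<std::uint64_t>& seq_lens, std::uint64_t budget_bytes,
          std::uint64_t per_token) {
  const std::uint64_t total =
      std::accumulate(seq_lens.begin(), seq_lens.end(), std::uint64_t{0});
  return total <= max_resident_tokens(budget_bytes, per_token);
}
```

Only the sum of lengths matters here, which is why one 16k sequence and eight 2k sequences cost the same.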

MQA and GQA Shrink the Cache

Full multi-head attention has $H_q=H_{kv}$. If $H_q=64$, every query head has its own key and value head.

MQA sets $H_{kv}=1$. The cache ratio relative to full KV is $1/64$ in this example, so the same 80 GiB cache becomes 1.25 GiB at 32k context, ignoring metadata.

GQA sets $H_{kv}$ to an intermediate value. If $H_q=64$ and $H_{kv}=8$, each KV head serves 8 query heads. The cache ratio is $8/64=1/8$, so the 80 GiB cache becomes 10 GiB.

The attention score computation changes only in the mapping from query head to KV head.

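// Assumes n_q_heads is a multiple of n_kv_heads (MQA is n_kv_heads == 1).
int kv_head_for_query(int q_head, int n_q_heads, int n_kv_heads) {
  int group_size = n_q_heads / n_kv_heads;  // query heads per KV head
  return q_head / group_size;
}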

// Example: 64 query heads, 8 KV heads.
// query heads 0..7 read KV head 0
// query heads 8..15 read KV head 1
// ...
// query heads 56..63 read KV head 7

MQA and GQA reduce memory capacity and memory traffic for K and V reads. They do not reduce the number of query heads, output projection size, or MLP cost. They also change model quality tradeoffs during training, so serving code cannot switch an already trained full-KV model to GQA without changing weights.
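The traffic saving is easiest to see in a kernel that loops over KV heads: each cached key and value is loaded once and reused by every query head in its group. A scalar C++ sketch, assuming `float` data, `q` laid out as `[q_head][d]`, `K` and `V` as `[token][kv_head][d]`, and a two-pass softmax for brevity; `gqa_decode` is an illustrative name, not a library function. Query head `kv * group + g` is exactly the head that `kv_head_for_query` maps to KV head `kv`.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// GQA decode attention. Each K/V row is read once per KV head and shared by the
// `group` query heads mapped to it. Returns out[q_head][d].
std::vector<float> gqa_decode(const std::vector<float>& q, const std::vector<float>& K,
                              const std::vector<float>& V, int n_tokens, int n_q_heads,
                              int n_kv_heads, int d) {
  const int group = n_q_heads / n_kv_heads;  // assumes exact division
  const float scale = 1.0f / std::sqrt(static_cast<float>(d));
  std::vector<float> out(static_cast<std::size_t>(n_q_heads) * d, 0.0f);
  std::vector<float> scores(static_cast<std::size_t>(group) * n_tokens);
  for (int kv = 0; kv < n_kv_heads; ++kv) {
    // Scores for every query head in this group, one key load per token.
    for (int t = 0; t < n_tokens; ++t) {
      const float* k = &K[(static_cast<std::size_t>(t) * n_kv_heads + kv) * d];
      for (int g = 0; g < group; ++g) {
        const float* qh = &q[static_cast<std::size_t>(kv * group + g) * d];
        float s = 0.0f;
        for (int i = 0; i < d; ++i) s += qh[i] * k[i];
        scores[static_cast<std::size_t>(g) * n_tokens + t] = s * scale;
      }
    }
    // Stable softmax per query head.
    for (int g = 0; g < group; ++g) {
      float* row = &scores[static_cast<std::size_t>(g) * n_tokens];
      float m = -std::numeric_limits<float>::infinity(), z = 0.0f;
      for (int t = 0; t < n_tokens; ++t) m = std::fmax(m, row[t]);
      for (int t = 0; t < n_tokens; ++t) { row[t] = std::exp(row[t] - m); z += row[t]; }
      for (int t = 0; t < n_tokens; ++t) row[t] /= z;
    }
    // Weighted value sum, one value load per token shared by the group.
    for (int t = 0; t < n_tokens; ++t) {
      const float* v = &V[(static_cast<std::size_t>(t) * n_kv_heads + kv) * d];
      for (int g = 0; g < group; ++g) {
        float* o = &out[static_cast<std::size_t>(kv * group + g) * d];
        const float p = scores[static_cast<std::size_t>(g) * n_tokens + t];
        for (int i = 0; i < d; ++i) o[i] += p * v[i];
      }
    }
  }
  return out;
}
```

With `n_kv_heads == n_q_heads` the same function computes full multi-head attention, and with `n_kv_heads == 1` it computes MQA.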

Paged Attention and Cache Allocation

A naive serving runtime allocates a contiguous KV buffer for the maximum sequence length of each request. If a request reserves 32k tokens but ends after 900 tokens, most of that reservation is wasted. If many requests grow token-by-token, contiguous allocation also causes fragmentation.

Paged attention stores fixed-size blocks. Suppose one block holds 16 tokens of one layer's K and V. With $H_{kv}=8$, $D_h=128$, and BF16, one layer block is

$$2 \times 16 \times 8 \times 128 \times 2=65536 \text{ bytes}.$$

That is 64 KiB per layer. Across 80 layers, a 16-token logical page for the whole model is 5 MiB. A 900-token request needs $\lceil 900/16 \rceil=57$ blocks per layer, not a 32k-token contiguous slab.

The page table maps logical block IDs to physical block IDs.

sequence A, logical blocks: 0  1  2  3
physical block IDs:       91 17 44 203

token 37:
logical block = 37 / 16 = 2
offset in block = 37 % 16 = 5
physical block = page_table[A][2] = 44

A CUDA kernel then adds one level of indirection.

__device__ uint16_t load_k(
    const uint16_t* K_pages,
    const int* page_table,
    int seq_id, int token, int kv_head, int d,
    int blocks_per_seq, int block_tokens,
    int n_kv_heads, int head_dim) {
  int logical_block = token / block_tokens;
  int token_in_block = token % block_tokens;
  int physical_block =
      page_table[seq_id * blocks_per_seq + logical_block];

  size_t per_block = (size_t)block_tokens * n_kv_heads * head_dim;
  size_t index = (size_t)physical_block * per_block
               + ((size_t)token_in_block * n_kv_heads + kv_head) * head_dim
               + d;
  return K_pages[index];
}

This indirection costs integer arithmetic and page table reads. The savings are much larger when serving many variable-length sequences, because allocated memory tracks the number of tokens actually present, rounded up to a whole block.
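The host-side bookkeeping behind the page table can be sketched with a free list. `BlockPool` and `PagedSequence` are hypothetical names for illustration; the sketch omits eviction and block sharing between sequences.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <utility>
#include <vector>

// Free list of physical KV block IDs (hypothetical allocator for illustration).
struct BlockPool {
  std::vector<int> free_ids;
  explicit BlockPool(int n_blocks) {
    for (int id = n_blocks - 1; id >= 0; --id) free_ids.push_back(id);
  }
  int allocate() {
    if (free_ids.empty()) throw std::runtime_error("KV cache out of blocks");
    int id = free_ids.back();
    free_ids.pop_back();
    return id;
  }
  void release(int id) { free_ids.push_back(id); }
};

// One sequence's page table: logical block index -> physical block ID.
struct PagedSequence {
  int block_tokens;
  int n_tokens = 0;
  std::vector<int> page_table;

  explicit PagedSequence(int bt) : block_tokens(bt) {}

  // Grow by one token, taking a new physical block only when the last one is full.
  void append_token(BlockPool& pool) {
    if (n_tokens % block_tokens == 0) page_table.push_back(pool.allocate());
    ++n_tokens;
  }

  // Map a logical token position to (physical block, offset within block).
  std::pair<int, int> locate(int token) const {
    return {page_table[token / block_tokens], token % block_tokens};
  }

  // Return every block to the pool when the request finishes.
  void release_all(BlockPool& pool) {
    for (int id : page_table) pool.release(id);
    page_table.clear();
    n_tokens = 0;
  }
};
```

A 900-token sequence with 16-token blocks ends up holding 57 blocks, and `locate(37)` on the page table `{91, 17, 44, 203}` returns block 44, offset 5, as in the worked example.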

Quantization and Offload

KV cache quantization stores K and V in INT8, FP8, or another compact format. If BF16 uses 2 bytes and INT8 uses 1 byte, the raw cache size is halved. The 80 GiB full-KV example becomes 40 GiB. With GQA using 8 KV heads, it becomes 5 GiB.

A simple per-head INT8 layout stores one scale per token and head.

K_q[token][kv_head][d]    int8 payload, 1 byte each
K_scale[token][kv_head]   fp16 or fp32 scale
decoded value             scale * int8_payload

For $D_h=128$, the INT8 payload per token per head is 128 bytes. If the scale is FP16, the scale overhead is 2 bytes, about 1.56 percent for keys and the same for values. Smaller cache entries reduce HBM capacity pressure and traffic, but dequantization adds arithmetic and can affect the output distribution.
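A sketch of symmetric per-token, per-head INT8 quantization, assuming the scale is `max|x| / 127` and is stored as `float` for portability (the layout above allows FP16 or FP32 scales, and standard C++ before C++23 has no portable FP16 type). `QuantizedHead`, `quantize_head`, and `dequantize_head` are illustrative names.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Quantized copy of one head vector: int8 payload plus one scale.
struct QuantizedHead {
  std::vector<std::int8_t> payload;  // head_dim bytes
  float scale;                       // decoded value = scale * payload[i]
};

QuantizedHead quantize_head(const std::vector<float>& x) {
  float max_abs = 0.0f;
  for (float v : x) max_abs = std::max(max_abs, std::fabs(v));
  const float scale = max_abs > 0.0f ? max_abs / 127.0f : 1.0f;
  QuantizedHead q{std::vector<std::int8_t>(x.size()), scale};
  for (std::size_t i = 0; i < x.size(); ++i) {
    const float r = std::round(x[i] / scale);  // round to nearest step
    q.payload[i] = static_cast<std::int8_t>(std::min(std::max(r, -127.0f), 127.0f));
  }
  return q;
}

std::vector<float> dequantize_head(const QuantizedHead& q) {
  std::vector<float> out(q.payload.size());
  for (std::size_t i = 0; i < out.size(); ++i) out[i] = q.scale * q.payload[i];
  return out;
}
```

Keys and values would each be quantized this way per token and KV head. Round-to-nearest keeps each decoded entry within half a quantization step (`scale / 2`) of the original.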

Offload moves cold KV pages to CPU memory or NVMe. The cost is latency and bandwidth. A single 16-token, 80-layer, GQA-8, BF16 page is 5 MiB as computed above. Moving 100 such pages back from CPU memory transfers about 500 MiB. That can dominate a decode step unless prefetching matches the attention access pattern.
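A back-of-the-envelope helper for the offload cost, assuming a sustained host-to-device bandwidth supplied by the caller. The 25 GB/s used below is an illustrative assumption, not a measured figure, and the names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Bytes in one logical page across all layers: 2 * block_tokens * N_L * H_kv * D_h * P.
std::uint64_t page_bytes(std::uint64_t block_tokens, std::uint64_t n_layers,
                         std::uint64_t n_kv_heads, std::uint64_t head_dim,
                         std::uint64_t bytes_per_elem) {
  return 2 * block_tokens * n_layers * n_kv_heads * head_dim * bytes_per_elem;
}

// Seconds to move n_pages at an assumed sustained bandwidth (bytes per second).
double transfer_seconds(std::uint64_t n_pages, std::uint64_t bytes_per_page,
                        double bandwidth_bytes_per_s) {
  return static_cast<double>(n_pages * bytes_per_page) / bandwidth_bytes_per_s;
}
```

At an assumed 25 GB/s, moving 100 such pages (500 MiB) takes about 21 ms.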

The Model

The inference-time constraint is a capacity and bandwidth model.

$$M_{KV}=2 \times B \times S \times N_L \times H_{kv} \times D_h \times P.$$

$$\text{decode bytes per token} \approx 2 \times S \times N_L \times H_{kv} \times D_h \times P.$$

The second formula counts reading K and V once for one generated token at batch size 1. It omits writes for the new K and V, page table loads, logits, and MLP traffic.

The Roofline model bounds attainable throughput by the smaller of peak compute and arithmetic intensity (FLOPs per byte moved) times memory bandwidth. Attention at decode has one query per head and many cached keys. Each cached K or V element is loaded once and used in only one multiply-add per query head that reads it, so FLOPs and bytes both grow linearly with $S$ and the arithmetic intensity stays small and roughly constant. When that intensity lies below the machine balance point, HBM bandwidth bounds throughput, and the attention part of each decode step slows in proportion to context length.
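The intensity can be made concrete. Assuming batch size 1, counting a multiply-add as 2 FLOPs, and charging only the K and V reads (softmax and the new K/V writes are ignored), each layer performs about $4 S H_q D_h$ FLOPs while reading $2 S H_{kv} D_h P$ bytes. The ratio $2H_q/(H_{kv}P)$ does not depend on $S$: 1 FLOP per byte for BF16 full multi-head KV and 8 for GQA-8 with 64 query heads. A sketch with illustrative names:

```cpp
#include <cassert>
#include <cstdint>

// Bytes of K and V read for one generated token at batch size 1.
std::uint64_t decode_kv_bytes(std::uint64_t seq_len, std::uint64_t n_layers,
                              std::uint64_t n_kv_heads, std::uint64_t head_dim,
                              std::uint64_t bytes_per_elem) {
  return 2 * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_elem;
}

// FLOPs for q.K^T plus the weighted V sum, counting a multiply-add as 2 FLOPs.
std::uint64_t decode_attn_flops(std::uint64_t seq_len, std::uint64_t n_layers,
                                std::uint64_t n_q_heads, std::uint64_t head_dim) {
  return 4 * seq_len * n_layers * n_q_heads * head_dim;
}

// Arithmetic intensity in FLOPs per byte of cache read.
double decode_attn_intensity(std::uint64_t seq_len, std::uint64_t n_layers,
                             std::uint64_t n_q_heads, std::uint64_t n_kv_heads,
                             std::uint64_t head_dim, std::uint64_t bytes_per_elem) {
  return static_cast<double>(decode_attn_flops(seq_len, n_layers, n_q_heads, head_dim)) /
         static_cast<double>(
             decode_kv_bytes(seq_len, n_layers, n_kv_heads, head_dim, bytes_per_elem));
}
```

Because the intensity is independent of $S$, longer contexts add bytes and FLOPs in the same proportion and never move decode attention toward the compute-bound side of the roofline.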

Proposition

KV Cache Grows Linearly with Resident Tokens

Statement

For a fixed decoder architecture and cache precision, KV cache memory is linear in the total number of resident tokens across active requests.

Intuition

Each new token appends exactly one key vector and one value vector per layer and per KV head. No past token can be discarded while future tokens may attend to it, unless a sliding-window or retrieval policy changes the attention pattern.

Proof Sketch

For one token, one layer stores $2 \times H_{kv} \times D_h$ scalar entries. Across $N_L$ layers at $P$ bytes per entry, that is $2 \times N_L \times H_{kv} \times D_h \times P$ bytes per token. Summing over the resident sequences with lengths $S_i$ gives $M_{KV}=2 \times N_L \times H_{kv} \times D_h \times P \times \sum_i S_i$, which reduces to the formula above when all $B$ sequences have length $S$.

Why It Matters

Serving admission control often reduces to bounding total resident tokens. Batch size and context length cannot be chosen independently.

Failure Mode

The statement fails for architectures that do not retain all past tokens, such as fixed sliding-window attention, recurrent memory, or external compression that replaces exact K and V vectors.

Common Confusions

Watch Out

The KV cache is not model weights

Model weights are shared across requests. KV cache is per request and grows with generated length. A 70B model with quantized weights can still run out of HBM because several long requests allocate tens of GiB of cache.

Watch Out

Prefill cost and decode cost are different

Prefill performs attention over many prompt queries at once and is compute-heavy for long prompts. Decode has one query per sequence but reads all cached keys and values at each step. The bottleneck often moves from matrix multiply throughput to memory bandwidth and capacity.

Watch Out

MQA and GQA are not runtime flags for any model

The number of KV heads is part of the trained architecture. Serving software can exploit MQA or GQA if the checkpoint has those tensors. It cannot share KV heads after training without changing the computation.

Watch Out

Paged attention does not compress vectors

Paging reduces unused reserved memory and fragmentation. It does not reduce the bytes per live token, except for small metadata overheads. GQA, MQA, quantization, or truncation reduce bytes per live token.

Exercises

Exercise · Core

Problem

Compute the BF16 KV cache size for a decoder with 32 layers, 32 KV heads, head dimension 128, batch size 4, and sequence length 4096.

Exercise · Core

Problem

A model has 64 query heads and uses GQA with 8 KV heads. For query heads 0, 7, 8, 31, and 63, give the KV head index. Then compute the cache reduction versus full multi-head KV.

Exercise · Advanced

Problem

A paged KV cache uses 16-token blocks. A sequence has 900 live tokens. Each layer block for GQA-8, head dimension 128, BF16 is 64 KiB. The model has 80 layers. How many logical blocks are needed, and how much KV memory do they occupy across all layers?

References

Canonical:

  • Vaswani et al., Attention Is All You Need (2017), §3.2. Defines scaled dot-product and multi-head attention.
  • NVIDIA, CUDA C++ Programming Guide (CUDA 12.x), §5.3 and §8.2. Covers memory throughput and device memory access patterns.
  • Williams, Waterman, and Patterson, Roofline: An Insightful Visual Performance Model for Multicore Architectures (2009), §2-3. Gives the bandwidth versus compute ceiling model.
  • Kwon et al., Efficient Memory Management for Large Language Model Serving with PagedAttention (2023), §3-4. Introduces paged KV cache management for LLM serving.
  • Hennessy and Patterson, Computer Architecture: A Quantitative Approach (6th ed., 2017), §2.2 and §2.6. Covers memory hierarchy and bandwidth constraints.

Accessible:

  • Olah et al., A Mathematical Framework for Transformer Circuits (2021), attention heads sections.
  • Jay Alammar, The Illustrated Transformer.
  • Hugging Face, KV Cache Explained documentation.

Next Topics

  • /computationpath/attention-kernels
  • /computationpath/gpu-memory-hierarchy
  • /computationpath/llm-serving-batching
  • /computationpath/quantization-and-compression