[Doc] Clarify FP8 KV cache computation workflow (#31071)

Signed-off-by: westers <steve.westerhouse@origami-analytics.com>
Steve Westerhouse 2025-12-21 18:41:37 -06:00 committed by GitHub
parent 06d490282f
commit 9d701e90d8
2 changed files with 31 additions and 21 deletions

View File

@@ -139,18 +139,18 @@ token data.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
```
<figure markdown="span">
![](../assets/design/paged_attention/query.png){ align="center" alt="query" width="70%" }
</figure>
<p align="center">
<img src="../assets/design/paged_attention/query.png" alt="query" width="70%" />
</p>
Each thread defines its own `q_ptr`, which points to the assigned
query token data in global memory. For example, if `VEC_SIZE` is 4
and `HEAD_SIZE` is 128, `q_ptr` points to data containing a
total of 128 elements, divided into 128 / 4 = 32 vecs.
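As a quick check on that arithmetic, the sketch below (host-side, with illustrative constants rather than the kernel's actual definitions) computes how many vecs each thread of a 2-thread group ends up owning:

```cpp
#include <cstdio>

// Illustrative constants matching the running example: a 128-element head
// split into 4-element vecs, handled by a thread group of 2 threads.
constexpr int HEAD_SIZE = 128;
constexpr int VEC_SIZE = 4;
constexpr int THREAD_GROUP_SIZE = 2;

constexpr int NUM_VECS_PER_TOKEN = HEAD_SIZE / VEC_SIZE;                     // 32 vecs per query token
constexpr int NUM_VECS_PER_THREAD = NUM_VECS_PER_TOKEN / THREAD_GROUP_SIZE;  // 16 vecs per thread

int main() {
  std::printf("vecs per token: %d, vecs per thread: %d\n",
              NUM_VECS_PER_TOKEN, NUM_VECS_PER_THREAD);
  return 0;
}
```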
<figure markdown="span">
![](../assets/design/paged_attention/q_vecs.png){ align="center" alt="q_vecs" width="70%" }
</figure>
<p align="center">
<img src="../assets/design/paged_attention/q_vecs.png" alt="q_vecs" width="70%" />
</p>
```cpp
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
@@ -187,9 +187,9 @@ key token at different iterations. As shown above, that `k_ptr`
points to key token data in `k_cache` at the assigned block,
assigned head, and assigned token.
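For intuition, here is a rough, host-side sketch of that indexing; the stride names are illustrative and the real kernel computes these offsets inline from the cache layout:

```cpp
#include <cstdint>
#include <cstdio>

// Rough sketch of the key-pointer indexing described above: jump to the
// token's physical block, then to the assigned KV head, then to the token's
// slot within that block. Names and strides are illustrative, not the exact
// kernel variables.
int64_t key_token_offset(int64_t physical_block_number,
                         int64_t block_stride,  // elements per block in k_cache
                         int64_t head_stride,   // elements per KV head within a block
                         int kv_head_idx,
                         int block_offset,      // token's slot inside the block
                         int x) {               // elements packed along the x dimension
  return physical_block_number * block_stride + kv_head_idx * head_stride +
         block_offset * x;
}

int main() {
  // Toy layout: BLOCK_SIZE = 16, HEAD_SIZE = 128, x = 8, 2 KV heads.
  const int64_t head_stride = 16 * 128;          // BLOCK_SIZE * HEAD_SIZE
  const int64_t block_stride = 2 * head_stride;  // num KV heads * head_stride
  const int64_t offset =
      key_token_offset(/*block=*/1, block_stride, head_stride,
                       /*head=*/0, /*slot=*/3, /*x=*/8);
  std::printf("k_ptr = k_cache + %lld\n", (long long)offset);
  return 0;
}
```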
<figure markdown="span">
![](../assets/design/paged_attention/key.png){ align="center" alt="key" width="70%" }
</figure>
<p align="center">
<img src="../assets/design/paged_attention/key.png" alt="key" width="70%" />
</p>
The diagram above illustrates the memory layout for key data. It
assumes that the `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is
@@ -202,9 +202,9 @@ iterations. Inside each rectangle, there are a total of 32 vecs (128
elements for one token) that will be processed by 2 threads (one
thread group) separately.
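The split of those 32 vecs between the two threads of a group can be pictured with the small sketch below; the interleaved assignment is only illustrative, not the kernel's exact indexing:

```cpp
#include <cstdio>

// Which of a key token's 32 vecs each thread of a 2-thread group handles,
// assuming an interleaved assignment (illustrative only).
constexpr int NUM_VECS_PER_TOKEN = 32;  // 128 elements / VEC_SIZE of 4
constexpr int THREAD_GROUP_SIZE = 2;

int main() {
  for (int lane = 0; lane < THREAD_GROUP_SIZE; ++lane) {
    std::printf("thread %d of the group takes vecs:", lane);
    for (int vec_idx = lane; vec_idx < NUM_VECS_PER_TOKEN;
         vec_idx += THREAD_GROUP_SIZE) {
      std::printf(" %d", vec_idx);
    }
    std::printf("\n");
  }
  return 0;
}
```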
<figure markdown="span">
![](../assets/design/paged_attention/k_vecs.png){ align="center" alt="k_vecs" width="70%" }
</figure>
<p align="center">
<img src="../assets/design/paged_attention/k_vecs.png" alt="k_vecs" width="70%" />
</p>
```cpp
K_vec k_vecs[NUM_VECS_PER_THREAD];  // one thread's share of the key token's vecs
@@ -361,17 +361,17 @@ later steps. Now, it should store the normalized softmax result of
## Value
<figure markdown="span">
![](../assets/design/paged_attention/value.png){ align="center" alt="value" width="70%" }
</figure>
<p align="center">
<img src="../assets/design/paged_attention/value.png" alt="value" width="70%" />
</p>
<figure markdown="span">
![](../assets/design/paged_attention/logits_vec.png){ align="center" alt="logits_vec" width="50%" }
</figure>
<p align="center">
<img src="../assets/design/paged_attention/logits_vec.png" alt="logits_vec" width="50%" />
</p>
<figure markdown="span">
![](../assets/design/paged_attention/v_vec.png){ align="center" alt="v_vec" width="70%" }
</figure>
<p align="center">
<img src="../assets/design/paged_attention/v_vec.png" alt="v_vec" width="70%" />
</p>
Now we need to retrieve the value data and perform dot multiplication
with `logits`. Unlike query and key, there is no thread group
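A minimal scalar sketch of that dot multiplication, using made-up numbers and an illustrative `dot` helper (the kernel operates on packed `V_vec` registers and accumulates across tokens instead):

```cpp
#include <array>
#include <cstdio>

constexpr int VEC_SIZE = 4;

// Illustrative dot product between a vec of softmax weights and a vec of
// value entries; the real kernel uses vectorized types and fused math.
float dot(const std::array<float, VEC_SIZE>& a,
          const std::array<float, VEC_SIZE>& b) {
  float s = 0.f;
  for (int i = 0; i < VEC_SIZE; ++i) s += a[i] * b[i];
  return s;
}

int main() {
  std::array<float, VEC_SIZE> logits_vec{0.1f, 0.2f, 0.3f, 0.4f};  // softmax weights for 4 tokens
  std::array<float, VEC_SIZE> v_vec{1.f, 2.f, 3.f, 4.f};           // value entries for those tokens
  float acc = dot(logits_vec, v_vec);  // partial result for one output element
  std::printf("acc = %f\n", acc);
  return 0;
}
```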

View File

@@ -17,6 +17,16 @@ The E4M3 format offers higher precision compared to E5M2. However, due to its sm
For now, only per-tensor (scalar) scaling factors are supported. Development is ongoing to support scaling factors of a finer granularity (e.g. per-channel).
### How FP8 KV Cache Works
The FP8 KV cache implementation follows this workflow:
1. **Storage**: Key and Value tensors are quantized to FP8 format using scaling factors before being stored in the KV cache
2. **Retrieval**: When needed for attention computation, cached KV tensors are dequantized back to higher precision (FP16/BF16)
3. **Attention**: The attention-value multiplication (softmax output × V) is performed using the dequantized higher-precision V tensor
This means the final attention computation operates on dequantized values, not FP8 tensors. The quantization reduces memory usage during storage but maintains computation accuracy by using higher precision during the actual attention operations.
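To make those three steps concrete, here is a minimal, self-contained sketch of the store/retrieve round trip with a per-tensor scale. It is illustrative only: `to_fp8_stub` stands in for the hardware E4M3/E5M2 conversion, and none of the names come from the vLLM kernels.

```cpp
#include <cstdio>

// Toy stand-in for the FP8 cast; the real kernels use a hardware conversion
// to E4M3/E5M2, which is where the precision loss actually happens.
float to_fp8_stub(float v) { return v; }

// Step 1 (storage): scale into FP8 range, then cast before writing the cache.
float store_kv_entry(float v, float scale) { return to_fp8_stub(v / scale); }

// Step 2 (retrieval): cast back up and multiply by the scale to recover the
// higher-precision (FP16/BF16) value used in the attention math.
float load_kv_entry(float cached, float scale) { return cached * scale; }

int main() {
  const float scale = 0.02f;  // per-tensor scaling factor
  const float v = 1.25f;      // one value-cache entry before caching
  float cached = store_kv_entry(v, scale);
  float v_used = load_kv_entry(cached, scale);
  // Step 3 (attention): the softmax output multiplies the dequantized V,
  // not the FP8 code stored in the cache.
  float softmax_weight = 0.5f;
  std::printf("attention contribution = %f\n", softmax_weight * v_used);
  return 0;
}
```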
### Performance Impact
The current FP8 KV cache implementation primarily benefits throughput: halving the bytes per cached element roughly doubles the number of tokens that fit in the same KV cache memory. This enables either: