From 3d28ad343f72c950be36aed7fb8c18ab39f14dd2 Mon Sep 17 00:00:00 2001
From: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Date: Fri, 23 May 2025 17:09:54 +0100
Subject: [PATCH] Fix figures in design doc (#18612)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
docs/design/kernel/paged_attention.md | 38 +++++++++------------------
1 file changed, 12 insertions(+), 26 deletions(-)
diff --git a/docs/design/kernel/paged_attention.md b/docs/design/kernel/paged_attention.md
index ad8b5c9264d24..fdfa38a29f837 100644
--- a/docs/design/kernel/paged_attention.md
+++ b/docs/design/kernel/paged_attention.md
@@ -140,22 +140,18 @@ title: vLLM Paged Attention
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
```
-
- { align="center" alt="query" width="70%" }
-
-
-
+
+ { align="center" alt="query" width="70%" }
+
- Each thread defines its own `q_ptr`, which points to the assigned
query token data in global memory. For example, if `VEC_SIZE` is 4
and `HEAD_SIZE` is 128, then `q_ptr` points to data that contains a
total of 128 elements, divided into 128 / 4 = 32 vecs.
-
- { align="center" alt="q_vecs" width="70%" }
-
-
-
+
+ { align="center" alt="q_vecs" width="70%" }
+
```cpp
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
@@ -192,11 +188,9 @@ title: vLLM Paged Attention
points to key token data based on `k_cache` at assigned block,
assigned head and assigned token.
-
- { align="center" alt="key" width="70%" }
-
-
-
+
+ { align="center" alt="key" width="70%" }
+
- The diagram above illustrates the memory layout for key data. It
assumes that the `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is
@@ -209,11 +203,9 @@ title: vLLM Paged Attention
elements for one token) that will be processed by 2 threads (one
thread group) separately.
-
- { align="center" alt="k_vecs" width="70%" }
-
-
-
+
+ { align="center" alt="k_vecs" width="70%" }
+
```cpp
K_vec k_vecs[NUM_VECS_PER_THREAD];
@@ -372,20 +364,14 @@ title: vLLM Paged Attention
{ align="center" alt="value" width="70%" }
-
-
{ align="center" alt="logits_vec" width="50%" }
-
-
{ align="center" alt="v_vec" width="70%" }
-
-
- Now we need to retrieve the value data and perform dot multiplication