diff --git a/docs/design/kernel/paged_attention.md b/docs/design/kernel/paged_attention.md
index ad8b5c9264d24..fdfa38a29f837 100644
--- a/docs/design/kernel/paged_attention.md
+++ b/docs/design/kernel/paged_attention.md
@@ -140,22 +140,18 @@ title: vLLM Paged Attention
     const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
     ```
 
-<figure markdown="span">
-  ![](../../assets/kernel/query.png){ align="center" alt="query" width="70%" }
-  <figcaption>
-  </figcaption>
-</figure>
+<figure markdown="span">
+  ![](../../assets/kernel/query.png){ align="center" alt="query" width="70%" }
+</figure>
 - Each thread defines its own `q_ptr`, which points to the assigned query
   token data in global memory. For example, if `VEC_SIZE` is 4 and
   `HEAD_SIZE` is 128, the `q_ptr` points to data that contains a total of
   128 elements divided into 128 / 4 = 32 vecs.
 
-<figure markdown="span">
-  ![](../../assets/kernel/q_vecs.png){ align="center" alt="q_vecs" width="70%" }
-  <figcaption>
-  </figcaption>
-</figure>
+<figure markdown="span">
+  ![](../../assets/kernel/q_vecs.png){ align="center" alt="q_vecs" width="70%" }
+</figure>
 
     ```cpp
     __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
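The hunk above ends at the step where each thread group stages its query token in the shared `q_vecs` array. As a rough illustration of the arithmetic in that paragraph (host-side C++, not the kernel; the loop structure and variable names are assumptions), the sketch below uses `VEC_SIZE = 4`, `HEAD_SIZE = 128` and a 2-thread group, so one head is 128 / 4 = 32 vecs and each thread handles 16 of them in an interleaved pattern:

```cpp
// Illustrative sketch only (plain C++, not the CUDA kernel): with the assumed
// values VEC_SIZE = 4, HEAD_SIZE = 128 and THREAD_GROUP_SIZE = 2, one query
// head is 32 vecs, and the two threads of a group take alternating vecs.
#include <cstdio>

constexpr int HEAD_SIZE = 128;        // elements per head (example value)
constexpr int VEC_SIZE = 4;           // elements per vectorized load
constexpr int THREAD_GROUP_SIZE = 2;  // threads cooperating on one token
constexpr int NUM_VECS = HEAD_SIZE / VEC_SIZE;                     // 32
constexpr int NUM_VECS_PER_THREAD = NUM_VECS / THREAD_GROUP_SIZE;  // 16

int main() {
    for (int tg_offset = 0; tg_offset < THREAD_GROUP_SIZE; ++tg_offset) {
        printf("thread %d of the group loads vecs:", tg_offset);
        for (int i = 0; i < NUM_VECS_PER_THREAD; ++i) {
            // Interleaved assignment: thread 0 takes vecs 0, 2, 4, ...,
            // thread 1 takes vecs 1, 3, 5, ...
            const int vec_idx = tg_offset + i * THREAD_GROUP_SIZE;
            printf(" %d", vec_idx);
        }
        printf("\n");
    }
    return 0;
}
```

On the GPU, each printed `vec_idx` roughly corresponds to one vectorized read through `q_ptr` into a slot of `q_vecs`.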
@@ -192,11 +188,9 @@ title: vLLM Paged Attention
   points to key token data based on `k_cache` at assigned block, assigned
   head and assigned token.
 
-<figure markdown="span">
-  ![](../../assets/kernel/key.png){ align="center" alt="key" width="70%" }
-  <figcaption>
-  </figcaption>
-</figure>
+<figure markdown="span">
+  ![](../../assets/kernel/key.png){ align="center" alt="key" width="70%" }
+</figure>
 
 - The diagram above illustrates the memory layout for key data. It
   assumes that the `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is
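The hunk above sits in the part of the walkthrough that describes where `k_ptr` lands inside the key cache. Below is a minimal indexing sketch, assuming a `[num_blocks, num_kv_heads, head_size / x, block_size, x]` key-cache layout with `BLOCK_SIZE = 16`, `HEAD_SIZE = 128` and `x = 8` (fp16); the function name and the concrete sizes are illustrative, not taken from the kernel:

```cpp
// Minimal indexing sketch. Assumptions: key cache laid out as
// [num_blocks, num_kv_heads, head_size / x, block_size, x], fp16 data so
// x = 16 / sizeof(scalar_t) = 8; names are illustrative, not the kernel's.
#include <cstdint>
#include <cstdio>

constexpr int BLOCK_SIZE   = 16;   // tokens per KV block
constexpr int HEAD_SIZE    = 128;  // elements per head
constexpr int X            = 8;    // elements per 16-byte chunk (fp16)
constexpr int NUM_KV_HEADS = 8;    // example value

// Flat element offset of key element `elem` of the token at slot
// `token_in_block` inside physical block `block_idx`, for KV head `head_idx`.
int64_t key_elem_offset(int64_t block_idx, int head_idx,
                        int token_in_block, int elem) {
    const int chunk = elem / X;  // which x-wide chunk along the head dimension
    const int inner = elem % X;  // position inside that 16-byte chunk
    return (((block_idx * NUM_KV_HEADS + head_idx) * (HEAD_SIZE / X) + chunk) *
                BLOCK_SIZE + token_in_block) * X + inner;
}

int main() {
    // Example: element 0 of token 3 in block 2 for head 1.
    printf("offset = %lld\n",
           static_cast<long long>(key_elem_offset(2, 1, 3, 0)));
    return 0;
}
```

With a layout like this, the `x` elements a thread group reads for one token sit next to each other, which is what makes the 16-byte vectorized reads possible.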
@@ -209,11 +203,9 @@ title: vLLM Paged Attention
   elements for one token) that will be processed by 2 threads (one thread
   group) separately.
 
-<figure markdown="span">
-  ![](../../assets/kernel/k_vecs.png){ align="center" alt="k_vecs" width="70%" }
-  <figcaption>
-  </figcaption>
-</figure>
+<figure markdown="span">
+  ![](../../assets/kernel/k_vecs.png){ align="center" alt="k_vecs" width="70%" }
+</figure>
 
     ```cpp
     K_vec k_vecs[NUM_VECS_PER_THREAD]
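The `K_vec k_vecs[NUM_VECS_PER_THREAD]` registers shown above hold the key fragments that get dotted with the matching entries of `q_vecs`. The toy sketch below (plain C++, an assumed simplification of the QK step) walks the same interleaved vec assignment for a 2-thread group and combines the two partial sums with an ordinary addition where the kernel would use a cross-thread reduction:

```cpp
// Toy sketch (plain C++, assumed simplification of the kernel's QK step):
// each thread of a 2-thread group dots its own vecs of the query with the
// matching key vecs; the partial sums are then combined.
#include <cstdio>
#include <vector>

constexpr int HEAD_SIZE = 128;
constexpr int VEC_SIZE = 4;
constexpr int THREAD_GROUP_SIZE = 2;
constexpr int NUM_VECS_PER_THREAD = HEAD_SIZE / (VEC_SIZE * THREAD_GROUP_SIZE);

int main() {
    std::vector<float> q(HEAD_SIZE, 1.0f), k(HEAD_SIZE, 0.5f);  // toy data
    float qk = 0.0f;
    for (int lane = 0; lane < THREAD_GROUP_SIZE; ++lane) {
        float partial = 0.0f;  // what one thread of the group accumulates
        for (int i = 0; i < NUM_VECS_PER_THREAD; ++i) {
            const int vec_idx = lane + i * THREAD_GROUP_SIZE;
            for (int e = 0; e < VEC_SIZE; ++e) {
                const int idx = vec_idx * VEC_SIZE + e;
                partial += q[idx] * k[idx];
            }
        }
        qk += partial;  // stands in for the cross-thread reduction on the GPU
    }
    printf("qk = %.1f\n", qk);  // 128 * 1.0 * 0.5 = 64.0
    return 0;
}
```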
@@ -372,20 +364,14 @@ title: vLLM Paged Attention
 
 <figure markdown="span">
   ![](../../assets/kernel/value.png){ align="center" alt="value" width="70%" }
-  <figcaption>
-  </figcaption>
 </figure>
 
 <figure markdown="span">
   ![](../../assets/kernel/logits_vec.png){ align="center" alt="logits_vec" width="50%" }
-  <figcaption>
-  </figcaption>
 </figure>
 
 <figure markdown="span">
   ![](../../assets/kernel/v_vec.png){ align="center" alt="v_vec" width="70%" }
-  <figcaption>
-  </figcaption>
 </figure>
 
 - Now we need to retrieve the value data and perform dot multiplication
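The three figures above introduce the value stage, where each block of value data is weighted by the softmax `logits` and accumulated. A rough sketch of that accumulation, in plain C++ with made-up sizes and a flat `[token][element]` value layout rather than the real `v_cache` layout:

```cpp
// Rough sketch only (plain C++; sizes, names and layout are illustrative):
// for each token in a block, its value elements are scaled by that token's
// logit and accumulated into per-element accumulators.
#include <cstdio>
#include <vector>

constexpr int BLOCK_SIZE = 16;  // tokens per block
constexpr int HEAD_SIZE = 128;  // value elements per token

int main() {
    // Toy inputs: logits after softmax, and one block of value data.
    std::vector<float> logits(BLOCK_SIZE, 1.0f / BLOCK_SIZE);
    std::vector<float> value(BLOCK_SIZE * HEAD_SIZE, 2.0f);  // [token][elem]

    std::vector<float> accs(HEAD_SIZE, 0.0f);  // output accumulators
    for (int token = 0; token < BLOCK_SIZE; ++token) {
        for (int elem = 0; elem < HEAD_SIZE; ++elem) {
            accs[elem] += logits[token] * value[token * HEAD_SIZE + elem];
        }
    }
    printf("accs[0] = %.1f\n", accs[0]);  // 16 * (1/16) * 2.0 = 2.0
    return 0;
}
```

In the kernel this work is carried out over vectorized `logits_vec` and `v_vec` chunks, which is what the two smaller diagrams depict.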