[Doc] Fix indentation problems in V0 Paged Attention docs (#18659)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in: parent e77dc4bad8, commit ef1dd6870f

@@ -9,6 +9,7 @@ Deploying vLLM on Kubernetes is a scalable and efficient way to serve machine le

* [Deployment with GPUs](#deployment-with-gpus)

Alternatively, you can deploy vLLM to Kubernetes using any of the following:

* [Helm](frameworks/helm.md)
* [InftyAI/llmaz](integrations/llmaz.md)
* [KServe](integrations/kserve.md)

@@ -3,78 +3,76 @@ title: vLLM Paged Attention

---

[](){ #design-paged-attention }

Currently, vLLM utilizes its own implementation of a multi-head query
attention kernel (`csrc/attention/attention_kernels.cu`).
This kernel is designed to be compatible with
vLLM's paged KV caches, where the key and value cache are stored in
separate blocks (note that this block concept differs from the GPU
thread block, so in this document I will refer to the vLLM paged
attention block as "block" and to the GPU thread block as
"thread block").

To achieve high performance, this kernel relies on a specially
designed memory layout and access method, specifically when threads
read data from global memory to shared memory. The purpose of this
document is to provide a high-level, step-by-step explanation of the
kernel implementation, aiding those who wish to learn about the
vLLM multi-head query attention kernel. After going through this
document, users will likely have a better understanding and find it
easier to follow the actual implementation.

Please note that this document may not cover all details, such as how
to calculate the correct index for the corresponding data or the dot
multiplication implementation. However, after reading this document
and becoming familiar with the high-level logic flow, it should be
easier for you to read the actual code and understand the details.

## Inputs

The kernel function takes a list of arguments for the current thread
to perform its assigned work. The three most important arguments are
the input pointers `q`, `k_cache`, and `v_cache`, which point
to query, key, and value data in global memory that need to be read
and processed. The output pointer `out` points to global memory
where the result should be written. These four pointers actually
refer to multi-dimensional arrays, but each thread only accesses the
portion of data assigned to it. I have omitted all other runtime
parameters here for simplicity.

```cpp
template<
  typename scalar_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS,
  int PARTITION_SIZE = 0>
__device__ void paged_attention_kernel(
  ... // Other side args.
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
  ... // Other side args.
)
```

There is also a list of template arguments above the function
signature that are determined at compile time. `scalar_t`
represents the data type of the query, key, and value data elements,
such as FP16. `HEAD_SIZE` indicates the number of elements in each
head. `BLOCK_SIZE` refers to the number of tokens in each block.
`NUM_THREADS` denotes the number of threads in each thread block.
`PARTITION_SIZE` represents the number of tensor parallel GPUs (for
simplicity, we assume this is 0 and tensor parallelism is disabled).

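To make these parameters more concrete, below is a hedged sketch (not verbatim kernel code) of the kind of compile-time constants that can be derived from them. The constant names follow the terminology used later in this document; the exact expressions in `csrc/attention/attention_kernels.cu` may differ.

```cpp
// Sketch only: plausible derived constants, consistent with the examples
// used later in this document (BLOCK_SIZE = 16, FP16 data, HEAD_SIZE = 128).
constexpr int THREAD_GROUP_SIZE =
    (WARP_SIZE / BLOCK_SIZE > 1) ? (WARP_SIZE / BLOCK_SIZE) : 1;  // e.g. 32 / 16 = 2
constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE;
// Each thread group reads 16 bytes of a token at a time, so the vector width
// depends on the element type: 16 / (2 * sizeof(FP16)) = 4 elements per vec.
constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t));
// With HEAD_SIZE = 128, each token is split into 32 vecs, 16 per thread.
constexpr int NUM_VECS_PER_THREAD = HEAD_SIZE / (THREAD_GROUP_SIZE * VEC_SIZE);
```
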
With these arguments, we need to perform a sequence of preparations.
This includes calculating the current head index, block index, and
other necessary variables. However, for now, we can ignore these
preparations and proceed directly to the actual calculations. It will
be easier to understand them once we grasp the entire flow.

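As a rough illustration of what those preparations look like, here is a hedged sketch. The variable names match those used in the snippets later in this document, but the exact grid mapping is an assumption, not a statement about the real kernel.

```cpp
// Sketch only: typical per-thread bookkeeping computed before the main loops.
const int seq_idx = blockIdx.y;        // which sequence this thread block works on (assumed mapping)
const int head_idx = blockIdx.x;       // which attention head (assumed mapping)
const int partition_idx = blockIdx.z;  // only meaningful when PARTITION_SIZE > 0

const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE;                         // lane within the warp

const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;     // which thread group
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;  // position within the group
```
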
## Concepts

Just before we dive into the calculation flow, I want to describe a
few concepts that are needed for later sections. However, you may
skip this section and return later if you encounter any confusing
terminology.

- **Sequence**: A sequence represents a client request. For example,
  the data pointed to by `q` has a shape of
  `[num_seqs, num_heads, head_size]`. That represents there are total

@@ -129,236 +127,236 @@ title: vLLM Paged Attention

## Query

This section will introduce how query data is stored in memory and
fetched by each thread. As mentioned above, each thread group fetches
one query token's data, while each thread itself only handles a part of
one query token's data. Within each warp, every thread group will fetch
the same query token data, but will multiply it with different key
token data.

```cpp
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
```

<figure markdown="span">
  { align="center" alt="query" width="70%" }
</figure>

Each thread defines its own `q_ptr`, which points to the assigned
query token data in global memory. For example, if `VEC_SIZE` is 4
and `HEAD_SIZE` is 128, the `q_ptr` points to data that contains a
total of 128 elements divided into 128 / 4 = 32 vecs.

<figure markdown="span">
  { align="center" alt="q_vecs" width="70%" }
</figure>

```cpp
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
```

Next, we need to read the global memory data pointed to by `q_ptr`
into shared memory as `q_vecs`. It is important to note that each
vec is assigned to a different row. For example, if the
`THREAD_GROUP_SIZE` is 2, thread 0 will handle the 0th row vecs,
while thread 1 handles the 1st row vecs. By reading the query data in
this way, neighboring threads like thread 0 and thread 1 can read
neighboring memory, achieving memory coalescing to improve
performance.

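A hedged sketch of what this coalesced load might look like (the loop bounds and unrolling are illustrative; see the real kernel for the exact form):

```cpp
// Sketch only: thread groups cooperatively copy the query token from global
// memory into q_vecs; each thread writes the row given by its group offset.
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] =
        *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
__syncthreads();  // q_vecs in shared memory is read by other threads later.
```
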
## Key

Similar to the "Query" section, this section introduces the memory layout
and assignment for keys. While each thread group only handles one
query token in one kernel run, it may handle multiple key tokens across
multiple iterations. Meanwhile, each warp will process multiple blocks
of key tokens in multiple iterations, ensuring that all context
tokens are processed by the entire thread group after the kernel run.
In this context, "handle" refers to performing the dot multiplication
between query data and key data.

```cpp
const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                                + kv_head_idx * kv_head_stride
                                + physical_block_offset * x;
```

Unlike `q_ptr`, `k_ptr` in each thread will point to a different
key token at different iterations. As shown above, `k_ptr`
points to key token data based on `k_cache` at the assigned block,
assigned head, and assigned token.

<figure markdown="span">
  { align="center" alt="key" width="70%" }
</figure>

The diagram above illustrates the memory layout for key data. It
assumes that the `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is
8, `THREAD_GROUP_SIZE` is 2, and there are a total of 4 warps. Each
rectangle represents all the elements for one key token at one head,
which will be processed by one thread group. The left half shows the
total 16 blocks of key token data for warp 0, while the right half
represents the remaining key token data for other warps or
iterations. Inside each rectangle, there are a total of 32 vecs (128
elements for one token) that will be processed by 2 threads (one
thread group) separately.

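As a standalone illustration (plain C++, not vLLM code) of this `[num_blocks, num_kv_heads, head_size/x, block_size, x]` layout, the snippet below computes the flat offset of a single key element. The concrete sizes are the ones assumed by the diagram, and `num_kv_heads` is an arbitrary example value.

```cpp
#include <cstdio>

int main() {
  // Sizes assumed by the diagram above; num_kv_heads is an arbitrary example.
  const long num_kv_heads = 8, head_size = 128, block_size = 16, x = 8;

  // Locate element `head_elem` of the key for token `token_in_block`
  // inside physical block `block_number` at head `kv_head_idx`.
  const long block_number = 3, kv_head_idx = 2, head_elem = 21, token_in_block = 5;

  // Row-major offset into [num_blocks, num_kv_heads, head_size/x, block_size, x].
  const long offset =
      (((block_number * num_kv_heads + kv_head_idx) * (head_size / x) + head_elem / x)
           * block_size + token_in_block) * x + head_elem % x;

  std::printf("flat offset into k_cache = %ld\n", offset);
  return 0;
}
```
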
<figure markdown="span">
  { align="center" alt="k_vecs" width="70%" }
</figure>

```cpp
K_vec k_vecs[NUM_VECS_PER_THREAD];
```

Next, we need to read the key token data from `k_ptr` and store
it in register memory as `k_vecs`. We use register memory for
`k_vecs` because it will only be accessed by one thread once,
whereas `q_vecs` will be accessed by multiple threads multiple
times. Each `k_vecs` will contain multiple vectors for later
calculation. Each vec will be set at each inner iteration. The
assignment of vecs allows neighboring threads in a warp to read
neighboring memory together, which again promotes memory
coalescing. For instance, thread 0 will read vec 0, while thread 1
will read vec 1. In the next inner loop, thread 0 will read vec 2,
while thread 1 will read vec 3, and so on.

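For reference, a hedged sketch of the inner loop that fills `k_vecs`. The `offset1`/`offset2` arithmetic is an illustration of how a vec is located in the `[.., head_size/x, block_size, x]` layout; the real kernel may differ in details.

```cpp
// Sketch only: each inner iteration reads one vec of the assigned key token.
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
    const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
    // offset1 selects the head_size/x group, offset2 the position inside it.
    const int offset1 = (vec_idx * VEC_SIZE) / x;
    const int offset2 = (vec_idx * VEC_SIZE) % x;
    k_vecs[j] = *reinterpret_cast<const K_vec*>(
        k_ptr + offset1 * BLOCK_SIZE * x + offset2);
}
```
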
You may still be a little confused about the overall flow. Don't
worry, please keep reading the next "QK" section. It will illustrate
the query and key calculation flow in a clearer and higher-level
manner.

## QK

As shown in the pseudocode below, before the entire for loop block, we
fetch the query data for one token and store it in `q_vecs`. Then,
in the outer for loop, we iterate through different `k_ptrs` that
point to different tokens and prepare the `k_vecs` in the inner for
loop. Finally, we perform the dot multiplication between the
`q_vecs` and each `k_vecs`.

```cpp
q_vecs = ...
for ... {
    k_ptr = ...
    for ... {
        k_vecs[i] = ...
    }
    ...
    float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
}
```

As mentioned before, each thread only fetches part of the
query and key token data at a time. However, a cross thread group
reduction happens in `Qk_dot<>::dot`. So the `qk`
returned here is not just the dot product between part of the query and
key token data, but actually the full result between the entire query and
key token data.

For example, if the value of `HEAD_SIZE` is 128 and
`THREAD_GROUP_SIZE` is 2, each thread's `k_vecs` will contain a
total of 64 elements. However, the returned `qk` is actually the
result of the dot multiplication between 128 query elements and 128 key
elements. If you want to learn more about the details of the dot
multiplication and reduction, you may refer to the implementation of
`Qk_dot<>::dot`. However, for the sake of simplicity, I will not
cover it in this document.

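That said, here is a minimal sketch of the idea behind such a cross-thread-group dot product: per-thread partial sums combined with a shuffle reduction. This is not the real `Qk_dot<>::dot` implementation.

```cpp
// Sketch only: each thread dots its own share of the vecs, then the partial
// sums are combined across the THREAD_GROUP_SIZE threads that share the same
// query token, so every one of them ends up holding the full qk.
float qk = 0.f;
#pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
    qk += dot(q_vecs[thread_group_offset][j], k_vecs[j]);  // partial dot product
}
#pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
    qk += VLLM_SHFL_XOR_SYNC(qk, mask);  // reduce across the thread group
}
```
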
## Softmax

Next, we need to calculate the normalized softmax for all `qk`s,
as shown below, where each $x$ represents a `qk`. To do this,
we must obtain the reduced value of `qk_max` ($m(x)$) and
the `exp_sum` ($\ell(x)$) of all `qk`s. The reduction
should be performed across the entire thread block, encompassing
results between the query token and all context key tokens.

$$
\begin{gather*}
m(x) := \max_i x_i \\
f(x) := \left[\begin{array}{lll}e^{x_1-m(x)} & \ldots & e^{x_B-m(x)}\end{array}\right] \\
\ell(x) := \sum_i f(x)_i \\
\operatorname{softmax}(x) := \frac{f(x)}{\ell(x)}
\end{gather*}
$$

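To make the notation concrete, here is a small standalone example (plain C++, host-side, not kernel code) that applies the same numerically stable softmax to a handful of made-up `qk` values:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // Pretend these are the qk scores of one query token against 4 context tokens.
  std::vector<float> qk = {2.0f, 1.0f, 0.5f, 3.0f};

  float qk_max = qk[0];                    // m(x)
  for (float v : qk) qk_max = std::fmax(qk_max, v);

  float exp_sum = 0.f;                     // l(x)
  for (float& v : qk) {
    v = std::exp(v - qk_max);              // f(x): subtract the max for stability
    exp_sum += v;
  }

  for (float& v : qk) v /= exp_sum;        // softmax(x) = f(x) / l(x)

  for (float v : qk) std::printf("%.4f ", v);  // normalized attention weights
  std::printf("\n");
  return 0;
}
```
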
### `qk_max` and `logits`

Right after we get the `qk` result, we can set the temporary
`logits` result with `qk` (in the end, the `logits` should
store the normalized softmax result). We can also compare and collect
the `qk_max` for all `qk`s that are calculated by the current
thread group.

```cpp
if (thread_group_offset == 0) {
    const bool mask = token_idx >= context_len;
    logits[token_idx - start_token_idx] = mask ? 0.f : qk;
    qk_max = mask ? qk_max : fmaxf(qk_max, qk);
}
```

Please note that the `logits` here is in shared memory, so each
thread group will set the fields for its own assigned context tokens.
Overall, the size of `logits` should be the number of context tokens.

```cpp
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}

if (lane == 0) {
    red_smem[warp_idx] = qk_max;
}
```

Then we need to get the reduced `qk_max` across each warp. The main
idea is to make the threads in a warp communicate with each other and
get the final max `qk`.

```cpp
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
```

Finally, we can get the reduced `qk_max` from the whole thread block by
comparing the `qk_max` from all warps in this thread block. Then we
need to broadcast the final result to each thread.

### `exp_sum`

Similar to `qk_max`, we need to get the reduced sum value from the
entire thread block too.

```cpp
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    float val = __expf(logits[i] - qk_max);
    logits[i] = val;
    exp_sum += val;
}
...
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
```

First, sum all exp values from each thread group, and meanwhile
convert each entry of `logits` from `qk` to `exp(qk - qk_max)`.
Please note that the `qk_max` here is already the max `qk` across the
whole thread block. Then we can do the reduction for `exp_sum`
across the whole thread block, just like for `qk_max`.

```cpp
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
    logits[i] *= inv_sum;
}
```

Finally, with the reduced `qk_max` and `exp_sum`, we can obtain
the final normalized softmax result as `logits`. This `logits`
variable will be used for dot multiplication with the value data in
later steps. Now, it should store the normalized softmax result of
`qk` for all assigned context tokens.

## Value

@@ -374,127 +372,127 @@ title: vLLM Paged Attention

  { align="center" alt="v_vec" width="70%" }
</figure>

Now we need to retrieve the value data and perform dot multiplication
with `logits`. Unlike query and key, there is no thread group
concept for value data. As shown in the diagram, different from the key
token memory layout, elements from the same column correspond to the same
value token. For one block of value data, there are `HEAD_SIZE`
rows and `BLOCK_SIZE` columns that are split into multiple
`v_vecs`.

Each thread always fetches `V_VEC_SIZE` elements from the same
`V_VEC_SIZE` tokens at a time. As a result, a single thread
retrieves multiple `v_vec`s from different rows and the same
columns through multiple inner iterations. For each `v_vec`, it
needs to be dot multiplied with the corresponding `logits_vec`,
which is also `V_VEC_SIZE` elements from `logits`. Overall, with
multiple inner iterations, each warp will process one block of value
tokens. And with multiple outer iterations, the whole context's value
tokens are processed.

```cpp
float accs[NUM_ROWS_PER_THREAD];
for ... { // Iteration over different blocks.
    logits_vec = ...
    for ... { // Iteration over different rows.
        v_vec = ...
        ...
        accs[i] += dot(logits_vec, v_vec);
    }
}
```

As shown in the above pseudocode, in the outer loop, similar to
`k_ptr`, `logits_vec` iterates over different blocks and reads
`V_VEC_SIZE` elements from `logits`. In the inner loop, each
thread reads `V_VEC_SIZE` elements from the same tokens as a
`v_vec` and performs dot multiplication. It is important to note
that in each inner iteration, the thread fetches different head
position elements for the same tokens. The dot result is then
accumulated in `accs`. Therefore, each entry of `accs` is mapped
to a head position assigned to the current thread.

For example, if `BLOCK_SIZE` is 16 and `V_VEC_SIZE` is 8, each
thread fetches 8 value elements for 8 tokens at a time. Each element
is from a different token at the same head position. If `HEAD_SIZE`
is 128 and `WARP_SIZE` is 32, for each inner loop, a warp needs to
fetch `WARP_SIZE * V_VEC_SIZE = 256` elements. This means there are
a total of 128 * 16 / 256 = 8 inner iterations for a warp to handle
a whole block of value tokens. And each `accs` in each thread
contains 8 elements that are accumulated at 8 different head positions.
For thread 0, the `accs` variable will have 8 elements, which
are the 0th, 32nd, …, 224th elements of a value head, accumulated
from all 8 assigned tokens.

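A quick standalone check of this arithmetic (illustrative only):

```cpp
#include <cstdio>

int main() {
  // Example sizes from the paragraph above.
  const int BLOCK_SIZE = 16, V_VEC_SIZE = 8, HEAD_SIZE = 128, WARP_SIZE = 32;

  const int elems_per_block     = HEAD_SIZE * BLOCK_SIZE;  // 128 * 16 = 2048 value elements per block
  const int elems_per_warp_iter = WARP_SIZE * V_VEC_SIZE;  // 32 * 8 = 256 elements per inner iteration
  const int inner_iters = elems_per_block / elems_per_warp_iter;  // 8 iterations per block

  std::printf("inner iterations per block of value tokens: %d\n", inner_iters);
  std::printf("accs entries (head positions) per thread:   %d\n", inner_iters);
  return 0;
}
```
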
## LV

Now, we need to perform reduction for `accs` within each warp. This
process allows each thread to accumulate the `accs` for the
assigned head positions of all tokens in one block.

```cpp
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    float acc = accs[i];
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
        acc += VLLM_SHFL_XOR_SYNC(acc, mask);
    }
    accs[i] = acc;
}
```

Next, we perform reduction for `accs` across all warps, allowing
each thread to have the accumulation of `accs` for the assigned
head positions of all context tokens. Please note that each `accs`
in every thread only stores the accumulation for a portion of
elements of the entire head for all context tokens. However, overall,
all results for the output have been calculated, but are just stored in
different threads' register memory.

```cpp
float* out_smem = reinterpret_cast<float*>(shared_mem);
for (int i = NUM_WARPS; i > 1; i /= 2) {
    // Upper warps write to shared memory.
    ...
    float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        ...
        dst[row_idx] = accs[i];
    }

    // Lower warps update the output.
    const float* src = &out_smem[warp_idx * HEAD_SIZE];
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
        ...
        accs[i] += src[row_idx];
    }

    // Write out the accs.
}
```

## Output

Now we can write all of the calculated results from local register memory
to the final output in global memory.

```cpp
scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
                        + head_idx * max_num_partitions * HEAD_SIZE
                        + partition_idx * HEAD_SIZE;
```

First, we need to define the `out_ptr` variable, which points to
the start address of the assigned sequence and assigned head.

```cpp
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
    const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
    if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
        from_float(*(out_ptr + row_idx), accs[i]);
    }
}
```

Finally, we need to iterate over the different assigned head positions
and write out the corresponding accumulated results based on
`out_ptr`.