From ef1dd6870f848c5814528a81b71bc87ba317e63f Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 24 May 2025 21:06:35 +0800 Subject: [PATCH] [Doc] Fix indentation problems in V0 Paged Attention docs (#18659) Signed-off-by: DarkLight1337 --- docs/deployment/k8s.md | 1 + docs/design/kernel/paged_attention.md | 678 +++++++++++++------------- 2 files changed, 339 insertions(+), 340 deletions(-) diff --git a/docs/deployment/k8s.md b/docs/deployment/k8s.md index bd2bd44cd5225..6b08c4960d028 100644 --- a/docs/deployment/k8s.md +++ b/docs/deployment/k8s.md @@ -9,6 +9,7 @@ Deploying vLLM on Kubernetes is a scalable and efficient way to serve machine le * [Deployment with GPUs](#deployment-with-gpus) Alternatively, you can deploy vLLM to Kubernetes using any of the following: + * [Helm](frameworks/helm.md) * [InftyAI/llmaz](integrations/llmaz.md) * [KServe](integrations/kserve.md) diff --git a/docs/design/kernel/paged_attention.md b/docs/design/kernel/paged_attention.md index fdfa38a29f837..6ebe1ee48acf1 100644 --- a/docs/design/kernel/paged_attention.md +++ b/docs/design/kernel/paged_attention.md @@ -3,78 +3,76 @@ title: vLLM Paged Attention --- [](){ #design-paged-attention } -- Currently, vLLM utilizes its own implementation of a multi-head query - attention kernel (`csrc/attention/attention_kernels.cu`). - This kernel is designed to be compatible with - vLLM's paged KV caches, where the key and value cache are stored in - separate blocks (note that this block concept differs from the GPU - thread block. So in a later document, I will refer to vLLM paged - attention block as "block", while refer to GPU thread block as - "thread block"). -- To achieve high performance, this kernel relies on a specially - designed memory layout and access method, specifically when threads - read data from global memory to shared memory. The purpose of this - document is to provide a high-level explanation of the kernel - implementation step by step, aiding those who wish to learn about the - vLLM multi-head query attention kernel. After going through this - document, users will likely have a better understanding and feel easier - to follow the actual implementation. -- Please note that this document may not cover all details, such as how - to calculate the correct index for the corresponding data or the dot - multiplication implementation. However, after reading this document - and becoming familiar with the high-level logic flow, it should be - easier for you to read the actual code and understand the details. +Currently, vLLM utilizes its own implementation of a multi-head query +attention kernel (`csrc/attention/attention_kernels.cu`). +This kernel is designed to be compatible with +vLLM's paged KV caches, where the key and value cache are stored in +separate blocks (note that this block concept differs from the GPU +thread block. So in a later document, I will refer to vLLM paged +attention block as "block", while refer to GPU thread block as +"thread block"). + +To achieve high performance, this kernel relies on a specially +designed memory layout and access method, specifically when threads +read data from global memory to shared memory. The purpose of this +document is to provide a high-level explanation of the kernel +implementation step by step, aiding those who wish to learn about the +vLLM multi-head query attention kernel. After going through this +document, users will likely have a better understanding and feel easier +to follow the actual implementation. 
+ +Please note that this document may not cover all details, such as how +to calculate the correct index for the corresponding data or the dot +multiplication implementation. However, after reading this document +and becoming familiar with the high-level logic flow, it should be +easier for you to read the actual code and understand the details. ## Inputs -- The kernel function takes a list of arguments for the current thread - to perform its assigned work. The three most important arguments are - the input pointers `q`, `k_cache`, and `v_cache`, which point - to query, key, and value data on global memory that need to be read - and processed. The output pointer `out` points to global memory - where the result should be written. These four pointers actually - refer to multi-dimensional arrays, but each thread only accesses the - portion of data assigned to it. I have omitted all other runtime - parameters here for simplicity. +The kernel function takes a list of arguments for the current thread +to perform its assigned work. The three most important arguments are +the input pointers `q`, `k_cache`, and `v_cache`, which point +to query, key, and value data on global memory that need to be read +and processed. The output pointer `out` points to global memory +where the result should be written. These four pointers actually +refer to multi-dimensional arrays, but each thread only accesses the +portion of data assigned to it. I have omitted all other runtime +parameters here for simplicity. - ```cpp - template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - int PARTITION_SIZE = 0> - __device__ void paged_attention_kernel( - ... // Other side args. - const scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - ... // Other side args. - ) - ``` +```cpp +template +__device__ void paged_attention_kernel( + ... // Other side args. + const scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + ... // Other side args. +) +``` -- There are also a list of template arguments above the function - signature that are determined during compilation time. `scalar_t` - represents the data type of the query, key, and value data elements, - such as FP16. `HEAD_SIZE` indicates the number of elements in each - head. `BLOCK_SIZE` refers to the number of tokens in each block. - `NUM_THREADS` denotes the number of threads in each thread block. - `PARTITION_SIZE` represents the number of tensor parallel GPUs (For - simplicity, we assume this is 0 and tensor parallel is disabled). +There are also a list of template arguments above the function +signature that are determined during compilation time. `scalar_t` +represents the data type of the query, key, and value data elements, +such as FP16. `HEAD_SIZE` indicates the number of elements in each +head. `BLOCK_SIZE` refers to the number of tokens in each block. +`NUM_THREADS` denotes the number of threads in each thread block. 
+`PARTITION_SIZE` represents the number of tensor parallel GPUs (For +simplicity, we assume this is 0 and tensor parallel is disabled). -- With these arguments, we need to perform a sequence of preparations. - This includes calculating the current head index, block index, and - other necessary variables. However, for now, we can ignore these - preparations and proceed directly to the actual calculations. It will - be easier to understand them once we grasp the entire flow. +With these arguments, we need to perform a sequence of preparations. +This includes calculating the current head index, block index, and +other necessary variables. However, for now, we can ignore these +preparations and proceed directly to the actual calculations. It will +be easier to understand them once we grasp the entire flow. ## Concepts -- Just before we dive into the calculation flow, I want to describe a - few concepts that are needed for later sections. However, you may - skip this section and return later if you encounter any confusing - terminologies. +Just before we dive into the calculation flow, I want to describe a +few concepts that are needed for later sections. However, you may +skip this section and return later if you encounter any confusing +terminologies. + - **Sequence**: A sequence represents a client request. For example, the data pointed to by `q` has a shape of `[num_seqs, num_heads, head_size]`. That represents there are total @@ -129,236 +127,236 @@ title: vLLM Paged Attention ## Query -- This section will introduce how query data is stored in memory and - fetched by each thread. As mentioned above, each thread group fetches - one query token data, while each thread itself only handles a part of - one query token data. Within each warp, every thread group will fetch - the same query token data, but will multiply it with different key - token data. +This section will introduce how query data is stored in memory and +fetched by each thread. As mentioned above, each thread group fetches +one query token data, while each thread itself only handles a part of +one query token data. Within each warp, every thread group will fetch +the same query token data, but will multiply it with different key +token data. - ```cpp - const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; - ``` +```cpp +const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; +```
![](../../assets/kernel/query.png){ align="center" alt="query" width="70%" }
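
To make the pointer arithmetic above concrete, here is a minimal host-side sketch (not the kernel itself) of how `q_ptr` lands on one query head vector inside a contiguous `[num_seqs, num_heads, head_size]` buffer. The assumption that `q_stride` equals `num_heads * HEAD_SIZE` and all concrete sizes are illustrative only.

```cpp
// Host-side sketch (not the kernel): how `q_ptr` selects one query head
// vector out of a contiguous [num_seqs, num_heads, head_size] buffer.
// Assumption: q_stride == num_heads * HEAD_SIZE (contiguous query tensor);
// the real kernel receives q_stride as a runtime argument.
#include <cstdio>
#include <vector>

int main() {
  constexpr int num_seqs = 2, num_heads = 8, HEAD_SIZE = 128;
  constexpr int q_stride = num_heads * HEAD_SIZE;  // assumed contiguous layout

  std::vector<float> q(num_seqs * q_stride);
  for (size_t i = 0; i < q.size(); ++i) q[i] = static_cast<float>(i);

  const int seq_idx = 1, head_idx = 3;
  // Same arithmetic as the kernel's q_ptr line above.
  const float* q_ptr = q.data() + seq_idx * q_stride + head_idx * HEAD_SIZE;

  // q_ptr now spans HEAD_SIZE consecutive elements: the query vector for
  // (seq_idx, head_idx). The kernel later splits it into HEAD_SIZE / VEC_SIZE vecs.
  std::printf("first element: %f, last element: %f\n",
              q_ptr[0], q_ptr[HEAD_SIZE - 1]);
  return 0;
}
```
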
-- Each thread defines its own `q_ptr` which points to the assigned - query token data on global memory. For example, if `VEC_SIZE` is 4 - and `HEAD_SIZE` is 128, the `q_ptr` points to data that contains - total of 128 elements divided into 128 / 4 = 32 vecs. +Each thread defines its own `q_ptr` which points to the assigned +query token data on global memory. For example, if `VEC_SIZE` is 4 +and `HEAD_SIZE` is 128, the `q_ptr` points to data that contains +total of 128 elements divided into 128 / 4 = 32 vecs.
![](../../assets/kernel/q_vecs.png){ align="center" alt="q_vecs" width="70%" }
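
Before looking at `q_vecs`, it may help to see where the vec counts in the example above come from. The following sketch reproduces the "32 vecs" bookkeeping under the assumption that one thread group fetches 16 bytes of FP16 data per step; the constant names and the exact assignment order shown here are simplifications for illustration, not the kernel's actual definitions.

```cpp
// Sketch of the vec bookkeeping used in the example above (FP16 query,
// HEAD_SIZE = 128, THREAD_GROUP_SIZE = 2). The "16 bytes per thread-group
// fetch" rule is an assumption used to reproduce VEC_SIZE = 4; the real
// kernel derives these values as compile-time constants.
#include <cstdio>

int main() {
  constexpr int HEAD_SIZE = 128;
  constexpr int THREAD_GROUP_SIZE = 2;
  constexpr int ELEM_BYTES = 2;  // e.g. FP16

  // Assumed rule: one thread group moves 16 bytes per fetch.
  constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * ELEM_BYTES);      // 4 elements per vec
  constexpr int NUM_VECS = HEAD_SIZE / VEC_SIZE;                       // 32 vecs per token
  constexpr int NUM_VECS_PER_THREAD = NUM_VECS / THREAD_GROUP_SIZE;    // 16 vecs per thread

  std::printf("VEC_SIZE=%d, vecs per token=%d, vecs per thread=%d\n",
              VEC_SIZE, NUM_VECS, NUM_VECS_PER_THREAD);

  // Neighboring vecs go to neighboring threads of the group, so that
  // adjacent threads read adjacent memory (coalesced accesses).
  for (int vec_idx = 0; vec_idx < 8; ++vec_idx) {
    std::printf("vec %d -> thread %d of the group\n",
                vec_idx, vec_idx % THREAD_GROUP_SIZE);
  }
  return 0;
}
```
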
- ```cpp - __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; - ``` +```cpp +__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +``` -- Next, we need to read the global memory data pointed to by `q_ptr` - into shared memory as `q_vecs`. It is important to note that each - vecs is assigned to a different row. For example, if the - `THREAD_GROUP_SIZE` is 2, thread 0 will handle the 0th row vecs, - while thread 1 handles the 1st row vecs. By reading the query data in - this way, neighboring threads like thread 0 and thread 1 can read - neighbor memory, achieving the memory coalescing to improve - performance. +Next, we need to read the global memory data pointed to by `q_ptr` +into shared memory as `q_vecs`. It is important to note that each +vecs is assigned to a different row. For example, if the +`THREAD_GROUP_SIZE` is 2, thread 0 will handle the 0th row vecs, +while thread 1 handles the 1st row vecs. By reading the query data in +this way, neighboring threads like thread 0 and thread 1 can read +neighbor memory, achieving the memory coalescing to improve +performance. ## Key -- Similar to the "Query" section, this section introduces memory layout - and assignment for keys. While each thread group only handle one - query token one kernel run, it may handle multiple key tokens across - multiple iterations. Meanwhile, each warp will process multiple blocks - of key tokens in multiple iterations, ensuring that all context - tokens are processed by the entire thread group after the kernel run. - In this context, "handle" refers to performing the dot multiplication - between query data and key data. +Similar to the "Query" section, this section introduces memory layout +and assignment for keys. While each thread group only handle one +query token one kernel run, it may handle multiple key tokens across +multiple iterations. Meanwhile, each warp will process multiple blocks +of key tokens in multiple iterations, ensuring that all context +tokens are processed by the entire thread group after the kernel run. +In this context, "handle" refers to performing the dot multiplication +between query data and key data. - ```cpp - const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride - + kv_head_idx * kv_head_stride - + physical_block_offset * x; - ``` +```cpp +const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + physical_block_offset * x; +``` -- Unlike to `q_ptr`, `k_ptr` in each thread will point to different - key token at different iterations. As shown above, that `k_ptr` - points to key token data based on `k_cache` at assigned block, - assigned head and assigned token. +Unlike to `q_ptr`, `k_ptr` in each thread will point to different +key token at different iterations. As shown above, that `k_ptr` +points to key token data based on `k_cache` at assigned block, +assigned head and assigned token.
![](../../assets/kernel/key.png){ align="center" alt="key" width="70%" }
-- The diagram above illustrates the memory layout for key data. It - assumes that the `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is - 8, `THREAD_GROUP_SIZE` is 2, and there are a total of 4 warps. Each - rectangle represents all the elements for one key token at one head, - which will be processed by one thread group. The left half shows the - total 16 blocks of key token data for warp 0, while the right half - represents the remaining key token data for other warps or - iterations. Inside each rectangle, there are a total 32 vecs (128 - elements for one token) that will be processed by 2 threads (one - thread group) separately. +The diagram above illustrates the memory layout for key data. It +assumes that the `BLOCK_SIZE` is 16, `HEAD_SIZE` is 128, `x` is +8, `THREAD_GROUP_SIZE` is 2, and there are a total of 4 warps. Each +rectangle represents all the elements for one key token at one head, +which will be processed by one thread group. The left half shows the +total 16 blocks of key token data for warp 0, while the right half +represents the remaining key token data for other warps or +iterations. Inside each rectangle, there are a total 32 vecs (128 +elements for one token) that will be processed by 2 threads (one +thread group) separately.
![](../../assets/kernel/k_vecs.png){ align="center" alt="k_vecs" width="70%" }
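
As a side note, the offsets behind the `k_ptr` expression can be checked with a small host-side sketch. It assumes a fully contiguous `[num_blocks, num_kv_heads, head_size/x, block_size, x]` cache so that `kv_block_stride` and `kv_head_stride` can be derived by hand; in the real kernel both strides arrive as runtime arguments, and the physical block number comes from the block table.

```cpp
// Host-side sketch of the k_cache addressing behind the `k_ptr` expression.
// The cache shape is [num_blocks, num_kv_heads, head_size/x, block_size, x],
// as listed in the Inputs section. Strides are derived here assuming a fully
// contiguous cache; the real kernel receives kv_block_stride / kv_head_stride
// at runtime.
#include <cstdio>

int main() {
  constexpr int BLOCK_SIZE = 16, HEAD_SIZE = 128, x = 8;
  constexpr int num_kv_heads = 8;

  // Assumed contiguous strides for the 5-D layout above.
  constexpr long kv_head_stride  = (HEAD_SIZE / x) * BLOCK_SIZE * x;  // elements per head per block
  constexpr long kv_block_stride = num_kv_heads * kv_head_stride;     // elements per block

  const long physical_block_number = 42;  // would come from the block table
  const int  kv_head_idx = 3;
  const int  physical_block_offset = 5;   // token slot inside the block

  // Same arithmetic as the kernel's k_ptr line: it lands at the first
  // x-element chunk of the chosen token inside the chosen block and head.
  const long k_offset = physical_block_number * kv_block_stride
                      + kv_head_idx * kv_head_stride
                      + physical_block_offset * x;

  std::printf("k_ptr = k_cache + %ld elements\n", k_offset);
  // Successive x-chunks of the same token are x * BLOCK_SIZE elements apart,
  // which is the stride pattern the key diagram above visualizes.
  std::printf("stride between x-chunks of one token: %d elements\n", x * BLOCK_SIZE);
  return 0;
}
```
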
- ```cpp - K_vec k_vecs[NUM_VECS_PER_THREAD] - ``` +```cpp +K_vec k_vecs[NUM_VECS_PER_THREAD] +``` -- Next, we need to read the key token data from `k_ptr` and store - them on register memory as `k_vecs`. We use register memory for - `k_vecs` because it will only be accessed by one thread once, - whereas `q_vecs` will be accessed by multiple threads multiple - times. Each `k_vecs` will contain multiple vectors for later - calculation. Each vec will be set at each inner iteration. The - assignment of vecs allows neighboring threads in a warp to read - neighboring memory together, which again promotes the memory - coalescing. For instance, thread 0 will read vec 0, while thread 1 - will read vec 1. In the next inner loop, thread 0 will read vec 2, - while thread 1 will read vec 3, and so on. +Next, we need to read the key token data from `k_ptr` and store +them on register memory as `k_vecs`. We use register memory for +`k_vecs` because it will only be accessed by one thread once, +whereas `q_vecs` will be accessed by multiple threads multiple +times. Each `k_vecs` will contain multiple vectors for later +calculation. Each vec will be set at each inner iteration. The +assignment of vecs allows neighboring threads in a warp to read +neighboring memory together, which again promotes the memory +coalescing. For instance, thread 0 will read vec 0, while thread 1 +will read vec 1. In the next inner loop, thread 0 will read vec 2, +while thread 1 will read vec 3, and so on. -- You may still be a little confused about the overall flow. Don't - worry, please keep reading the next "QK" section. It will illustrate - the query and key calculation flow in a clearer and higher-level - manner. +You may still be a little confused about the overall flow. Don't +worry, please keep reading the next "QK" section. It will illustrate +the query and key calculation flow in a clearer and higher-level +manner. ## QK -- As shown the pseudo code below, before the entire for loop block, we - fetch the query data for one token and store it in `q_vecs`. Then, - in the outer for loop, we iterate through different `k_ptrs` that - point to different tokens and prepare the `k_vecs` in the inner for - loop. Finally, we perform the dot multiplication between the - `q_vecs` and each `k_vecs`. +As shown the pseudo code below, before the entire for loop block, we +fetch the query data for one token and store it in `q_vecs`. Then, +in the outer for loop, we iterate through different `k_ptrs` that +point to different tokens and prepare the `k_vecs` in the inner for +loop. Finally, we perform the dot multiplication between the +`q_vecs` and each `k_vecs`. - ```cpp - q_vecs = ... - for ... { - k_ptr = ... - for ... { +```cpp +q_vecs = ... +for ... { + k_ptr = ... + for ... { k_vecs[i] = ... - } - ... - float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); - } - ``` + } + ... + float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); +} +``` -- As mentioned before, for each thread, it only fetches part of the - query and key token data at a time. However, there will be a cross - thread group reduction happen in the `Qk_dot<>::dot` . So `qk` - returned here is not just between part of the query and key token dot - multiplication, but actually a full result between entire query and - key token data. +As mentioned before, for each thread, it only fetches part of the +query and key token data at a time. However, there will be a cross +thread group reduction happen in the `Qk_dot<>::dot` . 
So `qk` +returned here is not just between part of the query and key token dot +multiplication, but actually a full result between entire query and +key token data. -- For example, if the value of `HEAD_SIZE` is 128 and - `THREAD_GROUP_SIZE` is 2, each thread's `k_vecs` will contain - total 64 elements. However, the returned `qk` is actually the - result of dot multiplication between 128 query elements and 128 key - elements. If you want to learn more about the details of the dot - multiplication and reduction, you may refer to the implementation of - `Qk_dot<>::dot`. However, for the sake of simplicity, I will not - cover it in this document. +For example, if the value of `HEAD_SIZE` is 128 and +`THREAD_GROUP_SIZE` is 2, each thread's `k_vecs` will contain +total 64 elements. However, the returned `qk` is actually the +result of dot multiplication between 128 query elements and 128 key +elements. If you want to learn more about the details of the dot +multiplication and reduction, you may refer to the implementation of +`Qk_dot<>::dot`. However, for the sake of simplicity, I will not +cover it in this document. ## Softmax -- Next, we need to calculate the normalized softmax for all `qk`s, - as shown above, where each $x$ represents a `qk`. To do this, - we must obtain the reduced value of `qk_max`($m(x)$) and - the `exp_sum`($\ell(x)$) of all `qk`s. The reduction - should be performed across the entire thread block, encompassing - results between the query token and all context key tokens. +Next, we need to calculate the normalized softmax for all `qk`s, +as shown above, where each $x$ represents a `qk`. To do this, +we must obtain the reduced value of `qk_max`($m(x)$) and +the `exp_sum`($\ell(x)$) of all `qk`s. The reduction +should be performed across the entire thread block, encompassing +results between the query token and all context key tokens. - $$ - \begin{gather*} - m(x):=\max _i \quad x_i \\ \quad f(x):=\left[\begin{array}{lll}e^{x_1-m(x)} & \ldots & e^{x_B-m(x)}\end{array}\right]\\ \quad \ell(x):=\sum_i f(x)_i \\ - \quad \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)} - \end{gather*} - $$ +$$ +\begin{gather*} +m(x):=\max _i \quad x_i \\ \quad f(x):=\left[\begin{array}{lll}e^{x_1-m(x)} & \ldots & e^{x_B-m(x)}\end{array}\right]\\ \quad \ell(x):=\sum_i f(x)_i \\ +\quad \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)} +\end{gather*} +$$ ### `qk_max` and `logits` -- Just right after we get the `qk` result, we can set the temporary - `logits` result with `qk` (In the end, the `logits` should - store the normalized softmax result). Also we can compare and collect - the `qk_max` for all `qk`s that are calculated by current - thread group. +Just right after we get the `qk` result, we can set the temporary +`logits` result with `qk` (In the end, the `logits` should +store the normalized softmax result). Also we can compare and collect +the `qk_max` for all `qk`s that are calculated by current +thread group. - ```cpp - if (thread_group_offset == 0) { - const bool mask = token_idx >= context_len; - logits[token_idx - start_token_idx] = mask ? 0.f : qk; - qk_max = mask ? qk_max : fmaxf(qk_max, qk); - } - ``` +```cpp +if (thread_group_offset == 0) { + const bool mask = token_idx >= context_len; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + qk_max = mask ? qk_max : fmaxf(qk_max, qk); +} +``` -- Please note that the `logits` here is on shared memory, so each - thread group will set the fields for its own assigned context tokens. 
- Overall, the size of logits should be number of context tokens. +Please note that the `logits` here is on shared memory, so each +thread group will set the fields for its own assigned context tokens. +Overall, the size of logits should be number of context tokens. - ```cpp - for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { - qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); - } +```cpp +for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); +} - if (lane == 0) { - red_smem[warp_idx] = qk_max; - } - ``` +if (lane == 0) { + red_smem[warp_idx] = qk_max; +} +``` -- Then we need to get the reduced `qk_max` across each warp. The main - idea is to make threads in warp to communicate with each other and - get the final max `qk` . +Then we need to get the reduced `qk_max` across each warp. The main +idea is to make threads in warp to communicate with each other and +get the final max `qk` . - ```cpp - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); - } - qk_max = VLLM_SHFL_SYNC(qk_max, 0); - ``` +```cpp +for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); +} +qk_max = VLLM_SHFL_SYNC(qk_max, 0); +``` -- Finally, we can get the reduced `qk_max` from whole thread block by - compare the `qk_max` from all warps in this thread block. Then we - need to broadcast the final result to each thread. +Finally, we can get the reduced `qk_max` from whole thread block by +compare the `qk_max` from all warps in this thread block. Then we +need to broadcast the final result to each thread. ### `exp_sum` -- Similar to `qk_max`, we need to get the reduced sum value from the - entire thread block too. +Similar to `qk_max`, we need to get the reduced sum value from the +entire thread block too. - ```cpp - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { - float val = __expf(logits[i] - qk_max); - logits[i] = val; - exp_sum += val; - } - ... - exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); - ``` +```cpp +for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; +} +... +exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); +``` -- Firstly, sum all exp values from each thread group, and meanwhile, - convert each entry of `logits` from `qk` to `exp(qk - qk_max)`. - Please note, the `qk_max` here is already the max `qk` across the - whole thread block. And then we can do reduction for `exp_sum` - across whole thread block just like the `qk_max`. +Firstly, sum all exp values from each thread group, and meanwhile, +convert each entry of `logits` from `qk` to `exp(qk - qk_max)`. +Please note, the `qk_max` here is already the max `qk` across the +whole thread block. And then we can do reduction for `exp_sum` +across whole thread block just like the `qk_max`. - ```cpp - const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { - logits[i] *= inv_sum; - } - ``` +```cpp +const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); +for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; +} +``` -- Finally, with the reduced `qk_max` and `exp_sum`, we can obtain - the final normalized softmax result as `logits`. This `logits` - variable will be used for dot multiplication with the value data in - later steps. 
Now, it should store the normalized softmax result of - `qk` for all assigned context tokens. +Finally, with the reduced `qk_max` and `exp_sum`, we can obtain +the final normalized softmax result as `logits`. This `logits` +variable will be used for dot multiplication with the value data in +later steps. Now, it should store the normalized softmax result of +`qk` for all assigned context tokens. ## Value @@ -374,127 +372,127 @@ title: vLLM Paged Attention ![](../../assets/kernel/v_vec.png){ align="center" alt="v_vec" width="70%" } -- Now we need to retrieve the value data and perform dot multiplication - with `logits`. Unlike query and key, there is no thread group - concept for value data. As shown in diagram, different from key token - memory layout, elements from the same column correspond to the same - value token. For one block of value data, there are `HEAD_SIZE` of - rows and `BLOCK_SIZE` of columns that are split into multiple - `v_vecs`. +Now we need to retrieve the value data and perform dot multiplication +with `logits`. Unlike query and key, there is no thread group +concept for value data. As shown in diagram, different from key token +memory layout, elements from the same column correspond to the same +value token. For one block of value data, there are `HEAD_SIZE` of +rows and `BLOCK_SIZE` of columns that are split into multiple +`v_vecs`. -- Each thread always fetches `V_VEC_SIZE` elements from the same - `V_VEC_SIZE` of tokens at a time. As a result, a single thread - retrieves multiple `v_vec`s from different rows and the same - columns through multiple inner iterations. For each `v_vec`, it - needs to be dot multiplied with the corresponding `logits_vec`, - which is also `V_VEC_SIZE` elements from `logits`. Overall, with - multiple inner iterations, each warp will process one block of value - tokens. And with multiple outer iterations, the whole context value - tokens are processed +Each thread always fetches `V_VEC_SIZE` elements from the same +`V_VEC_SIZE` of tokens at a time. As a result, a single thread +retrieves multiple `v_vec`s from different rows and the same +columns through multiple inner iterations. For each `v_vec`, it +needs to be dot multiplied with the corresponding `logits_vec`, +which is also `V_VEC_SIZE` elements from `logits`. Overall, with +multiple inner iterations, each warp will process one block of value +tokens. And with multiple outer iterations, the whole context value +tokens are processed - ```cpp - float accs[NUM_ROWS_PER_THREAD]; - for ... { // Iteration over different blocks. - logits_vec = ... - for ... { // Iteration over different rows. - v_vec = ... - ... - accs[i] += dot(logits_vec, v_vec); - } - } - ``` +```cpp +float accs[NUM_ROWS_PER_THREAD]; +for ... { // Iteration over different blocks. + logits_vec = ... + for ... { // Iteration over different rows. + v_vec = ... + ... + accs[i] += dot(logits_vec, v_vec); + } +} +``` -- As shown in the above pseudo code, in the outer loop, similar to - `k_ptr`, `logits_vec` iterates over different blocks and reads - `V_VEC_SIZE` elements from `logits`. In the inner loop, each - thread reads `V_VEC_SIZE` elements from the same tokens as a - `v_vec` and performs dot multiplication. It is important to note - that in each inner iteration, the thread fetches different head - position elements for the same tokens. The dot result is then - accumulated in `accs`. Therefore, each entry of `accs` is mapped - to a head position assigned to the current thread. 
+As shown in the above pseudo code, in the outer loop, similar to +`k_ptr`, `logits_vec` iterates over different blocks and reads +`V_VEC_SIZE` elements from `logits`. In the inner loop, each +thread reads `V_VEC_SIZE` elements from the same tokens as a +`v_vec` and performs dot multiplication. It is important to note +that in each inner iteration, the thread fetches different head +position elements for the same tokens. The dot result is then +accumulated in `accs`. Therefore, each entry of `accs` is mapped +to a head position assigned to the current thread. -- For example, if `BLOCK_SIZE` is 16 and `V_VEC_SIZE` is 8, each - thread fetches 8 value elements for 8 tokens at a time. Each element - is from different tokens at the same head position. If `HEAD_SIZE` - is 128 and `WARP_SIZE` is 32, for each inner loop, a warp needs to - fetch `WARP_SIZE * V_VEC_SIZE = 256` elements. This means there are - a total of 128 * 16 / 256 = 8 inner iterations for a warp to handle - a whole block of value tokens. And each `accs` in each thread - contains 8 elements that accumulated at 8 different head positions. - For the thread 0, the `accs` variable will have 8 elements, which - are 0th, 32th … 224th elements of a value head that are accumulated - from all assigned 8 tokens. +For example, if `BLOCK_SIZE` is 16 and `V_VEC_SIZE` is 8, each +thread fetches 8 value elements for 8 tokens at a time. Each element +is from different tokens at the same head position. If `HEAD_SIZE` +is 128 and `WARP_SIZE` is 32, for each inner loop, a warp needs to +fetch `WARP_SIZE * V_VEC_SIZE = 256` elements. This means there are +a total of 128 * 16 / 256 = 8 inner iterations for a warp to handle +a whole block of value tokens. And each `accs` in each thread +contains 8 elements that accumulated at 8 different head positions. +For the thread 0, the `accs` variable will have 8 elements, which +are 0th, 32th … 224th elements of a value head that are accumulated +from all assigned 8 tokens. ## LV -- Now, we need to perform reduction for `accs` within each warp. This - process allows each thread to accumulate the `accs` for the - assigned head positions of all tokens in one block. +Now, we need to perform reduction for `accs` within each warp. This +process allows each thread to accumulate the `accs` for the +assigned head positions of all tokens in one block. - ```cpp - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - float acc = accs[i]; - for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { +```cpp +for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { acc += VLLM_SHFL_XOR_SYNC(acc, mask); - } - accs[i] = acc; - } - ``` + } + accs[i] = acc; +} +``` -- Next, we perform reduction for `accs` across all warps, allowing - each thread to have the accumulation of `accs` for the assigned - head positions of all context tokens. Please note that each `accs` - in every thread only stores the accumulation for a portion of - elements of the entire head for all context tokens. However, overall, - all results for output have been calculated but are just stored in - different thread register memory. +Next, we perform reduction for `accs` across all warps, allowing +each thread to have the accumulation of `accs` for the assigned +head positions of all context tokens. Please note that each `accs` +in every thread only stores the accumulation for a portion of +elements of the entire head for all context tokens. 
However, overall, +all results for output have been calculated but are just stored in +different thread register memory. - ```cpp - float* out_smem = reinterpret_cast(shared_mem); - for (int i = NUM_WARPS; i > 1; i /= 2) { - // Upper warps write to shared memory. - ... - float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - ... - dst[row_idx] = accs[i]; - } +```cpp +float* out_smem = reinterpret_cast(shared_mem); +for (int i = NUM_WARPS; i > 1; i /= 2) { + // Upper warps write to shared memory. + ... + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + ... + dst[row_idx] = accs[i]; + } - // Lower warps update the output. - const float* src = &out_smem[warp_idx * HEAD_SIZE]; - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - ... - accs[i] += src[row_idx]; - } + // Lower warps update the output. + const float* src = &out_smem[warp_idx * HEAD_SIZE]; + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + ... + accs[i] += src[row_idx]; + } - // Write out the accs. - } - ``` + // Write out the accs. +} +``` ## Output -- Now we can write all of calculated result from local register memory - to final output global memory. +Now we can write all of calculated result from local register memory +to final output global memory. - ```cpp - scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE - + partition_idx * HEAD_SIZE; - ``` +```cpp +scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; +``` -- First, we need to define the `out_ptr` variable, which points to - the start address of the assigned sequence and assigned head. +First, we need to define the `out_ptr` variable, which points to +the start address of the assigned sequence and assigned head. - ```cpp - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - from_float(*(out_ptr + row_idx), accs[i]); - } - } - ``` +```cpp +for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + from_float(*(out_ptr + row_idx), accs[i]); + } +} +``` -- Finally, we need to iterate over different assigned head positions - and write out the corresponding accumulated result based on the - `out_ptr`. +Finally, we need to iterate over different assigned head positions +and write out the corresponding accumulated result based on the +`out_ptr`.
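
To tie the whole flow together, below is a toy, single-threaded C++ reference that mirrors the logical steps walked through in this document (QK dot products with `qk_max` tracking, `exp_sum` normalization, and the logits x value accumulation) for a single query head over a trivially "paged" cache. It is not vLLM code and performs no parallel reductions; all sizes and data are made up purely for illustration.

```cpp
// Toy single-threaded reference of the flow described above, for one query
// head over a paged KV cache. This is NOT the vLLM kernel; it only mirrors
// the logical steps (qk, qk_max, exp_sum, softmax, logits x value) so the
// per-thread and per-warp reductions in the kernel have a scalar counterpart.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  constexpr int HEAD_SIZE = 4;   // tiny sizes, illustration only
  constexpr int BLOCK_SIZE = 2;
  const int context_len = 3;     // number of cached tokens
  const float scale = 1.0f / std::sqrt(static_cast<float>(HEAD_SIZE));

  // One query vector and a "paged" cache of ceil(context_len / BLOCK_SIZE) blocks.
  std::vector<float> q = {0.1f, 0.2f, 0.3f, 0.4f};
  const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;
  // k_cache/v_cache: [num_blocks][BLOCK_SIZE][HEAD_SIZE], filled with dummy data.
  std::vector<float> k_cache(num_blocks * BLOCK_SIZE * HEAD_SIZE);
  std::vector<float> v_cache(num_blocks * BLOCK_SIZE * HEAD_SIZE);
  for (size_t i = 0; i < k_cache.size(); ++i) {
    k_cache[i] = 0.01f * static_cast<float>(i);
    v_cache[i] = 0.02f * static_cast<float>(i);
  }

  // Step 1: qk for every context token, tracking qk_max (the kernel does this
  // with thread-group dots plus warp and block reductions).
  std::vector<float> logits(context_len);
  float qk_max = -INFINITY;
  for (int token = 0; token < context_len; ++token) {
    const int block = token / BLOCK_SIZE, offset = token % BLOCK_SIZE;
    const float* k = &k_cache[(block * BLOCK_SIZE + offset) * HEAD_SIZE];
    float qk = 0.f;
    for (int i = 0; i < HEAD_SIZE; ++i) qk += q[i] * k[i];
    qk *= scale;
    logits[token] = qk;
    qk_max = std::fmax(qk_max, qk);
  }

  // Step 2: exponentiate, accumulate exp_sum, then normalize (Softmax section).
  float exp_sum = 0.f;
  for (int token = 0; token < context_len; ++token) {
    logits[token] = std::exp(logits[token] - qk_max);
    exp_sum += logits[token];
  }
  const float inv_sum = 1.f / (exp_sum + 1e-6f);
  for (int token = 0; token < context_len; ++token) logits[token] *= inv_sum;

  // Step 3: accumulate logits x value into the output head vector
  // (the Value, LV, and Output sections).
  std::vector<float> out(HEAD_SIZE, 0.f);
  for (int token = 0; token < context_len; ++token) {
    const int block = token / BLOCK_SIZE, offset = token % BLOCK_SIZE;
    const float* v = &v_cache[(block * BLOCK_SIZE + offset) * HEAD_SIZE];
    for (int i = 0; i < HEAD_SIZE; ++i) out[i] += logits[token] * v[i];
  }

  for (int i = 0; i < HEAD_SIZE; ++i) std::printf("out[%d] = %f\n", i, out[i]);
  return 0;
}
```

Mapping the loops back to the kernel: the first loop corresponds to the QK and Softmax sections (with the max and sum reductions done across warps instead of a scalar loop), and the last loop corresponds to the Value, LV, and Output sections.
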