mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 18:09:09 +08:00
[bugfix] Fix Llama3/4 issues caused by FlashInfer 0.2.10 (#22426)
Signed-off-by: Po-Han Huang <pohanh@nvidia.com>
This commit is contained in:
parent
157f9c1368
commit
af473f0a85
@@ -6,14 +6,22 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
|
def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
|
||||||
from flashinfer import next_positive_power_of_2
|
|
||||||
|
|
||||||
# Guess tokens per expert assuming perfect expert distribution first.
|
# FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
|
||||||
num_tokens_per_expert = (num_tokens * top_k) // num_experts
|
# TODO: Revert this to dynamic calculation once a new version of FlashInfer
|
||||||
# And pad the number to the next power of 2.
|
# with the necessary kernels is released.
|
||||||
tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
|
tile_tokens_dim = 8
|
||||||
# Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
|
|
||||||
tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
|
# from flashinfer import next_positive_power_of_2
|
||||||
|
|
||||||
|
# # Guess tokens per expert assuming perfect expert distribution first.
|
||||||
|
# num_tokens_per_expert = (num_tokens * top_k) // num_experts
|
||||||
|
# # And pad the number to the next power of 2.
|
||||||
|
# tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
|
||||||
|
# # Cap to 8-64 tokens per CTA tile as it's the range supported by the
|
||||||
|
# # kernel.
|
||||||
|
# tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
|
||||||
|
|
||||||
return tile_tokens_dim
|
return tile_tokens_dim
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -524,7 +524,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
|
|||||||
head_dim = self.kv_cache_spec.head_size
|
head_dim = self.kv_cache_spec.head_size
|
||||||
|
|
||||||
# currently prefill trtllm attention does not support fp8 kv cache
|
# currently prefill trtllm attention does not support fp8 kv cache
|
||||||
prefill_use_trtllm = use_trtllm_attention(
|
prefill_use_trtllm = not cache_dtype.startswith("fp8") \
|
||||||
|
and use_trtllm_attention(
|
||||||
num_prefill_tokens, max_seq_len, cache_dtype,
|
num_prefill_tokens, max_seq_len, cache_dtype,
|
||||||
num_qo_heads, num_kv_heads, head_dim)
|
num_qo_heads, num_kv_heads, head_dim)
|
||||||
decode_use_trtllm = use_trtllm_attention(
|
decode_use_trtllm = use_trtllm_attention(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user