diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
index c6f914febc0a2..9fb194767e4a4 100644
--- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
@@ -6,14 +6,22 @@ import torch
 
 
 def calculate_tile_tokens_dim(num_tokens, top_k, num_experts):
-    from flashinfer import next_positive_power_of_2
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    # FlashInfer 0.2.10 has issues with larger tile sizes. Set to 8 for now.
+    # TODO: Revert this to dynamic calculation once a new version of FlashInfer
+    # with the necessary kernels is released.
+    tile_tokens_dim = 8
+
+    # from flashinfer import next_positive_power_of_2
+
+    # # Guess tokens per expert assuming perfect expert distribution first.
+    # num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # # And pad the number to the next power of 2.
+    # tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
+    # # Cap to 8-64 tokens per CTA tile as it's the range supported by the
+    # # kernel.
+    # tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+
     return tile_tokens_dim
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 1fcb190286329..c85d8bce31f5d 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -524,7 +524,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         head_dim = self.kv_cache_spec.head_size
 
         # currently prefill trtllm attention does not support fp8 kv cache
-        prefill_use_trtllm = use_trtllm_attention(
+        prefill_use_trtllm = not cache_dtype.startswith("fp8") \
+            and use_trtllm_attention(
             num_prefill_tokens, max_seq_len, cache_dtype, num_qo_heads,
             num_kv_heads, head_dim)
         decode_use_trtllm = use_trtllm_attention(