[FlashInfer] Avoid FlashInfer block_size 16 + head_size 256 on blackwell (#27994)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent 002b07c4b2
commit c765f0b443
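The diff gates FlashInfer's KV-cache block alignment on the GPU generation: vLLM encodes compute capability 10.0 (Blackwell) as is_device_capability(100). As a minimal standalone sketch of that check — not vLLM's code, and assuming a CUDA-enabled PyTorch build — the decision reduces to:

import torch

def pick_block_alignment(head_size: int) -> int:
    # Blackwell datacenter GPUs report compute capability 10.0.
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) == (10, 0) and head_size == 256:
        # block_size 16 with head_size 256 trips
        # flashinfer-ai/flashinfer#1993 on this generation.
        return 32
    return 16

print(pick_block_alignment(head_size=256))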
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
+import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
 from vllm.utils.math_utils import cdiv, round_up
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
@@ -356,6 +357,17 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
             ).page_size_bytes
         else:
             kernel_block_alignment_size = 16
+            if (
+                current_platform.is_device_capability(100)
+                and model_config.get_head_size() == 256
+                and (
+                    envs.VLLM_ATTENTION_BACKEND is None
+                    or envs.VLLM_ATTENTION_BACKEND == "FLASHINFER"
+                )
+            ):
+                # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
+                # head size 256 and block size 16 is not supported on blackwell.
+                kernel_block_alignment_size = 32
             attn_page_size_1_token = FullAttentionSpec(
                 block_size=1,
                 num_kv_heads=model_config.get_num_kv_heads(parallel_config),
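The file already imports cdiv and round_up from vllm.utils.math_utils, and the alignment chosen above is the kind of value those helpers round a candidate block size up to. A hypothetical, self-contained illustration of that rounding — the helpers are re-implemented here and the candidate value is made up; this is not the surrounding vLLM logic:

def cdiv(a: int, b: int) -> int:
    return -(a // -b)  # ceiling division

def round_up(x: int, multiple: int) -> int:
    return cdiv(x, multiple) * multiple

kernel_block_alignment_size = 32  # the Blackwell + head_size 256 case above
candidate_block_size = 48         # hypothetical value derived elsewhere
print(round_up(candidate_block_size, kernel_block_alignment_size))  # -> 64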
@@ -402,6 +402,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         )
         self.paged_kv_last_page_len_np = self.paged_kv_last_page_len_cpu.numpy()
 
+        if self.head_dim == 256 and current_platform.is_device_capability(100):
+            # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
+            # head size 256 and block size 16 is not supported on blackwell.
+            assert kv_cache_spec.block_size != 16, (
+                "There is a bug in FlashInfer "
+                "block_size 16 head size 256 support. Please avoid this combination by "
+                "passing --block-size 32 or --block-size 64."
+            )
+
     def _get_workspace_buffer(self):
         if self._workspace_buffer is None:
             buffer_size = FLASHINFER_WORKSPACE_BUFFER_SIZE
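The assertion's message points users at the engine's block-size argument. A hedged usage sketch of that workaround — the model name is a placeholder for any model with head size 256, and block_size is assumed to pass through as a vLLM engine argument:

from vllm import LLM

# block_size=32 (or 64) avoids the FlashInfer block_size 16 +
# head_size 256 combination on Blackwell.
llm = LLM(
    model="some-org/model-with-head-dim-256",  # placeholder model name
    block_size=32,
)

The CLI flag named in the message, --block-size 32, does the same for vllm serve.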