[Bugfix][Attention] Fix FlashInfer MLA block size logic (#24692)
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
parent 7a70a71892
commit d4fd2768ef
@@ -146,6 +146,7 @@ class CudaPlatformBase(Platform):
             # required block_size.
             use_flashmla = False
             use_cutlass_mla = False
+            use_flashinfer_mla = False
 
             if envs.VLLM_ATTENTION_BACKEND is None:
                 # Default case
@@ -164,6 +165,8 @@ class CudaPlatformBase(Platform):
                 use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
                 use_cutlass_mla = (
                     envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA")
+                use_flashinfer_mla = (
+                    envs.VLLM_ATTENTION_BACKEND == "FLASHINFER_MLA")
 
             from vllm.attention.ops.flashmla import is_flashmla_supported
             if use_flashmla and is_flashmla_supported()[0] \
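The two hunks above give FlashInfer MLA the same treatment as FlashMLA and CUTLASS MLA: a flag that defaults to False and is set when the VLLM_ATTENTION_BACKEND override names that backend. Below is a minimal standalone sketch of that mapping with a hypothetical helper name; it is illustrative only, not vLLM's API, and it leaves out the hardware-based default selection.

import os

# Sketch only: hypothetical helper mirroring the flag selection above.
def select_mla_flags(env_backend):
    """Map the VLLM_ATTENTION_BACKEND override onto backend flags."""
    flags = {
        "use_flashmla": False,
        "use_cutlass_mla": False,
        "use_flashinfer_mla": False,
    }
    if env_backend is None:
        # Default case: the real code picks a backend from the device
        # capability instead (not reproduced in this sketch).
        return flags
    flags["use_flashmla"] = (env_backend == "FLASHMLA")
    flags["use_cutlass_mla"] = (env_backend == "CUTLASS_MLA")
    flags["use_flashinfer_mla"] = (env_backend == "FLASHINFER_MLA")
    return flags

print(select_mla_flags(os.environ.get("VLLM_ATTENTION_BACKEND")))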
@@ -176,6 +179,11 @@ class CudaPlatformBase(Platform):
                 cache_config.block_size = 128
                 logger.info("Forcing kv cache block size to 128 for "
                             "CUTLASS_MLA backend.")
+            if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
+                cache_config.block_size = 64
+                logger.info(
+                    "Forcing kv cache block size to 64 for FlashInferMLA "
+                    "backend.")
 
         # lazy import to avoid circular import
         from vllm.config import CUDAGraphMode
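The new block in this hunk is the config-side fix: when FlashInfer MLA is in use and the configured KV-cache block size is not one it supports, the block size is forced to 64, matching how CUTLASS MLA already forces 128. A small sketch of that coercion with a hypothetical function name (illustrative, not vLLM's API):

# Sketch only: hypothetical helper mirroring the coercion above.
def coerce_mla_block_size(block_size, use_cutlass_mla, use_flashinfer_mla):
    """Return a KV-cache block size the chosen MLA backend supports."""
    if use_cutlass_mla and block_size != 128:
        return 128  # CUTLASS MLA requires 128-token blocks
    if use_flashinfer_mla and block_size not in [32, 64]:
        return 64   # FlashInfer MLA accepts only 32 or 64
    return block_size

assert coerce_mla_block_size(16, use_cutlass_mla=False,
                             use_flashinfer_mla=True) == 64
assert coerce_mla_block_size(32, use_cutlass_mla=False,
                             use_flashinfer_mla=True) == 32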
@@ -228,8 +236,9 @@ class CudaPlatformBase(Platform):
             use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
                 selected_backend is None and cls.is_device_capability(100)
                 and block_size == 128)
-            use_flashinfermla = (selected_backend == _Backend.FLASHINFER_MLA
-                                 and cls.has_device_capability(100))
+            use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
+                selected_backend is None and cls.is_device_capability(100)
+                and block_size in [32, 64])
             use_flashmla = selected_backend in [
                 _Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1
             ] or (selected_backend is None and is_flashmla_supported()[0])
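On the backend-selection side, FlashInfer MLA previously required both an explicit selection and a capability check; the rewritten condition mirrors the CUTLASS MLA line above it: an explicit selection always wins, and otherwise the backend is auto-selected on compute capability 10.0 devices when the block size is 32 or 64. A sketch of just that predicate, with stand-in names for the backend enum and the capability check:

from enum import Enum, auto

# Stand-ins for vLLM internals; names here are illustrative only.
class Backend(Enum):
    CUTLASS_MLA = auto()
    FLASHINFER_MLA = auto()
    FLASHMLA = auto()

def use_flashinfer_mla(selected_backend, is_capability_100, block_size):
    """Explicitly requested, or auto-selected on capability-10.0 GPUs
    when the KV-cache block size is one FlashInfer MLA supports."""
    return selected_backend is Backend.FLASHINFER_MLA or (
        selected_backend is None and is_capability_100
        and block_size in [32, 64])

assert use_flashinfer_mla(None, is_capability_100=True, block_size=64)
assert not use_flashinfer_mla(None, is_capability_100=True, block_size=128)
assert use_flashinfer_mla(Backend.FLASHINFER_MLA, is_capability_100=False,
                          block_size=16)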