[Attention] Default to FlashMLA backend for MLA (#14451)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>

commit b0d541947a
parent 5f0b53c6ea
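In practical terms, an MLA model now picks up FlashMLA without any extra configuration. Below is a hedged usage sketch (not part of the commit): the model name is a placeholder, the offline LLM entry point is used for illustration, and the opt-out mentioned in the comment is the pre-existing VLLM_ATTENTION_BACKEND knob already referenced in the diff.

# Hedged usage sketch; "some-mla-model" is a placeholder, not a real checkpoint.
from vllm import LLM

# With this commit, an MLA model defaults to the FlashMLA backend when the
# kernel is supported; setting VLLM_ATTENTION_BACKEND=FLASHMLA is no longer
# required.
llm = LLM(model="some-mla-model")

# Opting back into Triton MLA should still work through the existing env var,
# e.g. VLLM_ATTENTION_BACKEND=TRITON_MLA set before launching.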
@@ -112,6 +112,7 @@ class CudaPlatformBase(Platform):
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         compilation_config = vllm_config.compilation_config
+        model_config = vllm_config.model_config
 
         if parallel_config.worker_cls == "auto":
             if scheduler_config.is_multi_step:
@@ -142,14 +143,21 @@ class CudaPlatformBase(Platform):
         cache_config = vllm_config.cache_config
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16
 
         # TODO(lucas): handle this more gracefully
-        if envs.VLLM_ATTENTION_BACKEND is not None \
-            and envs.VLLM_ATTENTION_BACKEND == "FLASHMLA" \
-            and cache_config.block_size != 64:
-            cache_config.block_size = 64
-            logger.info(
-                "FlashMLA: Forcing kv cache block size to 64 since this"
-                " is currently the only block size supported by the kernel.")
+        # Note: model_config may be None during testing
+        if model_config is not None and model_config.use_mla:
+            # if `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, then
+            # we default to FlashMLA backend, so we need to force the blocksize
+            # here
+            use_flashmla = (envs.VLLM_ATTENTION_BACKEND is None \
+                or envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
+            from vllm.attention.backends.flashmla import is_flashmla_supported
+            if use_flashmla and is_flashmla_supported()[0] \
+                and cache_config.block_size != 64:
+                cache_config.block_size = 64
+                logger.info(
+                    "Forcing kv cache block size to 64 for FlashMLA backend.")
 
         if (parallel_config.data_parallel_size > 1
                 and compilation_config.use_cudagraph):
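To make the forcing rule above easier to scan, here is a small self-contained sketch of the same decision. The function name and plain arguments are illustrative stand-ins (backend_env for envs.VLLM_ATTENTION_BACKEND, flashmla_ok for is_flashmla_supported()[0]), not vLLM API.

# Illustrative sketch of the block-size rule added above; not vLLM code.
from typing import Optional


def forced_block_size(backend_env: Optional[str], use_mla: bool,
                      flashmla_ok: bool, block_size: int) -> int:
    """Return the kv cache block size after the FlashMLA override."""
    if not use_mla:
        return block_size
    # With MLA, FlashMLA is now the default: it is chosen when the env var is
    # unset or explicitly "FLASHMLA", and it only supports a block size of 64.
    wants_flashmla = backend_env is None or backend_env == "FLASHMLA"
    if wants_flashmla and flashmla_ok and block_size != 64:
        return 64
    return block_size


assert forced_block_size(None, True, True, 16) == 64          # default -> FlashMLA
assert forced_block_size("TRITON_MLA", True, True, 16) == 16  # explicit Triton
assert forced_block_size(None, False, True, 16) == 16         # non-MLA model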
@@ -173,7 +181,15 @@ class CudaPlatformBase(Platform):
         if use_mla:
             # TODO(lucas): refactor to be more concise
             # we should probably consider factoring out V1 here
-            if selected_backend == _Backend.FLASHMLA:
+            if selected_backend == _Backend.TRITON_MLA or block_size != 64:
+                if use_v1:
+                    logger.info_once("Using Triton MLA backend on V1 engine.")
+                    return ("vllm.v1.attention.backends.mla."
+                            "triton_mla.TritonMLABackend")
+                else:
+                    logger.info("Using Triton MLA backend.")
+                    return "vllm.attention.backends.triton_mla.TritonMLABackend"
+            else:
                 from vllm.attention.backends.flashmla import (
                     is_flashmla_supported)
                 if not is_flashmla_supported()[0]:
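A hedged restatement of the selection order this hunk introduces. The names below are stand-ins for the _Backend enum values, and the handling of the not-supported case continues in lines not shown between this hunk and the next.

# Illustrative sketch only; the real logic is CudaPlatformBase.get_attn_backend_cls.
from typing import Optional


def prefers_flashmla(selected_backend: Optional[str], block_size: int) -> bool:
    # Triton MLA is chosen when it is explicitly requested or when the kv
    # cache block size is not 64 (the only size FlashMLA supports); otherwise
    # FlashMLA is now the default candidate for MLA models.
    return not (selected_backend == "TRITON_MLA" or block_size != 64)


assert prefers_flashmla("TRITON_MLA", 64) is False  # explicit Triton request
assert prefers_flashmla("FLASHMLA", 16) is False    # unsupported block size
assert prefers_flashmla(None, 64) is True           # unset -> FlashMLA default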
@@ -195,14 +211,6 @@ class CudaPlatformBase(Platform):
                         logger.info("Using FlashMLA backend.")
                         return ("vllm.attention.backends."
                                 "flashmla.FlashMLABackend")
-
-            if use_v1:
-                logger.info_once("Using Triton MLA backend on V1 engine.")
-                return ("vllm.v1.attention.backends.mla."
-                        "triton_mla.TritonMLABackend")
-            else:
-                logger.info("Using Triton MLA backend.")
-                return "vllm.attention.backends.triton_mla.TritonMLABackend"
         if use_v1:
             logger.info_once("Using Flash Attention backend on V1 engine.")
             return ("vllm.v1.attention.backends.flash_attn."