mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 11:45:22 +08:00
[UX] Rename CUTLASS_MLA_VLLM_V1 to CUTLASS_MLA (#21966)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
ec02e536df
commit
61445453df
@ -1417,7 +1417,7 @@ class EngineArgs:
|
|||||||
"PALLAS_VLLM_V1",
|
"PALLAS_VLLM_V1",
|
||||||
"TRITON_ATTN_VLLM_V1",
|
"TRITON_ATTN_VLLM_V1",
|
||||||
"TRITON_MLA",
|
"TRITON_MLA",
|
||||||
"CUTLASS_MLA_VLLM_V1",
|
"CUTLASS_MLA",
|
||||||
"FLASHMLA",
|
"FLASHMLA",
|
||||||
"FLASHINFER",
|
"FLASHINFER",
|
||||||
"FLASHINFER_VLLM_V1",
|
"FLASHINFER_VLLM_V1",
|
||||||
|
|||||||
@ -162,7 +162,7 @@ class CudaPlatformBase(Platform):
|
|||||||
if cls.is_device_capability(100):
|
if cls.is_device_capability(100):
|
||||||
# Blackwell => Force CutlassMLA.
|
# Blackwell => Force CutlassMLA.
|
||||||
use_cutlass_mla = True
|
use_cutlass_mla = True
|
||||||
envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA_VLLM_V1"
|
envs.VLLM_ATTENTION_BACKEND = "CUTLASS_MLA"
|
||||||
else:
|
else:
|
||||||
# Not Blackwell
|
# Not Blackwell
|
||||||
use_flashmla = True
|
use_flashmla = True
|
||||||
@ -170,7 +170,7 @@ class CudaPlatformBase(Platform):
|
|||||||
# Forced case
|
# Forced case
|
||||||
use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
|
use_flashmla = (envs.VLLM_ATTENTION_BACKEND == "FLASHMLA")
|
||||||
use_cutlass_mla = (
|
use_cutlass_mla = (
|
||||||
envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA_VLLM_V1")
|
envs.VLLM_ATTENTION_BACKEND == "CUTLASS_MLA")
|
||||||
|
|
||||||
from vllm.attention.ops.flashmla import is_flashmla_supported
|
from vllm.attention.ops.flashmla import is_flashmla_supported
|
||||||
if use_flashmla and is_flashmla_supported()[0] \
|
if use_flashmla and is_flashmla_supported()[0] \
|
||||||
@ -182,7 +182,7 @@ class CudaPlatformBase(Platform):
|
|||||||
if use_cutlass_mla and cache_config.block_size != 128:
|
if use_cutlass_mla and cache_config.block_size != 128:
|
||||||
cache_config.block_size = 128
|
cache_config.block_size = 128
|
||||||
logger.info("Forcing kv cache block size to 128 for "
|
logger.info("Forcing kv cache block size to 128 for "
|
||||||
"CUTLASS_MLA_VLLM_V1 backend.")
|
"CUTLASS_MLA backend.")
|
||||||
|
|
||||||
compilation_config = vllm_config.compilation_config
|
compilation_config = vllm_config.compilation_config
|
||||||
if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
|
if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput"
|
||||||
@ -211,9 +211,9 @@ class CudaPlatformBase(Platform):
|
|||||||
kv_cache_dtype, block_size, use_v1,
|
kv_cache_dtype, block_size, use_v1,
|
||||||
use_mla) -> str:
|
use_mla) -> str:
|
||||||
if use_mla:
|
if use_mla:
|
||||||
# TODO(lucas): refactor to be more concise
|
# TODO(lucas): refactor to be more concise
|
||||||
# we should probably consider factoring out V1 here
|
# we should probably consider factoring out V1 here
|
||||||
if selected_backend == _Backend.CUTLASS_MLA_VLLM_V1:
|
if selected_backend == _Backend.CUTLASS_MLA:
|
||||||
if use_v1:
|
if use_v1:
|
||||||
logger.info_once("Using Cutlass MLA backend on V1 engine.")
|
logger.info_once("Using Cutlass MLA backend on V1 engine.")
|
||||||
return ("vllm.v1.attention.backends.mla."
|
return ("vllm.v1.attention.backends.mla."
|
||||||
|
|||||||
@ -53,7 +53,7 @@ class _Backend(enum.Enum):
|
|||||||
TRITON_MLA_VLLM_V1 = enum.auto()
|
TRITON_MLA_VLLM_V1 = enum.auto()
|
||||||
FLASHMLA_VLLM_V1 = enum.auto()
|
FLASHMLA_VLLM_V1 = enum.auto()
|
||||||
FLASHMLA = enum.auto() # Supported by V1
|
FLASHMLA = enum.auto() # Supported by V1
|
||||||
CUTLASS_MLA_VLLM_V1 = enum.auto()
|
CUTLASS_MLA = enum.auto()
|
||||||
PALLAS = enum.auto()
|
PALLAS = enum.auto()
|
||||||
PALLAS_VLLM_V1 = enum.auto()
|
PALLAS_VLLM_V1 = enum.auto()
|
||||||
IPEX = enum.auto()
|
IPEX = enum.auto()
|
||||||
|
|||||||
@ -21,7 +21,7 @@ class CutlassMLABackend(MLACommonBackend):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_name() -> str:
|
def get_name() -> str:
|
||||||
return "CUTLASS_MLA_VLLM_V1"
|
return "CUTLASS_MLA"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_impl_cls() -> type["CutlassMLAImpl"]:
|
def get_impl_cls() -> type["CutlassMLAImpl"]:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user