[V1] [Hybrid] Enable Full CUDA graph by default for hybrid models in V1 (#22594)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Author: Thomas Parnell, 2025-08-27 01:28:55 +02:00 (committed by GitHub)
parent c3b0fd1ee6
commit 5f1af97f86
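The new default only kicks in when the user has not chosen a CUDA graph mode themselves. Below is a minimal, self-contained sketch of that precedence rule, using simplified stand-in types rather than vLLM's real config classes (CompilationConfigSketch and apply_mamba_default are hypothetical names for illustration):

from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class CUDAGraphMode(Enum):
    # Subset of vLLM's modes, for illustration only.
    PIECEWISE = auto()
    FULL_AND_PIECEWISE = auto()


@dataclass
class CompilationConfigSketch:
    # None means "the user did not pick a mode", mirroring the
    # `compilation_config.cudagraph_mode is None` check in the diff below.
    cudagraph_mode: Optional[CUDAGraphMode] = None


def apply_mamba_default(cfg: CompilationConfigSketch, architecture: str) -> None:
    # Architectures without full CUDA graph support keep the old behavior.
    fcg_not_supported = {"Lfm2ForCausalLM", "MiniMaxText01ForCausalLM"}
    if architecture not in fcg_not_supported and cfg.cudagraph_mode is None:
        cfg.cudagraph_mode = CUDAGraphMode.FULL_AND_PIECEWISE


# An unset mode is upgraded to FULL_AND_PIECEWISE ...
cfg = CompilationConfigSketch()
apply_mamba_default(cfg, "Mamba2ForCausalLM")
assert cfg.cudagraph_mode is CUDAGraphMode.FULL_AND_PIECEWISE

# ... but an explicit user choice is never overridden.
cfg = CompilationConfigSketch(cudagraph_mode=CUDAGraphMode.PIECEWISE)
apply_mamba_default(cfg, "Mamba2ForCausalLM")
assert cfg.cudagraph_mode is CUDAGraphMode.PIECEWISE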

vllm/model_executor/models/config.py

@@ -4,6 +4,7 @@ from copy import deepcopy
 from typing import TYPE_CHECKING
 
 import vllm.envs as envs
+from vllm.config.compilation import CUDAGraphMode
 from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
@@ -275,6 +276,42 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
                 "%d for performance.", 1024)
 
 
+class MambaModelConfig(VerifyAndUpdateConfig):
+
+    @classmethod
+    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        """
+        Enable FULL_AND_PIECEWISE cuda graph mode by default (required
+        to get good performance for mamba layers in V1).
+
+        Args:
+            vllm_config: vLLM Config
+        """
+
+        if not envs.VLLM_USE_V1:
+            return
+
+        model_config = vllm_config.model_config
+        compilation_config = vllm_config.compilation_config
+
+        model_cls, _ = ModelRegistry.resolve_model_cls(
+            model_config.architecture,
+            model_config=model_config,
+        )
+
+        # TODO(tdoublep): remove as full cuda graph support is added
+        FCG_NOT_SUPPORTED_MODELS = [
+            "Lfm2ForCausalLM", "MiniMaxText01ForCausalLM"
+        ]
+
+        if (model_config.architecture not in FCG_NOT_SUPPORTED_MODELS
+                and compilation_config.cudagraph_mode is None):
+            logger.info(
+                "Hybrid or mamba-based model detected: setting cudagraph mode "
+                "to FULL_AND_PIECEWISE in order to optimize performance.")
+            compilation_config.cudagraph_mode = CUDAGraphMode.FULL_AND_PIECEWISE
+
+
 class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
 
     @classmethod
@@ -293,6 +330,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
         if not envs.VLLM_USE_V1:
             return
 
+        # Enable FULL_AND_PIECEWISE by default
+        MambaModelConfig.verify_and_update_config(vllm_config)
+
         cache_config = vllm_config.cache_config
         model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
@@ -374,4 +414,6 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
     "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
     "GptOssForCausalLM": GptOssForCausalLMConfig,
+    "MambaForCausalLM": MambaModelConfig,
+    "Mamba2ForCausalLM": MambaModelConfig,
 }
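For context, MODELS_CONFIG_MAP is consulted while the engine finalizes its configuration: the detected architecture is looked up and the matching hook's verify_and_update_config is invoked. A rough sketch of that dispatch (run_arch_config_hook is a hypothetical name; vLLM's actual call site may differ):

def run_arch_config_hook(vllm_config, models_config_map) -> None:
    # e.g. "Mamba2ForCausalLM" now resolves to MambaModelConfig above.
    architecture = vllm_config.model_config.architecture
    hook_cls = models_config_map.get(architecture)
    if hook_cls is not None:
        hook_cls.verify_and_update_config(vllm_config)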