[V1] [Hybrid] Enable Full CUDA graph by default for hybrid models in V1 (#22594)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
parent c3b0fd1ee6
commit 5f1af97f86
@@ -4,6 +4,7 @@ from copy import deepcopy
 from typing import TYPE_CHECKING
 
 import vllm.envs as envs
+from vllm.config.compilation import CUDAGraphMode
 from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
@@ -275,6 +276,42 @@ class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
                 "%d for performance.", 1024)
 
 
+class MambaModelConfig(VerifyAndUpdateConfig):
+
+    @classmethod
+    def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
+        """
+        Enable FULL_AND_PIECEWISE cuda graph mode by default (required
+        to get good performance for mamba layers in V1).
+
+        Args:
+            vllm_config: vLLM Config
+        """
+
+        if not envs.VLLM_USE_V1:
+            return
+
+        model_config = vllm_config.model_config
+        compilation_config = vllm_config.compilation_config
+
+        model_cls, _ = ModelRegistry.resolve_model_cls(
+            model_config.architecture,
+            model_config=model_config,
+        )
+
+        # TODO(tdoublep): remove as full cuda graph support is added
+        FCG_NOT_SUPPORTED_MODELS = [
+            "Lfm2ForCausalLM", "MiniMaxText01ForCausalLM"
+        ]
+
+        if (model_config.architecture not in FCG_NOT_SUPPORTED_MODELS
+                and compilation_config.cudagraph_mode is None):
+            logger.info(
+                "Hybrid or mamba-based model detected: setting cudagraph mode "
+                "to FULL_AND_PIECEWISE in order to optimize performance.")
+            compilation_config.cudagraph_mode = CUDAGraphMode.FULL_AND_PIECEWISE
+
+
 class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
 
     @classmethod
@@ -293,6 +330,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
         if not envs.VLLM_USE_V1:
             return
 
+        # Enable FULL_AND_PIECEWISE by default
+        MambaModelConfig.verify_and_update_config(vllm_config)
+
         cache_config = vllm_config.cache_config
         model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
@@ -374,4 +414,6 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
     "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
     "GptOssForCausalLM": GptOssForCausalLMConfig,
+    "MambaForCausalLM": MambaModelConfig,
+    "Mamba2ForCausalLM": MambaModelConfig,
 }
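
For context, a minimal usage sketch of the new default, written against the offline LLM entry point. The model name, the dict form of the compilation_config argument, and the string spelling of the mode are illustrative assumptions, not part of this commit:

    from vllm import LLM

    # After this commit, loading a Mamba/hybrid model on V1 with no explicit
    # cudagraph_mode leaves it None, so MambaModelConfig sets it to
    # CUDAGraphMode.FULL_AND_PIECEWISE during config verification.
    llm = LLM(model="state-spaces/mamba-130m-hf")  # example model, assumption

    # An explicit setting is respected: the default only applies when
    # compilation_config.cudagraph_mode is None. (Assumption: the dict form
    # of compilation_config accepts a cudagraph_mode key by name.)
    llm_override = LLM(
        model="state-spaces/mamba-130m-hf",
        compilation_config={"cudagraph_mode": "PIECEWISE"},
    )

The key design choice is the `is None` guard: a user-supplied cudagraph_mode is never overridden, and the two architectures in FCG_NOT_SUPPORTED_MODELS are skipped until full CUDA graph support lands for them.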