mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 15:17:15 +08:00
[Model] Add ModelConfig class for GraniteMoeHybrid to override default max_seq_len_to_capture (#20923)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
parent
153c6f1e61
commit
6cbc4d4bea
@ -205,6 +205,19 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
|
||||
}
|
||||
|
||||
|
||||
class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
config = vllm_config.model_config
|
||||
config.max_seq_len_to_capture = config.max_model_len
|
||||
logger.info(
|
||||
"Setting max_seq_len_to_capture to %d "
|
||||
"to ensure that CUDA graph capture "
|
||||
"covers sequences of length up to max_model_len.",
|
||||
config.max_model_len)
|
||||
|
||||
|
||||
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
||||
|
||||
@classmethod
|
||||
@ -297,4 +310,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
|
||||
"XLMRobertaModel": JinaRobertaModelConfig,
|
||||
"JinaVLForRanking": JinaVLForSequenceClassificationConfig,
|
||||
"GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user