[Model] Add ModelConfig class for GraniteMoeHybrid to override default max_seq_len_to_capture (#20923)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell 2025-07-16 04:19:10 +02:00 committed by GitHub
parent 153c6f1e61
commit 6cbc4d4bea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -205,6 +205,19 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
}
class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig):

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Pin the CUDA-graph capture limit to the full model length.

        Overwrites ``max_seq_len_to_capture`` on the model config with
        ``max_model_len`` so that CUDA graph capture covers every sequence
        length this model can serve, and logs the new value.
        """
        model_config = vllm_config.model_config
        model_config.max_seq_len_to_capture = model_config.max_model_len
        logger.info(
            "Setting max_seq_len_to_capture to %d "
            "to ensure that CUDA graph capture "
            "covers sequences of length up to max_model_len.",
            model_config.max_model_len)
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
@classmethod
@ -297,4 +310,5 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
"XLMRobertaModel": JinaRobertaModelConfig,
"JinaVLForRanking": JinaVLForSequenceClassificationConfig,
"GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
}