[V0] Correct CUDA Graph capture for encoder-decoder models (#22630)

This commit is contained in:
Sugar-zsg 2025-08-12 17:01:08 +08:00 committed by GitHub
parent 9f909b8996
commit 8d17fa633e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1164,8 +1164,18 @@ class ModelConfig:
"non-quantized models.", self.quantization)
def _verify_cuda_graph(self) -> None:
# The `max_seq_len_to_capture` was incorrectly
# based on the encoder's input length (448)
# but not the decoder's larger input length (1500).
# This change ensures the CUDA Graph captures the correct,
# larger sequence length, allowing it to work as intended.
effective_max_seq_len = self.max_model_len
if self.is_encoder_decoder:
effective_max_seq_len = max(
effective_max_seq_len,
getattr(self.hf_config, "max_source_positions", 0))
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len)
effective_max_seq_len)
# CUDAGraph capture not supported for enc-dec models and mllama on ROCm
ROCM_UNSUPPORTED_MODELS = ['mllama']
unsupported_rocm = (self.hf_config.model_type