[V0 Deprecation] Remove max_seq_len_to_capture (#25543)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Woosuk Kwon 2025-09-24 01:51:39 -07:00 committed by yewentao256
parent 7441d07360
commit 9914857f2b
7 changed files with 2 additions and 48 deletions
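For users of the Python API, migration simply means dropping the argument: with the V0 engine paths going away, CUDA graph behavior is bounded by `max_model_len` alone. A minimal sketch of the updated call, with the model name and values borrowed from the test change below (they are not prescribed by this commit):

```python
import vllm

# Before this commit a caller could pass max_seq_len_to_capture=256 here;
# the keyword no longer exists and is now rejected as an unexpected argument.
llm = vllm.LLM(
    model="Qwen/Qwen2.5-3B-Instruct",
    max_model_len=256,  # the remaining length bound relevant to CUDA graph capture
    max_num_seqs=8,
)
```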

View File

@@ -31,7 +31,6 @@ def use_v1_only(monkeypatch: pytest.MonkeyPatch):
def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
max_model_len=256,
-max_seq_len_to_capture=256,
max_num_seqs=8,
tensor_parallel_size=tp,
enable_lora=True,

View File

@@ -304,7 +304,7 @@ class CommonAttentionState(AttentionState):
max_query_len=1,
max_decode_query_len=1,
max_prefill_seq_len=0,
-max_decode_seq_len=self.runner.max_seq_len_to_capture,
+max_decode_seq_len=self.runner.max_model_len,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
@@ -390,7 +390,7 @@ class CommonAttentionState(AttentionState):
dtype=torch.int).cuda()
attn_metadata.encoder_seq_lens_tensor = torch.full(
(batch_size, ), 1, dtype=torch.int).cuda()
-attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
+attn_metadata.max_encoder_seq_len = self.runner.max_model_len
attn_metadata.num_encoder_tokens = 0
def _add_additional_input_buffers_for_enc_dec_model(
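The only behavioral change in this file is the bound used when `CommonAttentionState` builds placeholder metadata for CUDA graph capture: the dummy decode length and the dummy encoder length now both come from the runner's `max_model_len`. A hypothetical helper (not vLLM source) summarizing the two changed lines:

```python
# Hypothetical summary of the two edits above; `runner` stands in for the model runner.
def capture_seq_len_bounds(runner) -> tuple[int, int]:
    bound = runner.max_model_len   # previously runner.max_seq_len_to_capture
    max_decode_seq_len = bound     # placeholder for the dummy decode-only batch
    max_encoder_seq_len = bound    # placeholder for encoder-decoder dummy metadata
    return max_decode_seq_len, max_encoder_seq_len
```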

View File

@@ -177,11 +177,6 @@ class ModelConfig:
graph and always execute the model in eager mode. If False, we will use
CUDA graph and eager execution in hybrid for maximal performance and
flexibility."""
-max_seq_len_to_capture: int = 8192
-"""Maximum sequence len covered by CUDA graphs. When a sequence has context
-length larger than this, we fall back to eager mode. Additionally for
-encoder-decoder models, if the sequence length of the encoder input is
-larger than this, we fall back to the eager mode."""
max_logprobs: int = 20
"""Maximum number of log probabilities to return when `logprobs` is
specified in `SamplingParams`. The default value comes the default for the
@@ -1024,21 +1019,8 @@ class ModelConfig:
current_platform.verify_quantization(self.quantization)
def _verify_cuda_graph(self) -> None:
-# The `max_seq_len_to_capture` was incorrectly
-# based on the encoder's input length (448)
-# but not the decoder's larger input length (1500).
-# This change ensures the CUDA Graph captures the correct,
-# larger sequence length, allowing it to work as intended.
-effective_max_seq_len = self.max_model_len
-if self.is_encoder_decoder:
-effective_max_seq_len = max(
-effective_max_seq_len,
-getattr(self.hf_config, "max_source_positions", 0))
-self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
-effective_max_seq_len)
# CUDAGraph capture not supported for encoder-decoder models on ROCm
unsupported_rocm = self.is_encoder_decoder
if (unsupported_rocm and not self.enforce_eager
and current_platform.is_rocm()):
logger.warning(
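To see why deleting the clamp is benign for typical configurations, here is a small worked example (illustrative only; 8192 is the old default from the removed `ModelConfig` field above):

```python
# Old behavior: capture bound = min(max_seq_len_to_capture, effective max_model_len)
# New behavior: capture bound = max_model_len
max_model_len = 4096
old_bound = min(8192, max_model_len)  # 4096 under the old default clamp
new_bound = max_model_len             # 4096 after this commit
assert old_bound == new_bound         # identical whenever max_model_len <= 8192
```

For encoder-decoder models the removed code additionally raised the bound to `max_source_positions`; with the attribute gone, the remaining call sites (see the attention-state change above) use `max_model_len` directly.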

View File

@@ -285,8 +285,6 @@ class SpeculativeConfig:
max_model_len,
quantization=self.quantization,
enforce_eager=self.target_model_config.enforce_eager,
-max_seq_len_to_capture=self.target_model_config.
-max_seq_len_to_capture,
max_logprobs=self.target_model_config.max_logprobs,
hf_overrides=SpeculativeConfig.hf_config_override,
)

View File

@@ -373,7 +373,6 @@ class EngineArgs:
tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
quantization: Optional[QuantizationMethods] = ModelConfig.quantization
enforce_eager: bool = ModelConfig.enforce_eager
-max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
limit_mm_per_prompt: dict[str, int] = \
get_field(MultiModalConfig, "limit_per_prompt")
@@ -545,8 +544,6 @@
**model_kwargs["quantization"])
model_group.add_argument("--enforce-eager",
**model_kwargs["enforce_eager"])
-model_group.add_argument("--max-seq-len-to-capture",
-**model_kwargs["max_seq_len_to_capture"])
model_group.add_argument("--max-logprobs",
**model_kwargs["max_logprobs"])
model_group.add_argument("--logprobs-mode",
@@ -1008,7 +1005,6 @@ class EngineArgs:
max_model_len=self.max_model_len,
quantization=self.quantization,
enforce_eager=self.enforce_eager,
-max_seq_len_to_capture=self.max_seq_len_to_capture,
max_logprobs=self.max_logprobs,
logprobs_mode=self.logprobs_mode,
disable_sliding_window=self.disable_sliding_window,
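The same removal applies to `EngineArgs` and the CLI: the dataclass field and the `--max-seq-len-to-capture` flag are both dropped, so the option is no longer recognized. A minimal sketch of the updated usage (import path and values are assumptions, not taken from this diff):

```python
from vllm.engine.arg_utils import EngineArgs  # import path assumed

# max_seq_len_to_capture is gone from the dataclass, so passing it here (or
# --max-seq-len-to-capture on the command line) is now rejected.
engine_args = EngineArgs(
    model="Qwen/Qwen2.5-3B-Instruct",
    max_model_len=256,
    enforce_eager=False,
)
```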

View File

@@ -130,11 +130,6 @@ class LLM:
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
-max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
-When a sequence has context length larger than this, we fall back
-to eager mode. Additionally for encoder-decoder models, if the
-sequence length of the encoder input is larger than this, we fall
-back to the eager mode.
disable_custom_all_reduce: See
[ParallelConfig][vllm.config.ParallelConfig].
hf_token: The token to use as HTTP bearer authorization for remote files
@@ -184,7 +179,6 @@ class LLM:
swap_space: float = 4,
cpu_offload_gb: float = 0,
enforce_eager: bool = False,
-max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
hf_token: Optional[Union[bool, str]] = None,
hf_overrides: Optional[HfOverrides] = None,
@@ -281,7 +275,6 @@ class LLM:
swap_space=swap_space,
cpu_offload_gb=cpu_offload_gb,
enforce_eager=enforce_eager,
-max_seq_len_to_capture=max_seq_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
hf_token=hf_token,
hf_overrides=hf_overrides,

View File

@@ -245,19 +245,6 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
}
-class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig):
-@staticmethod
-def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-config = vllm_config.model_config
-config.max_seq_len_to_capture = config.max_model_len
-logger.info(
-"Setting max_seq_len_to_capture to %d "
-"to ensure that CUDA graph capture "
-"covers sequences of length up to max_model_len.",
-config.max_model_len)
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod
@@ -426,7 +413,6 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"XLMRobertaModel": JinaRobertaModelConfig,
"JinaVLForRanking": JinaVLForSequenceClassificationConfig,
"JambaForSequenceClassification": JambaForSequenceClassificationConfig,
-"GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
"GptOssForCausalLM": GptOssForCausalLMConfig,
"MambaForCausalLM": MambaModelConfig,
"Mamba2ForCausalLM": MambaModelConfig,