Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[V0 Deprecation] Remove max_seq_len_to_capture (#25543)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: yewentao256 <zhyanwentao@126.com>

parent 7441d07360
commit 9914857f2b
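
For callers of the public API, the effect of this removal is simply dropping the argument; the updated LoRA test in the first hunk below already shows the pattern. A minimal usage sketch (illustrative only; the model name and sizes are taken from that test, everything else is ordinary vllm.LLM usage, not part of the diff):

import vllm

# CUDA-graph coverage is no longer bounded by a separate capture length;
# max_model_len still bounds the sequence length, and enforce_eager=True
# remains the switch for disabling CUDA graphs entirely.
llm = vllm.LLM(
    model="Qwen/Qwen2.5-3B-Instruct",
    max_model_len=256,
    max_num_seqs=8,
    enable_lora=True,
    # max_seq_len_to_capture=256,  # removed by this commit; no longer accepted
)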
@@ -31,7 +31,6 @@ def use_v1_only(monkeypatch: pytest.MonkeyPatch):
 def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
     return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
                     max_model_len=256,
-                    max_seq_len_to_capture=256,
                     max_num_seqs=8,
                     tensor_parallel_size=tp,
                     enable_lora=True,
@@ -304,7 +304,7 @@ class CommonAttentionState(AttentionState):
             max_query_len=1,
             max_decode_query_len=1,
             max_prefill_seq_len=0,
-            max_decode_seq_len=self.runner.max_seq_len_to_capture,
+            max_decode_seq_len=self.runner.max_model_len,
             query_start_loc=None,
             seq_start_loc=None,
             context_lens_tensor=None,
@@ -390,7 +390,7 @@ class CommonAttentionState(AttentionState):
                 dtype=torch.int).cuda()
             attn_metadata.encoder_seq_lens_tensor = torch.full(
                 (batch_size, ), 1, dtype=torch.int).cuda()
-            attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
+            attn_metadata.max_encoder_seq_len = self.runner.max_model_len
             attn_metadata.num_encoder_tokens = 0

     def _add_additional_input_buffers_for_enc_dec_model(
@@ -177,11 +177,6 @@ class ModelConfig:
     graph and always execute the model in eager mode. If False, we will use
     CUDA graph and eager execution in hybrid for maximal performance and
     flexibility."""
-    max_seq_len_to_capture: int = 8192
-    """Maximum sequence len covered by CUDA graphs. When a sequence has context
-    length larger than this, we fall back to eager mode. Additionally for
-    encoder-decoder models, if the sequence length of the encoder input is
-    larger than this, we fall back to the eager mode."""
     max_logprobs: int = 20
     """Maximum number of log probabilities to return when `logprobs` is
     specified in `SamplingParams`. The default value comes the default for the
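
Background, as an aside that is not stated in the diff itself: the V1 engine captures CUDA graphs per decode batch size rather than per sequence length, so a sequence-length capture cap has nothing left to control. If your installed vLLM version exposes compilation_config (an assumption about the version, not something this commit introduces), graph capture is tuned along these lines instead:

from vllm import LLM

# Hypothetical illustration: restrict CUDA-graph capture to small decode
# batch sizes; the dict form is converted to a CompilationConfig internally.
llm = LLM(
    model="Qwen/Qwen2.5-3B-Instruct",
    compilation_config={"cudagraph_capture_sizes": [1, 2, 4, 8]},
)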
@@ -1024,21 +1019,8 @@ class ModelConfig:
         current_platform.verify_quantization(self.quantization)

     def _verify_cuda_graph(self) -> None:
-        # The `max_seq_len_to_capture` was incorrectly
-        # based on the encoder's input length (448)
-        # but not the decoder's larger input length (1500).
-        # This change ensures the CUDA Graph captures the correct,
-        # larger sequence length, allowing it to work as intended.
-        effective_max_seq_len = self.max_model_len
-        if self.is_encoder_decoder:
-            effective_max_seq_len = max(
-                effective_max_seq_len,
-                getattr(self.hf_config, "max_source_positions", 0))
-        self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
-                                          effective_max_seq_len)
         # CUDAGraph capture not supported for encoder-decoder models on ROCm
         unsupported_rocm = self.is_encoder_decoder

         if (unsupported_rocm and not self.enforce_eager
                 and current_platform.is_rocm()):
             logger.warning(
@@ -285,8 +285,6 @@ class SpeculativeConfig:
                 max_model_len,
                 quantization=self.quantization,
                 enforce_eager=self.target_model_config.enforce_eager,
-                max_seq_len_to_capture=self.target_model_config.
-                max_seq_len_to_capture,
                 max_logprobs=self.target_model_config.max_logprobs,
                 hf_overrides=SpeculativeConfig.hf_config_override,
             )
@@ -373,7 +373,6 @@ class EngineArgs:
     tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
     quantization: Optional[QuantizationMethods] = ModelConfig.quantization
     enforce_eager: bool = ModelConfig.enforce_eager
-    max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture
     disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
     limit_mm_per_prompt: dict[str, int] = \
         get_field(MultiModalConfig, "limit_per_prompt")
@@ -545,8 +544,6 @@ class EngineArgs:
                                  **model_kwargs["quantization"])
         model_group.add_argument("--enforce-eager",
                                  **model_kwargs["enforce_eager"])
-        model_group.add_argument("--max-seq-len-to-capture",
-                                 **model_kwargs["max_seq_len_to_capture"])
         model_group.add_argument("--max-logprobs",
                                  **model_kwargs["max_logprobs"])
         model_group.add_argument("--logprobs-mode",
@@ -1008,7 +1005,6 @@ class EngineArgs:
             max_model_len=self.max_model_len,
             quantization=self.quantization,
             enforce_eager=self.enforce_eager,
-            max_seq_len_to_capture=self.max_seq_len_to_capture,
             max_logprobs=self.max_logprobs,
             logprobs_mode=self.logprobs_mode,
             disable_sliding_window=self.disable_sliding_window,
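
For programmatic users of EngineArgs, the removal is equally mechanical. A hedged sketch using only fields that remain in the hunks above (illustrative, not part of the diff):

from vllm.engine.arg_utils import EngineArgs

# EngineArgs is a dataclass, so the removed field disappears from both the
# constructor and the generated --max-seq-len-to-capture CLI flag.
engine_args = EngineArgs(
    model="Qwen/Qwen2.5-3B-Instruct",
    max_model_len=256,
    enforce_eager=False,
    max_logprobs=20,
)
# EngineArgs(max_seq_len_to_capture=256)  # would now raise a TypeError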
@@ -130,11 +130,6 @@ class LLM:
         enforce_eager: Whether to enforce eager execution. If True, we will
             disable CUDA graph and always execute the model in eager mode.
             If False, we will use CUDA graph and eager execution in hybrid.
-        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
-            When a sequence has context length larger than this, we fall back
-            to eager mode. Additionally for encoder-decoder models, if the
-            sequence length of the encoder input is larger than this, we fall
-            back to the eager mode.
         disable_custom_all_reduce: See
             [ParallelConfig][vllm.config.ParallelConfig].
         hf_token: The token to use as HTTP bearer authorization for remote files
@@ -184,7 +179,6 @@ class LLM:
         swap_space: float = 4,
         cpu_offload_gb: float = 0,
         enforce_eager: bool = False,
-        max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
         hf_token: Optional[Union[bool, str]] = None,
         hf_overrides: Optional[HfOverrides] = None,
@@ -281,7 +275,6 @@ class LLM:
             swap_space=swap_space,
             cpu_offload_gb=cpu_offload_gb,
             enforce_eager=enforce_eager,
-            max_seq_len_to_capture=max_seq_len_to_capture,
             disable_custom_all_reduce=disable_custom_all_reduce,
             hf_token=hf_token,
             hf_overrides=hf_overrides,
@@ -245,19 +245,6 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
     }


-class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig):
-
-    @staticmethod
-    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
-        config = vllm_config.model_config
-        config.max_seq_len_to_capture = config.max_model_len
-        logger.info(
-            "Setting max_seq_len_to_capture to %d "
-            "to ensure that CUDA graph capture "
-            "covers sequences of length up to max_model_len.",
-            config.max_model_len)
-
-
 class GptOssForCausalLMConfig(VerifyAndUpdateConfig):

     @staticmethod
@@ -426,7 +413,6 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "XLMRobertaModel": JinaRobertaModelConfig,
     "JinaVLForRanking": JinaVLForSequenceClassificationConfig,
     "JambaForSequenceClassification": JambaForSequenceClassificationConfig,
-    "GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
     "GptOssForCausalLM": GptOssForCausalLMConfig,
     "MambaForCausalLM": MambaModelConfig,
     "Mamba2ForCausalLM": MambaModelConfig,