mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-11 01:42:15 +08:00
[V0 Deprecation] Remove max_seq_len_to_capture (#25543)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
77a7fce1bb
commit
2e19a848d4
@ -31,7 +31,6 @@ def use_v1_only(monkeypatch: pytest.MonkeyPatch):
|
|||||||
def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
|
def setup_vllm(num_loras: int, tp: int) -> vllm.LLM:
|
||||||
return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
|
return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
|
||||||
max_model_len=256,
|
max_model_len=256,
|
||||||
max_seq_len_to_capture=256,
|
|
||||||
max_num_seqs=8,
|
max_num_seqs=8,
|
||||||
tensor_parallel_size=tp,
|
tensor_parallel_size=tp,
|
||||||
enable_lora=True,
|
enable_lora=True,
|
||||||
|
|||||||
@ -304,7 +304,7 @@ class CommonAttentionState(AttentionState):
|
|||||||
max_query_len=1,
|
max_query_len=1,
|
||||||
max_decode_query_len=1,
|
max_decode_query_len=1,
|
||||||
max_prefill_seq_len=0,
|
max_prefill_seq_len=0,
|
||||||
max_decode_seq_len=self.runner.max_seq_len_to_capture,
|
max_decode_seq_len=self.runner.max_model_len,
|
||||||
query_start_loc=None,
|
query_start_loc=None,
|
||||||
seq_start_loc=None,
|
seq_start_loc=None,
|
||||||
context_lens_tensor=None,
|
context_lens_tensor=None,
|
||||||
@ -390,7 +390,7 @@ class CommonAttentionState(AttentionState):
|
|||||||
dtype=torch.int).cuda()
|
dtype=torch.int).cuda()
|
||||||
attn_metadata.encoder_seq_lens_tensor = torch.full(
|
attn_metadata.encoder_seq_lens_tensor = torch.full(
|
||||||
(batch_size, ), 1, dtype=torch.int).cuda()
|
(batch_size, ), 1, dtype=torch.int).cuda()
|
||||||
attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
|
attn_metadata.max_encoder_seq_len = self.runner.max_model_len
|
||||||
attn_metadata.num_encoder_tokens = 0
|
attn_metadata.num_encoder_tokens = 0
|
||||||
|
|
||||||
def _add_additional_input_buffers_for_enc_dec_model(
|
def _add_additional_input_buffers_for_enc_dec_model(
|
||||||
|
|||||||
@ -177,11 +177,6 @@ class ModelConfig:
|
|||||||
graph and always execute the model in eager mode. If False, we will use
|
graph and always execute the model in eager mode. If False, we will use
|
||||||
CUDA graph and eager execution in hybrid for maximal performance and
|
CUDA graph and eager execution in hybrid for maximal performance and
|
||||||
flexibility."""
|
flexibility."""
|
||||||
max_seq_len_to_capture: int = 8192
|
|
||||||
"""Maximum sequence len covered by CUDA graphs. When a sequence has context
|
|
||||||
length larger than this, we fall back to eager mode. Additionally for
|
|
||||||
encoder-decoder models, if the sequence length of the encoder input is
|
|
||||||
larger than this, we fall back to the eager mode."""
|
|
||||||
max_logprobs: int = 20
|
max_logprobs: int = 20
|
||||||
"""Maximum number of log probabilities to return when `logprobs` is
|
"""Maximum number of log probabilities to return when `logprobs` is
|
||||||
specified in `SamplingParams`. The default value comes the default for the
|
specified in `SamplingParams`. The default value comes the default for the
|
||||||
@ -1024,21 +1019,8 @@ class ModelConfig:
|
|||||||
current_platform.verify_quantization(self.quantization)
|
current_platform.verify_quantization(self.quantization)
|
||||||
|
|
||||||
def _verify_cuda_graph(self) -> None:
|
def _verify_cuda_graph(self) -> None:
|
||||||
# The `max_seq_len_to_capture` was incorrectly
|
|
||||||
# based on the encoder's input length (448)
|
|
||||||
# but not the decoder's larger input length (1500).
|
|
||||||
# This change ensures the CUDA Graph captures the correct,
|
|
||||||
# larger sequence length, allowing it to work as intended.
|
|
||||||
effective_max_seq_len = self.max_model_len
|
|
||||||
if self.is_encoder_decoder:
|
|
||||||
effective_max_seq_len = max(
|
|
||||||
effective_max_seq_len,
|
|
||||||
getattr(self.hf_config, "max_source_positions", 0))
|
|
||||||
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
|
|
||||||
effective_max_seq_len)
|
|
||||||
# CUDAGraph capture not supported for encoder-decoder models on ROCm
|
# CUDAGraph capture not supported for encoder-decoder models on ROCm
|
||||||
unsupported_rocm = self.is_encoder_decoder
|
unsupported_rocm = self.is_encoder_decoder
|
||||||
|
|
||||||
if (unsupported_rocm and not self.enforce_eager
|
if (unsupported_rocm and not self.enforce_eager
|
||||||
and current_platform.is_rocm()):
|
and current_platform.is_rocm()):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@ -285,8 +285,6 @@ class SpeculativeConfig:
|
|||||||
max_model_len,
|
max_model_len,
|
||||||
quantization=self.quantization,
|
quantization=self.quantization,
|
||||||
enforce_eager=self.target_model_config.enforce_eager,
|
enforce_eager=self.target_model_config.enforce_eager,
|
||||||
max_seq_len_to_capture=self.target_model_config.
|
|
||||||
max_seq_len_to_capture,
|
|
||||||
max_logprobs=self.target_model_config.max_logprobs,
|
max_logprobs=self.target_model_config.max_logprobs,
|
||||||
hf_overrides=SpeculativeConfig.hf_config_override,
|
hf_overrides=SpeculativeConfig.hf_config_override,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -373,7 +373,6 @@ class EngineArgs:
|
|||||||
tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
|
tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
|
||||||
quantization: Optional[QuantizationMethods] = ModelConfig.quantization
|
quantization: Optional[QuantizationMethods] = ModelConfig.quantization
|
||||||
enforce_eager: bool = ModelConfig.enforce_eager
|
enforce_eager: bool = ModelConfig.enforce_eager
|
||||||
max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture
|
|
||||||
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
|
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
|
||||||
limit_mm_per_prompt: dict[str, int] = \
|
limit_mm_per_prompt: dict[str, int] = \
|
||||||
get_field(MultiModalConfig, "limit_per_prompt")
|
get_field(MultiModalConfig, "limit_per_prompt")
|
||||||
@ -545,8 +544,6 @@ class EngineArgs:
|
|||||||
**model_kwargs["quantization"])
|
**model_kwargs["quantization"])
|
||||||
model_group.add_argument("--enforce-eager",
|
model_group.add_argument("--enforce-eager",
|
||||||
**model_kwargs["enforce_eager"])
|
**model_kwargs["enforce_eager"])
|
||||||
model_group.add_argument("--max-seq-len-to-capture",
|
|
||||||
**model_kwargs["max_seq_len_to_capture"])
|
|
||||||
model_group.add_argument("--max-logprobs",
|
model_group.add_argument("--max-logprobs",
|
||||||
**model_kwargs["max_logprobs"])
|
**model_kwargs["max_logprobs"])
|
||||||
model_group.add_argument("--logprobs-mode",
|
model_group.add_argument("--logprobs-mode",
|
||||||
@ -1008,7 +1005,6 @@ class EngineArgs:
|
|||||||
max_model_len=self.max_model_len,
|
max_model_len=self.max_model_len,
|
||||||
quantization=self.quantization,
|
quantization=self.quantization,
|
||||||
enforce_eager=self.enforce_eager,
|
enforce_eager=self.enforce_eager,
|
||||||
max_seq_len_to_capture=self.max_seq_len_to_capture,
|
|
||||||
max_logprobs=self.max_logprobs,
|
max_logprobs=self.max_logprobs,
|
||||||
logprobs_mode=self.logprobs_mode,
|
logprobs_mode=self.logprobs_mode,
|
||||||
disable_sliding_window=self.disable_sliding_window,
|
disable_sliding_window=self.disable_sliding_window,
|
||||||
|
|||||||
@ -130,11 +130,6 @@ class LLM:
|
|||||||
enforce_eager: Whether to enforce eager execution. If True, we will
|
enforce_eager: Whether to enforce eager execution. If True, we will
|
||||||
disable CUDA graph and always execute the model in eager mode.
|
disable CUDA graph and always execute the model in eager mode.
|
||||||
If False, we will use CUDA graph and eager execution in hybrid.
|
If False, we will use CUDA graph and eager execution in hybrid.
|
||||||
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
|
|
||||||
When a sequence has context length larger than this, we fall back
|
|
||||||
to eager mode. Additionally for encoder-decoder models, if the
|
|
||||||
sequence length of the encoder input is larger than this, we fall
|
|
||||||
back to the eager mode.
|
|
||||||
disable_custom_all_reduce: See
|
disable_custom_all_reduce: See
|
||||||
[ParallelConfig][vllm.config.ParallelConfig].
|
[ParallelConfig][vllm.config.ParallelConfig].
|
||||||
hf_token: The token to use as HTTP bearer authorization for remote files
|
hf_token: The token to use as HTTP bearer authorization for remote files
|
||||||
@ -184,7 +179,6 @@ class LLM:
|
|||||||
swap_space: float = 4,
|
swap_space: float = 4,
|
||||||
cpu_offload_gb: float = 0,
|
cpu_offload_gb: float = 0,
|
||||||
enforce_eager: bool = False,
|
enforce_eager: bool = False,
|
||||||
max_seq_len_to_capture: int = 8192,
|
|
||||||
disable_custom_all_reduce: bool = False,
|
disable_custom_all_reduce: bool = False,
|
||||||
hf_token: Optional[Union[bool, str]] = None,
|
hf_token: Optional[Union[bool, str]] = None,
|
||||||
hf_overrides: Optional[HfOverrides] = None,
|
hf_overrides: Optional[HfOverrides] = None,
|
||||||
@ -281,7 +275,6 @@ class LLM:
|
|||||||
swap_space=swap_space,
|
swap_space=swap_space,
|
||||||
cpu_offload_gb=cpu_offload_gb,
|
cpu_offload_gb=cpu_offload_gb,
|
||||||
enforce_eager=enforce_eager,
|
enforce_eager=enforce_eager,
|
||||||
max_seq_len_to_capture=max_seq_len_to_capture,
|
|
||||||
disable_custom_all_reduce=disable_custom_all_reduce,
|
disable_custom_all_reduce=disable_custom_all_reduce,
|
||||||
hf_token=hf_token,
|
hf_token=hf_token,
|
||||||
hf_overrides=hf_overrides,
|
hf_overrides=hf_overrides,
|
||||||
|
|||||||
@ -245,19 +245,6 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class GraniteMoeHybridModelConfig(VerifyAndUpdateConfig):
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
|
||||||
config = vllm_config.model_config
|
|
||||||
config.max_seq_len_to_capture = config.max_model_len
|
|
||||||
logger.info(
|
|
||||||
"Setting max_seq_len_to_capture to %d "
|
|
||||||
"to ensure that CUDA graph capture "
|
|
||||||
"covers sequences of length up to max_model_len.",
|
|
||||||
config.max_model_len)
|
|
||||||
|
|
||||||
|
|
||||||
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
|
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -426,7 +413,6 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
|||||||
"XLMRobertaModel": JinaRobertaModelConfig,
|
"XLMRobertaModel": JinaRobertaModelConfig,
|
||||||
"JinaVLForRanking": JinaVLForSequenceClassificationConfig,
|
"JinaVLForRanking": JinaVLForSequenceClassificationConfig,
|
||||||
"JambaForSequenceClassification": JambaForSequenceClassificationConfig,
|
"JambaForSequenceClassification": JambaForSequenceClassificationConfig,
|
||||||
"GraniteMoeHybridForCausalLM": GraniteMoeHybridModelConfig,
|
|
||||||
"GptOssForCausalLM": GptOssForCausalLMConfig,
|
"GptOssForCausalLM": GptOssForCausalLMConfig,
|
||||||
"MambaForCausalLM": MambaModelConfig,
|
"MambaForCausalLM": MambaModelConfig,
|
||||||
"Mamba2ForCausalLM": MambaModelConfig,
|
"Mamba2ForCausalLM": MambaModelConfig,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user