[Core] Gate prompt_embeds behind a feature flag (#17607)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-05-04 00:19:20 +08:00 committed by GitHub
parent a92842454c
commit 887d7af882
8 changed files with 84 additions and 33 deletions

View File

@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
from contextlib import nullcontext

import pytest

from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams


@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
def test_skip_tokenizer_initialization(model: str):
    # This test checks if the flag skip_tokenizer_init skips the initialization
    # of tokenizer and detokenizer. The generated output is expected to contain
    # token ids.
    llm = LLM(
        model=model,
        skip_tokenizer_init=True,
        enforce_eager=True,
    )
    sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)

    with pytest.raises(ValueError, match="cannot pass text prompts when"):
        llm.generate("abc", sampling_params)

    outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
                           sampling_params=sampling_params)
    assert len(outputs) > 0
    completions = outputs[0].outputs
    assert len(completions) > 0
    assert completions[0].text == ""
    assert completions[0].token_ids


@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_enable_prompt_embeds(hf_runner, model: str,
                              enable_prompt_embeds: bool):
    prompt = "abc"

    with hf_runner(model) as hf_model:
        token_ids = hf_model.tokenizer(prompt, return_tensors="pt").input_ids
        token_ids = token_ids.to(hf_model.model.device)

        embed_layer = hf_model.model.get_input_embeddings()
        prompt_embeds = embed_layer(token_ids).squeeze(0)

    ctx = (nullcontext() if enable_prompt_embeds else pytest.raises(
        ValueError, match="set `--enable-prompt-embeds`"))
    # This test checks that `prompt_embeds` inputs are only accepted when the
    # `enable_prompt_embeds` flag is set; otherwise a ValueError pointing at
    # `--enable-prompt-embeds` is raised.
    llm = LLM(
        model=model,
        enable_prompt_embeds=enable_prompt_embeds,
        enforce_eager=True,
    )

    with ctx:
        llm.generate({"prompt_embeds": prompt_embeds})

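For context, here is a minimal usage sketch of the gated path exercised by the test above. It assumes a local GPU, the V0 engine (the preprocessor change below still rejects `prompt_embeds` under V1, so `VLLM_USE_V1=0` may be required), and distilgpt2's hidden size of 768; the random embeddings are placeholders, so the generated text is not meaningful.

import torch

from vllm import LLM, SamplingParams

# The flag must be set; otherwise preprocessing raises
# "You must set `--enable-prompt-embeds` to input `prompt_embeds`."
llm = LLM(model="distilbert/distilgpt2",
          enable_prompt_embeds=True,
          enforce_eager=True)

# A [seq_len, hidden_size] tensor. In practice this comes from an embedding
# layer; random data is used here purely for illustration.
prompt_embeds = torch.randn(4, 768)

outputs = llm.generate({"prompt_embeds": prompt_embeds},
                       SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)
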
View File

@@ -1,29 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams


@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
def test_skip_tokenizer_initialization(model: str):
    # This test checks if the flag skip_tokenizer_init skips the initialization
    # of tokenizer and detokenizer. The generated output is expected to contain
    # token ids.
    llm = LLM(
        model=model,
        skip_tokenizer_init=True,
    )
    sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)

    with pytest.raises(ValueError, match="cannot pass text prompts when"):
        llm.generate("abc", sampling_params)

    outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
                           sampling_params=sampling_params)
    assert len(outputs) > 0
    completions = outputs[0].outputs
    assert len(completions) > 0
    assert completions[0].text == ""
    assert completions[0].token_ids

View File

@@ -109,12 +109,15 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
        # in parts of the operators
        pytest.skip(f"Skipping '{model}' model test with AITER kernel.")

    use_prompt_embeds = os.getenv("VLLM_USE_V1") == "0"

    with hf_runner(model) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

        prompt_embeds: Optional[list[torch.Tensor]] = ([] if use_prompt_embeds
                                                        else None)
        prompt_token_ids = []
        for prompt in example_prompts:
            token_ids = hf_model.tokenizer(prompt,
@@ -131,6 +134,7 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
            tokenizer_mode=model_info.tokenizer_mode,
            trust_remote_code=model_info.trust_remote_code,
            max_num_seqs=2,
            enable_prompt_embeds=use_prompt_embeds,
    ) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

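For reference, here is a standalone sketch of the per-prompt embedding extraction that the truncated loop above performs through the `hf_runner` fixture, written directly against Hugging Face transformers; the prompts are illustrative only.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "distilbert/distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt_embeds: list[torch.Tensor] = []
prompt_token_ids: list[torch.Tensor] = []
for prompt in ["Hello, my name is", "The capital of France is"]:
    token_ids = tokenizer(prompt, return_tensors="pt").input_ids
    prompt_token_ids.append(token_ids)
    with torch.no_grad():
        # Each entry is a [seq_len, hidden_size] tensor of input embeddings.
        embeds = model.get_input_embeddings()(token_ids).squeeze(0)
    prompt_embeds.append(embeds)
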
View File

@@ -43,6 +43,7 @@ def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch):
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enable_prompt_embeds=True,
    )

    seq_lens: list[int] = []
@@ -179,6 +180,7 @@ def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch):
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
        enable_prompt_embeds=True,
    )

    context_lens: list[int] = []
@@ -359,6 +361,7 @@ def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds,
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=True,
        enable_prompt_embeds=True,
    )

    # Add prefill requests.

View File

@@ -321,6 +321,10 @@ class ModelConfig:
    """Skip initialization of tokenizer and detokenizer. Expects valid
    `prompt_token_ids` and `None` for prompt from the input. The generated
    output will contain token ids."""
    enable_prompt_embeds: bool = False
    """If `True`, enables passing text embeddings as inputs via the
    `prompt_embeds` key. Note that enabling this will double the time required
    for graph compilation."""
    served_model_name: Optional[Union[str, list[str]]] = None
    """The model name(s) used in the API. If multiple names are provided, the
    server will respond to any of the provided names. The model name in the

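As a rough sketch of how this new field is expected to flow from the `EngineArgs` entry point (next file) into `ModelConfig`; the `create_engine_config()` call and the `model_config` attribute are assumed from vLLM's existing config plumbing rather than this diff.

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(model="distilbert/distilgpt2",
                         enable_prompt_embeds=True)
vllm_config = engine_args.create_engine_config()
# The flag set on EngineArgs should land on the model config.
assert vllm_config.model_config.enable_prompt_embeds
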
View File

@@ -234,6 +234,7 @@ class EngineArgs:
    hf_config_path: Optional[str] = ModelConfig.hf_config_path
    task: TaskOption = ModelConfig.task
    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
    enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
    tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
    trust_remote_code: bool = ModelConfig.trust_remote_code
    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
@@ -445,6 +446,8 @@ class EngineArgs:
                                 **model_kwargs["disable_cascade_attn"])
        model_group.add_argument("--skip-tokenizer-init",
                                 **model_kwargs["skip_tokenizer_init"])
        model_group.add_argument("--enable-prompt-embeds",
                                 **model_kwargs["enable_prompt_embeds"])
        model_group.add_argument("--served-model-name",
                                 **model_kwargs["served_model_name"])
        # This one is a special case because it is the
@@ -874,6 +877,7 @@ class EngineArgs:
            disable_sliding_window=self.disable_sliding_window,
            disable_cascade_attn=self.disable_cascade_attn,
            skip_tokenizer_init=self.skip_tokenizer_init,
            enable_prompt_embeds=self.enable_prompt_embeds,
            served_model_name=self.served_model_name,
            limit_mm_per_prompt=self.limit_mm_per_prompt,
            use_async_output_proc=not self.disable_async_output_proc,

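A small sketch of exercising the new CLI flag end to end; the `FlexibleArgumentParser` helper, `add_cli_args`, and `from_cli_args` are assumed from vLLM's existing CLI plumbing, and the bare-flag form relies on the boolean-flag handling of the generated argument.

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser()
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args(
    ["--model", "distilbert/distilgpt2", "--enable-prompt-embeds"])
engine_args = EngineArgs.from_cli_args(args)
# The parsed flag should round-trip into EngineArgs.
assert engine_args.enable_prompt_embeds
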
View File

@@ -303,8 +303,11 @@ class InputPreprocessor:
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        if not self.model_config.enable_prompt_embeds:
            raise ValueError("You must set `--enable-prompt-embeds` to input "
                             "`prompt_embeds`.")

        if envs.VLLM_USE_V1:
            raise ValueError("`prompt_embeds` is only available in V0.")

        prompt_embeds = parsed_content["prompt_embeds"]

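A short sketch of the failure mode this check guards against, seen from the caller's side; the error text follows the diff, the zero tensor is a placeholder, and a local GPU is assumed.

import torch

from vllm import LLM

llm = LLM(model="distilbert/distilgpt2", enforce_eager=True)  # flag left unset
try:
    llm.generate({"prompt_embeds": torch.zeros(4, 768)})
except ValueError as err:
    print(err)  # message points at `--enable-prompt-embeds`
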
View File

@@ -1565,7 +1565,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
            # product.
            cudagraph_capture_sizes = self.vllm_config.compilation_config\
                .cudagraph_capture_sizes
            cudagraph_inputs_embeds = ((
                True, False) if self.model_config.enable_prompt_embeds else
                (False, ))
            compilation_cases = itertools.product(
                cudagraph_capture_sizes,
                cudagraph_inputs_embeds,
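
To illustrate why the config docstring warns about doubled graph-compilation time: the runner captures one CUDA graph per (batch size, uses-inputs-embeds) combination, so enabling the flag doubles the number of capture cases. A toy sketch with illustrative capture sizes:

import itertools

cudagraph_capture_sizes = [1, 2, 4, 8]  # illustrative batch sizes

for enable_prompt_embeds in (False, True):
    cudagraph_inputs_embeds = ((True, False) if enable_prompt_embeds else
                               (False, ))
    compilation_cases = list(
        itertools.product(cudagraph_capture_sizes, cudagraph_inputs_embeds))
    # Prints 4 capture cases without the flag, 8 with it.
    print(enable_prompt_embeds, len(compilation_cases))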