Add renderer-based prompt processing for embedding and classification endpoints (#24356)

Signed-off-by: sfeng33 <4florafeng@gmail.com>
Flora Feng 2025-09-07 01:26:48 -07:00 committed by GitHub
parent 105d3d62ef
commit 0661cb9df3
6 changed files with 60 additions and 57 deletions

View File

@@ -73,17 +73,11 @@ async def test_zero_truncation_size(client: openai.AsyncOpenAI):
         "truncate_prompt_tokens": truncation_size
     }
 
-    with pytest.raises(openai.BadRequestError) as err:
-        await client.post(path="embeddings", cast_to=object, body={**kwargs})
-
-    assert err.value.status_code == 400
-    error_details = err.value.response.json()["error"]
-    assert error_details["type"] == "BadRequestError"
-    assert "This model's maximum context length is" in error_details["message"]
-    assert "tokens in the input for embedding generation" in error_details[
-        "message"]
-    assert "Please reduce the length of the input" in error_details["message"]
+    response = await client.post(path="embeddings",
+                                 cast_to=object,
+                                 body={**kwargs})
+
+    assert response["usage"]["prompt_tokens"] == truncation_size
 
 
 @pytest.mark.asyncio
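Note: the updated test exercises truncation through the OpenAI SDK's low-level post() helper, exactly as shown above. For a rough idea of how the same behaviour looks from a regular client, here is a minimal sketch against a local vLLM OpenAI-compatible server; the base URL, API key, and model name are placeholders, not values taken from this commit.

import asyncio

import openai


async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")  # placeholder server settings
    truncation_size = 10
    response = await client.post(
        path="embeddings",
        cast_to=object,
        body={
            "model": "my-embedding-model",               # placeholder model name
            "input": "This is a deliberately long input. " * 100,
            "truncate_prompt_tokens": truncation_size,
        },
    )
    # With renderer-based truncation the request succeeds and the reported
    # prompt token count matches the requested truncation size.
    assert response["usage"]["prompt_tokens"] == truncation_size


asyncio.run(main())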

View File

@@ -130,6 +130,23 @@ class TestRenderPrompt:
         assert call_args.kwargs["truncation"] is True
         assert call_args.kwargs["max_length"] == 50
 
+    @pytest.mark.asyncio
+    async def test_truncation_negative(self, renderer, mock_async_tokenizer):
+        # Test that negative truncation uses model's max_model_len
+        mock_async_tokenizer.return_value = MockTokenizerResult(
+            [101, 7592, 2088])  # Truncated to max_model_len
+        renderer.async_tokenizer_pool[
+            renderer.tokenizer] = mock_async_tokenizer
+
+        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
+                                               max_length=200,
+                                               truncate_prompt_tokens=-1)
+
+        assert len(results) == 1
+        call_args = mock_async_tokenizer.call_args
+        assert call_args.kwargs["truncation"] is True
+        assert call_args.kwargs["max_length"] == 100  # model's max_model_len
+
     @pytest.mark.asyncio
     async def test_token_truncation_last_elements(self, renderer):
         # Test that token truncation keeps the last N elements

View File

@@ -54,14 +54,11 @@ class ClassificationMixin(OpenAIServing):
             ctx.tokenizer = await self.engine_client.get_tokenizer(
                 ctx.lora_request)
 
-            (
-                ctx.request_prompts,
-                ctx.engine_prompts,
-            ) = await self._preprocess_completion(
-                ctx.request,
-                ctx.tokenizer,
-                ctx.request.input,
-            )
+            renderer = self._get_renderer(ctx.tokenizer)
+            ctx.engine_prompts = await renderer.render_prompt(
+                prompt_or_prompts=ctx.request.input,
+                max_length=self.max_model_len,
+                truncate_prompt_tokens=ctx.request.truncate_prompt_tokens)
 
             return None
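The classification endpoint now funnels its input through the same renderer, so truncate_prompt_tokens behaves consistently across classification and embeddings. A rough usage sketch follows; the /classify route, server address, model name, and response fields are assumptions about a typical vLLM deployment rather than anything taken from this diff.

import requests

resp = requests.post(
    "http://localhost:8000/classify",            # assumed route and address
    json={
        "model": "my-classification-model",       # placeholder model name
        "input": "Some very long text to classify. " * 200,
        "truncate_prompt_tokens": 16,
    },
    timeout=30,
)
resp.raise_for_status()
# The prompt is tokenized and truncated by the renderer before reaching the
# engine, so the reported prompt token usage is bounded by the truncation size.
print(resp.json().get("usage"))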

View File

@@ -24,7 +24,6 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                               ErrorResponse, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
                                                     OpenAIServing,
-                                                    RequestPrompt,
                                                     ServeContext,
                                                     TextTokensPrompt)
 # yapf: enable
@@ -79,11 +78,12 @@ class EmbeddingMixin(OpenAIServing):
             tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
                                                                )
+            renderer = self._get_renderer(tokenizer)
 
             if isinstance(ctx.request, EmbeddingChatRequest):
                 (
                     _,
-                    ctx.request_prompts,
+                    _,
                     ctx.engine_prompts,
                 ) = await self._preprocess_chat(
                     ctx.request,
@@ -98,11 +98,16 @@ class EmbeddingMixin(OpenAIServing):
                     add_special_tokens=ctx.request.add_special_tokens,
                 )
             else:
-                (ctx.request_prompts,
-                 ctx.engine_prompts) = await self._preprocess_completion(
-                     ctx.request,
-                     tokenizer,
-                     ctx.request.input,
+                # Set max_length based on chunked processing capability
+                if self._should_use_chunked_processing(ctx.request):
+                    max_length = None
+                else:
+                    max_length = self.max_embed_len or self.max_model_len
+
+                ctx.engine_prompts = await renderer.render_prompt(
+                    prompt_or_prompts=ctx.request.input,
+                    max_length=max_length,
+                    truncate_prompt_tokens=ctx.request.truncate_prompt_tokens,
                     add_special_tokens=ctx.request.add_special_tokens,
                 )
 
             return None
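The max_length selection above is the key behavioural switch: when chunked processing will handle over-long inputs, the renderer is given no length cap; otherwise the cap is max_embed_len when configured, falling back to max_model_len. A toy reproduction of that decision (not vLLM code) for illustration:

from typing import Optional


def select_max_length(use_chunked: bool, max_embed_len: Optional[int],
                      max_model_len: int) -> Optional[int]:
    # Mirrors the branch above: no cap when chunking, otherwise prefer the
    # configured embedding limit over the model's context length.
    if use_chunked:
        return None
    return max_embed_len or max_model_len


assert select_max_length(True, 3072, 512) is None
assert select_max_length(False, 3072, 512) == 3072
assert select_max_length(False, None, 512) == 512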
@@ -286,7 +291,6 @@ class EmbeddingMixin(OpenAIServing):
         self,
         ctx: EmbeddingServeContext,
         engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt],
-        request_prompt: RequestPrompt,
         pooling_params: PoolingParams,
         trace_headers: Optional[Mapping[str, str]],
         prompt_index: int,
@@ -295,7 +299,7 @@ class EmbeddingMixin(OpenAIServing):
         request_id_item = f"{ctx.request_id}-{prompt_index}"
 
         self._log_inputs(request_id_item,
-                         request_prompt,
+                         engine_prompt,
                          params=pooling_params,
                          lora_request=ctx.lora_request)
@@ -353,20 +357,14 @@ class EmbeddingMixin(OpenAIServing):
                 return self.create_error_response(
                     "Engine prompts not available")
 
-            if ctx.request_prompts is None:
-                return self.create_error_response(
-                    "Request prompts not available")
-
             max_pos_embeddings = self._get_max_position_embeddings()
 
             for i, engine_prompt in enumerate(ctx.engine_prompts):
-                request_prompt = ctx.request_prompts[i]
-
                 # Check if this specific prompt needs chunked processing
-                if self._is_text_tokens_prompt(request_prompt):
+                if self._is_text_tokens_prompt(engine_prompt):
                     # Cast to TextTokensPrompt since we've verified
                     # prompt_token_ids
-                    text_tokens_prompt = cast(TextTokensPrompt, request_prompt)
+                    text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
                     if (len(text_tokens_prompt["prompt_token_ids"])
                             > max_pos_embeddings):
                         # Use chunked processing for this prompt
@@ -382,8 +380,7 @@ class EmbeddingMixin(OpenAIServing):
                         Union[EngineTokensPrompt, EngineEmbedsPrompt],
                         engine_prompt)
                     generator = await self._create_single_prompt_generator(
-                        ctx, engine_prompt_typed, request_prompt, pooling_params,
-                        trace_headers, i)
+                        ctx, engine_prompt_typed, pooling_params, trace_headers, i)
                     generators.append(generator)
 
         from vllm.utils import merge_async_iterators
@@ -419,10 +416,6 @@ class EmbeddingMixin(OpenAIServing):
             if not use_chunked:
                 return await super()._collect_batch(ctx=ctx)
 
-            if ctx.request_prompts is None:
-                return self.create_error_response(
-                    "Request prompts not available")
-
             if ctx.result_generator is None:
                 return self.create_error_response(
                     "Result generator not available")
@@ -538,7 +531,7 @@ class EmbeddingMixin(OpenAIServing):
                         data=final_embedding)
 
                     # Get original prompt token IDs for this prompt
-                    original_prompt = ctx.request_prompts[prompt_idx]
+                    original_prompt = ctx.engine_prompts[prompt_idx]
                     if not self._is_text_tokens_prompt(original_prompt):
                         return self.create_error_response(
                             f"Chunked prompt {prompt_idx} is not a "

View File

@@ -368,23 +368,20 @@ class OpenAIServing:
         for i, engine_prompt in enumerate(ctx.engine_prompts):
             request_id_item = f"{ctx.request_id}-{i}"
 
-            if ctx.request_prompts is None:
-                return self.create_error_response(
-                    "Request prompts not available")
-
-            self._log_inputs(
-                request_id_item,
-                ctx.request_prompts[i],
-                params=pooling_params,
-                lora_request=ctx.lora_request,
-            )
-
             # Mypy has an existing bug related to inferring the variance of
             # TypedDicts with `builtins.enumerate`:
             # https://github.com/python/mypy/issues/8586#issuecomment-2867698435
             engine_prompt = cast(
                 Union[EngineTokensPrompt, EngineEmbedsPrompt],
                 engine_prompt)
 
+            self._log_inputs(
+                request_id_item,
+                engine_prompt,
+                params=pooling_params,
+                lora_request=ctx.lora_request,
+            )
+
             generator = self.engine_client.encode(
                 engine_prompt,
                 pooling_params,

View File

@@ -108,10 +108,15 @@ class CompletionRenderer(BaseRenderer):
         for detailed parameter documentation.
         """
         if truncate_prompt_tokens is not None:
-            if max_length is not None:
-                assert 0 <= truncate_prompt_tokens <= max_length
             if truncate_prompt_tokens == 0:
                 return []
+            if truncate_prompt_tokens < 0:
+                truncate_prompt_tokens = self.model_config.max_model_len
+            if max_length is not None and truncate_prompt_tokens > max_length:
+                raise ValueError(
+                    f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
+                    f"cannot be greater than max_length ({max_length}). "
+                    f"Please select a smaller truncation size.")
 
         # Parse and batch the input prompts
         batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
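Taken together with the negative-truncation test above, the semantics are: 0 renders nothing, a negative value is shorthand for the model's max_model_len, and a value larger than max_length is rejected with a ValueError instead of being asserted away. A standalone toy version of the rule (not vLLM code) to make the behaviour concrete:

from typing import Optional


def resolve_truncation(truncate_prompt_tokens: Optional[int],
                       max_length: Optional[int],
                       max_model_len: int) -> Optional[int]:
    # Mirrors the logic added above: 0 means "render nothing", a negative
    # value means "truncate to the model's maximum length", and anything
    # above max_length is rejected.
    if truncate_prompt_tokens is None:
        return None
    if truncate_prompt_tokens == 0:
        return 0  # render_prompt returns [] in this case
    if truncate_prompt_tokens < 0:
        truncate_prompt_tokens = max_model_len
    if max_length is not None and truncate_prompt_tokens > max_length:
        raise ValueError(
            f"truncate_prompt_tokens ({truncate_prompt_tokens}) cannot be "
            f"greater than max_length ({max_length}).")
    return truncate_prompt_tokens


assert resolve_truncation(-1, 200, 100) == 100  # matches test_truncation_negative
assert resolve_truncation(0, 200, 100) == 0
try:
    resolve_truncation(300, 200, 100)
except ValueError:
    pass  # oversized truncation is now a clear error, not an assert failure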