Add renderer-based prompt processing for embedding and classification endpoints (#24356)
Signed-off-by: sfeng33 <4florafeng@gmail.com>
This commit is contained in:
parent 105d3d62ef
commit 0661cb9df3
@@ -73,17 +73,11 @@ async def test_zero_truncation_size(client: openai.AsyncOpenAI):
         "truncate_prompt_tokens": truncation_size
     }
 
-    with pytest.raises(openai.BadRequestError) as err:
-        await client.post(path="embeddings", cast_to=object, body={**kwargs})
+    response = await client.post(path="embeddings",
+                                 cast_to=object,
+                                 body={**kwargs})
 
-    assert err.value.status_code == 400
-    error_details = err.value.response.json()["error"]
-
-    assert error_details["type"] == "BadRequestError"
-    assert "This model's maximum context length is" in error_details["message"]
-    assert "tokens in the input for embedding generation" in error_details[
-        "message"]
-    assert "Please reduce the length of the input" in error_details["message"]
+    assert response["usage"]["prompt_tokens"] == truncation_size
 
 
 @pytest.mark.asyncio
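For orientation, here is a minimal client-side sketch of the behaviour the updated test expects: with renderer-based processing, the embeddings endpoint truncates the prompt to the requested size instead of returning a 400. The base URL, API key, model name, and input below are placeholders and not values taken from this change.

import asyncio

import openai


async def main() -> None:
    # Assumes a vLLM OpenAI-compatible server is already running locally.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    truncation_size = 10
    response = await client.post(path="embeddings",
                                 cast_to=object,
                                 body={
                                     "model": "my-embedding-model",
                                     "input": "Hello world " * 200,
                                     "truncate_prompt_tokens": truncation_size,
                                 })
    # The prompt is clipped to the requested number of tokens.
    print(response["usage"]["prompt_tokens"])  # expected: truncation_size


if __name__ == "__main__":
    asyncio.run(main())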
@@ -130,6 +130,23 @@ class TestRenderPrompt:
         assert call_args.kwargs["truncation"] is True
         assert call_args.kwargs["max_length"] == 50
 
+    @pytest.mark.asyncio
+    async def test_truncation_negative(self, renderer, mock_async_tokenizer):
+        # Test that negative truncation uses model's max_model_len
+        mock_async_tokenizer.return_value = MockTokenizerResult(
+            [101, 7592, 2088])  # Truncated to max_model_len
+        renderer.async_tokenizer_pool[
+            renderer.tokenizer] = mock_async_tokenizer
+
+        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
+                                               max_length=200,
+                                               truncate_prompt_tokens=-1)
+
+        assert len(results) == 1
+        call_args = mock_async_tokenizer.call_args
+        assert call_args.kwargs["truncation"] is True
+        assert call_args.kwargs["max_length"] == 100  # model's max_model_len
+
     @pytest.mark.asyncio
     async def test_token_truncation_last_elements(self, renderer):
         # Test that token truncation keeps the last N elements
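The new test reuses the suite's existing `renderer` and `mock_async_tokenizer` fixtures and the `MockTokenizerResult` helper, none of which appear in this diff. A rough, hypothetical approximation of what they provide, purely to make the assertions above easier to read in isolation:

from dataclasses import dataclass
from unittest.mock import AsyncMock


@dataclass
class MockTokenizerResult:
    # Only the token ids matter for these assertions; the real helper may
    # carry more fields.
    input_ids: list[int]


# An AsyncMock stands in for the async tokenizer. Installing it into
# renderer.async_tokenizer_pool[renderer.tokenizer] lets the test capture the
# truncation/max_length kwargs that render_prompt forwards to the tokenizer.
mock_async_tokenizer = AsyncMock(
    return_value=MockTokenizerResult(input_ids=[101, 7592, 2088]))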
@@ -54,14 +54,11 @@ class ClassificationMixin(OpenAIServing):
             ctx.tokenizer = await self.engine_client.get_tokenizer(
                 ctx.lora_request)
 
-            (
-                ctx.request_prompts,
-                ctx.engine_prompts,
-            ) = await self._preprocess_completion(
-                ctx.request,
-                ctx.tokenizer,
-                ctx.request.input,
-            )
+            renderer = self._get_renderer(ctx.tokenizer)
+            ctx.engine_prompts = await renderer.render_prompt(
+                prompt_or_prompts=ctx.request.input,
+                max_length=self.max_model_len,
+                truncate_prompt_tokens=ctx.request.truncate_prompt_tokens)
 
             return None
 
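In isolation, the classification preprocessing step now boils down to a single renderer call. A simplified sketch of that flow follows; `ctx`, `renderer`, and `max_model_len` are stand-ins for the objects used above, not the actual vLLM types.

from typing import Any, Optional


async def preprocess_classification(ctx: Any, renderer: Any,
                                     max_model_len: int) -> Optional[Any]:
    # One render_prompt call replaces the old
    # (request_prompts, engine_prompts) = _preprocess_completion(...) pair:
    # the renderer tokenizes the raw input and applies truncation itself.
    ctx.engine_prompts = await renderer.render_prompt(
        prompt_or_prompts=ctx.request.input,
        max_length=max_model_len,
        truncate_prompt_tokens=ctx.request.truncate_prompt_tokens)
    return None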
@@ -24,7 +24,6 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                               ErrorResponse, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
                                                     OpenAIServing,
-                                                    RequestPrompt,
                                                     ServeContext,
                                                     TextTokensPrompt)
 # yapf: enable
@@ -79,11 +78,12 @@ class EmbeddingMixin(OpenAIServing):
 
             tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
                                                                )
+            renderer = self._get_renderer(tokenizer)
 
             if isinstance(ctx.request, EmbeddingChatRequest):
                 (
                     _,
-                    ctx.request_prompts,
+                    _,
                     ctx.engine_prompts,
                 ) = await self._preprocess_chat(
                     ctx.request,
@@ -98,11 +98,16 @@ class EmbeddingMixin(OpenAIServing):
                     add_special_tokens=ctx.request.add_special_tokens,
                 )
             else:
-                (ctx.request_prompts,
-                 ctx.engine_prompts) = await self._preprocess_completion(
-                    ctx.request,
-                    tokenizer,
-                    ctx.request.input,
+                # Set max_length based on chunked processing capability
+                if self._should_use_chunked_processing(ctx.request):
+                    max_length = None
+                else:
+                    max_length = self.max_embed_len or self.max_model_len
+
+                ctx.engine_prompts = await renderer.render_prompt(
+                    prompt_or_prompts=ctx.request.input,
+                    max_length=max_length,
+                    truncate_prompt_tokens=ctx.request.truncate_prompt_tokens,
                     add_special_tokens=ctx.request.add_special_tokens,
                 )
             return None
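The max_length selection added in the else branch can be read as a small standalone rule, sketched below with the attribute names used above (illustrative only, not the vLLM implementation):

from typing import Optional


def select_max_length(use_chunked_processing: bool,
                      max_embed_len: Optional[int],
                      max_model_len: int) -> Optional[int]:
    # When chunked processing will handle over-length inputs downstream,
    # the renderer is given no upper bound; otherwise the configured embed
    # length cap (if any) takes precedence over the model context length.
    if use_chunked_processing:
        return None
    return max_embed_len or max_model_len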
@@ -286,7 +291,6 @@ class EmbeddingMixin(OpenAIServing):
         self,
         ctx: EmbeddingServeContext,
         engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt],
-        request_prompt: RequestPrompt,
         pooling_params: PoolingParams,
         trace_headers: Optional[Mapping[str, str]],
         prompt_index: int,
@@ -295,7 +299,7 @@ class EmbeddingMixin(OpenAIServing):
         request_id_item = f"{ctx.request_id}-{prompt_index}"
 
         self._log_inputs(request_id_item,
-                         request_prompt,
+                         engine_prompt,
                          params=pooling_params,
                          lora_request=ctx.lora_request)
 
@@ -353,20 +357,14 @@ class EmbeddingMixin(OpenAIServing):
             return self.create_error_response(
                 "Engine prompts not available")
 
-        if ctx.request_prompts is None:
-            return self.create_error_response(
-                "Request prompts not available")
-
         max_pos_embeddings = self._get_max_position_embeddings()
 
         for i, engine_prompt in enumerate(ctx.engine_prompts):
-            request_prompt = ctx.request_prompts[i]
-
             # Check if this specific prompt needs chunked processing
-            if self._is_text_tokens_prompt(request_prompt):
+            if self._is_text_tokens_prompt(engine_prompt):
                 # Cast to TextTokensPrompt since we've verified
                 # prompt_token_ids
-                text_tokens_prompt = cast(TextTokensPrompt, request_prompt)
+                text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
                 if (len(text_tokens_prompt["prompt_token_ids"])
                         > max_pos_embeddings):
                     # Use chunked processing for this prompt
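Since request prompts are gone, the chunking decision is keyed off the engine prompt itself. A reduced sketch of that check; the TypedDict here is trimmed to the one field the check needs and is not the full vLLM definition:

from typing import TypedDict


class TextTokensPrompt(TypedDict):
    prompt_token_ids: list[int]


def needs_chunked_processing(engine_prompt: TextTokensPrompt,
                             max_pos_embeddings: int) -> bool:
    # A prompt longer than the model's maximum position embeddings is routed
    # to chunked processing.
    return len(engine_prompt["prompt_token_ids"]) > max_pos_embeddings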
@@ -382,8 +380,7 @@ class EmbeddingMixin(OpenAIServing):
                     Union[EngineTokensPrompt, EngineEmbedsPrompt],
                     engine_prompt)
                 generator = await self._create_single_prompt_generator(
-                    ctx, engine_prompt_typed, request_prompt, pooling_params,
-                    trace_headers, i)
+                    ctx, engine_prompt_typed, pooling_params, trace_headers, i)
             generators.append(generator)
 
         from vllm.utils import merge_async_iterators
@@ -419,10 +416,6 @@ class EmbeddingMixin(OpenAIServing):
         if not use_chunked:
             return await super()._collect_batch(ctx=ctx)
 
-        if ctx.request_prompts is None:
-            return self.create_error_response(
-                "Request prompts not available")
-
         if ctx.result_generator is None:
             return self.create_error_response(
                 "Result generator not available")
@@ -538,7 +531,7 @@ class EmbeddingMixin(OpenAIServing):
                     data=final_embedding)
 
                 # Get original prompt token IDs for this prompt
-                original_prompt = ctx.request_prompts[prompt_idx]
+                original_prompt = ctx.engine_prompts[prompt_idx]
                 if not self._is_text_tokens_prompt(original_prompt):
                     return self.create_error_response(
                         f"Chunked prompt {prompt_idx} is not a "
@@ -368,23 +368,20 @@ class OpenAIServing:
         for i, engine_prompt in enumerate(ctx.engine_prompts):
             request_id_item = f"{ctx.request_id}-{i}"
 
-            if ctx.request_prompts is None:
-                return self.create_error_response(
-                    "Request prompts not available")
-
-            self._log_inputs(
-                request_id_item,
-                ctx.request_prompts[i],
-                params=pooling_params,
-                lora_request=ctx.lora_request,
-            )
             # Mypy has an existing bug related to inferring the variance of
             # TypedDicts with `builtins.enumerate`:
             # https://github.com/python/mypy/issues/8586#issuecomment-2867698435
             engine_prompt = cast(
                 Union[EngineTokensPrompt, EngineEmbedsPrompt],
                 engine_prompt)
 
+            self._log_inputs(
+                request_id_item,
+                engine_prompt,
+                params=pooling_params,
+                lora_request=ctx.lora_request,
+            )
+
             generator = self.engine_client.encode(
                 engine_prompt,
                 pooling_params,
@@ -108,10 +108,15 @@ class CompletionRenderer(BaseRenderer):
         for detailed parameter documentation.
         """
         if truncate_prompt_tokens is not None:
-            if max_length is not None:
-                assert 0 <= truncate_prompt_tokens <= max_length
             if truncate_prompt_tokens == 0:
                 return []
+            if truncate_prompt_tokens < 0:
+                truncate_prompt_tokens = self.model_config.max_model_len
+            if max_length is not None and truncate_prompt_tokens > max_length:
+                raise ValueError(
+                    f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
+                    f"cannot be greater than max_length ({max_length}). "
+                    f"Please select a smaller truncation size.")
 
         # Parse and batch the input prompts
         batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
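Restated outside the renderer, the truncation rules introduced above are: 0 short-circuits to an empty batch, a negative value means "use the model's full context length", and a value larger than max_length is rejected. A standalone sketch of that resolution logic (not the actual CompletionRenderer code):

from typing import Optional


def resolve_truncation(truncate_prompt_tokens: Optional[int],
                       max_length: Optional[int],
                       max_model_len: int) -> Optional[int]:
    if truncate_prompt_tokens is None:
        return None
    if truncate_prompt_tokens == 0:
        return 0  # caller returns an empty prompt list
    if truncate_prompt_tokens < 0:
        truncate_prompt_tokens = max_model_len
    if max_length is not None and truncate_prompt_tokens > max_length:
        raise ValueError(
            f"truncate_prompt_tokens ({truncate_prompt_tokens}) cannot be "
            f"greater than max_length ({max_length}). "
            "Please select a smaller truncation size.")
    return truncate_prompt_tokens


# Matches the new renderer test: -1 with max_length=200 resolves to the
# model's max_model_len (100 in that test).
assert resolve_truncation(-1, 200, 100) == 100
assert resolve_truncation(0, None, 100) == 0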