diff --git a/tests/entrypoints/openai/test_truncation.py b/tests/entrypoints/openai/test_truncation.py
index 121c0413e1af..6bdf5ce7c4a6 100644
--- a/tests/entrypoints/openai/test_truncation.py
+++ b/tests/entrypoints/openai/test_truncation.py
@@ -73,17 +73,11 @@ async def test_zero_truncation_size(client: openai.AsyncOpenAI):
         "truncate_prompt_tokens": truncation_size
     }
 
-    with pytest.raises(openai.BadRequestError) as err:
-        await client.post(path="embeddings", cast_to=object, body={**kwargs})
+    response = await client.post(path="embeddings",
+                                 cast_to=object,
+                                 body={**kwargs})
 
-    assert err.value.status_code == 400
-    error_details = err.value.response.json()["error"]
-
-    assert error_details["type"] == "BadRequestError"
-    assert "This model's maximum context length is" in error_details["message"]
-    assert "tokens in the input for embedding generation" in error_details[
-        "message"]
-    assert "Please reduce the length of the input" in error_details["message"]
+    assert response["usage"]["prompt_tokens"] == truncation_size
 
 
 @pytest.mark.asyncio
diff --git a/tests/entrypoints/test_renderer.py b/tests/entrypoints/test_renderer.py
index 54b5271ba67a..1d80ea6cb491 100644
--- a/tests/entrypoints/test_renderer.py
+++ b/tests/entrypoints/test_renderer.py
@@ -130,6 +130,23 @@ class TestRenderPrompt:
         assert call_args.kwargs["truncation"] is True
         assert call_args.kwargs["max_length"] == 50
 
+    @pytest.mark.asyncio
+    async def test_truncation_negative(self, renderer, mock_async_tokenizer):
+        # Test that negative truncation uses model's max_model_len
+        mock_async_tokenizer.return_value = MockTokenizerResult(
+            [101, 7592, 2088])  # Truncated to max_model_len
+        renderer.async_tokenizer_pool[
+            renderer.tokenizer] = mock_async_tokenizer
+
+        results = await renderer.render_prompt(prompt_or_prompts="Hello world",
+                                               max_length=200,
+                                               truncate_prompt_tokens=-1)
+
+        assert len(results) == 1
+        call_args = mock_async_tokenizer.call_args
+        assert call_args.kwargs["truncation"] is True
+        assert call_args.kwargs["max_length"] == 100  # model's max_model_len
+
     @pytest.mark.asyncio
     async def test_token_truncation_last_elements(self, renderer):
         # Test that token truncation keeps the last N elements
diff --git a/vllm/entrypoints/openai/serving_classification.py b/vllm/entrypoints/openai/serving_classification.py
index b4fdc3639031..98b7a206fa0c 100644
--- a/vllm/entrypoints/openai/serving_classification.py
+++ b/vllm/entrypoints/openai/serving_classification.py
@@ -54,14 +54,11 @@ class ClassificationMixin(OpenAIServing):
             ctx.tokenizer = await self.engine_client.get_tokenizer(
                 ctx.lora_request)
 
-            (
-                ctx.request_prompts,
-                ctx.engine_prompts,
-            ) = await self._preprocess_completion(
-                ctx.request,
-                ctx.tokenizer,
-                ctx.request.input,
-            )
+            renderer = self._get_renderer(ctx.tokenizer)
+            ctx.engine_prompts = await renderer.render_prompt(
+                prompt_or_prompts=ctx.request.input,
+                max_length=self.max_model_len,
+                truncate_prompt_tokens=ctx.request.truncate_prompt_tokens)
 
             return None
 
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index c6d3509afda7..c375f9e7c506 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -24,7 +24,6 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
                                               ErrorResponse, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (EmbeddingServeContext,
                                                     OpenAIServing,
-                                                    RequestPrompt,
                                                     ServeContext,
                                                     TextTokensPrompt)
 # yapf: enable
@@ -79,11 +78,12 @@ class EmbeddingMixin(OpenAIServing):
 
             tokenizer = await self.engine_client.get_tokenizer(ctx.lora_request
                                                                )
+            renderer = self._get_renderer(tokenizer)
 
             if isinstance(ctx.request, EmbeddingChatRequest):
                 (
                     _,
-                    ctx.request_prompts,
+                    _,
                     ctx.engine_prompts,
                 ) = await self._preprocess_chat(
                     ctx.request,
@@ -98,13 +98,18 @@ class EmbeddingMixin(OpenAIServing):
                     add_special_tokens=ctx.request.add_special_tokens,
                 )
             else:
-                (ctx.request_prompts,
-                 ctx.engine_prompts) = await self._preprocess_completion(
-                     ctx.request,
-                     tokenizer,
-                     ctx.request.input,
-                     add_special_tokens=ctx.request.add_special_tokens,
-                 )
+                # Set max_length based on chunked processing capability
+                if self._should_use_chunked_processing(ctx.request):
+                    max_length = None
+                else:
+                    max_length = self.max_embed_len or self.max_model_len
+
+                ctx.engine_prompts = await renderer.render_prompt(
+                    prompt_or_prompts=ctx.request.input,
+                    max_length=max_length,
+                    truncate_prompt_tokens=ctx.request.truncate_prompt_tokens,
+                    add_special_tokens=ctx.request.add_special_tokens,
+                )
             return None
         except (ValueError, TypeError) as e:
             logger.exception("Error in preprocessing prompt inputs")
@@ -286,7 +291,6 @@ class EmbeddingMixin(OpenAIServing):
         self,
         ctx: EmbeddingServeContext,
         engine_prompt: Union[EngineTokensPrompt, EngineEmbedsPrompt],
-        request_prompt: RequestPrompt,
         pooling_params: PoolingParams,
         trace_headers: Optional[Mapping[str, str]],
         prompt_index: int,
@@ -295,7 +299,7 @@ class EmbeddingMixin(OpenAIServing):
         request_id_item = f"{ctx.request_id}-{prompt_index}"
 
         self._log_inputs(request_id_item,
-                         request_prompt,
+                         engine_prompt,
                          params=pooling_params,
                          lora_request=ctx.lora_request)
 
@@ -353,20 +357,14 @@ class EmbeddingMixin(OpenAIServing):
                 return self.create_error_response(
                     "Engine prompts not available")
 
-            if ctx.request_prompts is None:
-                return self.create_error_response(
-                    "Request prompts not available")
-
             max_pos_embeddings = self._get_max_position_embeddings()
 
             for i, engine_prompt in enumerate(ctx.engine_prompts):
-                request_prompt = ctx.request_prompts[i]
-
                 # Check if this specific prompt needs chunked processing
-                if self._is_text_tokens_prompt(request_prompt):
+                if self._is_text_tokens_prompt(engine_prompt):
                     # Cast to TextTokensPrompt since we've verified
                     # prompt_token_ids
-                    text_tokens_prompt = cast(TextTokensPrompt, request_prompt)
+                    text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
                     if (len(text_tokens_prompt["prompt_token_ids"]) >
                             max_pos_embeddings):
                         # Use chunked processing for this prompt
@@ -382,8 +380,7 @@ class EmbeddingMixin(OpenAIServing):
                     Union[EngineTokensPrompt, EngineEmbedsPrompt],
                     engine_prompt)
                 generator = await self._create_single_prompt_generator(
-                    ctx, engine_prompt_typed, request_prompt, pooling_params,
-                    trace_headers, i)
+                    ctx, engine_prompt_typed, pooling_params, trace_headers, i)
                 generators.append(generator)
 
         from vllm.utils import merge_async_iterators
@@ -419,10 +416,6 @@ class EmbeddingMixin(OpenAIServing):
             if not use_chunked:
                 return await super()._collect_batch(ctx=ctx)
 
-            if ctx.request_prompts is None:
-                return self.create_error_response(
-                    "Request prompts not available")
-
             if ctx.result_generator is None:
                 return self.create_error_response(
                     "Result generator not available")
@@ -538,7 +531,7 @@ class EmbeddingMixin(OpenAIServing):
                         data=final_embedding)
 
                     # Get original prompt token IDs for this prompt
-                    original_prompt = ctx.request_prompts[prompt_idx]
+                    original_prompt = ctx.engine_prompts[prompt_idx]
                     if not self._is_text_tokens_prompt(original_prompt):
                         return self.create_error_response(
                             f"Chunked prompt {prompt_idx} is not a "
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index a218f6882f8c..1a2236de4fa4 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -368,23 +368,20 @@ class OpenAIServing:
 
         for i, engine_prompt in enumerate(ctx.engine_prompts):
             request_id_item = f"{ctx.request_id}-{i}"
 
-            if ctx.request_prompts is None:
-                return self.create_error_response(
-                    "Request prompts not available")
-
-            self._log_inputs(
-                request_id_item,
-                ctx.request_prompts[i],
-                params=pooling_params,
-                lora_request=ctx.lora_request,
-            )
-
             # Mypy has an existing bug related to inferring the variance of
             # TypedDicts with `builtins.enumerate`:
             # https://github.com/python/mypy/issues/8586#issuecomment-2867698435
             engine_prompt = cast(
                 Union[EngineTokensPrompt, EngineEmbedsPrompt], engine_prompt)
+
+            self._log_inputs(
+                request_id_item,
+                engine_prompt,
+                params=pooling_params,
+                lora_request=ctx.lora_request,
+            )
+
             generator = self.engine_client.encode(
                 engine_prompt,
                 pooling_params,
diff --git a/vllm/entrypoints/renderer.py b/vllm/entrypoints/renderer.py
index 29200dda8998..d3f3a8cfa5aa 100644
--- a/vllm/entrypoints/renderer.py
+++ b/vllm/entrypoints/renderer.py
@@ -108,10 +108,15 @@ class CompletionRenderer(BaseRenderer):
         for detailed parameter documentation.
         """
         if truncate_prompt_tokens is not None:
-            if max_length is not None:
-                assert 0 <= truncate_prompt_tokens <= max_length
             if truncate_prompt_tokens == 0:
                 return []
+            if truncate_prompt_tokens < 0:
+                truncate_prompt_tokens = self.model_config.max_model_len
+            if max_length is not None and truncate_prompt_tokens > max_length:
+                raise ValueError(
+                    f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
+                    f"cannot be greater than max_length ({max_length}). "
+                    f"Please select a smaller truncation size.")
 
         # Parse and batch the input prompts
         batch_inputs = parse_and_batch_prompt(prompt_or_prompts)
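The behavioral core of this patch is the truncation handling in CompletionRenderer.render_prompt (the remaining hunks migrate the embedding and classification endpoints onto the renderer and drop the separate request_prompts bookkeeping): truncate_prompt_tokens=0 still short-circuits to an empty batch, a negative value now falls back to the model's max_model_len instead of tripping an assert, and a value larger than max_length is rejected with a ValueError. The snippet below is a standalone sketch of that rule for reference; resolve_truncation and its parameters are illustrative names only, not part of the patched vLLM API.

from typing import Optional


def resolve_truncation(truncate_prompt_tokens: Optional[int],
                       max_length: Optional[int],
                       max_model_len: int) -> Optional[int]:
    """Normalize a requested truncation size.

    Mirrors the renderer rules above: None disables truncation, a negative
    value falls back to the model's max_model_len, and anything larger than
    max_length is rejected. (The zero case is short-circuited by the caller
    before this normalization runs, as in the patched render_prompt.)
    """
    if truncate_prompt_tokens is None:
        return None
    if truncate_prompt_tokens < 0:
        truncate_prompt_tokens = max_model_len
    if max_length is not None and truncate_prompt_tokens > max_length:
        raise ValueError(
            f"truncate_prompt_tokens ({truncate_prompt_tokens}) "
            f"cannot be greater than max_length ({max_length}). "
            f"Please select a smaller truncation size.")
    return truncate_prompt_tokens


if __name__ == "__main__":
    # -1 resolves to the model limit (100 here, as in test_truncation_negative).
    assert resolve_truncation(-1, max_length=200, max_model_len=100) == 100
    # An explicit in-range size passes through unchanged.
    assert resolve_truncation(50, max_length=200, max_model_len=100) == 50

The new test_truncation_negative case above exercises the same path: with max_length=200 and a model whose max_model_len is 100, a request for -1 resolves to 100.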