From 6c11ecf8d337d3b89e891588c81640a0bd30f6e1 Mon Sep 17 00:00:00 2001
From: Ryan McConville
Date: Sat, 12 Apr 2025 21:19:19 +0100
Subject: [PATCH] [Bugfix] Validate logit biases to prevent out of vocab ids
 crashing engine (#16529)

Signed-off-by: Ryan McConville
---
 .../openai/test_chat_logit_bias_validation.py | 88 +++++++++++++++++++
 vllm/v1/engine/processor.py                   | 21 +++++
 vllm/v1/sample/sampler.py                     | 10 +++
 3 files changed, 119 insertions(+)
 create mode 100644 tests/entrypoints/openai/test_chat_logit_bias_validation.py

diff --git a/tests/entrypoints/openai/test_chat_logit_bias_validation.py b/tests/entrypoints/openai/test_chat_logit_bias_validation.py
new file mode 100644
index 0000000000000..9dab524ea4801
--- /dev/null
+++ b/tests/entrypoints/openai/test_chat_logit_bias_validation.py
@@ -0,0 +1,88 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import openai
+import pytest
+import pytest_asyncio
+
+from vllm.config import ModelConfig
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
+
+
+def get_vocab_size(model_name):
+    config = ModelConfig(
+        model=model_name,
+        task="auto",
+        tokenizer=model_name,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="bfloat16",
+    )
+    return config.get_vocab_size()
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "1024",
+        "--enforce-eager",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest_asyncio.fixture
+async def client(server):
+    async with server.get_async_client() as async_client:
+        yield async_client
+
+
+@pytest.mark.asyncio
+async def test_chat_logit_bias_valid(client):
+    """Test that valid logit_bias values are accepted in chat completions."""
+    vocab_size = get_vocab_size(MODEL_NAME)
+    valid_token_id = vocab_size - 1
+
+    completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=[{
+            "role": "user",
+            "content": "Testing valid logit bias"
+        }],
+        max_tokens=5,
+        logit_bias={str(valid_token_id): 1.0},
+    )
+
+    assert completion.choices[0].message.content is not None
+
+
+@pytest.mark.asyncio
+async def test_chat_logit_bias_invalid(client):
+    """Test that invalid logit_bias values are rejected in chat completions."""
+    vocab_size = get_vocab_size(MODEL_NAME)
+    invalid_token_id = vocab_size + 1
+
+    with pytest.raises(openai.BadRequestError) as excinfo:
+        await client.chat.completions.create(
+            model=MODEL_NAME,
+            messages=[{
+                "role": "user",
+                "content": "Testing invalid logit bias"
+            }],
+            max_tokens=5,
+            logit_bias={str(invalid_token_id): 1.0},
+        )
+
+    error = excinfo.value
+    error_message = str(error)
+
+    assert error.status_code == 400
+    assert str(invalid_token_id) in error_message
+    assert str(vocab_size) in error_message
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 7d1913ecebed2..6d3290f165653 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -77,6 +77,7 @@ class Processor:
         params: SamplingParams,
     ) -> None:
         self._validate_structured_output(params)
+        self._validate_logit_bias(params)
 
         if params.allowed_token_ids is None:
             return
@@ -87,6 +88,26 @@ class Processor:
             raise ValueError(
                 "allowed_token_ids contains out-of-vocab token id!")
 
+    def _validate_logit_bias(
+        self,
+        params: SamplingParams,
+    ) -> None:
+        """Validate that logit_bias token IDs are within the vocabulary."""
+        if not params.logit_bias:
+            return
+
+        vocab_size = self.model_config.get_vocab_size()
+        invalid_token_ids = []
+
+        for token_id in params.logit_bias:
+            if token_id < 0 or token_id >= vocab_size:
+                invalid_token_ids.append(token_id)
+
+        if invalid_token_ids:
+            raise ValueError(
+                f"token_id(s) {invalid_token_ids} in logit_bias contain "
+                f"out-of-vocab token ids. Vocabulary size: {vocab_size}")
+
     def _validate_supported_sampling_params(
         self,
         params: SamplingParams,
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index 004f98496b0d7..16561d30a6dc3 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -230,9 +230,19 @@ class Sampler(nn.Module):
         # TODO(houseroad): this implementation is extremely inefficient.
         # One idea is implement this as a PyTorch C++ op, and we may
         # even optimize the logit_bias layout.
+
+        # Get the vocabulary size from the logits tensor.
+        vocab_size = logits.shape[-1]
+
         for i, logit_bias in enumerate(sampling_metadata.logit_bias):
             if logit_bias:
                 for token_id, bias in logit_bias.items():
+                    # Check that token_id is within the vocabulary bounds.
+                    if token_id < 0 or token_id >= vocab_size:
+                        raise ValueError(
+                            f"token_id {token_id} in logit_bias contains "
+                            f"out-of-vocab token id. Vocabulary size: "
+                            f"{vocab_size}")
                     logits[i, token_id] += bias
         return logits
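
Why the up-front check matters: applying logit_bias indexes the logits tensor directly, and an out-of-vocab token id only fails deep inside the sampling loop, which is what previously took the engine down. A minimal standalone sketch of the failure mode (not part of the patch; plain PyTorch with an illustrative vocabulary size):

import torch

# Illustrative vocabulary size; real models have tens of thousands of tokens.
vocab_size = 8
logits = torch.zeros(1, vocab_size)

# An in-range token id is applied cleanly, as in the sampler loop above.
logits[0, vocab_size - 1] += 1.0

# An out-of-range token id raises only once sampling is underway, hence the
# new validation in Processor._validate_logit_bias before the request runs.
try:
    logits[0, vocab_size + 1] += 1.0
except IndexError as exc:
    print(exc)  # index 9 is out of bounds for dimension 1 with size 8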
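
How the validation surfaces to engine callers: a minimal sketch (not part of the patch), assuming the offline vllm.LLM API; the prompt is illustrative, and the accessor path used to fetch the vocabulary size is an assumption rather than something this patch defines. The OpenAI frontend maps the same ValueError to the HTTP 400 asserted in the test above.

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", enforce_eager=True)
# Assumed accessor path; mirrors self.model_config.get_vocab_size()
# in Processor._validate_logit_bias.
vocab_size = llm.llm_engine.model_config.get_vocab_size()

# In-range bias keys pass validation and are applied during sampling.
llm.generate(["Hello"],
             SamplingParams(max_tokens=5, logit_bias={vocab_size - 1: 1.0}))

# Out-of-range keys are now rejected before the request is scheduled,
# instead of crashing the engine mid-sampling.
try:
    llm.generate(["Hello"],
                 SamplingParams(max_tokens=5,
                                logit_bias={vocab_size + 1: 1.0}))
except ValueError as exc:
    print(exc)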