From 80141bbf2f1b8b0beaac097f94923f95773734ef Mon Sep 17 00:00:00 2001
From: Marko Rosenmueller <5467316+dr75@users.noreply.github.com>
Date: Tue, 19 Aug 2025 20:12:25 +0200
Subject: [PATCH] fix: use cache_salt for gpt-oss (#23186)

Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com>
---
 tests/entrypoints/openai/test_serving_chat.py | 4 +++-
 vllm/entrypoints/openai/serving_chat.py       | 5 +++++
 vllm/entrypoints/openai/serving_responses.py  | 5 +++++
 3 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py
index 8a7892cf6d6aa..10879f0be83c8 100644
--- a/tests/entrypoints/openai/test_serving_chat.py
+++ b/tests/entrypoints/openai/test_serving_chat.py
@@ -282,9 +282,11 @@ async def test_serving_chat_could_load_correct_generation_config():
     assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05


+@pytest.mark.parametrize("model_type", ["gpt_oss", "any"])
 @pytest.mark.asyncio
-async def test_serving_chat_did_set_correct_cache_salt():
+async def test_serving_chat_did_set_correct_cache_salt(model_type):
     mock_model_config = MockModelConfig()
+    mock_model_config.hf_config.model_type = model_type

     mock_engine = MagicMock(spec=MQLLMEngineClient)
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 1789521afc84c..d57868847eedd 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -1483,4 +1483,9 @@ class OpenAIServingChat(OpenAIServing):
         # Render prompt token ids.
         prompt_token_ids = render_for_completion(messages)
         engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
+
+        # Add cache_salt if provided in the request
+        if request.cache_salt is not None:
+            engine_prompt["cache_salt"] = request.cache_salt
+
         return messages, [prompt_token_ids], [engine_prompt]
diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py
index 86c16df40e693..1b30fa01ea91f 100644
--- a/vllm/entrypoints/openai/serving_responses.py
+++ b/vllm/entrypoints/openai/serving_responses.py
@@ -408,6 +408,11 @@ class OpenAIServingResponses(OpenAIServing):
                                                       request, prev_response)
         prompt_token_ids = render_for_completion(messages)
         engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
+
+        # Add cache_salt if provided in the request
+        if request.cache_salt is not None:
+            engine_prompt["cache_salt"] = request.cache_salt
+
         return messages, [prompt_token_ids], [engine_prompt]

     async def responses_full_generator(
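
Not part of the patch: a minimal client-side sketch of the behavior this change covers, assuming a vLLM OpenAI-compatible server is already running on localhost:8000; the base URL, API key, and served model name below are illustrative. The `cache_salt` field is the same request attribute the hunks above copy into the engine prompt, and the openai Python client forwards it via `extra_body`.

    # Illustrative only; not part of this patch.
    # Assumes a vLLM server at http://localhost:8000 serving a gpt-oss model
    # (model name here is a placeholder; adjust for a real deployment).
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    resp = client.chat.completions.create(
        model="openai/gpt-oss-20b",  # placeholder served model name
        messages=[{"role": "user", "content": "Hello"}],
        # cache_salt is a vLLM-specific request field, so it is passed through
        # extra_body; requests with different salts keep their prefix-cache
        # entries separated.
        extra_body={"cache_salt": "tenant-a"},
    )
    print(resp.choices[0].message.content)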