Mirror of https://git.datalinker.icu/vllm-project/vllm.git
fix: use cache_salt for gpt-oss (#23186)

Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com>

parent: b94faf9d50
commit: 80141bbf2f
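For context, a minimal sketch of how a client could pass cache_salt to a vLLM OpenAI-compatible server; the server URL, model name, and salt value are placeholders and are not part of this commit:

# Sketch: sending cache_salt with a chat completion request to a vLLM server.
# The base_url, model name, and salt value below are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="openai/gpt-oss-20b",  # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],
    # cache_salt is not a standard OpenAI field, so it is passed via extra_body;
    # vLLM uses it to keep prefix-cache entries separate between requests.
    extra_body={"cache_salt": "tenant-1234"},
)
print(response.choices[0].message.content)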
@@ -282,9 +282,11 @@ async def test_serving_chat_could_load_correct_generation_config():
     assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
 
 
+@pytest.mark.parametrize("model_type", ["gpt_oss", "any"])
 @pytest.mark.asyncio
-async def test_serving_chat_did_set_correct_cache_salt():
+async def test_serving_chat_did_set_correct_cache_salt(model_type):
     mock_model_config = MockModelConfig()
+    mock_model_config.hf_config.model_type = model_type
 
     mock_engine = MagicMock(spec=MQLLMEngineClient)
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
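The remainder of the test body is outside this hunk. The following is only a self-contained sketch of the kind of assertion such a test could make; the call-args layout, the salt value, and the prompt shape are assumptions, not the actual test code:

# Hypothetical sketch of the assertion shape: the salt from the request should
# end up on the engine prompt passed to generate(). Names and the positional
# index used here are assumptions.
from unittest.mock import MagicMock

mock_engine = MagicMock()
mock_engine.generate({"prompt_token_ids": [1, 2, 3], "cache_salt": "test_salt"})

engine_prompt = mock_engine.generate.call_args.args[0]
assert engine_prompt.get("cache_salt") == "test_salt"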
@@ -1483,4 +1483,9 @@ class OpenAIServingChat(OpenAIServing):
         # Render prompt token ids.
         prompt_token_ids = render_for_completion(messages)
         engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
+
+        # Add cache_salt if provided in the request
+        if request.cache_salt is not None:
+            engine_prompt["cache_salt"] = request.cache_salt
+
         return messages, [prompt_token_ids], [engine_prompt]
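As a hedged illustration of the guarded assignment above: EngineTokensPrompt behaves like a dict with optional keys, so cache_salt is only set when the request carries one. The field list in this sketch is illustrative, not vLLM's actual definition:

# Illustrative sketch only: the real EngineTokensPrompt definition is not part
# of this diff. A TypedDict with total=False lets optional keys be omitted,
# which is why the assignment is guarded by "is not None".
from typing import TypedDict

class EngineTokensPromptSketch(TypedDict, total=False):
    prompt_token_ids: list[int]
    cache_salt: str

def build_engine_prompt(prompt_token_ids: list[int],
                        cache_salt: str | None) -> EngineTokensPromptSketch:
    engine_prompt: EngineTokensPromptSketch = {
        "prompt_token_ids": prompt_token_ids,
    }
    if cache_salt is not None:
        engine_prompt["cache_salt"] = cache_salt
    return engine_prompt

assert "cache_salt" not in build_engine_prompt([1, 2, 3], None)
assert build_engine_prompt([1, 2, 3], "salt")["cache_salt"] == "salt"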
@@ -408,6 +408,11 @@ class OpenAIServingResponses(OpenAIServing):
             request, prev_response)
         prompt_token_ids = render_for_completion(messages)
         engine_prompt = EngineTokensPrompt(prompt_token_ids=prompt_token_ids)
+
+        # Add cache_salt if provided in the request
+        if request.cache_salt is not None:
+            engine_prompt["cache_salt"] = request.cache_salt
+
         return messages, [prompt_token_ids], [engine_prompt]
 
     async def responses_full_generator(
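Conceptually, the salt keeps otherwise identical prompts from sharing prefix-cache entries. The following sketch is not vLLM's block-hash implementation, only an illustration of that idea:

# Conceptual only, not vLLM's prefix-cache code: a cache key derived from the
# prompt token ids plus an optional salt. Identical prompts with different salts
# get different keys, so they never reuse each other's cached prefix blocks.
import hashlib

def prefix_cache_key(token_ids: list[int], cache_salt: str | None = None) -> str:
    hasher = hashlib.sha256()
    if cache_salt is not None:
        hasher.update(cache_salt.encode("utf-8"))
    hasher.update(" ".join(map(str, token_ids)).encode("utf-8"))
    return hasher.hexdigest()

assert prefix_cache_key([1, 2, 3], "tenant-a") != prefix_cache_key([1, 2, 3], "tenant-b")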