[CI] Expand OpenAI test_chat.py guided decoding tests (#11048)

Signed-off-by: mgoin <michael@neuralmagic.com>
Author: Michael Goin
Date: 2024-12-23 13:35:38 -05:00 (committed by GitHub)
Commit: 63afbe9215 (parent: 8cef6e02dc)


@@ -17,6 +17,8 @@ from .test_completion import zephyr_lora_files  # noqa: F401
 
 # any model with a chat template should work here
 MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
+GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
+
 
 @pytest.fixture(scope="module")
 def server(zephyr_lora_files, zephyr_lora_added_tokens_files):  # noqa: F811
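A note on mechanics: the new module-level list feeds `@pytest.mark.parametrize`, so pytest generates one test case per backend and every parametrized test below now also exercises the new `xgrammar` backend. A minimal, self-contained sketch (not part of this diff):

    import pytest

    GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]


    @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
    def test_backend_is_known(guided_decoding_backend: str):
        # pytest runs this three times, once per entry in the list
        assert guided_decoding_backend in GUIDED_DECODING_BACKENDS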
@ -464,8 +466,7 @@ async def test_chat_completion_stream_options(client: openai.AsyncOpenAI,
# will fail on the second `guided_decoding_backend` even when I swap their order # will fail on the second `guided_decoding_backend` even when I swap their order
# (ref: https://github.com/vllm-project/vllm/pull/5526#issuecomment-2173772256) # (ref: https://github.com/vllm-project/vllm/pull/5526#issuecomment-2173772256)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", @pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat(client: openai.AsyncOpenAI, async def test_guided_choice_chat(client: openai.AsyncOpenAI,
guided_decoding_backend: str, guided_decoding_backend: str,
sample_guided_choice): sample_guided_choice):
@@ -506,8 +507,7 @@ async def test_guided_choice_chat(client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("guided_decoding_backend",
-                         ["outlines", "lm-format-enforcer"])
+@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 async def test_guided_json_chat(client: openai.AsyncOpenAI,
                                 guided_decoding_backend: str,
                                 sample_json_schema):
@@ -554,8 +554,7 @@ async def test_guided_json_chat(client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("guided_decoding_backend",
-                         ["outlines", "lm-format-enforcer"])
+@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 async def test_guided_regex_chat(client: openai.AsyncOpenAI,
                                  guided_decoding_backend: str, sample_regex):
     messages = [{
@@ -613,8 +612,7 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI):
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("guided_decoding_backend",
-                         ["outlines", "lm-format-enforcer"])
+@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
                                            guided_decoding_backend: str,
                                            sample_guided_choice):
@@ -646,8 +644,7 @@ async def test_guided_choice_chat_logprobs(client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("guided_decoding_backend",
-                         ["outlines", "lm-format-enforcer"])
+@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
 async def test_named_tool_use(client: openai.AsyncOpenAI,
                               guided_decoding_backend: str,
                               sample_json_schema):
@@ -681,7 +678,8 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
             "function": {
                 "name": "dummy_function_name"
             }
-        })
+        },
+        extra_body=dict(guided_decoding_backend=guided_decoding_backend))
     message = chat_completion.choices[0].message
     assert len(message.content) == 0
     json_string = message.tool_calls[0].function.arguments
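The `json_string` pulled from the forced tool call is then validated against the fixture's schema. A sketch of that kind of check, assuming the `jsonschema` package; the schema and output string here are stand-ins, not the real test fixtures:

    import json

    import jsonschema

    sample_json_schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "age": {"type": "integer"},
        },
        "required": ["name", "age"],
    }

    json_string = '{"name": "Alice", "age": 30}'  # stand-in for model output
    # raises jsonschema.ValidationError if the arguments do not fit the schema
    jsonschema.validate(instance=json.loads(json_string),
                        schema=sample_json_schema)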
@@ -716,6 +714,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
             "function": {
                 "name": "dummy_function_name"
             }
         },
+        extra_body=dict(guided_decoding_backend=guided_decoding_backend),
         stream=True)
     output = []
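Both call sites in `test_named_tool_use` now forward the backend through `extra_body`, the openai-python mechanism that merges extra keys into the request JSON; vLLM reads vendor-specific options such as `guided_decoding_backend` from the request body. A standalone sketch, assuming a vLLM OpenAI-compatible server on `localhost:8000` (endpoint, model, and prompt are examples):

    import asyncio

    import openai


    async def main() -> None:
        client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                    api_key="EMPTY")
        chat_completion = await client.chat.completions.create(
            model="HuggingFaceH4/zephyr-7b-beta",
            messages=[{"role": "user", "content": "Pick a color."}],
            # extra_body keys are merged into the request JSON as-is
            extra_body=dict(guided_choice=["red", "green", "blue"],
                            guided_decoding_backend="xgrammar"),
        )
        print(chat_completion.choices[0].message.content)


    asyncio.run(main())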
@@ -738,9 +737,7 @@ async def test_named_tool_use(client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
-async def test_required_tool_use_not_yet_supported(
-        client: openai.AsyncOpenAI, guided_decoding_backend: str,
-        sample_json_schema):
+async def test_required_tool_use_not_yet_supported(client: openai.AsyncOpenAI,
+                                                   sample_json_schema):
     messages = [{
         "role": "system",
@@ -785,9 +782,7 @@ async def test_required_tool_use_not_yet_supported(
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
 async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI,
-                                                  guided_decoding_backend: str,
                                                   sample_json_schema):
     messages = [{
         "role": "system",