[CI/Build] Test both text and token IDs in batched OpenAI Completions API (#5568)

Cyrus Leung 2024-06-15 19:29:42 +08:00 committed by GitHub
parent 0e9164b40a
commit 81fbb3655f


@@ -655,50 +655,52 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,
     [MODEL_NAME, "zephyr-lora"],
 )
 async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
-    # test simple list
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=["Hello, my name is", "Hello, my name is"],
-        max_tokens=5,
-        temperature=0.0,
-    )
-    assert len(batch.choices) == 2
-    assert batch.choices[0].text == batch.choices[1].text
+    # test both text and token IDs
+    for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2):
+        # test simple list
+        batch = await client.completions.create(
+            model=model_name,
+            prompt=prompts,
+            max_tokens=5,
+            temperature=0.0,
+        )
+        assert len(batch.choices) == 2
+        assert batch.choices[0].text == batch.choices[1].text
 
-    # test n = 2
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=["Hello, my name is", "Hello, my name is"],
-        n=2,
-        max_tokens=5,
-        temperature=0.0,
-        extra_body=dict(
-            # NOTE: this has to be true for n > 1 in vLLM, but not necessary
-            # for official client.
-            use_beam_search=True),
-    )
-    assert len(batch.choices) == 4
-    assert batch.choices[0].text != batch.choices[
-        1].text, "beam search should be different"
-    assert batch.choices[0].text == batch.choices[
-        2].text, "two copies of the same prompt should be the same"
-    assert batch.choices[1].text == batch.choices[
-        3].text, "two copies of the same prompt should be the same"
+        # test n = 2
+        batch = await client.completions.create(
+            model=model_name,
+            prompt=prompts,
+            n=2,
+            max_tokens=5,
+            temperature=0.0,
+            extra_body=dict(
+                # NOTE: this has to be true for n > 1 in vLLM, but not necessary
+                # for official client.
+                use_beam_search=True),
+        )
+        assert len(batch.choices) == 4
+        assert batch.choices[0].text != batch.choices[
+            1].text, "beam search should be different"
+        assert batch.choices[0].text == batch.choices[
+            2].text, "two copies of the same prompt should be the same"
+        assert batch.choices[1].text == batch.choices[
+            3].text, "two copies of the same prompt should be the same"
 
-    # test streaming
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=["Hello, my name is", "Hello, my name is"],
-        max_tokens=5,
-        temperature=0.0,
-        stream=True,
-    )
-    texts = [""] * 2
-    async for chunk in batch:
-        assert len(chunk.choices) == 1
-        choice = chunk.choices[0]
-        texts[choice.index] += choice.text
-    assert texts[0] == texts[1]
+        # test streaming
+        batch = await client.completions.create(
+            model=model_name,
+            prompt=prompts,
+            max_tokens=5,
+            temperature=0.0,
+            stream=True,
+        )
+        texts = [""] * 2
+        async for chunk in batch:
+            assert len(chunk.choices) == 1
+            choice = chunk.choices[0]
+            texts[choice.index] += choice.text
+        assert texts[0] == texts[1]
 
 
 @pytest.mark.asyncio
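
For context, a minimal sketch (not part of the commit) of what the new parametrization exercises from the client side: the Completions endpoint accepts either plain strings or pre-tokenized prompts (lists of token IDs) in the prompt field, and a list of prompts returns one choice per prompt. The base URL, API key, and model name below are placeholders for a locally running vLLM OpenAI-compatible server.

import asyncio

import openai


async def main() -> None:
    # Placeholder connection details; adjust for your deployment.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    batch = await client.completions.create(
        model="zephyr-lora",  # placeholder model name
        # Two pre-tokenized prompts; plain strings work the same way.
        prompt=[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]],
        max_tokens=5,
        temperature=0.0,
    )
    # One choice is returned per prompt, in request order.
    for choice in batch.choices:
        print(choice.index, choice.text)


asyncio.run(main())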