[Bugfix][Frontend] Cleanup "fix chat logprobs" (#5026)
commit 640052b069
parent 351d5e7b82
@@ -55,9 +55,8 @@ async def test_single_completion(server, client: openai.AsyncOpenAI):
                                                  temperature=0.0)

     assert completion.id is not None
-    assert completion.choices is not None and len(completion.choices) == 1
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
+    assert len(completion.choices) == 1
+    assert len(completion.choices[0].text) >= 5
     assert completion.choices[0].finish_reason == "length"
     assert completion.usage == openai.types.CompletionUsage(
         completion_tokens=5, prompt_tokens=6, total_tokens=11)
@@ -69,8 +68,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI):
         max_tokens=5,
         temperature=0.0,
     )
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
+    assert len(completion.choices[0].text) >= 5


 @pytest.mark.asyncio
@@ -90,15 +88,14 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
                                                             logprobs=True,
                                                             top_logprobs=5)
     assert chat_completion.id is not None
-    assert chat_completion.choices is not None and len(
-        chat_completion.choices) == 1
-    assert chat_completion.choices[0].message is not None
-    assert chat_completion.choices[0].logprobs is not None
-    assert chat_completion.choices[0].logprobs.content[
-        0].top_logprobs is not None
-    assert len(
-        chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
-    message = chat_completion.choices[0].message
+    assert len(chat_completion.choices) == 1
+
+    choice = chat_completion.choices[0]
+    assert choice.finish_reason == "length"
+    assert chat_completion.usage == openai.types.CompletionUsage(
+        completion_tokens=10, prompt_tokens=13, total_tokens=23)
+
+    message = choice.message
     assert message.content is not None and len(message.content) >= 10
     assert message.role == "assistant"
     messages.append({"role": "assistant", "content": message.content})
@@ -167,9 +167,10 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,

     assert completion.id is not None
     assert completion.choices is not None and len(completion.choices) == 1
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
-    assert completion.choices[0].finish_reason == "length"
+
+    choice = completion.choices[0]
+    assert len(choice.text) >= 5
+    assert choice.finish_reason == "length"
     assert completion.usage == openai.types.CompletionUsage(
         completion_tokens=5, prompt_tokens=6, total_tokens=11)

@@ -180,8 +181,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
         max_tokens=5,
         temperature=0.0,
     )
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
+    assert len(completion.choices[0].text) >= 5


 @pytest.mark.asyncio
@@ -206,9 +206,9 @@ async def test_no_logprobs(server, client: openai.AsyncOpenAI,

 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    # first test base model, then test loras
+    # just test 1 lora hereafter
     "model_name",
-    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+    [MODEL_NAME, "zephyr-lora"],
 )
 async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
                              model_name: str):
@@ -291,55 +291,7 @@ async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI,
         max_tokens=5,
         temperature=0.0,
     )
-    completion = completion.choices[0].text
-    assert completion is not None and len(completion) >= 0
+    assert len(completion.choices[0].text) >= 0


-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    # just test 1 lora hereafter
-    "model_name",
-    [MODEL_NAME, "zephyr-lora"],
-)
-async def test_single_chat_session(server, client: openai.AsyncOpenAI,
-                                   model_name: str):
-    messages = [{
-        "role": "system",
-        "content": "you are a helpful assistant"
-    }, {
-        "role": "user",
-        "content": "what is 1+1?"
-    }]
-
-    # test single completion
-    chat_completion = await client.chat.completions.create(model=model_name,
-                                                            messages=messages,
-                                                            max_tokens=10,
-                                                            logprobs=True,
-                                                            top_logprobs=5)
-    assert chat_completion.id is not None
-    assert chat_completion.choices is not None and len(
-        chat_completion.choices) == 1
-    assert chat_completion.choices[0].message is not None
-    assert chat_completion.choices[0].logprobs is not None
-    assert chat_completion.choices[0].logprobs.content[
-        0].top_logprobs is not None
-    assert len(
-        chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
-    message = chat_completion.choices[0].message
-    assert message.content is not None and len(message.content) >= 10
-    assert message.role == "assistant"
-    messages.append({"role": "assistant", "content": message.content})
-
-    # test multi-turn dialogue
-    messages.append({"role": "user", "content": "express your result in json"})
-    chat_completion = await client.chat.completions.create(
-        model=model_name,
-        messages=messages,
-        max_tokens=10,
-    )
-    message = chat_completion.choices[0].message
-    assert message.content is not None and len(message.content) >= 0
-
-
 @pytest.mark.asyncio
@@ -394,7 +346,7 @@ async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI,
     choice = chat_completion.choices[0]
     assert choice.logprobs is not None
     assert choice.logprobs.content is not None
-    assert len(choice.logprobs.content[0].top_logprobs) <= 1
+    assert len(choice.logprobs.content[0].top_logprobs) == 0


 @pytest.mark.asyncio
@@ -422,11 +374,14 @@ async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI,
     choice = chat_completion.choices[0]
     assert choice.logprobs is not None
     assert choice.logprobs.content is not None
-    assert len(choice.logprobs.content[0].top_logprobs) <= 6
+    assert len(choice.logprobs.content[0].top_logprobs) == 5


 @pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
 async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI,
                                       model_name: str):
     messages = [{
@@ -467,7 +422,51 @@ async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI,

 @pytest.mark.asyncio
 @pytest.mark.parametrize(
-    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_single_chat_session(server, client: openai.AsyncOpenAI,
+                                   model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    # test single completion
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                            messages=messages,
+                                                            max_tokens=10,
+                                                            logprobs=True,
+                                                            top_logprobs=5)
+    assert chat_completion.id is not None
+    assert len(chat_completion.choices) == 1
+
+    choice = chat_completion.choices[0]
+    assert choice.finish_reason == "length"
+    assert chat_completion.usage == openai.types.CompletionUsage(
+        completion_tokens=10, prompt_tokens=37, total_tokens=47)
+
+    message = choice.message
+    assert message.content is not None and len(message.content) >= 10
+    assert message.role == "assistant"
+    messages.append({"role": "assistant", "content": message.content})
+
+    # test multi-turn dialogue
+    messages.append({"role": "user", "content": "express your result in json"})
+    chat_completion = await client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=10,
+    )
+    message = chat_completion.choices[0].message
+    assert message.content is not None and len(message.content) >= 0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
     "model_name",
     [MODEL_NAME, "zephyr-lora"],
 )
@@ -753,8 +752,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
         logit_bias={str(token_id): 100},
         seed=42,
     )
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
+    assert len(completion.choices[0].text) >= 5
     response_tokens = tokenizer(completion.choices[0].text,
                                 add_special_tokens=False)["input_ids"]
     expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
@@ -801,9 +799,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
             guided_decoding_backend=guided_decoding_backend))

     assert completion.id is not None
-    assert completion.choices is not None and len(completion.choices) == 3
+    assert len(completion.choices) == 3
     for i in range(3):
-        assert completion.choices[i].text is not None
         output_json = json.loads(completion.choices[i].text)
         jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)

@@ -870,9 +867,8 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
             guided_decoding_backend=guided_decoding_backend))

     assert completion.id is not None
-    assert completion.choices is not None and len(completion.choices) == 3
+    assert len(completion.choices) == 3
     for i in range(3):
-        assert completion.choices[i].text is not None
         assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None


@@ -929,7 +925,7 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
             guided_decoding_backend=guided_decoding_backend))

     assert completion.id is not None
-    assert completion.choices is not None and len(completion.choices) == 2
+    assert len(completion.choices) == 2
     for i in range(2):
         assert completion.choices[i].text in TEST_CHOICE

@@ -1031,12 +1027,14 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
         top_logprobs=5,
         extra_body=dict(guided_choice=TEST_CHOICE,
                         guided_decoding_backend=guided_decoding_backend))
+
+    assert chat_completion.choices[0].logprobs is not None
+    assert chat_completion.choices[0].logprobs.content is not None
     top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs

     # -9999.0 is the minimum logprob returned by OpenAI
-    assert all(
-        isinstance(token.logprob, float) and token.logprob >= -9999.0
-        for token in top_logprobs)
+    for item in top_logprobs:
+        assert item.logprob >= -9999.0, f"Failed (top_logprobs={top_logprobs})"


 @pytest.mark.asyncio
@@ -1238,6 +1236,8 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
         response_format={"type": "json_object"})

     content = resp.choices[0].message.content
+    assert content is not None
+
     loaded = json.loads(content)
     assert loaded == {"result": 2}, loaded

@@ -1365,8 +1365,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,

     prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
                                                          list) else prompt
-    assert (completion.choices[0].text is not None
-            and re.search(r"^" + prompt_text, completion.choices[0].text))
+    assert re.search(r"^" + prompt_text, completion.choices[0].text)
     logprobs = completion.choices[0].logprobs
     assert logprobs is not None
     assert len(logprobs.text_offset) > 5
@@ -1407,32 +1406,32 @@ async def test_long_seed(server, client: openai.AsyncOpenAI):
 )
 async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
                                 model_name: str):
-    input = [
+    input_texts = [
         "The chef prepared a delicious meal.",
     ]

     # test single embedding
     embeddings = await client.embeddings.create(
         model=model_name,
-        input=input,
+        input=input_texts,
         encoding_format="float",
     )
     assert embeddings.id is not None
-    assert embeddings.data is not None and len(embeddings.data) == 1
+    assert len(embeddings.data) == 1
     assert len(embeddings.data[0].embedding) == 4096
     assert embeddings.usage.completion_tokens == 0
     assert embeddings.usage.prompt_tokens == 9
     assert embeddings.usage.total_tokens == 9

     # test using token IDs
-    input = [1, 1, 1, 1, 1]
+    input_tokens = [1, 1, 1, 1, 1]
     embeddings = await client.embeddings.create(
         model=model_name,
-        input=input,
+        input=input_tokens,
         encoding_format="float",
     )
     assert embeddings.id is not None
-    assert embeddings.data is not None and len(embeddings.data) == 1
+    assert len(embeddings.data) == 1
     assert len(embeddings.data[0].embedding) == 4096
     assert embeddings.usage.completion_tokens == 0
     assert embeddings.usage.prompt_tokens == 5
@@ -1447,29 +1446,29 @@ async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
 async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI,
                                model_name: str):
     # test List[str]
-    inputs = [
+    input_texts = [
         "The cat sat on the mat.", "A feline was resting on a rug.",
         "Stars twinkle brightly in the night sky."
     ]
     embeddings = await client.embeddings.create(
         model=model_name,
-        input=inputs,
+        input=input_texts,
         encoding_format="float",
     )
     assert embeddings.id is not None
-    assert embeddings.data is not None and len(embeddings.data) == 3
+    assert len(embeddings.data) == 3
     assert len(embeddings.data[0].embedding) == 4096

     # test List[List[int]]
-    inputs = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
+    input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
               [25, 32, 64, 77]]
     embeddings = await client.embeddings.create(
         model=model_name,
-        input=inputs,
+        input=input_tokens,
         encoding_format="float",
     )
     assert embeddings.id is not None
-    assert embeddings.data is not None and len(embeddings.data) == 4
+    assert len(embeddings.data) == 4
     assert len(embeddings.data[0].embedding) == 4096
     assert embeddings.usage.completion_tokens == 0
     assert embeddings.usage.prompt_tokens == 17
@@ -209,9 +209,8 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
                                                  temperature=0.0)

     assert completion.id is not None
-    assert completion.choices is not None and len(completion.choices) == 1
-    assert completion.choices[0].text is not None and len(
-        completion.choices[0].text) >= 5
+    assert len(completion.choices) == 1
+    assert len(completion.choices[0].text) >= 5
     assert completion.choices[0].finish_reason == "length"
     assert completion.usage == openai.types.CompletionUsage(
         completion_tokens=5, prompt_tokens=6, total_tokens=11)
@@ -513,7 +513,8 @@ class CompletionLogProbs(OpenAIBaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
-    top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None
+    top_logprobs: List[Optional[Dict[str,
+                                     float]]] = Field(default_factory=list)


 class CompletionResponseChoice(OpenAIBaseModel):
@@ -612,7 +613,7 @@ class ChatCompletionResponseChoice(OpenAIBaseModel):
     index: int
     message: ChatMessage
     logprobs: Optional[ChatCompletionLogProbs] = None
-    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
+    finish_reason: Optional[str] = None
     stop_reason: Optional[Union[int, str]] = None


@@ -635,7 +636,7 @@ class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     delta: DeltaMessage
     logprobs: Optional[ChatCompletionLogProbs] = None
-    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
+    finish_reason: Optional[str] = None
     stop_reason: Optional[Union[int, str]] = None


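Taken together, the protocol.py hunks above relax the response schema: CompletionLogProbs.top_logprobs now defaults to an empty list instead of None, and the chat finish_reason fields accept any string rather than a fixed Literal. A minimal pydantic sketch of the resulting behavior (the class names here are hypothetical stand-ins, not the actual vllm.entrypoints.openai.protocol definitions):

# Minimal sketch only; class names are hypothetical stand-ins for the
# fields changed above, assuming standard pydantic Field semantics.
from typing import Dict, List, Optional

from pydantic import BaseModel, Field


class LogProbsSketch(BaseModel):
    # default_factory=list: an "empty" object serializes as [] rather than null
    top_logprobs: List[Optional[Dict[str, float]]] = Field(
        default_factory=list)


class ChoiceSketch(BaseModel):
    # plain Optional[str]: values outside {"stop", "length", "tool_calls"}
    # no longer fail validation
    finish_reason: Optional[str] = None


print(LogProbsSketch().top_logprobs)        # []
print(ChoiceSketch(finish_reason="abort"))  # finish_reason='abort'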
@@ -373,13 +373,15 @@ class OpenAIServingChat(OpenAIServing):
                     continue

                 delta_token_ids = output.token_ids[previous_num_tokens[i]:]
-                top_logprobs = output.logprobs[
+                out_logprobs = output.logprobs[
                     previous_num_tokens[i]:] if output.logprobs else None

-                if request.logprobs:
+                if request.logprobs and request.top_logprobs is not None:
+                    assert out_logprobs is not None, (
+                        "Did not output logprobs")
                     logprobs = self._create_chat_logprobs(
                         token_ids=delta_token_ids,
-                        top_logprobs=top_logprobs,
+                        top_logprobs=out_logprobs,
                         num_output_top_logprobs=request.top_logprobs,
                     )
                 else:
@@ -490,12 +492,13 @@ class OpenAIServingChat(OpenAIServing):
         role = self.get_chat_request_role(request)
         for output in final_res.outputs:
             token_ids = output.token_ids
-            top_logprobs = output.logprobs
+            out_logprobs = output.logprobs

-            if request.logprobs:
+            if request.logprobs and request.top_logprobs is not None:
+                assert out_logprobs is not None, "Did not output logprobs"
                 logprobs = self._create_chat_logprobs(
                     token_ids=token_ids,
-                    top_logprobs=top_logprobs,
+                    top_logprobs=out_logprobs,
                     num_output_top_logprobs=request.top_logprobs,
                 )
             else:
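The guard added in both serving_chat.py hunks (request.logprobs and request.top_logprobs is not None) matches the request shape exercised by the chat tests above. A client-side sketch of that request, assuming a vLLM OpenAI-compatible server on localhost:8000 and a placeholder model name:

# Sketch only: base_url, api_key and model name are illustrative assumptions.
import asyncio

import openai

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                            api_key="EMPTY")


async def main() -> None:
    chat_completion = await client.chat.completions.create(
        model="your-model-name",
        messages=[{"role": "user", "content": "what is 1+1?"}],
        max_tokens=10,
        logprobs=True,   # return logprobs for the sampled tokens
        top_logprobs=5,  # plus the 5 most likely alternatives per position
    )
    choice = chat_completion.choices[0]
    assert choice.logprobs is not None and choice.logprobs.content is not None
    print(choice.logprobs.content[0].top_logprobs)


if __name__ == "__main__":
    asyncio.run(main())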
@@ -8,6 +8,7 @@ from fastapi import Request

 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+# yapf conflicts with isort for this block
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               CompletionRequest,
@@ -16,7 +17,6 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               CompletionResponseStreamChoice,
                                               CompletionStreamResponse,
                                               UsageInfo)
-# yapf: enable
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
 from vllm.logger import init_logger
@@ -221,7 +221,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     # only return the prompt
                     delta_text = res.prompt
                     delta_token_ids = res.prompt_token_ids
-                    top_logprobs = res.prompt_logprobs
+                    out_logprobs = res.prompt_logprobs
                     has_echoed[i] = True
                 elif (request.echo and request.max_tokens > 0
                       and not has_echoed[i]):
@@ -229,7 +229,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     delta_text = res.prompt + output.text
                     delta_token_ids = (res.prompt_token_ids +
                                        output.token_ids)
-                    top_logprobs = res.prompt_logprobs + (output.logprobs
+                    out_logprobs = res.prompt_logprobs + (output.logprobs
                                                           or [])
                     has_echoed[i] = True
                 else:
@@ -237,13 +237,15 @@ class OpenAIServingCompletion(OpenAIServing):
                     delta_text = output.text[len(previous_texts[i]):]
                     delta_token_ids = output.token_ids[
                         previous_num_tokens[i]:]
-                    top_logprobs = output.logprobs[previous_num_tokens[
+                    out_logprobs = output.logprobs[previous_num_tokens[
                         i]:] if output.logprobs else None

                 if request.logprobs is not None:
+                    assert out_logprobs is not None, (
+                        "Did not output logprobs")
                     logprobs = self._create_completion_logprobs(
                         token_ids=delta_token_ids,
-                        top_logprobs=top_logprobs,
+                        top_logprobs=out_logprobs,
                         num_output_top_logprobs=request.logprobs,
                         initial_text_offset=len(previous_texts[i]),
                     )
@@ -325,25 +327,23 @@ class OpenAIServingCompletion(OpenAIServing):
             assert request.max_tokens is not None
             if request.echo and request.max_tokens == 0:
                 token_ids = prompt_token_ids
-                top_logprobs = prompt_logprobs
+                out_logprobs = prompt_logprobs
                 output_text = prompt_text
             elif request.echo and request.max_tokens > 0:
                 token_ids = prompt_token_ids + output.token_ids
-                top_logprobs = (prompt_logprobs + output.logprobs
+                out_logprobs = (prompt_logprobs + output.logprobs
                                 if request.logprobs is not None else None)
                 output_text = prompt_text + output.text
             else:
                 token_ids = output.token_ids
-                top_logprobs = output.logprobs
+                out_logprobs = output.logprobs
                 output_text = output.text

             if request.logprobs is not None:
-                assert top_logprobs is not None, (
-                    "top_logprobs must be provided when logprobs "
-                    "is requested")
+                assert out_logprobs is not None, "Did not output logprobs"
                 logprobs = self._create_completion_logprobs(
                     token_ids=token_ids,
-                    top_logprobs=top_logprobs,
+                    top_logprobs=out_logprobs,
                     num_output_top_logprobs=request.logprobs,
                 )
             else:
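For the completions path, request.logprobs is an integer count rather than a boolean, and the echo branches above decide whether prompt logprobs are folded in. A client-side sketch of the two request shapes those branches distinguish (server address, api_key, and model name are placeholders, not values from this commit):

# Sketch only: base_url, api_key and model name are illustrative assumptions.
import asyncio

import openai

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                            api_key="EMPTY")


async def main() -> None:
    # logprobs=1 asks for one top alternative per generated token.
    completion = await client.completions.create(
        model="your-model-name",
        prompt="Hello, my name is",
        max_tokens=5,
        logprobs=1,
    )
    print(completion.choices[0].logprobs)

    # echo=True with max_tokens=0 returns only the prompt, so the
    # prompt logprobs are what ends up being serialized.
    echoed = await client.completions.create(
        model="your-model-name",
        prompt="Hello, my name is",
        echo=True,
        max_tokens=0,
        logprobs=1,
    )
    print(echoed.choices[0].text)


if __name__ == "__main__":
    asyncio.run(main())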