[V1] Logprobs and prompt logprobs support (#9880)
This PR adds support for sample logprobs & prompt logprobs to vLLM v1.

New behavior:
- During model execution, the model runner computes sample logprobs (if the user-provided logprobs setting is not None) and prompt logprobs (if the user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns 3 vectors: token ids, token logprob values, and token ranks. Ranks reflect tokens' 1-indexed positions in the vocabulary vector after sorting the vocabulary by log probability in descending order.
- In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure, which is transferred to the engine client. If multiprocessing is enabled, sample and prompt logprobs are (de)serialized along with the EngineCoreOutput data structure.
- During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprob values, and token ranks into the OpenAI-compatible List[Dict[token id, Logprob]] format (for sample and prompt logprobs, respectively).
- Each Logprob instance (whether sample- or prompt-) consists of a token's log-probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor, not the detokenizer.

Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent 538fab93cd
commit 0630d4537a
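For illustration, a minimal usage sketch of the behavior described above (a hedged example, not part of this diff; the model name and prompt are placeholders, and prefix caching is disabled because this PR does not yet support prompt logprobs together with automatic prefix caching on V1):

# Sketch: request sample and prompt logprobs on vLLM V1 and read the
# OpenAI-compatible structures produced by the LogprobsProcessor.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m",      # placeholder model
          enable_prefix_caching=False)    # APC + prompt_logprobs is rejected in V1
params = SamplingParams(
    max_tokens=8,
    logprobs=5,          # top-5 sample logprobs per generated token
    prompt_logprobs=5,   # top-5 prompt logprobs per prompt token
)
out = llm.generate("Hello, my name is", params)[0]

# Sample logprobs: List[Dict[token id, Logprob]], one dict per generated token.
for pos, lp_dict in enumerate(out.outputs[0].logprobs):
    for tok_id, lp in lp_dict.items():
        # Each Logprob carries the log-probability, the token's 1-indexed rank
        # in the vocabulary sorted by descending log probability, and the
        # detokenized string.
        print(pos, tok_id, lp.logprob, lp.rank, lp.decoded_token)

# Prompt logprobs: the first entry is None (nothing predicts the first prompt
# token); later entries are Dict[token id, Logprob] for each prompt position.
assert out.prompt_logprobs[0] is None

Because the sampled token is always included in its position's dict, a dict may hold either num_logprobs or num_logprobs + 1 entries when the sampled token falls outside the requested top-k.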
@@ -195,8 +195,8 @@ def test_schedule_partial_requests():
         req_ids=[request.request_id for request in requests],
         req_id_to_index=req_to_index,
         sampled_token_ids=[0] * len(requests),
-        logprob_token_ids_cpu=None,
-        logprobs_cpu=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
     )
     scheduler.update_from_output(output, model_runner_output)
tests/v1/engine/conftest.py  (new file, 90 lines)
@@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0

from typing import List, Tuple

import pytest
import torch
from transformers import AutoTokenizer

from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
                                   NUM_SAMPLE_LOGPROBS_UNDER_TEST, PROMPT_LEN,
                                   TOKENIZER_NAME,
                                   DummyOutputProcessorTestVectors,
                                   generate_dummy_prompt_logprobs_tensors,
                                   generate_dummy_sample_logprobs)
from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs

from tests.v1.engine.utils import FULL_STRINGS  # isort: skip

EngineCoreSampleLogprobsType = List[Tuple[torch.Tensor, torch.Tensor]]
EngineCorePromptLogprobsType = Tuple[torch.Tensor, torch.Tensor]


def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
    """Generate output processor dummy test vectors, without logprobs

    Returns:
      DummyOutputProcessorTestVectors instance with no logprobs
    """

    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    vllm_config = EngineArgs(model=TOKENIZER_NAME).create_engine_config()
    # Tokenize prompts under test & create dummy generated tokens
    prompt_tokens = [
        tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS
    ]
    generation_tokens = [
        tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS
    ]
    # Generate prompt strings
    prompt_strings = [
        tokenizer.decode(prompt_tokens, skip_special_tokens=True)
        for prompt_tokens in prompt_tokens
    ]
    prompt_strings_len = [
        len(prompt_string) for prompt_string in prompt_strings
    ]
    return DummyOutputProcessorTestVectors(
        tokenizer=tokenizer,
        tokenizer_group=init_tokenizer_from_configs(
            vllm_config.model_config, vllm_config.scheduler_config,
            vllm_config.parallel_config, vllm_config.lora_config),
        vllm_config=vllm_config,
        full_tokens=[tokenizer(text).input_ids for text in FULL_STRINGS],
        prompt_tokens=prompt_tokens,
        generation_tokens=generation_tokens,
        prompt_strings=prompt_strings,
        prompt_strings_len=prompt_strings_len,
        generation_strings=[
            text[prompt_len:]
            for text, prompt_len in zip(FULL_STRINGS, prompt_strings_len)
        ],
        prompt_logprobs=[],
        generation_logprobs=[])


@pytest.fixture
def dummy_test_vectors() -> DummyOutputProcessorTestVectors:
    """Generate output processor dummy test vectors, with logprobs

    Returns:
      DummyOutputProcessorTestVectors instance with logprobs
    """
    # Build dummy test vectors without logprobs
    dtv = _build_test_vectors_no_logprobs()
    # Inject logprobs into dummy test vectors
    # data structure
    dtv.generation_logprobs = [
        generate_dummy_sample_logprobs(
            sampled_tokens_list=tokens_list,
            num_logprobs=NUM_SAMPLE_LOGPROBS_UNDER_TEST,
            tokenizer=dtv.tokenizer) for tokens_list in dtv.generation_tokens
    ]
    dtv.prompt_logprobs = [
        generate_dummy_prompt_logprobs_tensors(
            prompt_tokens_list=tokens_list,
            num_logprobs=NUM_PROMPT_LOGPROBS_UNDER_TEST,
            tokenizer=dtv.tokenizer) for tokens_list in dtv.prompt_tokens
    ]
    return dtv
@@ -2,10 +2,11 @@

 import asyncio
 from contextlib import ExitStack
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 import pytest

+from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
 from vllm import SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.platforms import current_platform
@@ -21,13 +22,19 @@ ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B",
                               disable_log_requests=True)


-async def generate(engine: AsyncLLM, request_id: str,
+async def generate(engine: AsyncLLM,
+                   request_id: str,
                    output_kind: RequestOutputKind,
-                   max_tokens: int) -> Tuple[int, str]:
+                   max_tokens: int,
+                   prompt_logprobs: Optional[int] = None) -> Tuple[int, str]:
     # Ensure generate doesn't complete too fast for cancellation test.
     await asyncio.sleep(0.2)

     count = 0
     sampling_params = SamplingParams(max_tokens=max_tokens,
                                      output_kind=output_kind,
-                                     temperature=0)
+                                     temperature=0,
+                                     prompt_logprobs=prompt_logprobs)
     async for out in engine.generate(request_id=request_id,
                                      prompt="Hello my name is Robert and",
                                      sampling_params=sampling_params):
@@ -43,6 +50,40 @@ async def generate(engine: AsyncLLM, request_id: str,
     return count, request_id


+@pytest.mark.parametrize(
+    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
+@pytest.mark.asyncio
+async def test_async_llm_refuses_prompt_logprobs_with_apc(
+        monkeypatch, output_kind: RequestOutputKind):
+    """Test passes if AsyncLLM raises an exception when it is configured
+    for automatic prefix caching and it receives a request with
+    prompt_logprobs enabled, which is incompatible."""
+    # TODO(rickyx): Remove monkeypatch VLLM_USE_V1 setting once we have a
+    # better way to test V1 so that in the future when we switch, we don't
+    # have to change all the tests.
+    monkeypatch.setenv("VLLM_USE_V1", "1")
+    # Create AsyncLLM engine with APC
+    apc_engine_args = AsyncEngineArgs(model="facebook/opt-125m",
+                                      enable_prefix_caching=True,
+                                      gpu_memory_utilization=0.8,
+                                      disable_log_requests=True)
+    engine = AsyncLLM.from_engine_args(apc_engine_args)
+    try:
+        with pytest.raises(ValueError) as excinfo:
+            # Issue a request with prompt logprobs enabled, which should fail
+            await asyncio.create_task(
+                generate(engine,
+                         "request-0",
+                         output_kind,
+                         10,
+                         prompt_logprobs=5))
+        # Validate exception string is correct
+        assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG
+    finally:
+        # Shut down engine
+        engine.shutdown()
+
+
 @pytest.mark.parametrize(
     "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
 @pytest.mark.asyncio
tests/v1/engine/test_llm_engine.py  (new file, 23 lines)
@@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0

import pytest

from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
from vllm import LLM, SamplingParams


def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch):
    """Test passes if LLMEngine raises an exception when it is configured
    for automatic prefix caching and it receives a request with
    prompt_logprobs enabled, which is incompatible."""

    monkeypatch.setenv("VLLM_USE_V1", "1")
    # TODO(nick): Single-proc to work around a ZMQ shutdown hang for now.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    with pytest.raises(ValueError) as excinfo:
        LLM(model="facebook/opt-125m", enable_prefix_caching=True).generate(
            "Hello, my name is",
            SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5))

    # Validate exception string is correct
    assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG
@ -1,82 +1,47 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import List
|
||||
import math
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
|
||||
NUM_SAMPLE_LOGPROBS_UNDER_TEST,
|
||||
STOP_STRINGS,
|
||||
DummyOutputProcessorTestVectors,
|
||||
MockEngineCore)
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
|
||||
from vllm.sequence import PromptLogprobs, SampleLogprobs
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.output_processor import OutputProcessor
|
||||
|
||||
TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
VLLM_CONFIG = EngineArgs(model=TOKENIZER_NAME).create_engine_config()
|
||||
TOKENIZER_GROUP = init_tokenizer_from_configs(VLLM_CONFIG.model_config,
|
||||
VLLM_CONFIG.scheduler_config,
|
||||
VLLM_CONFIG.parallel_config,
|
||||
VLLM_CONFIG.lora_config)
|
||||
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
|
||||
|
||||
FULL_STRINGS = [
|
||||
"My name is Robert from Neural Magic and I love working on vLLM so much!",
|
||||
"Red Hat is the best open source company by far across Linux, K8s, and AI.",
|
||||
"Nick is the name of my brother in addition to my colleague from Red Hat.",
|
||||
]
|
||||
def _ref_convert_id_to_token(
|
||||
tokenizer: AnyTokenizer,
|
||||
token_id: int,
|
||||
) -> str:
|
||||
"""Reference impl of logprobs detokenization.
|
||||
|
||||
STOP_STRINGS = ["I love working on", "company by far", "brother in"]
|
||||
Args:
|
||||
tokenizer: tokenizer used by the model under test
|
||||
token_id: convert this token id
|
||||
|
||||
FULL_TOKENS = [tokenizer(text).input_ids for text in FULL_STRINGS]
|
||||
PROMPT_LEN = 5
|
||||
PROMPT_TOKENS = [
|
||||
tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS
|
||||
]
|
||||
GENERATION_TOKENS = [
|
||||
tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS
|
||||
]
|
||||
PROMPT_STRINGS = [
|
||||
tokenizer.decode(prompt_tokens, skip_special_tokens=True)
|
||||
for prompt_tokens in PROMPT_TOKENS
|
||||
]
|
||||
PROMPT_STRINGS_LEN = [len(prompt_string) for prompt_string in PROMPT_STRINGS]
|
||||
GENERATION_STRINGS = [
|
||||
text[prompt_len:]
|
||||
for text, prompt_len in zip(FULL_STRINGS, PROMPT_STRINGS_LEN)
|
||||
]
|
||||
|
||||
|
||||
class MockEngineCore:
|
||||
"""Mock outputs form premade tokens lists."""
|
||||
|
||||
def __init__(self, tokens_list: List[List[int]]):
|
||||
self.tokens_list = tokens_list
|
||||
self.current_idx = 0
|
||||
|
||||
def get_outputs(self) -> List[EngineCoreOutput]:
|
||||
token_idx = self.current_idx
|
||||
self.current_idx += 1
|
||||
|
||||
outputs = []
|
||||
for req_idx, token_ids in enumerate(self.tokens_list):
|
||||
if len(token_ids) > token_idx:
|
||||
output = EngineCoreOutput(request_id=f"request-{req_idx}",
|
||||
new_token_ids=[token_ids[token_idx]],
|
||||
finished=False)
|
||||
if token_idx == len(token_ids) - 1:
|
||||
output.finished = True
|
||||
output.finish_reason = "stopped"
|
||||
outputs.append(output)
|
||||
|
||||
return outputs
|
||||
Returns:
|
||||
String representation of input token id
|
||||
"""
|
||||
return tokenizer.convert_ids_to_tokens(token_id) or ""
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"request_output_kind",
|
||||
[RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
|
||||
def test_incremental_detokenization(request_output_kind: RequestOutputKind):
|
||||
output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=False)
|
||||
engine_core = MockEngineCore(GENERATION_TOKENS)
|
||||
def test_incremental_detokenization(request_output_kind: RequestOutputKind,
|
||||
dummy_test_vectors):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
|
||||
log_stats=False)
|
||||
engine_core = MockEngineCore(
|
||||
tokens_list=dummy_test_vectors.generation_tokens)
|
||||
|
||||
# Make N requests.
|
||||
requests = [
|
||||
@ -94,10 +59,10 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind):
|
||||
spaces_between_special_tokens=False,
|
||||
output_kind=request_output_kind,
|
||||
stop=[],
|
||||
include_stop_str_in_output=False))
|
||||
for idx, (
|
||||
prompt,
|
||||
prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
|
||||
include_stop_str_in_output=False,
|
||||
)) for idx, (prompt, prompt_tokens) in enumerate(
|
||||
zip(dummy_test_vectors.prompt_strings,
|
||||
dummy_test_vectors.prompt_tokens))
|
||||
]
|
||||
|
||||
# Add requests to the detokenizer.
|
||||
@ -113,7 +78,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind):
|
||||
break
|
||||
|
||||
# Step the Detokenizer.
|
||||
processed_outputs = output_processor.process_outputs(outputs, )
|
||||
processed_outputs = output_processor.process_outputs(outputs)
|
||||
request_outputs = processed_outputs.request_outputs
|
||||
requests_to_abort = processed_outputs.reqs_to_abort
|
||||
assert len(requests_to_abort) == 0
|
||||
@ -132,7 +97,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind):
|
||||
|
||||
# Confirmed tracked values matches what we expected.
|
||||
for idx, (ref_gen_str, ref_gen_toks) in enumerate(
|
||||
zip(GENERATION_STRINGS, GENERATION_TOKENS)):
|
||||
zip(dummy_test_vectors.generation_strings,
|
||||
dummy_test_vectors.generation_tokens)):
|
||||
gen_str = gen_strings[f"request-{idx}"]
|
||||
gen_toks = gen_tokens[f"request-{idx}"]
|
||||
|
||||
@ -143,15 +109,390 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind):
|
||||
assert not output_processor.has_unfinished_requests()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
|
||||
def test_stop_string(include_stop_str_in_output: bool):
|
||||
output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=False)
|
||||
engine_core = MockEngineCore(GENERATION_TOKENS)
|
||||
def _validate_logprobs(
|
||||
gen_tokens: Dict[str, List[int]],
|
||||
gen_logprobs: Dict[str, Optional[SampleLogprobs]],
|
||||
gen_prompt_logprobs: Dict[str, Optional[PromptLogprobs]],
|
||||
gen_cumulative_logprob: Dict[str, float],
|
||||
dtv: DummyOutputProcessorTestVectors,
|
||||
request_id_list: List[str],
|
||||
num_sample_logprobs: Optional[int],
|
||||
num_prompt_logprobs: Optional[int],
|
||||
) -> None:
|
||||
for req_idx, req_id in enumerate(request_id_list):
|
||||
new_tokens = gen_tokens[req_id]
|
||||
logprobs = gen_logprobs[req_id]
|
||||
prompt_logprobs = gen_prompt_logprobs[req_id]
|
||||
cumulative_logprob = gen_cumulative_logprob[req_id]
|
||||
prompt_token_ids = dtv.prompt_tokens[req_idx]
|
||||
ref_logprobs = dtv.generation_logprobs[req_idx]
|
||||
ref_prompt_logprobs = dtv.prompt_logprobs[req_idx]
|
||||
if num_sample_logprobs is not None:
|
||||
# Validate sample logprobs
|
||||
assert logprobs is not None, (f"Request {req_id} requires sample"
|
||||
" logprobs but sample logprobs are"
|
||||
" None.")
|
||||
# Require num sampled tokens to match num
|
||||
# sampled logprobs - especially important
|
||||
# to check since the detokenizer can cause
|
||||
# a request to finish early due to a stop
|
||||
# string being hit
|
||||
num_new_tokens = len(new_tokens)
|
||||
len_sample_logprobs = len(logprobs)
|
||||
assert num_new_tokens == len_sample_logprobs, (
|
||||
f"Request {req_id} has {num_new_tokens}"
|
||||
" completion tokens but has"
|
||||
f" {len_sample_logprobs} sample logprobs.")
|
||||
ref_cumulative_logprob = 0.0
|
||||
for idx, (sampled_token,
|
||||
pos_logprob_dict) in enumerate(zip(new_tokens,
|
||||
logprobs)):
|
||||
# Break out the reference log probability value &
|
||||
# logprob token id tensors associated with this
|
||||
# position in the completion. Also break out the
|
||||
# sampled token ranks
|
||||
(ref_pos_logprob_toks, ref_pos_logprob_vals,
|
||||
ref_sampled_token_rank) = ref_logprobs[idx]
|
||||
# For each position in the completion sequence,
|
||||
# ensure the actual sampled token is among the
|
||||
# logprobs
|
||||
assert sampled_token in pos_logprob_dict, (
|
||||
f"Sampled token {sampled_token} not"
|
||||
f" present in logprob at index {idx}")
|
||||
|
||||
# Validate number of sample logprobs
|
||||
num_lp_toks = len(pos_logprob_dict)
|
||||
assert (num_lp_toks == num_sample_logprobs
|
||||
or num_lp_toks == num_sample_logprobs +
|
||||
1), ("Valid numbers of sample logprobs are"
|
||||
f" {num_sample_logprobs} or"
|
||||
f" {num_sample_logprobs+1} but"
|
||||
f" {num_lp_toks} logprobs found at"
|
||||
f" position {idx}. Logprobs dict:"
|
||||
f" {pos_logprob_dict}")
|
||||
|
||||
# Validate sampled token logprob rank
|
||||
smp_lp = pos_logprob_dict[sampled_token]
|
||||
smp_lp_rank = smp_lp.rank
|
||||
assert (ref_sampled_token_rank == smp_lp_rank), (
|
||||
"Sampled token logprob rank"
|
||||
f" {smp_lp_rank} does not match"
|
||||
" correct value"
|
||||
f" {ref_sampled_token_rank}"
|
||||
f" in Logprob {smp_lp}")
|
||||
|
||||
# Validate that the logprob processor yields
|
||||
# the correct log probabilities and valid
|
||||
# rankings
|
||||
rank_one_appears = False
|
||||
for jdx in range(1, len(ref_pos_logprob_toks)):
|
||||
# Iterate over the (logprob val,logprob tok id)
|
||||
# pairs expected by the test fixture at this
|
||||
# position in the completion.
|
||||
ref_lp_val = ref_pos_logprob_vals[jdx]
|
||||
ref_tok_id = ref_pos_logprob_toks[jdx]
|
||||
assert ref_tok_id in pos_logprob_dict, (
|
||||
f"Expected token {ref_tok_id} to be"
|
||||
f" in logprob dict but it is not.")
|
||||
|
||||
# Extract actually-generated logprob
|
||||
# info
|
||||
lp = pos_logprob_dict[ref_tok_id]
|
||||
lp_val = lp.logprob
|
||||
lp_rank = lp.rank
|
||||
|
||||
# A "top" (rank 1) logprob must be
|
||||
# present
|
||||
rank_one_appears = (True
|
||||
if lp_rank == 1 else rank_one_appears)
|
||||
|
||||
# Rank must be >= 1
|
||||
assert lp_rank >= 1, (f"Logprob {lp} has invalid"
|
||||
f" rank {lp_rank} < 1."
|
||||
f" Logprob dict: {pos_logprob_dict}")
|
||||
|
||||
# Validate log probability
|
||||
assert math.isclose(lp_val, ref_lp_val), (
|
||||
f"Token id {ref_tok_id} appears in logprobs dict"
|
||||
f" at position {idx} in completion with log"
|
||||
f" probability {lp_val} but {ref_lp_val} was"
|
||||
f" expected. Logprob: {lp}")
|
||||
|
||||
assert rank_one_appears, (f"No Logprob has rank 1"
|
||||
" in the following Logprob"
|
||||
f" dict: {pos_logprob_dict}")
|
||||
|
||||
# Validate logprobs detokenization
|
||||
for lp_tok in pos_logprob_dict:
|
||||
# Confirm that sample logprob decoded token matches
|
||||
# the logprob token id at this sequence position
|
||||
decoded_token = pos_logprob_dict[lp_tok].decoded_token
|
||||
ref_decoded_token = _ref_convert_id_to_token(
|
||||
dtv.tokenizer, lp_tok)
|
||||
assert decoded_token == ref_decoded_token, (
|
||||
f"Sampled logprob token id {lp_tok} decodes to"
|
||||
f" {ref_decoded_token} but Logprob decoded"
|
||||
f" token is {decoded_token} instead"
|
||||
f" (at position {idx})")
|
||||
|
||||
ref_cumulative_logprob += pos_logprob_dict[
|
||||
sampled_token].logprob
|
||||
# Assert that cumulative logprobs are correct
|
||||
assert math.isclose(cumulative_logprob, ref_cumulative_logprob)
|
||||
else:
|
||||
# Sample logprobs disabled for this request
|
||||
assert logprobs is None
|
||||
assert cumulative_logprob is None
|
||||
|
||||
if num_prompt_logprobs is not None:
|
||||
# Validate prompt logprobs
|
||||
assert prompt_logprobs is not None, (
|
||||
f"Request {req_id} requires prompt"
|
||||
" logprobs but prompt logprobs are"
|
||||
" None.")
|
||||
# Require num prompt tokens to match num
|
||||
# prompt logprobs
|
||||
num_prompt_tokens = len(prompt_token_ids)
|
||||
len_prompt_logprobs = len(prompt_logprobs)
|
||||
assert num_prompt_tokens == len_prompt_logprobs, (
|
||||
f"Request {req_id} has {num_prompt_tokens}"
|
||||
" prompt tokens but has"
|
||||
f" {len_prompt_logprobs} prompt logprobs.")
|
||||
# First prompt logprob is None
|
||||
first_plp_dict = prompt_logprobs[0]
|
||||
assert first_plp_dict is None, (
|
||||
f"Request {req_id} first prompt logprob"
|
||||
f" should be None but has following value"
|
||||
f" instead: {first_plp_dict}")
|
||||
# Break out the reference prompt log prob value &
|
||||
# logprob token id matrices for the whole prompt.
|
||||
# Also break out the prompt token rank vector
|
||||
(ref_prompt_logprob_toks, ref_prompt_logprob_vals,
|
||||
ref_prompt_token_ranks) = ref_prompt_logprobs
|
||||
for idx, (prompt_token, pos_logprob_dict) in enumerate(
|
||||
zip(prompt_token_ids[1:], prompt_logprobs[1:])):
|
||||
|
||||
# Break out the reference prompt log prob value
|
||||
# vector, prompt logprob token id vector, and
|
||||
# prompt token rank at the current position.
|
||||
(ref_pos_prompt_logprob_toks, ref_pos_prompt_logprob_vals,
|
||||
ref_pos_prompt_token_rank) = (ref_prompt_logprob_toks[idx, :],
|
||||
ref_prompt_logprob_vals[idx, :],
|
||||
ref_prompt_token_ranks[idx])
|
||||
|
||||
# For each position in the prompt sequence,
|
||||
# ensure the actual prompt token is among the
|
||||
# logprobs
|
||||
assert prompt_token in pos_logprob_dict, (
|
||||
f"Prompt token {prompt_token} not"
|
||||
f" present in logprob at index {idx}")
|
||||
# Validate number of prompt logprobs
|
||||
num_plp_toks = len(pos_logprob_dict)
|
||||
assert (num_plp_toks == num_prompt_logprobs
|
||||
or num_plp_toks == num_prompt_logprobs +
|
||||
1), ("Valid numbers of prompt logprobs are"
|
||||
f" {num_prompt_logprobs} or"
|
||||
f" {num_prompt_logprobs+1} but"
|
||||
f" {num_plp_toks} logprobs found at"
|
||||
f" position {idx}. Logprobs dict:"
|
||||
f" {pos_logprob_dict}")
|
||||
|
||||
# Validate prompt token logprob rank
|
||||
prmpt_tok_lp = pos_logprob_dict[prompt_token]
|
||||
prmpt_tok_lp_rank = prmpt_tok_lp.rank
|
||||
ref_prmpt_tok_lp_rank = ref_pos_prompt_token_rank
|
||||
assert (ref_prmpt_tok_lp_rank == prmpt_tok_lp_rank), (
|
||||
"Prompt token logprob rank"
|
||||
f" {prmpt_tok_lp_rank} does not match"
|
||||
" correct value"
|
||||
f" {ref_prmpt_tok_lp_rank}"
|
||||
f" in Logprob {prmpt_tok_lp}")
|
||||
|
||||
# Validate that the logprob processor yields
|
||||
# the correct prompt log probs and valid
|
||||
# rankings
|
||||
rank_one_appears = False
|
||||
for jdx in range(1, len(ref_pos_prompt_logprob_toks)):
|
||||
# Iterate over the (logprob val,logprob tok id)
|
||||
# pairs expected by the test fixture at this
|
||||
# position in the completion.
|
||||
ref_plp_val = float(ref_pos_prompt_logprob_vals[jdx])
|
||||
ref_tok_id = int(ref_pos_prompt_logprob_toks[jdx])
|
||||
assert ref_tok_id in pos_logprob_dict, (
|
||||
f"Expected token {ref_tok_id} to be"
|
||||
f" in logprob dict but it is not.")
|
||||
|
||||
# Extract actually-generated logprob
|
||||
# info
|
||||
plp = pos_logprob_dict[ref_tok_id]
|
||||
plp_val = plp.logprob
|
||||
plp_rank = plp.rank
|
||||
|
||||
# A "top" (rank 1) logprob must be
|
||||
# present
|
||||
rank_one_appears = (True
|
||||
if plp_rank == 1 else rank_one_appears)
|
||||
|
||||
# Rank must be >= 1
|
||||
assert plp_rank >= 1, (
|
||||
f"Logprob {plp} has invalid"
|
||||
f" rank {plp_rank} < 1."
|
||||
f" Logprob dict: {pos_logprob_dict}")
|
||||
|
||||
# Validate log probability
|
||||
assert math.isclose(plp_val, ref_plp_val), (
|
||||
f"Token id {ref_tok_id} appears in logprobs dict"
|
||||
f" at position {idx} in completion with log"
|
||||
f" probability {plp_val} but {ref_plp_val} was"
|
||||
f" expected. Logprob: {plp}")
|
||||
|
||||
assert rank_one_appears, (f"No Logprob has rank 1"
|
||||
" in the following Logprob"
|
||||
f" dict: {pos_logprob_dict}")
|
||||
|
||||
# Validate prompt logprob detokenization
|
||||
for plp_tok in pos_logprob_dict:
|
||||
# Confirm that prompt logprob decoded token matches
|
||||
# the logprob token id at this sequence position
|
||||
decoded_token = pos_logprob_dict[plp_tok].decoded_token
|
||||
ref_decoded_token = _ref_convert_id_to_token(
|
||||
dtv.tokenizer, plp_tok)
|
||||
assert decoded_token == ref_decoded_token, (
|
||||
f"Prompt logprob token id {plp_tok} decodes to"
|
||||
f" {ref_decoded_token} but Logprob decoded"
|
||||
f" token is {decoded_token} instead"
|
||||
f" (at position {idx})")
|
||||
else:
|
||||
# Prompt logprobs disabled for this request
|
||||
assert prompt_logprobs is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"request_output_kind",
|
||||
[RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
|
||||
@pytest.mark.parametrize("num_sample_logprobs",
|
||||
[None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
|
||||
@pytest.mark.parametrize("num_prompt_logprobs",
|
||||
[None, NUM_PROMPT_LOGPROBS_UNDER_TEST])
|
||||
def test_logprobs_processor(request_output_kind: RequestOutputKind,
|
||||
num_sample_logprobs: Optional[int],
|
||||
num_prompt_logprobs: Optional[int],
|
||||
dummy_test_vectors):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
|
||||
log_stats=False)
|
||||
engine_core = MockEngineCore(
|
||||
tokens_list=dummy_test_vectors.generation_tokens,
|
||||
generated_logprobs_raw=None if num_sample_logprobs is None else
|
||||
dummy_test_vectors.generation_logprobs,
|
||||
prompt_logprobs_raw=None
|
||||
if num_prompt_logprobs is None else dummy_test_vectors.prompt_logprobs)
|
||||
|
||||
# Make N requests.
|
||||
request_id_list = [
|
||||
f"request-{idx}"
|
||||
for idx in range(len(dummy_test_vectors.prompt_strings))
|
||||
]
|
||||
requests = [
|
||||
EngineCoreRequest(request_id=request_id_list[idx],
|
||||
prompt=prompt,
|
||||
prompt_token_ids=prompt_tokens,
|
||||
arrival_time=0,
|
||||
mm_inputs=None,
|
||||
mm_hashes=None,
|
||||
mm_placeholders=None,
|
||||
eos_token_id=None,
|
||||
lora_request=None,
|
||||
sampling_params=SamplingParams(
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=False,
|
||||
output_kind=request_output_kind,
|
||||
stop=[],
|
||||
include_stop_str_in_output=False,
|
||||
logprobs=num_sample_logprobs,
|
||||
prompt_logprobs=num_prompt_logprobs,
|
||||
)) for idx, (prompt, prompt_tokens) in enumerate(
|
||||
zip(dummy_test_vectors.prompt_strings,
|
||||
dummy_test_vectors.prompt_tokens))
|
||||
]
|
||||
|
||||
# Add requests to the detokenizer.
|
||||
for request in requests:
|
||||
output_processor.add_request(request)
|
||||
|
||||
gen_tokens = {}
|
||||
gen_logprobs = {}
|
||||
gen_prompt_logprobs = {}
|
||||
gen_cumulative_logprobs = {}
|
||||
while True:
|
||||
# Mock output from the EngineCore.
|
||||
outputs = engine_core.get_outputs()
|
||||
if len(outputs) == 0:
|
||||
break
|
||||
|
||||
# Step the logprobs processor.
|
||||
processed_outputs = output_processor.process_outputs(outputs)
|
||||
request_outputs = processed_outputs.request_outputs
|
||||
requests_to_abort = processed_outputs.reqs_to_abort
|
||||
assert len(requests_to_abort) == 0
|
||||
|
||||
# Update tracking.
|
||||
for request_output in request_outputs:
|
||||
request_id = request_output.request_id
|
||||
new_tokens = request_output.outputs[0].token_ids
|
||||
prompt_logprobs = request_output.prompt_logprobs
|
||||
logprobs = request_output.outputs[0].logprobs
|
||||
gen_cumulative_logprobs[request_id] = request_output.outputs[
|
||||
0].cumulative_logprob
|
||||
if request_id not in gen_logprobs:
|
||||
# Start tracking sample and prompt logprobs for this request
|
||||
gen_tokens[request_id] = new_tokens
|
||||
gen_logprobs[request_id] = logprobs
|
||||
gen_prompt_logprobs[request_id] = prompt_logprobs
|
||||
else:
|
||||
# Extend logprobs tracker
|
||||
gen_tokens[request_id].extend(new_tokens)
|
||||
lp = gen_logprobs[request_id]
|
||||
plp = gen_prompt_logprobs[request_id]
|
||||
if lp:
|
||||
lp.extend(logprobs)
|
||||
if plp:
|
||||
plp.extend(prompt_logprobs)
|
||||
|
||||
# Confirmed tracked logprobs match what we expect
|
||||
_validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs,
|
||||
gen_cumulative_logprobs, dummy_test_vectors,
|
||||
request_id_list, num_sample_logprobs,
|
||||
num_prompt_logprobs)
|
||||
|
||||
assert output_processor.get_num_unfinished_requests() == 0
|
||||
assert not output_processor.has_unfinished_requests()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
|
||||
@pytest.mark.parametrize("num_sample_logprobs",
|
||||
[None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
|
||||
@pytest.mark.parametrize("num_prompt_logprobs",
|
||||
[None, NUM_PROMPT_LOGPROBS_UNDER_TEST])
|
||||
def test_stop_string(include_stop_str_in_output: bool,
|
||||
num_sample_logprobs: Optional[int],
|
||||
num_prompt_logprobs: Optional[int], dummy_test_vectors):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
|
||||
log_stats=False)
|
||||
engine_core = MockEngineCore(
|
||||
tokens_list=dummy_test_vectors.generation_tokens,
|
||||
generated_logprobs_raw=dummy_test_vectors.generation_logprobs
|
||||
if num_sample_logprobs else None,
|
||||
prompt_logprobs_raw=dummy_test_vectors.prompt_logprobs
|
||||
if num_prompt_logprobs else None)
|
||||
|
||||
# Make N requests.
|
||||
request_id_list = [
|
||||
f"request-{idx}"
|
||||
for idx in range(len(dummy_test_vectors.prompt_strings))
|
||||
]
|
||||
requests = [
|
||||
EngineCoreRequest(
|
||||
request_id=f"request-{idx}",
|
||||
request_id=request_id_list[idx],
|
||||
prompt=prompt,
|
||||
prompt_token_ids=prompt_tokens,
|
||||
arrival_time=0,
|
||||
@ -166,9 +507,11 @@ def test_stop_string(include_stop_str_in_output: bool):
|
||||
output_kind=RequestOutputKind.DELTA,
|
||||
stop=STOP_STRINGS,
|
||||
include_stop_str_in_output=include_stop_str_in_output,
|
||||
)) for idx, (
|
||||
prompt,
|
||||
prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
|
||||
logprobs=num_sample_logprobs,
|
||||
prompt_logprobs=num_prompt_logprobs,
|
||||
)) for idx, (prompt, prompt_tokens) in enumerate(
|
||||
zip(dummy_test_vectors.prompt_strings,
|
||||
dummy_test_vectors.prompt_tokens))
|
||||
]
|
||||
|
||||
# Add requests to the detokenizer.
|
||||
@ -176,6 +519,10 @@ def test_stop_string(include_stop_str_in_output: bool):
|
||||
output_processor.add_request(request)
|
||||
|
||||
gen_strings = {}
|
||||
gen_tokens = {}
|
||||
gen_logprobs = {}
|
||||
gen_prompt_logprobs = {}
|
||||
gen_cumulative_logprobs = {}
|
||||
aborted = []
|
||||
while True:
|
||||
# Mock output from the EngineCore.
|
||||
@ -199,14 +546,29 @@ def test_stop_string(include_stop_str_in_output: bool):
|
||||
|
||||
request_id = request_output.request_id
|
||||
new_text = request_output.outputs[0].text
|
||||
new_tokens = request_output.outputs[0].token_ids
|
||||
prompt_logprobs = request_output.prompt_logprobs
|
||||
logprobs = request_output.outputs[0].logprobs
|
||||
gen_cumulative_logprobs[request_id] = request_output.outputs[
|
||||
0].cumulative_logprob
|
||||
if request_id not in gen_strings:
|
||||
gen_strings[request_id] = new_text
|
||||
gen_tokens[request_id] = new_tokens
|
||||
gen_logprobs[request_id] = logprobs
|
||||
gen_prompt_logprobs[request_id] = prompt_logprobs
|
||||
else:
|
||||
gen_strings[request_id] += new_text
|
||||
gen_tokens[request_id].extend(new_tokens)
|
||||
lp = gen_logprobs[request_id]
|
||||
plp = gen_prompt_logprobs[request_id]
|
||||
if lp:
|
||||
lp.extend(logprobs)
|
||||
if plp:
|
||||
plp.extend(prompt_logprobs)
|
||||
|
||||
# Confirmed tracked values matches what we expected.
|
||||
for idx, (ref_gen_str,
|
||||
stop_str) in enumerate(zip(GENERATION_STRINGS, STOP_STRINGS)):
|
||||
for idx, (ref_gen_str, stop_str) in enumerate(
|
||||
zip(dummy_test_vectors.generation_strings, STOP_STRINGS)):
|
||||
|
||||
# Request should be aborted.
|
||||
request_id = f"request-{idx}"
|
||||
@ -227,13 +589,20 @@ def test_stop_string(include_stop_str_in_output: bool):
|
||||
assert gen_str == ref_str_exc_stop, (
|
||||
f"{gen_str=}, {ref_str_exc_stop=}")
|
||||
|
||||
# Confirmed tracked logprobs match what we expect
|
||||
_validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs,
|
||||
gen_cumulative_logprobs, dummy_test_vectors,
|
||||
request_id_list, num_sample_logprobs,
|
||||
num_prompt_logprobs)
|
||||
|
||||
assert output_processor.get_num_unfinished_requests() == 0
|
||||
assert not output_processor.has_unfinished_requests()
|
||||
|
||||
|
||||
def test_iteration_stats():
|
||||
output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=True)
|
||||
engine_core = MockEngineCore(GENERATION_TOKENS)
|
||||
def test_iteration_stats(dummy_test_vectors):
|
||||
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
|
||||
log_stats=True)
|
||||
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
|
||||
|
||||
# Make N requests.
|
||||
requests = [
|
||||
@ -248,13 +617,13 @@ def test_iteration_stats():
|
||||
eos_token_id=None,
|
||||
lora_request=None,
|
||||
sampling_params=SamplingParams(),
|
||||
) for idx, (
|
||||
prompt,
|
||||
prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
|
||||
) for idx, (prompt, prompt_tokens) in enumerate(
|
||||
zip(dummy_test_vectors.prompt_strings,
|
||||
dummy_test_vectors.prompt_tokens))
|
||||
]
|
||||
|
||||
# Add all requests except one to the OutputProcessor.
|
||||
num_active = len(GENERATION_TOKENS) - 1
|
||||
num_active = len(dummy_test_vectors.generation_tokens) - 1
|
||||
for request in requests[:num_active]:
|
||||
output_processor.add_request(request)
|
||||
inactive_request = requests[num_active]
|
||||
@ -263,8 +632,10 @@ def test_iteration_stats():
|
||||
outputs = engine_core.get_outputs()[:num_active]
|
||||
processed_outputs = output_processor.process_outputs(outputs)
|
||||
iteration_stats = processed_outputs.iteration_stats
|
||||
total_prompt_tokens = sum(
|
||||
[len(prompt_tokens) for prompt_tokens in PROMPT_TOKENS[:num_active]])
|
||||
total_prompt_tokens = sum([
|
||||
len(prompt_tokens)
|
||||
for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active]
|
||||
])
|
||||
|
||||
assert iteration_stats.num_prompt_tokens == total_prompt_tokens
|
||||
assert iteration_stats.num_generation_tokens == num_active
|
||||
@ -283,7 +654,7 @@ def test_iteration_stats():
|
||||
outputs = engine_core.get_outputs()[:num_active]
|
||||
processed_outputs = output_processor.process_outputs(outputs)
|
||||
iteration_stats = processed_outputs.iteration_stats
|
||||
total_prompt_tokens = len(PROMPT_TOKENS[num_active - 1])
|
||||
total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1])
|
||||
|
||||
assert iteration_stats.num_prompt_tokens == total_prompt_tokens
|
||||
assert iteration_stats.num_generation_tokens == num_active
|
||||
|
||||
tests/v1/engine/utils.py  (new file, 382 lines)
@@ -0,0 +1,382 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
|
||||
BaseTokenizerGroup)
|
||||
from vllm.v1.engine import EngineCoreOutput, FinishReason
|
||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
||||
|
||||
GeneralTokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
||||
|
||||
# Number of sample logprobs to request when testing sample logprobs
|
||||
NUM_SAMPLE_LOGPROBS_UNDER_TEST = 5
|
||||
# Number of prompt logprobs to request when testing prompt logprobs
|
||||
NUM_PROMPT_LOGPROBS_UNDER_TEST = 7
|
||||
|
||||
TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
|
||||
FULL_STRINGS = [
|
||||
"My name is Robert from Neural Magic and I love working on vLLM so much!",
|
||||
"Red Hat is the best open source company by far across Linux, K8s, and AI.",
|
||||
"Nick is the name of my brother in addition to my colleague from Red Hat.",
|
||||
]
|
||||
STOP_STRINGS = ["I love working on", "company by far", "brother in"]
|
||||
PROMPT_LEN = 5
|
||||
|
||||
PLP_APC_UNSUPPORTED_MSG = ("Prefix caching with prompt logprobs not yet "
|
||||
"supported on VLLM V1.")
|
||||
|
||||
random.seed(42)
|
||||
|
||||
|
||||
def _create_random_top_logprob_test_vector(
|
||||
num_logprobs: int,
|
||||
lower: float,
|
||||
upper: float,
|
||||
) -> torch.Tensor:
|
||||
"""Create a random vector of top logprob float values.
|
||||
|
||||
Use to create fake sample logprobs for testing.
|
||||
|
||||
Note that a real production scenario would require
|
||||
logprobs to be sorted in descending order, something
|
||||
which is omitted in this function.
|
||||
|
||||
Args:
|
||||
num_logprobs: number of top logprobs
|
||||
lower: lower range of logprob float values
|
||||
upper: upper range of logprob float values
|
||||
|
||||
Returns:
|
||||
1D length-`num_logprobs` torch Tensor of float logprob values
|
||||
"""
|
||||
return torch.rand(num_logprobs) * (upper - lower) + lower
|
||||
|
||||
|
||||
def _create_random_top_logprob_test_matrix(
|
||||
shape: Tuple,
|
||||
lower: float,
|
||||
upper: float,
|
||||
) -> torch.Tensor:
|
||||
"""Create a random matrix of top logprob float values.
|
||||
|
||||
Use to create fake prompt logprobs for testing.
|
||||
|
||||
Note that a real production scenario would require
|
||||
logprobs to be sorted in descending order along rows,
|
||||
something which is omitted in this function.
|
||||
|
||||
Args:
|
||||
shape: (num_tokens,num_logprobs) tuple representing
|
||||
matrix shape
|
||||
lower: lower range of logprob float values
|
||||
upper: upper range of logprob float values
|
||||
|
||||
Returns:
|
||||
2D num_tokens x num_logprobs torch Tensor of float logprob values
|
||||
"""
|
||||
return torch.rand(*shape) * (upper - lower) + lower
|
||||
|
||||
|
||||
def _create_random_top_token_test_vector(
|
||||
num_logprobs: int,
|
||||
lower: int,
|
||||
upper: int,
|
||||
sampled_token_id: int,
|
||||
adjust_num_logprobs: bool = True) -> Tuple[torch.Tensor, int]:
|
||||
"""Create a random vector of top logprob token indices
|
||||
|
||||
Use to create fake sample logprobs for testing. The sampled token
|
||||
ID must always be one of the top logprobs, which this dummy test
|
||||
vector generator enforces. OpenAI API
|
||||
compatible engines must be able to return an additional sample
|
||||
logprob for the sampled token if the sampled token was not
|
||||
among the top sample logprobs; `adjust_num_logprobs` emulates
|
||||
this behavior by increasing the vector length by 1 if
|
||||
`adjust_num_logprobs` is set.
|
||||
|
||||
Args:
|
||||
num_logprobs: number of top logprobs
|
||||
lower: lower range of token ids
|
||||
upper: upper range of token ids
|
||||
sampled_token_id: the token actually sampled
|
||||
adjust_num_logprobs: if True, emulate situation where sampled
|
||||
token logprob must be injected into top
|
||||
logprobs
|
||||
|
||||
Returns:
|
||||
1D length-x torch Tensor of token ids where x is
|
||||
`num_logprobs+1` if `adjust_num_logprobs` and
|
||||
`num_logprobs` otherwise
|
||||
sampled_token_rank: the rank of sampled_token_id in the vocab
|
||||
vector when sorted in descending order by
|
||||
logprob
|
||||
"""
|
||||
|
||||
# Calculate the final number of logprobs required
|
||||
total_logprobs = num_logprobs + 1 if adjust_num_logprobs else num_logprobs
|
||||
|
||||
# Generate random indices using torch
|
||||
choice_tensor = torch.randperm(upper - lower)[:total_logprobs] + lower
|
||||
|
||||
# Ensure the sampled token ID is included in the tensor
|
||||
choice_tensor[0] = sampled_token_id
|
||||
|
||||
# Check if the sampled_token_id occurs in choice_tensor[1:]
|
||||
if sampled_token_id in choice_tensor[1:]:
|
||||
sampled_token_rank = (choice_tensor[1:] == sampled_token_id).nonzero(
|
||||
as_tuple=True)[0].item()
|
||||
else:
|
||||
# If not found, assign a random int between num_logprobs and 50700
|
||||
sampled_token_rank = random.randint(num_logprobs, 50700)
|
||||
|
||||
return choice_tensor, sampled_token_rank
|
||||
|
||||
|
||||
def _create_random_top_token_test_matrix(
|
||||
shape: Tuple[int, int],
|
||||
lower: int,
|
||||
upper: int,
|
||||
tokens_list: List[int],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Create a random matrix of top logprob token indices
|
||||
|
||||
Use to create fake prompt logprobs for testing.
|
||||
|
||||
Token ids are generated randomly and sampled without
|
||||
replacement.
|
||||
|
||||
Args:
|
||||
shape: (num_tokens, num_logprobs) tuple representing
|
||||
matrix shape
|
||||
lower: lower range of token ids
|
||||
upper: upper range of token ids
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- 2D num_tokens x num_logprobs+1 torch Tensor of token ids
|
||||
- 1D tensor of ranks of prompt tokens in their respective
|
||||
rows, or random values
|
||||
"""
|
||||
num_elements = shape[0] * shape[1]
|
||||
choice_tensor = torch.randperm(upper - lower)[:num_elements] + lower
|
||||
matrix = torch.cat(
|
||||
(torch.tensor(tokens_list, dtype=torch.int).unsqueeze(-1),
|
||||
choice_tensor.view(shape)),
|
||||
dim=1)
|
||||
|
||||
# Initialize the tensor for storing the ranks
|
||||
prompt_token_ranks = torch.empty(shape[0], dtype=torch.int)
|
||||
|
||||
# Iterate over each row to check presence of
|
||||
# tokens_list[rdx] and determine its index
|
||||
for rdx in range(shape[0]):
|
||||
row = matrix[rdx,
|
||||
1:] # Skip the first column as it contains the token list
|
||||
token_index = (row == tokens_list[rdx]).nonzero(as_tuple=True)[0]
|
||||
if token_index.numel() > 0:
|
||||
prompt_token_ranks[rdx] = token_index.item()
|
||||
else:
|
||||
prompt_token_ranks[rdx] = random.randint(shape[1], 50700)
|
||||
|
||||
return matrix, prompt_token_ranks
|
||||
|
||||
|
||||
def decode_token(
|
||||
tok_id: int,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
) -> str:
|
||||
"""Reproduce the process of detokenizing a token for testing purposes.
|
||||
|
||||
Args:
|
||||
tok_id: token id to detokenize
|
||||
tokenizer: tokenizer to use for detokenization
|
||||
|
||||
Returns:
|
||||
string representation of token
|
||||
"""
|
||||
return tokenizer.convert_ids_to_tokens(tok_id)
|
||||
|
||||
|
||||
def generate_dummy_sample_logprobs(
|
||||
sampled_tokens_list: List,
|
||||
num_logprobs: int,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
) -> List[Tuple[List[int], List[float], int]]:
|
||||
"""Generate dummy sample logprobs
|
||||
|
||||
Generate a test data structure which imitates the list of sample logprobs
|
||||
which would be assembled in the engine core during decode phase.
|
||||
|
||||
Args:
|
||||
sampled_tokens_list: list of sampled tokens
|
||||
num_logprobs: return `num_logprobs` or `num_logprobs+1` logprobs per token
|
||||
tokenizer: model tokenizer to use for detokenization
|
||||
|
||||
Returns
|
||||
List of (top token ids vector, logprobs vector, sampled token rank)
|
||||
Python lists tuples; in each tuple the logprobs and top token ids
|
||||
vectors have the same length which is either `num_logprobs` or
|
||||
`num_logprobs+1`. Sampled token rank is the rank (index+1) of the
|
||||
sampled token within the vocab vector when sorted by logprob in
|
||||
descending order.
|
||||
"""
|
||||
res = []
|
||||
for sampled_token_id in sampled_tokens_list:
|
||||
(
|
||||
token_vector,
|
||||
sampled_token_rank,
|
||||
) = _create_random_top_token_test_vector(num_logprobs, 0,
|
||||
len(tokenizer.vocab) - 1,
|
||||
sampled_token_id)
|
||||
|
||||
res.append(
|
||||
(token_vector,
|
||||
_create_random_top_logprob_test_vector(num_logprobs + 1, -100,
|
||||
0), sampled_token_rank))
|
||||
|
||||
# Convert tensors in the list tuples to Python lists
|
||||
res_list_format = [
|
||||
(log_probs_tensor.tolist(), token_ids_tensor.tolist(),
|
||||
sampled_token_rank)
|
||||
for log_probs_tensor, token_ids_tensor, sampled_token_rank in res
|
||||
]
|
||||
|
||||
return res_list_format
|
||||
|
||||
|
||||
def generate_dummy_prompt_logprobs_tensors(
|
||||
prompt_tokens_list: List,
|
||||
num_logprobs: int,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
) -> LogprobsTensors:
|
||||
"""Generate dummy prompt logprobs tensors
|
||||
|
||||
Generate a test data structure which imitates the torch Tensors of prompt
|
||||
logprobs which would be assembled in the engine core during chunked
|
||||
prefill.
|
||||
|
||||
Args:
|
||||
prompt_tokens_list: list of prompt tokens
|
||||
num_logprobs: return `num_logprobs` logprobs per token
|
||||
tokenizer: model tokenizer to use for detokenization
|
||||
|
||||
Returns
|
||||
Single Tuple of (logprobs matrix, top token ids matrix) torch Tensor,
|
||||
where both matrices have dimensions
|
||||
num_prompt_tokens x num_logprobs
|
||||
"""
|
||||
# For now, assume the whole prompt is processed in one chunk; thus,
|
||||
# the number of non-`None` prompt logprobs is `len(prompt_tokens_list)-1`.
|
||||
# Prior to injecting `None` at the beginning of prompt logprobs (which
|
||||
# happens later in the detokenizer, not here), the prompt logprobs in
|
||||
# the ith position are predicting the probability distribution of the
|
||||
# prompt token in (i+1)st position. Thus, we concat
|
||||
# `prompt_tokens_list[1:]` to the dummy token ids, just as the engine
|
||||
# would.
|
||||
num_prompt_logprobs = len(prompt_tokens_list) - 1
|
||||
(
|
||||
token_vector,
|
||||
prompt_token_ranks,
|
||||
) = _create_random_top_token_test_matrix(
|
||||
(num_prompt_logprobs, num_logprobs), 0,
|
||||
len(tokenizer.vocab) - 1, prompt_tokens_list[1:])
|
||||
return LogprobsTensors(
|
||||
token_vector,
|
||||
_create_random_top_logprob_test_matrix(
|
||||
(num_prompt_logprobs, num_logprobs + 1), -100, 0),
|
||||
prompt_token_ranks)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DummyOutputProcessorTestVectors:
|
||||
"""Dummy test vectors for output processor tests"""
|
||||
tokenizer: GeneralTokenizerType
|
||||
tokenizer_group: BaseTokenizerGroup
|
||||
vllm_config: EngineArgs
|
||||
full_tokens: List[List[int]] # Prompt + generated tokens
|
||||
prompt_tokens: List[List[int]]
|
||||
generation_tokens: List[List[int]]
|
||||
# Each request is associated with a tuple of
|
||||
# (top tokens, top logprobs, ranks) prompt logprobs tensors
|
||||
prompt_logprobs: List[LogprobsTensors]
|
||||
# Each request is associated with a sample logprobs; a request's
|
||||
# sample logprobs are a list of (top tokens, top logprobs, ranks)
|
||||
# sample logprobs tensors at each sequence position
|
||||
generation_logprobs: List[List[Tuple[List[int], List[float], int]]]
|
||||
prompt_strings: List[str]
|
||||
prompt_strings_len: List[int]
|
||||
generation_strings: List[str]
|
||||
|
||||
|
||||
class MockEngineCore:
|
||||
"""Mock engine core outputs form premade tokens lists."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokens_list: List[List[int]],
|
||||
# For each request, for each sampled token offset,
|
||||
# a tuple of
|
||||
# (list of topk token ids, list of sample logprob vals, rank)
|
||||
generated_logprobs_raw: Optional[List[List[Tuple[List[int],
|
||||
List[float],
|
||||
int]]]] = None,
|
||||
# For each request, a tuple of
|
||||
# (prompt logprob val matrix, prompt logprob tok id matrix);
|
||||
# each matrix has dimensions
|
||||
# (num prompt toks) x (num prompt logprobs+1)
|
||||
prompt_logprobs_raw: Optional[List[LogprobsTensors]] = None,
|
||||
) -> None:
|
||||
self.tokens_list = tokens_list
|
||||
self.current_idx = 0
|
||||
self.generated_logprobs_raw = generated_logprobs_raw
|
||||
self.do_logprobs = generated_logprobs_raw is not None
|
||||
self.prompt_logprobs_raw = prompt_logprobs_raw
|
||||
self.do_prompt_logprobs = prompt_logprobs_raw is not None
|
||||
|
||||
def get_outputs(self) -> List[EngineCoreOutput]:
|
||||
do_logprobs = self.do_logprobs
|
||||
do_prompt_logprobs = self.do_prompt_logprobs
|
||||
token_idx = self.current_idx
|
||||
|
||||
outputs = []
|
||||
for req_idx, token_ids in enumerate(self.tokens_list):
|
||||
if len(token_ids) > token_idx:
|
||||
if do_logprobs:
|
||||
assert self.generated_logprobs_raw is not None
|
||||
(logprobs_token_ids_, logprobs_, sampled_token_ranks_) = (
|
||||
self.generated_logprobs_raw[req_idx][token_idx])
|
||||
logprobs = LogprobsLists(
|
||||
[logprobs_token_ids_],
|
||||
[logprobs_],
|
||||
[sampled_token_ranks_],
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
if do_prompt_logprobs:
|
||||
if self.current_idx == 0:
|
||||
assert self.prompt_logprobs_raw is not None
|
||||
prompt_logprobs = self.prompt_logprobs_raw[req_idx]
|
||||
else:
|
||||
prompt_logprobs = None
|
||||
else:
|
||||
prompt_logprobs = None
|
||||
output = EngineCoreOutput(
|
||||
request_id=f"request-{req_idx}",
|
||||
new_token_ids=[token_ids[token_idx]],
|
||||
new_logprobs=logprobs,
|
||||
new_prompt_logprobs_tensors=prompt_logprobs,
|
||||
)
|
||||
if token_idx == len(token_ids) - 1:
|
||||
output.finish_reason = FinishReason.STOP
|
||||
outputs.append(output)
|
||||
|
||||
self.current_idx += 1
|
||||
return outputs
|
||||
tests/v1/entrypoints/__init__.py  (new empty file)

tests/v1/entrypoints/conftest.py  (new file, 161 lines)
@@ -0,0 +1,161 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_prompts():
|
||||
return [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_token_ids():
|
||||
return [
|
||||
[0],
|
||||
[0, 1],
|
||||
[0, 2, 1],
|
||||
[0, 3, 1, 2],
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_regex():
|
||||
return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
|
||||
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_json_schema():
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"age": {
|
||||
"type": "integer"
|
||||
},
|
||||
"skills": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"maxLength": 10
|
||||
},
|
||||
"minItems": 3
|
||||
},
|
||||
"work_history": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"company": {
|
||||
"type": "string"
|
||||
},
|
||||
"duration": {
|
||||
"type": "number"
|
||||
},
|
||||
"position": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["company", "position"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["name", "age", "skills", "work_history"]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_complex_json_schema():
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"score": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"maximum": 100 # Numeric range
|
||||
},
|
||||
"grade": {
|
||||
"type": "string",
|
||||
"pattern": "^[A-D]$" # Regex pattern
|
||||
},
|
||||
"email": {
|
||||
"type": "string",
|
||||
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
|
||||
},
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"pattern":
|
||||
"^[a-z]{1,10}$" # Combining length and pattern restrictions
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["score", "grade", "email", "tags"]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_definition_json_schema():
|
||||
return {
|
||||
'$defs': {
|
||||
'Step': {
|
||||
'properties': {
|
||||
'explanation': {
|
||||
'title': 'Explanation',
|
||||
'type': 'string'
|
||||
},
|
||||
'output': {
|
||||
'title': 'Output',
|
||||
'type': 'string'
|
||||
}
|
||||
},
|
||||
'required': ['explanation', 'output'],
|
||||
'title': 'Step',
|
||||
'type': 'object'
|
||||
}
|
||||
},
|
||||
'properties': {
|
||||
'steps': {
|
||||
'items': {
|
||||
'$ref': '#/$defs/Step'
|
||||
},
|
||||
'title': 'Steps',
|
||||
'type': 'array'
|
||||
},
|
||||
'final_answer': {
|
||||
'title': 'Final Answer',
|
||||
'type': 'string'
|
||||
}
|
||||
},
|
||||
'required': ['steps', 'final_answer'],
|
||||
'title': 'MathReasoning',
|
||||
'type': 'object'
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_guided_choice():
|
||||
return [
|
||||
"Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript",
|
||||
"Ruby", "Swift", "Kotlin"
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_sql_statements():
|
||||
return ("""
|
||||
start: select_statement
|
||||
select_statement: "SELECT" column "from" table "where" condition
|
||||
column: "col_1" | "col_2"
|
||||
table: "table_1" | "table_2"
|
||||
condition: column "=" number
|
||||
number: "1" | "2"
|
||||
""")
|
||||
tests/v1/entrypoints/openai/test_completion.py  (new file, 475 lines)
@@ -0,0 +1,475 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import re
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from openai import BadRequestError
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
# any model with a chat template should work here
|
||||
MODEL_NAME = "facebook/opt-125m"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def default_server_args():
|
||||
return [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--max-model-len",
|
||||
"2048",
|
||||
"--max-num-seqs",
|
||||
"128",
|
||||
"--enforce-eager"
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module",
|
||||
params=[["--no-enable-prefix-caching"],
|
||||
[
|
||||
"--no-enable-prefix-caching",
|
||||
"--disable-frontend-multiprocessing"
|
||||
]])
|
||||
def server(default_server_args, request):
|
||||
if request.param:
|
||||
default_server_args.extend(request.param)
|
||||
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client(server):
|
||||
async with server.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_single_completion(client: openai.AsyncOpenAI,
|
||||
model_name: str) -> None:
|
||||
completion = await client.completions.create(model=model_name,
|
||||
prompt="Hello, my name is",
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 1
|
||||
|
||||
choice = completion.choices[0]
|
||||
assert len(choice.text) >= 5
|
||||
assert choice.finish_reason == "length"
|
||||
assert completion.usage == openai.types.CompletionUsage(
|
||||
completion_tokens=5, prompt_tokens=6, total_tokens=11)
|
||||
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
assert len(completion.choices[0].text) >= 1
|
||||
assert completion.choices[0].prompt_logprobs is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
logprobs=None,
|
||||
)
|
||||
choice = completion.choices[0]
|
||||
assert choice.logprobs is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
logprobs=0,
|
||||
)
|
||||
choice = completion.choices[0]
|
||||
assert choice.logprobs is not None
|
||||
assert choice.logprobs.token_logprobs is not None
|
||||
assert choice.logprobs.top_logprobs is not None
|
||||
assert len(choice.logprobs.top_logprobs[0]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
logprobs=5,
|
||||
)
|
||||
choice = completion.choices[0]
|
||||
assert choice.logprobs is not None
|
||||
assert choice.logprobs.token_logprobs is not None
|
||||
assert choice.logprobs.top_logprobs is not None
|
||||
assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
|
||||
model_name: str) -> None:
|
||||
|
||||
with pytest.raises(
|
||||
(openai.BadRequestError, openai.APIError)): # test using token IDs
|
||||
await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
# vLLM has higher default max_logprobs (20 instead of 5) to support
|
||||
# both Completion API and Chat Completion API
|
||||
logprobs=21,
|
||||
)
|
||||
...
|
||||
with pytest.raises(
|
||||
(openai.BadRequestError, openai.APIError)): # test using token IDs
|
||||
stream = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
# vLLM has higher default max_logprobs (20 instead of 5) to support
|
||||
# both Completion API and Chat Completion API
|
||||
logprobs=30,
|
||||
stream=True,
|
||||
)
|
||||
async for chunk in stream:
|
||||
...
|
||||
|
||||
# the server should still work afterwards
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
assert len(completion.choices[0].text) >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
|
||||
(MODEL_NAME, 0),
|
||||
(MODEL_NAME, 1),
|
||||
(MODEL_NAME, None)])
|
||||
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
|
||||
model_name: str,
|
||||
prompt_logprobs: Optional[int]):
|
||||
params: Dict = {
|
||||
"prompt": ["A robot may not injure another robot", "My name is"],
|
||||
"model": model_name,
|
||||
}
|
||||
if prompt_logprobs is not None:
|
||||
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}
|
||||
|
||||
if prompt_logprobs is not None and prompt_logprobs < 0:
|
||||
with pytest.raises(BadRequestError):
|
||||
await client.completions.create(**params)
|
||||
else:
|
||||
completion = await client.completions.create(**params)
|
||||
if prompt_logprobs is not None:
|
||||
assert completion.choices[0].prompt_logprobs is not None
|
||||
assert len(completion.choices[0].prompt_logprobs) > 0
|
||||
|
||||
assert completion.choices[1].prompt_logprobs is not None
|
||||
assert len(completion.choices[1].prompt_logprobs) > 0
|
||||
|
||||
else:
|
||||
assert completion.choices[0].prompt_logprobs is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_completion_streaming(client: openai.AsyncOpenAI,
|
||||
model_name: str) -> None:
|
||||
prompt = "What is an LLM?"
|
||||
|
||||
single_completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
single_output = single_completion.choices[0].text
|
||||
stream = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True)
|
||||
chunks: List[str] = []
|
||||
finish_reason_count = 0
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk.choices[0].text)
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
finish_reason_count += 1
|
||||
# finish reason should only be returned in the last chunk
|
||||
assert finish_reason_count == 1
|
||||
assert chunk.choices[0].finish_reason == "length"
|
||||
assert chunk.choices[0].text
|
||||
assert "".join(chunks) == single_output
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_completion_stream_options(client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
prompt = "What is the capital of France?"
|
||||
|
||||
# Test stream=True, stream_options=
|
||||
# {"include_usage": False, "continuous_usage_stats": False}
|
||||
stream = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
stream_options={
|
||||
"include_usage": False,
|
||||
"continuous_usage_stats":
|
||||
False,
|
||||
})
|
||||
|
||||
async for chunk in stream:
|
||||
assert chunk.usage is None
|
||||
|
||||
# Test stream=True, stream_options=
|
||||
# {"include_usage": False, "continuous_usage_stats": True}
|
||||
stream = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
stream_options={
|
||||
"include_usage": False,
|
||||
"continuous_usage_stats":
|
||||
True,
|
||||
})
|
||||
async for chunk in stream:
|
||||
assert chunk.usage is None
|
||||
|
||||
# Test stream=True, stream_options=
|
||||
# {"include_usage": True, "continuous_usage_stats": False}
|
||||
stream = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
stream_options={
|
||||
"include_usage": True,
|
||||
"continuous_usage_stats":
|
||||
False,
|
||||
})
|
||||
async for chunk in stream:
|
||||
if chunk.choices[0].finish_reason is None:
|
||||
assert chunk.usage is None
|
||||
else:
|
||||
assert chunk.usage is None
|
||||
final_chunk = await stream.__anext__()
|
||||
assert final_chunk.usage is not None
|
||||
assert final_chunk.usage.prompt_tokens > 0
|
||||
assert final_chunk.usage.completion_tokens > 0
|
||||
assert final_chunk.usage.total_tokens == (
|
||||
final_chunk.usage.prompt_tokens +
|
||||
final_chunk.usage.completion_tokens)
|
||||
assert final_chunk.choices == []
|
||||
|
||||
# Test stream=True, stream_options=
|
||||
# {"include_usage": True, "continuous_usage_stats": True}
|
||||
stream = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
stream_options={
|
||||
"include_usage": True,
|
||||
"continuous_usage_stats":
|
||||
True,
|
||||
})
|
||||
async for chunk in stream:
|
||||
assert chunk.usage is not None
|
||||
assert chunk.usage.prompt_tokens > 0
|
||||
assert chunk.usage.completion_tokens > 0
|
||||
assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
|
||||
chunk.usage.completion_tokens)
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
final_chunk = await stream.__anext__()
|
||||
assert final_chunk.usage is not None
|
||||
assert final_chunk.usage.prompt_tokens > 0
|
||||
assert final_chunk.usage.completion_tokens > 0
|
||||
assert final_chunk.usage.total_tokens == (
|
||||
final_chunk.usage.prompt_tokens +
|
||||
final_chunk.usage.completion_tokens)
|
||||
assert final_chunk.choices == []
|
||||
|
||||
# Test stream=False, stream_options=
|
||||
# {"include_usage": None}
|
||||
with pytest.raises(BadRequestError):
|
||||
await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=False,
|
||||
stream_options={"include_usage": None})
|
||||
|
||||
# Test stream=False, stream_options=
|
||||
# {"include_usage": True}
|
||||
with pytest.raises(BadRequestError):
|
||||
await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=False,
|
||||
stream_options={"include_usage": True})
|
||||
|
||||
# Test stream=False, stream_options=
|
||||
# {"continuous_usage_stats": None}
|
||||
with pytest.raises(BadRequestError):
|
||||
await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=False,
|
||||
stream_options={"continuous_usage_stats": None})
|
||||
|
||||
# Test stream=False, stream_options=
|
||||
# {"continuous_usage_stats": True}
|
||||
with pytest.raises(BadRequestError):
|
||||
await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=False,
|
||||
stream_options={"continuous_usage_stats": True})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
|
||||
# test both text and token IDs
|
||||
for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2):
|
||||
# test simple list
|
||||
batch = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=prompts,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
assert len(batch.choices) == 2
|
||||
assert batch.choices[0].text == batch.choices[1].text
|
||||
|
||||
# test n = 2
|
||||
batch = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=prompts,
|
||||
n=2,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
extra_body=dict(
|
||||
# NOTE: this has to be true for n > 1 in vLLM, but
|
||||
# not necessary for official client.
|
||||
use_beam_search=True),
|
||||
)
|
||||
assert len(batch.choices) == 4
|
||||
assert batch.choices[0].text != batch.choices[
|
||||
1].text, "beam search should be different"
|
||||
assert batch.choices[0].text == batch.choices[
|
||||
2].text, "two copies of the same prompt should be the same"
|
||||
assert batch.choices[1].text == batch.choices[
|
||||
3].text, "two copies of the same prompt should be the same"
|
||||
|
||||
# test streaming
|
||||
batch = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=prompts,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
)
|
||||
texts = [""] * 2
|
||||
async for chunk in batch:
|
||||
assert len(chunk.choices) == 1
|
||||
choice = chunk.choices[0]
|
||||
texts[choice.index] += choice.text
|
||||
assert texts[0] == texts[1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME],
|
||||
)
|
||||
@pytest.mark.parametrize("logprobs_arg", [1, 0])
|
||||
async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
|
||||
model_name: str, logprobs_arg: int):
|
||||
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
|
||||
# test using text and token IDs
|
||||
for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
|
||||
completion = await client.completions.create(model=model_name,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
echo=True,
|
||||
logprobs=logprobs_arg)
|
||||
|
||||
prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
|
||||
list) else prompt
|
||||
assert re.search(r"^" + prompt_text, completion.choices[0].text)
|
||||
logprobs = completion.choices[0].logprobs
|
||||
assert logprobs is not None
|
||||
assert len(logprobs.text_offset) > 5
|
||||
assert (len(logprobs.token_logprobs) > 5
|
||||
and logprobs.token_logprobs[0] is None)
|
||||
assert (len(logprobs.top_logprobs) > 5
|
||||
and logprobs.top_logprobs[0] is None)
|
||||
for top_logprobs in logprobs.top_logprobs[1:]:
|
||||
assert max(logprobs_arg,
|
||||
1) <= len(top_logprobs) <= logprobs_arg + 1
|
||||
assert len(logprobs.tokens) > 5
|
||||
392
tests/v1/sample/test_logprobs.py
Normal file
@ -0,0 +1,392 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import itertools
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.utils import override_backend_env_variable
|
||||
from tests.v1.sample.utils import (
|
||||
assert_incr_detok_str_matches_non_incr_detok_str,
|
||||
compute_correct_cumulative_logprob, get_test_batch)
|
||||
from vllm import SamplingParams
|
||||
|
||||
from ...conftest import VllmRunner
|
||||
|
||||
MODEL = "meta-llama/Llama-3.2-1B"
|
||||
DTYPE = "half"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vllm_model(vllm_runner):
|
||||
with vllm_runner(
|
||||
MODEL,
|
||||
dtype=DTYPE,
|
||||
max_logprobs=7,
|
||||
# Very small number of batched tokens to ensure
|
||||
# that we test chunking.
|
||||
max_num_batched_tokens=16,
|
||||
max_num_seqs=16,
|
||||
max_model_len=128,
|
||||
enforce_eager=True,
|
||||
#TODO: enable this once we support it for
|
||||
# prompt logprobs.
|
||||
enable_prefix_caching=False,
|
||||
gpu_memory_utilization=0.5,
|
||||
) as vllm_model:
|
||||
yield vllm_model
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def hf_model(hf_runner):
|
||||
with hf_runner(MODEL, dtype=DTYPE) as hf_model:
|
||||
yield hf_model
|
||||
|
||||
|
||||
def _repeat_logprob_config(
|
||||
test_prompts,
|
||||
logprob_prompt_logprob_list: List[Tuple],
|
||||
) -> List[Tuple]:
|
||||
"""Ensure each test prompt has a logprob config.
|
||||
|
||||
A logprob config specifies the optional (i.e.
|
||||
may-be-`None`) number of sample logprobs and
|
||||
the optional number of prompt logprobs.
|
||||
|
||||
If more test prompts than logprob configs are
|
||||
provided, the provided logprob configs are
|
||||
tiled to match the number of test prompts.
|
||||
|
||||
If fewer test prompts than logprob configs
|
||||
are provided, the list of logprob configs
|
||||
is truncated to match the number of test
|
||||
prompts.
|
||||
|
||||
Otherwise, the list of logprob configs
|
||||
is returned as-is.
|
||||
|
||||
Args:
|
||||
test_prompts: list of prompts under test
|
||||
logprob_prompt_logprob_list: list of
|
||||
(optional num sample logprob,
|
||||
optional num prompt logprob)
|
||||
tuples
|
||||
|
||||
Returns:
|
||||
List of
|
||||
(optional num sample logprob,optional num prompt logprob)
|
||||
tuples which is either identical to
|
||||
`logprob_prompt_logprob_list`, or else repeats
|
||||
`logprob_prompt_logprob_list` enough times to match the
|
||||
number of `test_prompts`, or else is truncated to match
|
||||
the number of `test_prompts`
|
||||
"""
|
||||
num_test_prompts = len(test_prompts)
|
||||
# Make sure there is a logprobs configuration for each test prompt
|
||||
logprob_prompt_logprob_list = list(
|
||||
itertools.islice(itertools.cycle(logprob_prompt_logprob_list),
|
||||
num_test_prompts))
|
||||
# Now the number of prompts should match the number of sample params combos
|
||||
assert num_test_prompts == len(logprob_prompt_logprob_list)
|
||||
return logprob_prompt_logprob_list
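A minimal sketch of the tiling behavior described in the docstring, with invented values:
import itertools

configs = [(5, None), (None, 3)]            # two logprob configs
prompts = ["p0", "p1", "p2", "p3", "p4"]    # five test prompts
tiled = list(itertools.islice(itertools.cycle(configs), len(prompts)))
# tiled == [(5, None), (None, 3), (5, None), (None, 3), (5, None)]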
|
||||
|
||||
|
||||
def _test_case_get_logprobs_and_prompt_logprobs(
|
||||
hf_model,
|
||||
vllm_model,
|
||||
batch_logprobs_composition: str,
|
||||
temperature: float,
|
||||
example_prompts,
|
||||
) -> None:
|
||||
test_prompts = example_prompts
|
||||
|
||||
max_tokens = 5
|
||||
hf_outputs = hf_model.generate_greedy(
|
||||
test_prompts,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
hf_logprobs = hf_model.generate_greedy_logprobs(
|
||||
test_prompts,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
# Batch has mixed sample params
|
||||
# (different logprobs/prompt logprobs combos)
|
||||
logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)
|
||||
|
||||
# Ensure that each test prompt has a logprob config for testing
|
||||
logprob_prompt_logprob_list = _repeat_logprob_config(
|
||||
test_prompts, logprob_prompt_logprob_list)
|
||||
# Generate SamplingParams
|
||||
vllm_sampling_params = [
|
||||
SamplingParams(max_tokens=max_tokens,
|
||||
logprobs=num_lp,
|
||||
prompt_logprobs=num_plp,
|
||||
temperature=temperature,
|
||||
seed=1984)
|
||||
for num_lp, num_plp in logprob_prompt_logprob_list
|
||||
]
|
||||
|
||||
vllm_results = vllm_model.model.generate(
|
||||
test_prompts, sampling_params=vllm_sampling_params)
|
||||
|
||||
for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip(
|
||||
vllm_results, hf_logprobs, hf_outputs,
|
||||
logprob_prompt_logprob_list):
|
||||
|
||||
# Extract request-level (prompt)logprobs config
|
||||
num_top_logprobs, num_top_prompt_logprobs = logprob_prompt_logprob
|
||||
|
||||
# Test whether sampled token output is consistent between vLLM and HF
|
||||
# vLLM prompt+completion should match HF output
|
||||
if temperature == 0.0:
|
||||
assert (vllm_result.prompt_token_ids +
|
||||
vllm_result.outputs[0].token_ids == hf_output[0])
|
||||
else:
|
||||
# Sampled tokens won't match if not greedy
|
||||
assert (vllm_result.prompt_token_ids == hf_output[0]
|
||||
[:len(vllm_result.prompt_token_ids)])
|
||||
|
||||
# Validate sample logprobs
|
||||
if num_top_logprobs is not None:
|
||||
assert num_top_logprobs is not None
|
||||
# Confirm that the structure of the sample logprobs in the result is
|
||||
# correct
|
||||
assert vllm_result.outputs[0].logprobs is not None
|
||||
assert len(vllm_result.outputs[0].logprobs) == max_tokens
|
||||
for logprobs, token_id in zip(vllm_result.outputs[0].logprobs,
|
||||
vllm_result.outputs[0].token_ids):
|
||||
assert logprobs is not None
|
||||
|
||||
# Confirm that the output token appears among the logprobs
|
||||
assert token_id in logprobs
|
||||
token_in_topk = logprobs[token_id].rank <= num_top_logprobs
|
||||
|
||||
# If the sampled token is not among the top K
# logprobs, one additional entry is returned for it
|
||||
if token_in_topk and num_top_logprobs != 0:
|
||||
assert len(logprobs) == num_top_logprobs
|
||||
else:
|
||||
assert len(logprobs) == num_top_logprobs + 1
|
||||
|
||||
if num_top_logprobs > 0:
|
||||
# We should have an entry for each of the topk ranks
|
||||
all_ranks = {lp.rank for lp in logprobs.values()}
|
||||
assert all(r in all_ranks
|
||||
for r in range(1, num_top_logprobs + 1))
|
||||
|
||||
output_text = vllm_result.outputs[0].text
|
||||
output_string_from_most_likely_tokens_lst: List[str] = []
|
||||
for top_logprobs in vllm_result.outputs[0].logprobs:
|
||||
top_logprob = next(iter(top_logprobs.values()))
|
||||
output_string_from_most_likely_tokens_lst.append(
|
||||
top_logprob.decoded_token)
|
||||
|
||||
output_string_from_most_likely_tokens = "".join(
|
||||
output_string_from_most_likely_tokens_lst)
|
||||
assert_incr_detok_str_matches_non_incr_detok_str(
|
||||
output_text, output_string_from_most_likely_tokens,
|
||||
"The output text from the top logprob for each token "
|
||||
"position should be the same as the output text in the "
|
||||
"result.")
|
||||
|
||||
# Compare vLLM sample logprobs to HF
|
||||
vllm_sample_logprobs = vllm_result.outputs[0].logprobs
|
||||
for i, top_logprobs in enumerate(vllm_sample_logprobs):
|
||||
for token_id, sample_logprob in top_logprobs.items():
|
||||
if temperature == 0.0 or i == 0:
|
||||
logprob = sample_logprob.logprob
|
||||
torch.testing.assert_close(
|
||||
logprob,
|
||||
hf_logprob[i][-1][token_id].item(),
|
||||
atol=1e-2,
|
||||
rtol=1e-2)
|
||||
assert isinstance(
|
||||
sample_logprob.decoded_token,
|
||||
str), ("The token should be decoded by the time it is"
|
||||
" returned to the user.")
|
||||
|
||||
# At this point we know the sample logprobs are correct for this
|
||||
# request. Validate that cumulative_logprob is actually the sum.
|
||||
# For each request, assert that the returned cumulative logprob
|
||||
# matches the correct value, which is computed below.
|
||||
torch.testing.assert_close(
|
||||
vllm_result.outputs[0].cumulative_logprob,
|
||||
compute_correct_cumulative_logprob(vllm_result.outputs[0]),
|
||||
atol=1e-6,
|
||||
rtol=1e-6)
|
||||
else:
|
||||
# Logprobs disabled for this request; should be None
|
||||
assert vllm_result.outputs[0].logprobs is None
|
||||
|
||||
# Validate prompt logprobs
|
||||
if num_top_prompt_logprobs is not None:
|
||||
# Confirm that structure of prompt logprobs in result is correct
|
||||
assert vllm_result.prompt_logprobs is not None
|
||||
# - The first prompt logprob is always None
|
||||
assert vllm_result.prompt_logprobs[0] is None
|
||||
# - Prompt logprobs are returned for all indices in
|
||||
# the prompt
|
||||
assert len(vllm_result.prompt_logprobs) == len(
|
||||
vllm_result.prompt_token_ids)
|
||||
for prompt_logprobs, prompt_token_id in zip(
|
||||
vllm_result.prompt_logprobs[1:],
|
||||
vllm_result.prompt_token_ids[1:]):
|
||||
assert prompt_logprobs is not None
|
||||
|
||||
# Confirm that the prompt token appears among the logprobs
|
||||
assert prompt_token_id in prompt_logprobs
|
||||
token_in_topk = prompt_logprobs[
|
||||
prompt_token_id].rank <= num_top_prompt_logprobs
|
||||
|
||||
# If the prompt token is not among the top K
# logprobs, one additional entry is returned for it
|
||||
if token_in_topk and num_top_prompt_logprobs != 0:
|
||||
assert len(prompt_logprobs) == num_top_prompt_logprobs
|
||||
else:
|
||||
assert len(prompt_logprobs) == num_top_prompt_logprobs + 1
|
||||
|
||||
if num_top_prompt_logprobs > 0:
|
||||
# We should have an entry for each of the topk ranks
|
||||
all_ranks = {lp.rank for lp in prompt_logprobs.values()}
|
||||
assert all(r in all_ranks
|
||||
for r in range(1, num_top_prompt_logprobs + 1))
|
||||
|
||||
# Compare prompt logprobs to HF
|
||||
# The first prompt logprob is always None, so we compare it from
|
||||
# 1:.
|
||||
vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
|
||||
for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
|
||||
for token_id, logprob in vllm_prompt_logprob_dict.items():
|
||||
torch.testing.assert_close(
|
||||
logprob.logprob,
|
||||
hf_logprob[0][i][token_id].item(),
|
||||
atol=2e-2,
|
||||
rtol=2e-2)
|
||||
else:
|
||||
assert vllm_result.prompt_logprobs is None
|
||||
|
||||
|
||||
#@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("batch_logprobs_composition",
|
||||
["NONE", "SAMPLE", "PROMPT", "SAMPLE_PROMPT"])
|
||||
@pytest.mark.parametrize("temperature", [0.0, 2.0])
|
||||
def test_get_logprobs_and_prompt_logprobs(
|
||||
hf_model,
|
||||
vllm_model,
|
||||
batch_logprobs_composition: str,
|
||||
temperature: float,
|
||||
example_prompts,
|
||||
) -> None:
|
||||
"""Test V1 Engine logprobs & prompt logprobs
|
||||
|
||||
Exercise a variety of combinations of `logprobs` and `prompt_logprobs`
|
||||
settings and validate that
|
||||
* The generated logprobs and prompt logprobs are consistent with the
|
||||
configuration settings, in terms of whether or not the logprobs
|
||||
(of either type) were requested and how many were requested
|
||||
* The generated logprobs are consistent with the generated tokens
|
||||
* The generated (prompt)logprobs are consistent with HuggingFace
|
||||
(prompt)logprobs, as a reference
|
||||
|
||||
batch_logprobs_composition controls the logprobs configurations for
|
||||
requests in the batch under test.
|
||||
|
||||
Args:
|
||||
hf_model
|
||||
vllm_model
|
||||
batch_logprobs_composition: logprobs configuration for test batch
|
||||
example_prompts
|
||||
monkeypatch
|
||||
"""
|
||||
_test_case_get_logprobs_and_prompt_logprobs(
|
||||
hf_model=hf_model,
|
||||
vllm_model=vllm_model,
|
||||
batch_logprobs_composition=batch_logprobs_composition,
|
||||
temperature=temperature,
|
||||
example_prompts=example_prompts)
|
||||
|
||||
|
||||
def test_max_logprobs(monkeypatch):
|
||||
"""vLLM v1 engine should fail a request with `logprobs > max_logprobs`
|
||||
|
||||
Should also fail for `prompt_logprobs > max_logprobs`
|
||||
|
||||
Args:
|
||||
monkeypatch
|
||||
"""
|
||||
override_backend_env_variable(monkeypatch, "FLASH_ATTN")
|
||||
|
||||
runner = VllmRunner("facebook/opt-125m",
|
||||
max_logprobs=1,
|
||||
enable_prefix_caching=False,
|
||||
max_model_len=256)
|
||||
vllm_sampling_params = SamplingParams(logprobs=1)
|
||||
# should pass
|
||||
runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
|
||||
|
||||
bad_sampling_params = SamplingParams(logprobs=2)
|
||||
with pytest.raises(ValueError):
|
||||
runner.generate(["Hello world"], sampling_params=bad_sampling_params)
|
||||
|
||||
|
||||
def test_none_logprobs(vllm_model, example_prompts, monkeypatch):
|
||||
"""Engine should return `logprobs` and `prompt_logprobs` as `None`
|
||||
|
||||
Args:
|
||||
vllm_model: vLLM model fixture
|
||||
example_prompts: list of example prompts (test fixture)
|
||||
monkeypatch: supports editing env vars and rolling back changes
|
||||
after the test
|
||||
"""
|
||||
max_tokens = 5
|
||||
|
||||
sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
|
||||
logprobs=None,
|
||||
prompt_logprobs=None,
|
||||
temperature=0.0)
|
||||
results_logprobs_none = vllm_model.model.generate(
|
||||
example_prompts, sampling_params=sampling_params_logprobs_none)
|
||||
|
||||
for i in range(len(results_logprobs_none)):
|
||||
# Check sample logprobs are None
|
||||
assert results_logprobs_none[i].outputs[0].logprobs is None
|
||||
assert results_logprobs_none[i].outputs[0].cumulative_logprob is None
|
||||
# Check prompt logprobs are None
|
||||
assert results_logprobs_none[i].prompt_logprobs is None
|
||||
|
||||
|
||||
def test_zero_logprobs(vllm_model, example_prompts, monkeypatch):
|
||||
"""Engine should return sampled token and prompt token logprobs
|
||||
|
||||
Args:
|
||||
vllm_model: vLLM model fixture
|
||||
example_prompts: list of example prompts (test fixture)
|
||||
monkeypatch: supports editing env vars and rolling back changes
|
||||
after the test
|
||||
"""
|
||||
max_tokens = 5
|
||||
|
||||
sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens,
|
||||
logprobs=0,
|
||||
prompt_logprobs=0,
|
||||
temperature=0.0)
|
||||
results_logprobs_zero = vllm_model.model.generate(
|
||||
example_prompts, sampling_params=sampling_params_logprobs_zero)
|
||||
|
||||
for i in range(len(results_logprobs_zero)):
|
||||
# Check that there is one sample logprob dict for each
|
||||
# sample token
|
||||
logprobs = results_logprobs_zero[i].outputs[0].logprobs
|
||||
prompt_logprobs = results_logprobs_zero[i].prompt_logprobs
|
||||
sampled_token_ids = results_logprobs_zero[i].outputs[0].token_ids
|
||||
prompt_token_ids = results_logprobs_zero[i].prompt_token_ids
|
||||
assert logprobs is not None
|
||||
assert len(sampled_token_ids) == len(logprobs)
|
||||
assert results_logprobs_zero[i].outputs[
|
||||
0].cumulative_logprob is not None
|
||||
# Check that there is one prompt logprob dict for each
|
||||
# prompt token
|
||||
assert prompt_logprobs is not None
|
||||
assert len(prompt_token_ids) == len(prompt_logprobs)
|
||||
52
tests/v1/sample/test_logprobs_e2e.py
Normal file
@ -0,0 +1,52 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import lm_eval
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
# arc-easy uses prompt_logprobs=1, logprobs=1
|
||||
TASK = "arc_easy"
|
||||
FILTER = "acc_norm,none"
|
||||
RTOL = 0.03
|
||||
EXPECTED_VALUE = 0.62
|
||||
|
||||
# FIXME(rob): enable prefix caching once supported.
|
||||
MODEL = "meta-llama/Llama-3.2-1B"
|
||||
MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False" # noqa: E501
|
||||
SERVER_ARGS = [
|
||||
"--enforce_eager", "--no_enable_prefix_caching", "--disable-log-requests"
|
||||
]
|
||||
NUM_CONCURRENT = 100
|
||||
|
||||
|
||||
def test_prompt_logprobs_e2e():
|
||||
results = lm_eval.simple_evaluate(model="vllm",
|
||||
model_args=MODEL_ARGS,
|
||||
tasks=TASK,
|
||||
batch_size="auto")
|
||||
|
||||
measured_value = results["results"][TASK][FILTER]
|
||||
assert (measured_value - RTOL < EXPECTED_VALUE
|
||||
and measured_value + RTOL > EXPECTED_VALUE
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
|
||||
|
||||
def test_prompt_logprobs_e2e_server():
|
||||
with RemoteOpenAIServer(MODEL, SERVER_ARGS) as remote_server:
|
||||
url = f"{remote_server.url_for('v1')}/completions"
|
||||
|
||||
model_args = (
|
||||
f"model={MODEL},"
|
||||
f"base_url={url},"
|
||||
f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")
|
||||
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="local-completions",
|
||||
model_args=model_args,
|
||||
tasks=TASK,
|
||||
)
|
||||
|
||||
measured_value = results["results"][TASK][FILTER]
|
||||
assert (measured_value - RTOL < EXPECTED_VALUE
|
||||
and measured_value + RTOL > EXPECTED_VALUE
|
||||
), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
|
||||
120
tests/v1/sample/utils.py
Normal file
@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
|
||||
from vllm import CompletionOutput
|
||||
|
||||
|
||||
def get_test_batch(batch_logprobs_composition: str) -> List[Tuple]:
|
||||
"""Generate logprobs configs for a batch of requests
|
||||
|
||||
A given request's logprobs configuration is (1) num_sample_logprobs and (2)
|
||||
num_prompt_logprobs. The batch logprobs configuration is the list of request
|
||||
logprobs configs.
|
||||
|
||||
batch_logprobs_composition == "NONE" yields a batch with no sample or prompt
|
||||
logprobs
|
||||
|
||||
batch_logprobs_composition == "SAMPLE" yields a batch with some requests
|
||||
configured for sample logprobs only, and others configured for no logprobs
|
||||
|
||||
batch_logprobs_composition == "PROMPT" yields a batch with some requests
|
||||
configured for prompt logprobs only, and others configured for no logprobs
|
||||
|
||||
batch_logprobs_composition == "SAMPLE_PROMPT" yields a batch with some
|
||||
requests configured for sample logprobs and prompt logprobs, some configured
|
||||
for only sample logprobs or only prompt logprobs, and some configured for
|
||||
no logprobs
|
||||
|
||||
Args:
|
||||
batch_logprobs_composition: types of logprobs configs to include in batch
|
||||
|
||||
Returns:
|
||||
|
||||
List of (Optional[num_sample_logprobs], Optional[num_prompt_logprobs])
|
||||
tuples
|
||||
"""
|
||||
if batch_logprobs_composition == "NONE":
|
||||
# No requests with sample or prompt logprobs
|
||||
return [(None, None)]
|
||||
elif batch_logprobs_composition == "SAMPLE":
|
||||
# Requests requiring sample logprobs or no logprobs
|
||||
return [
|
||||
(None, None),
|
||||
(0, None),
|
||||
(5, None),
|
||||
(3, None),
|
||||
]
|
||||
elif batch_logprobs_composition == "PROMPT":
|
||||
# Requests requiring prompt logprobs or no logprobs
|
||||
return [
|
||||
(None, None),
|
||||
(None, 0),
|
||||
(None, 6),
|
||||
(None, 5),
|
||||
]
|
||||
elif batch_logprobs_composition == "SAMPLE_PROMPT":
|
||||
# Requests requiring either no logprobs, just
|
||||
# sample logprobs, just prompt logprobs, or
|
||||
# both sample and prompt logprobs
|
||||
return [
|
||||
(None, None),
|
||||
(0, None),
|
||||
(5, None),
|
||||
(3, None),
|
||||
(0, 3),
|
||||
(6, 0),
|
||||
(6, 3),
|
||||
(None, 6),
|
||||
(None, 5),
|
||||
(None, 0),
|
||||
]
|
||||
else:
|
||||
raise ValueError("Invalid logprobs batch configuration for test.")
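For context, a short sketch (mirroring the test file above; values illustrative) of how these tuples become per-request sampling parameters:
from vllm import SamplingParams

# Each (num_sample_logprobs, num_prompt_logprobs) tuple configures one request.
sampling_params = [
    SamplingParams(max_tokens=5, logprobs=lp, prompt_logprobs=plp)
    for lp, plp in get_test_batch("SAMPLE_PROMPT")
]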
|
||||
|
||||
|
||||
def assert_incr_detok_str_matches_non_incr_detok_str(
|
||||
incremental_detokenization_str: str,
|
||||
non_incremental_detokenization_str: str,
|
||||
msg: str,
|
||||
) -> None:
|
||||
"""Compare incrementally detok. text to non-incrementally detok. text
|
||||
|
||||
Fail if the strings mismatch after non-alphanumeric characters are stripped
|
||||
out.
|
||||
|
||||
Rationale: incremental detokenization in the text generation process allows
|
||||
the tokenizer to adjust the next token text output based on the token's
|
||||
context in the string. However, logprobs detokenization detokenizes each
|
||||
token individually, and the resultant strings may include some
|
||||
non-alphanumeric placeholder characters in place of, e.g.,
|
||||
whitespace. So, this function compares only the alphanumeric text
|
||||
between two strings and fails if there is a mismatch, which helps
|
||||
with validating logprobs detokenization.
|
||||
|
||||
Args:
|
||||
incremental_detokenization_str: incrementally-detokenized generated text
|
||||
non_incremental_detokenization_str: non-incrementally-detokenized logprob
|
||||
tokens
|
||||
msg: error message if `assert` fails
|
||||
"""
|
||||
rgx = r'[^a-zA-Z0-9]+'
|
||||
assert (re.sub(rgx, '', incremental_detokenization_str) == re.sub(
|
||||
rgx, '', non_incremental_detokenization_str)), (msg)
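A tiny self-contained illustration of the comparison (strings invented; the underscore stands in for a tokenizer placeholder character):
import re

rgx = r'[^a-zA-Z0-9]+'
incremental = "Hello, world!"   # incrementally detokenized text
per_token = "Hello_world!"      # per-token detokenization with a placeholder
# Both reduce to "Helloworld" once non-alphanumerics are stripped.
assert re.sub(rgx, '', incremental) == re.sub(rgx, '', per_token)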
|
||||
|
||||
|
||||
def compute_correct_cumulative_logprob(
|
||||
completion_output: CompletionOutput) -> float:
|
||||
"""Compute known-good value for evaluating cumulative logprob
|
||||
|
||||
Args:
|
||||
completion_output: completion output from engine
|
||||
|
||||
Returns:
|
||||
Known-good cumulative logprob value
|
||||
"""
|
||||
token_ids = completion_output.token_ids
|
||||
logprobs = completion_output.logprobs
|
||||
assert logprobs is not None
|
||||
return sum([lp[tok_id].logprob for tok_id, lp in zip(token_ids, logprobs)])
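As a quick sanity check with invented numbers: three sampled tokens with logprobs -0.25, -1.5 and -0.75 should yield a cumulative logprob of -2.5, which is what the helper returns.
# Invented values only; the real test pulls these from a CompletionOutput.
assert abs(sum([-0.25, -1.5, -0.75]) - (-2.5)) < 1e-9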
|
||||
@ -142,6 +142,9 @@ class RequestOutput:
|
||||
prompt_token_ids: Optional[List[int]],
|
||||
text: str,
|
||||
token_ids: List[int],
|
||||
logprobs: Optional[SampleLogprobs],
|
||||
prompt_logprobs: Optional[PromptLogprobs],
|
||||
cumulative_logprob: Optional[float],
|
||||
finished: bool = False,
|
||||
) -> "RequestOutput":
|
||||
"""Initialize a new RequestOutput object."""
|
||||
@ -151,15 +154,14 @@ class RequestOutput:
|
||||
index=0,
|
||||
text=text,
|
||||
token_ids=token_ids,
|
||||
cumulative_logprob=None,
|
||||
logprobs=None, # TODO
|
||||
)
|
||||
cumulative_logprob=cumulative_logprob,
|
||||
logprobs=logprobs)
|
||||
|
||||
return RequestOutput(
|
||||
request_id=request_id,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_logprobs=None, # TODO
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
outputs=[completion_output],
|
||||
finished=finished,
|
||||
)
|
||||
|
||||
@ -74,6 +74,25 @@ def convert_prompt_ids_to_tokens(
|
||||
return new_tokens, prefix_offset, read_offset
|
||||
|
||||
|
||||
def convert_ids_list_to_tokens(
|
||||
tokenizer: AnyTokenizer,
|
||||
token_ids: List[int],
|
||||
) -> List[str]:
|
||||
"""Detokenize the input ids individually.
|
||||
|
||||
Args:
|
||||
tokenizer: tokenizer used by model under test
|
||||
token_ids: convert these tokens (Python list form)
|
||||
|
||||
Returns:
|
||||
Python list of token string representations
|
||||
|
||||
"""
|
||||
token_str_lst = tokenizer.convert_ids_to_tokens(token_ids)
|
||||
_replace_none_with_empty(token_str_lst) # type: ignore
|
||||
return token_str_lst
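A hedged usage sketch (model name is illustrative, and the exact token strings depend on the tokenizer): per-id detokenization keeps one raw string per token id, unlike incremental detokenization, which merges tokens using surrounding context.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/opt-125m")  # illustrative model
ids = tok("Hello world").input_ids
# One raw string per id, subword/space markers included; None entries would be
# replaced with "" by _replace_none_with_empty in the helper above.
print(tok.convert_ids_to_tokens(ids))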
|
||||
|
||||
|
||||
# Based on
|
||||
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
|
||||
# under Apache 2.0 license
|
||||
|
||||
@ -437,6 +437,8 @@ class Scheduler:
|
||||
) -> EngineCoreOutputs:
|
||||
# NOTE(woosuk): This method doesn't consider speculative decoding.
|
||||
sampled_token_ids = model_runner_output.sampled_token_ids
|
||||
logprobs = model_runner_output.logprobs
|
||||
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||
new_running: List[Request] = []
|
||||
outputs: List[EngineCoreOutput] = []
|
||||
@ -471,6 +473,13 @@ class Scheduler:
|
||||
self.encoder_cache_manager.free_encoder_input(
|
||||
request, input_id)
|
||||
|
||||
# Get prompt logprobs for this request.
|
||||
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
|
||||
|
||||
stopped = False
|
||||
new_logprobs = None
|
||||
new_token_ids = None
|
||||
|
||||
if request.num_computed_tokens == request.num_tokens:
|
||||
req_index = model_runner_output.req_id_to_index[req_id]
|
||||
# NOTE(woosuk): Currently, we assume that each request
|
||||
@ -486,20 +495,30 @@ class Scheduler:
|
||||
if stopped:
|
||||
self._free_request(request)
|
||||
|
||||
# Extract sample logprobs if needed.
|
||||
if request.sampling_params.logprobs is not None:
|
||||
assert logprobs is not None
|
||||
# NOTE: once we support N tokens per step (spec decode),
|
||||
# the outer lists can be of length > 1.
|
||||
new_logprobs = logprobs.slice(req_index, req_index + 1)
|
||||
|
||||
new_token_ids = request.output_token_ids[-num_new_tokens:]
|
||||
|
||||
# Transmit a partial output if chunked prefill & prompt logprobs are enabled
|
||||
if new_token_ids or prompt_logprobs_tensors is not None:
|
||||
# Add EngineCoreOutput for this Request.
|
||||
output = EngineCoreOutput(
|
||||
request_id=req_id,
|
||||
new_token_ids=request.output_token_ids[-num_new_tokens:],
|
||||
finished=request.is_finished(),
|
||||
finish_reason=request.get_finished_reason(),
|
||||
stop_reason=request.stop_reason)
|
||||
outputs.append(output)
|
||||
outputs.append(
|
||||
EngineCoreOutput(
|
||||
request_id=req_id,
|
||||
new_token_ids=new_token_ids or [],
|
||||
finish_reason=request.get_finished_reason(),
|
||||
new_logprobs=new_logprobs,
|
||||
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
|
||||
stop_reason=request.stop_reason))
|
||||
|
||||
# Break out of the loop.
|
||||
if stopped:
|
||||
continue
|
||||
if not stopped:
|
||||
new_running.append(request)
|
||||
|
||||
new_running.append(request)
|
||||
self.running = new_running
|
||||
return EngineCoreOutputs(
|
||||
outputs=outputs,
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, Optional, Union
|
||||
import msgspec
|
||||
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -67,10 +68,17 @@ class EngineCoreOutput(
|
||||
|
||||
request_id: str
|
||||
new_token_ids: List[int]
|
||||
finished: bool
|
||||
|
||||
new_logprobs: Optional[LogprobsLists] = None
|
||||
new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None
|
||||
|
||||
finish_reason: Optional[FinishReason] = None
|
||||
stop_reason: Union[int, str, None] = None
|
||||
|
||||
@property
|
||||
def finished(self) -> bool:
|
||||
return self.finish_reason is not None
|
||||
|
||||
|
||||
class EngineCoreOutputs(
|
||||
msgspec.Struct,
|
||||
|
||||
@ -11,7 +11,6 @@ from typing import List, Tuple, Type
|
||||
import psutil
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from msgspec import msgpack
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
@ -26,7 +25,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
|
||||
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.serial_utils import PickleEncoder
|
||||
from vllm.v1.serial_utils import MsgpackEncoder, PickleEncoder
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -292,7 +291,7 @@ class EngineCoreProc(EngineCore):
|
||||
"""Output socket IO thread."""
|
||||
|
||||
# Msgpack serialization encoding.
|
||||
encoder = msgpack.Encoder()
|
||||
encoder = MsgpackEncoder()
|
||||
# Reuse send buffer.
|
||||
buffer = bytearray()
|
||||
|
||||
|
||||
@ -7,7 +7,6 @@ import weakref
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import msgspec
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
@ -20,7 +19,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
|
||||
EngineCoreRequestUnion, EngineCoreResetPrefixCache)
|
||||
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.serial_utils import PickleEncoder
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, PickleEncoder
|
||||
from vllm.v1.utils import BackgroundProcHandle
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -163,7 +162,7 @@ class MPClient(EngineCoreClient):
|
||||
|
||||
# Serialization setup.
|
||||
self.encoder = PickleEncoder()
|
||||
self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)
|
||||
self.decoder = MsgpackDecoder(EngineCoreOutputs)
|
||||
|
||||
# ZMQ setup.
|
||||
self.ctx = (
|
||||
|
||||
@ -1,27 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional
|
||||
|
||||
from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.transformers_utils.detokenizer_utils import (
|
||||
AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
|
||||
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DetokenizerOutput:
|
||||
output_text: str
|
||||
token_ids: List[int]
|
||||
finished: bool
|
||||
finish_reason: Optional[FinishReason] = None
|
||||
stop_reason: Union[int, str, None] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class IncrementalDetokenizer:
|
||||
|
||||
@ -42,7 +32,6 @@ class IncrementalDetokenizer:
|
||||
# Parameters for detokenization
|
||||
skip_special_tokens: bool
|
||||
spaces_between_special_tokens: bool
|
||||
output_kind: RequestOutputKind
|
||||
|
||||
# Tokenizer for this request
|
||||
tokenizer: AnyTokenizer
|
||||
@ -90,25 +79,19 @@ class IncrementalDetokenizer:
|
||||
skip_special_tokens=request.sampling_params.skip_special_tokens,
|
||||
spaces_between_special_tokens=request.sampling_params.
|
||||
spaces_between_special_tokens,
|
||||
output_kind=request.sampling_params.output_kind,
|
||||
prompt_len=len(request.prompt_token_ids),
|
||||
tokenizer=tokenizer,
|
||||
stop_buffer_length=stop_buffer_length,
|
||||
)
|
||||
|
||||
def update_from_output(
|
||||
self,
|
||||
output: EngineCoreOutput,
|
||||
) -> Optional[DetokenizerOutput]:
|
||||
def update(self, new_token_ids: List[int]) -> Optional[str]:
|
||||
"""
|
||||
Update RequestState for the request_id by:
|
||||
1) Detokenize the new token ids incrementally.
|
||||
2) Update the RequestOutput with the new text.
|
||||
"""
|
||||
2) Evaluate stop criteria.
|
||||
|
||||
new_token_ids = output.new_token_ids
|
||||
finish_reason = output.finish_reason
|
||||
stop_reason = output.stop_reason
|
||||
Return matched stop string or None.
|
||||
"""
|
||||
|
||||
# 1) Detokenize the new token ids incrementally.
|
||||
# TODO(woosuk): This method becomes very inefficient when the number of
|
||||
@ -131,11 +114,13 @@ class IncrementalDetokenizer:
|
||||
self.tokens.extend(new_tokens)
|
||||
self.prefix_offset = prefix_offset
|
||||
self.read_offset = read_offset
|
||||
self.output_text += new_decoded_token_text
|
||||
|
||||
decoded_text += new_decoded_token_text
|
||||
|
||||
self.output_text += decoded_text
|
||||
|
||||
# 2) Evaluate stop criteria.
|
||||
stop_string = None
|
||||
if self.stop:
|
||||
stop = StopChecker.check_stop_strings(
|
||||
output_text=self.output_text,
|
||||
@ -144,28 +129,13 @@ class IncrementalDetokenizer:
|
||||
include_in_output=self.include_stop_str_in_output,
|
||||
)
|
||||
if stop is not None:
|
||||
stop_str, truncate_to = stop
|
||||
stop_string, truncate_to = stop
|
||||
if truncate_to != -1:
|
||||
self.output_text = self.output_text[:truncate_to]
|
||||
finish_reason = FinishReason.STOP
|
||||
stop_reason = stop_str
|
||||
|
||||
# TODO: handle stop_token_ids here too?
|
||||
return stop_string
|
||||
|
||||
# 3) Update the RequestOutput object with the new text.
|
||||
finished = finish_reason is not None
|
||||
if self.output_kind == RequestOutputKind.FINAL_ONLY \
|
||||
and not finished:
|
||||
return None
|
||||
|
||||
delta = self.output_kind == RequestOutputKind.DELTA
|
||||
output_text = self._get_next_output_text(finished, delta)
|
||||
token_ids = new_token_ids if delta else self.output_token_ids
|
||||
|
||||
return DetokenizerOutput(output_text, token_ids, finished,
|
||||
finish_reason, stop_reason)
|
||||
|
||||
def _get_next_output_text(self, finished: bool, delta: bool) -> str:
|
||||
def get_next_output_text(self, finished: bool, delta: bool) -> str:
|
||||
"""If delta is True, only new text since the last call to
|
||||
this method is returned"""
|
||||
|
||||
|
||||
@ -45,6 +45,7 @@ class LLMEngine:
|
||||
multiprocess_mode: bool = False,
|
||||
) -> None:
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
|
||||
# Tokenizer (+ ensure liveness if running in another process).
|
||||
self.tokenizer = init_tokenizer_from_configs(
|
||||
|
||||
194
vllm/v1/engine/logprobs.py
Normal file
@ -0,0 +1,194 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
|
||||
from vllm.transformers_utils.detokenizer_utils import (
|
||||
AnyTokenizer, convert_ids_list_to_tokens)
|
||||
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
|
||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LogprobsProcessor:
|
||||
|
||||
# Tokenizer for this request
|
||||
tokenizer: AnyTokenizer
|
||||
|
||||
# Logprobs for this request
|
||||
logprobs: Optional[SampleLogprobs]
|
||||
prompt_logprobs: Optional[PromptLogprobs]
|
||||
cumulative_logprob: Optional[float]
|
||||
num_logprobs: Optional[int]
|
||||
num_prompt_logprobs: Optional[int]
|
||||
|
||||
@classmethod
|
||||
def from_new_request(
|
||||
cls,
|
||||
tokenizer: AnyTokenizer,
|
||||
request: EngineCoreRequest,
|
||||
) -> "LogprobsProcessor":
|
||||
num_logprobs = request.sampling_params.logprobs
|
||||
num_prompt_logprobs = request.sampling_params.prompt_logprobs
|
||||
return cls(
|
||||
tokenizer=tokenizer,
|
||||
cumulative_logprob=(None if num_logprobs is None else 0.),
|
||||
logprobs=(None if num_logprobs is None else []),
|
||||
# NOTE: logprob of first prompt token is None.
|
||||
prompt_logprobs=(None if num_prompt_logprobs is None else [None]),
|
||||
num_prompt_logprobs=num_prompt_logprobs,
|
||||
num_logprobs=num_logprobs,
|
||||
)
|
||||
|
||||
def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
|
||||
"""Update with sample logprobs from EngineCore.
|
||||
|
||||
Outer lists only have len > 1 if EngineCore generated
more than one token in the prior step (e.g. in spec decoding).
|
||||
|
||||
Args:
|
||||
logprobs_lists: the lists of logprob tokens, logprobs, and ranks.
|
||||
|
||||
"""
|
||||
|
||||
assert self.num_logprobs is not None
|
||||
assert self.logprobs is not None
|
||||
assert self.cumulative_logprob is not None
|
||||
|
||||
token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists
|
||||
|
||||
for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst,
|
||||
token_ids_lst):
|
||||
|
||||
# Detokenize (non-incrementally).
|
||||
decoded_tokens = convert_ids_list_to_tokens(
|
||||
self.tokenizer, token_ids)
|
||||
|
||||
# Sampler puts the sampled logprob in first.
|
||||
sampled_token_logprob = logprobs[0]
|
||||
self.cumulative_logprob += sampled_token_logprob
|
||||
|
||||
# Update with the Logprob dictionary for this pos.
|
||||
self.logprobs.append(
|
||||
self._make_logprob_dict(
|
||||
logprobs,
|
||||
token_ids,
|
||||
decoded_tokens,
|
||||
rank,
|
||||
self.num_logprobs,
|
||||
))
|
||||
|
||||
def _update_prompt_logprobs(
|
||||
self,
|
||||
prompt_logprobs_tensors: LogprobsTensors,
|
||||
) -> None:
|
||||
"""Update with prompt logprobs from EngineCore.
|
||||
|
||||
Args:
|
||||
prompt_logprobs_tensors: tuple containing the prompt logprobs
|
||||
tensors.
|
||||
|
||||
"""
|
||||
|
||||
# Prompt logprobs are enabled.
|
||||
assert self.num_prompt_logprobs is not None
|
||||
assert self.prompt_logprobs is not None
|
||||
|
||||
token_ids, logprobs, ranks = prompt_logprobs_tensors
|
||||
|
||||
# Detokenize non-incrementally.
|
||||
# Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
|
||||
decoded_tokens = convert_ids_list_to_tokens(
|
||||
self.tokenizer,
|
||||
token_ids.flatten().tolist())
|
||||
|
||||
# Recover shapes.
|
||||
num_prompt_tokens, num_logprobs = logprobs.shape
|
||||
|
||||
# Pythonize the torch tensors.
|
||||
# TODO(rob): experiment with doing this in EngineCore?
|
||||
prompt_token_ranks = ranks.tolist()
|
||||
prompt_logprobs = logprobs.tolist()
|
||||
token_ids = token_ids.tolist()
|
||||
|
||||
# Make Logprob for each position.
|
||||
for pos in range(num_prompt_tokens):
|
||||
# Handle flattening.
|
||||
offset = pos * num_logprobs
|
||||
offset_end = offset + num_logprobs
|
||||
decoded_tokens_for_pos = decoded_tokens[offset:offset_end]
|
||||
|
||||
# Update with the Logprob dictionary for this pos.
|
||||
self.prompt_logprobs.append(
|
||||
self._make_logprob_dict(prompt_logprobs[pos], token_ids[pos],
|
||||
decoded_tokens_for_pos,
|
||||
prompt_token_ranks[pos],
|
||||
self.num_prompt_logprobs))
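A shape sketch with invented numbers to make the flattening explicit:
import torch

token_ids = torch.arange(6).reshape(3, 2)  # [num_prompt_tokens=3, num_logprobs=2]
flat = token_ids.flatten().tolist()        # detokenized as one flat list of 6 ids
pos = 1
offset = pos * 2
assert flat[offset:offset + 2] == [2, 3]   # entries belonging to prompt position 1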
|
||||
|
||||
def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
|
||||
"""Pop and return all request prompt logprobs
|
||||
|
||||
The logprobs processor aggregates prompt chunk logprobs
|
||||
over one or more prefill chunks. This method returns
|
||||
all prompt logprobs at once and then forgets them.
|
||||
Ensures correct RequestOutputKind.DELTA semantics
|
||||
wherein all prompt logprobs are returned at once at
|
||||
the end of prefill.
|
||||
|
||||
Returns:
|
||||
None if prompt logprobs are disabled for this request.
|
||||
List of all prompt logprobs, otherwise.
|
||||
"""
|
||||
plp = self.prompt_logprobs
|
||||
if plp:
|
||||
self.prompt_logprobs = []
|
||||
return plp
|
||||
|
||||
@staticmethod
|
||||
def _make_logprob_dict(
|
||||
logprobs: List[float],
|
||||
logprob_token_ids: List[int],
|
||||
decoded_tokens: List[str],
|
||||
rank: int,
|
||||
num_logprobs: int,
|
||||
) -> Dict[int, Logprob]:
|
||||
"""Make a Logprob dictionary for a position.
|
||||
|
||||
Args:
|
||||
logprobs: list of log probabilities
|
||||
logprob_token_ids: list of top token ids
|
||||
decoded_tokens: list of decoded top tokens
|
||||
rank: rank of the sampled token
|
||||
num_logprobs: number of logprobs requested
|
||||
by the user (in addition to sampled logprob)
|
||||
|
||||
Returns:
|
||||
Dict[token id, Logprob]
|
||||
"""
|
||||
|
||||
# We do not need a special case for the sampled token
|
||||
# being in the topk, since inserting duplicated data
|
||||
# into a dictionary twice is the same as doing it once.
|
||||
topk_ranks = range(1, num_logprobs + 1)
|
||||
ranks = itertools.chain((rank, ), topk_ranks)
|
||||
|
||||
return {
|
||||
token_id: Logprob(
|
||||
logprob=logprob,
|
||||
rank=rank,
|
||||
decoded_token=token,
|
||||
)
|
||||
for token_id, logprob, rank, token in zip(
|
||||
logprob_token_ids, logprobs, ranks, decoded_tokens)
|
||||
}
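A minimal sketch (invented ids and ranks) of why the duplicate insert mentioned above is harmless: when the sampled token also appears in the top-k, its second insertion simply overwrites the identical entry.
import itertools

logprob_token_ids = [42, 7, 42, 13]           # sampled id first, then top-3 ids
ranks = itertools.chain((2,), range(1, 4))    # sampled rank, then ranks 1..3
d = {tid: r for tid, r in zip(logprob_token_ids, ranks)}
assert d == {42: 2, 7: 1, 13: 3}              # id 42 inserted twice, one entry kept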
|
||||
|
||||
def update_from_output(self, output: EngineCoreOutput) -> None:
|
||||
if output.new_logprobs is not None:
|
||||
self._update_sample_logprobs(output.new_logprobs)
|
||||
if output.new_prompt_logprobs_tensors is not None:
|
||||
self._update_prompt_logprobs(output.new_prompt_logprobs_tensors)
|
||||
@ -5,11 +5,12 @@ from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.transformers_utils.detokenizer_utils import AnyTokenizer
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
||||
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
|
||||
from vllm.v1.engine.detokenizer import (DetokenizerOutput,
|
||||
IncrementalDetokenizer)
|
||||
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
|
||||
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
|
||||
from vllm.v1.engine.logprobs import LogprobsProcessor
|
||||
from vllm.v1.metrics.stats import IterationStats, RequestStateStats
|
||||
|
||||
|
||||
@ -26,16 +27,20 @@ class RequestState:
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
output_kind: RequestOutputKind,
|
||||
prompt: Optional[str],
|
||||
prompt_token_ids: List[int],
|
||||
logprobs_processor: LogprobsProcessor,
|
||||
detokenizer: IncrementalDetokenizer,
|
||||
arrival_time: float,
|
||||
queue: Optional[asyncio.Queue[RequestOutput]],
|
||||
):
|
||||
self.request_id = request_id
|
||||
self.output_kind = output_kind
|
||||
self.prompt = prompt
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.prompt_len = len(prompt_token_ids)
|
||||
self.logprobs_processor = logprobs_processor
|
||||
self.detokenizer = detokenizer
|
||||
self.is_prefilling = True
|
||||
self.queue = queue
|
||||
@@ -51,8 +56,13 @@ class RequestState:
    ) -> "RequestState":
        return cls(
            request_id=request.request_id,
            output_kind=request.sampling_params.output_kind,
            prompt=request.prompt,
            prompt_token_ids=request.prompt_token_ids,
            logprobs_processor=LogprobsProcessor.from_new_request(
                tokenizer=tokenizer,
                request=request,
            ),
            detokenizer=IncrementalDetokenizer.from_new_request(
                tokenizer=tokenizer,
                request=request,
@@ -127,13 +137,8 @@ class OutputProcessor:
        batch to ensure system overheads are minimized. This is the
        only function that should loop over EngineCoreOutputs.

        If you need to touch every element of the batch, implement a
        method called XXXClass.update_from_output() to be called
        within the loop below. For examples, see:
            * IterationStats.update_from_output()
            * Detokenizer.update_from_output()

        TODO(rob): add a Protocol to make update_from_output explicit.
        If you need to touch every element of the batch, do it from
        within the loop below.

        **********************************************************
        """
@@ -154,17 +159,37 @@ class OutputProcessor:
                                               req_state.is_prefilling,
                                               req_state.prompt_len,
                                               req_state.stats)
                req_state.is_prefilling = False

            # 2) Detokenize the token ids into text.
            detokenizer_output = req_state.detokenizer.update_from_output(
                engine_core_output)
            new_token_ids = engine_core_output.new_token_ids
            finish_reason = engine_core_output.finish_reason

            # 3) Create and handle RequestOutput objects.
            if detokenizer_output is not None:
                request_output = self._make_request_output(
                    req_state, detokenizer_output)
            # TODO(andy): prompt logprobs + chunked prefill can
            # result in engine core returning an output for a
            # partial prefill (in order to send back partial
            # prompt logprobs.) This breaks the invariant that
            # process_outputs is only operating on engine core
            # outputs associated with non-partial completions.
            # Currently this is handled by having `is_prefilling`
            # check for new decoded tokens, indicating that
            # the completion is not partial.
            #
            # Follow up will aggregate partial prompt logprobs
            # in the EngineCore.
            req_state.is_prefilling = not new_token_ids

            # 2) Detokenize the token ids into text and check for stop
            # strings.
            stop_reason = req_state.detokenizer.update(new_token_ids)
            if stop_reason:
                finish_reason = FinishReason.STOP

            # 3) Compute sample and prompt logprobs for request,
            # if required.
            req_state.logprobs_processor.update_from_output(engine_core_output)

            # 4) Create and handle RequestOutput objects.
            if request_output := self._make_request_output(
                    req_state, new_token_ids, finish_reason, stop_reason):
                if req_state.queue is not None:
                    # AsyncLLM: put into queue for handling by generate().
                    req_state.queue.put_nowait(request_output)
@@ -174,18 +199,16 @@ class OutputProcessor:

                # Free completed requests.
                if request_output.finished:
                    assert detokenizer_output.finish_reason is not None

                    self.request_states.pop(req_id)
                    if not engine_core_output.finished:
                        # If req not finished in EngineCore, but Detokenizer
                        # detected stop string, abort needed in EngineCore.
                        reqs_to_abort.append(req_id)

                    # Track per-request stats
                    # Track per-request stats.
                    assert finish_reason is not None
                    iteration_stats.update_from_finished_request(
                        detokenizer_output.finish_reason, request_output,
                        req_state.stats)
                        finish_reason, request_output, req_state.stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
@@ -196,20 +219,47 @@ class OutputProcessor:
    @staticmethod
    def _make_request_output(
        request_state: RequestState,
        detokenizer_output: DetokenizerOutput,
    ) -> RequestOutput:
        new_token_ids: List[int],
        finish_reason: Optional[FinishReason],
        stop_reason: Optional[str],
    ) -> Optional[RequestOutput]:

        finished = finish_reason is not None
        output_kind = request_state.output_kind
        # In follow up, we will switch to invariant where EngineCore
        # does not stream partial prefills.
        if not finished and (request_state.is_prefilling
                             or output_kind == RequestOutputKind.FINAL_ONLY):
            # Only the final output is required in FINAL_ONLY mode.
            return None

        detokenizer = request_state.detokenizer
        logprobs_processor = request_state.logprobs_processor

        delta = output_kind == RequestOutputKind.DELTA
        logprobs = logprobs_processor.logprobs
        if delta:
            if logprobs:
                logprobs = logprobs[-len(new_token_ids):]
            # Side effect: logprobs processor forgets prompt logprobs
            prompt_logprobs = logprobs_processor.pop_prompt_logprobs()
        else:
            prompt_logprobs = logprobs_processor.prompt_logprobs

        request_output = RequestOutput.new(
            request_state.request_id,
            request_state.prompt,
            request_state.prompt_token_ids,
            detokenizer_output.output_text,
            detokenizer_output.token_ids,
            detokenizer_output.finished,
            request_id=request_state.request_id,
            prompt=request_state.prompt,
            prompt_token_ids=request_state.prompt_token_ids,
            text=detokenizer.get_next_output_text(finished, delta),
            token_ids=new_token_ids if delta else detokenizer.output_token_ids,
            logprobs=logprobs,
            prompt_logprobs=prompt_logprobs,
            cumulative_logprob=logprobs_processor.cumulative_logprob,
            finished=finished,
        )
        if detokenizer_output.finished:
        if finished:
            completion_output = request_output.outputs[0]
            completion_output.finish_reason = str(
                detokenizer_output.finish_reason)
            completion_output.stop_reason = detokenizer_output.stop_reason
            completion_output.finish_reason = str(finish_reason)
            completion_output.stop_reason = stop_reason

        return request_output

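Illustration (not part of the diff): a rough sketch of the DELTA handling above with made-up values. In delta mode only the logprob dicts for the newest tokens are emitted, and prompt logprobs are popped so they are sent exactly once.

# Processor state accumulated so far: one dict per generated token (stand-ins here).
accumulated = ["lp_t0", "lp_t1", "lp_t2", "lp_t3"]
new_token_ids = [101, 102]        # tokens produced this step

delta = True
logprobs = accumulated
if delta and logprobs:
    logprobs = logprobs[-len(new_token_ids):]   # only the newest entries
assert logprobs == ["lp_t2", "lp_t3"]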
@@ -33,6 +33,7 @@ class Processor:
    ):

        self.model_config = model_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.tokenizer = tokenizer

@@ -51,6 +52,37 @@ class Processor:
        self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
            cache_config.enable_prefix_caching

    def _validate_logprobs(
        self,
        params: Union[SamplingParams, PoolingParams],
    ) -> None:
        if not isinstance(params, SamplingParams):
            return

        max_logprobs = self.model_config.max_logprobs
        # Validate sample logprobs.
        if params.logprobs and params.logprobs > max_logprobs:
            raise ValueError(
                f"Requested sample logprobs of {params.logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")

        # Validate prompt logprobs.
        if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
            raise ValueError(
                f"Requested prompt logprobs of {params.prompt_logprobs}, "
                f"which is greater than max allowed: {max_logprobs}")

        # TODO(andy): enable this in follow up by recomputing.
        if (params.prompt_logprobs is not None
                and self.cache_config.enable_prefix_caching):
            raise ValueError("Prefix caching with prompt logprobs not yet "
                             "supported on VLLM V1.")

    def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

    def process_inputs(
        self,
        request_id: str,
@@ -64,12 +96,11 @@ class Processor:
    ) -> EngineCoreRequest:

        # TODO(woosuk): Support pooling models.
        # TODO(woosuk): Check max_logprobs
        # TODO(woosuk): Support encoder-decoder models.

        if lora_request is not None and not self.lora_config:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")
        self._validate_logprobs(params)
        self._validate_lora(lora_request)

        if arrival_time is None:
            arrival_time = time.time()
        assert priority == 0, "vLLM V1 does not support priority at the moment."

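Illustration (not part of the diff): a sketch of the validation logic above, assuming a max_logprobs cap of 20 (the usual model_config default); the cap value here is a stand-in for illustration only.

from vllm import SamplingParams

max_logprobs = 20  # stand-in for self.model_config.max_logprobs
params = SamplingParams(logprobs=50)

try:
    # Same check and error message as _validate_logprobs above.
    if params.logprobs and params.logprobs > max_logprobs:
        raise ValueError(
            f"Requested sample logprobs of {params.logprobs}, "
            f"which is greater than max allowed: {max_logprobs}")
except ValueError as err:
    print(err)
    # Requested sample logprobs of 50, which is greater than max allowed: 20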
@@ -60,14 +60,17 @@ class IterationStats:

        self.num_generation_tokens += num_new_generation_tokens
        if is_prefilling:
            # This relies on the invariant that EngineCore does
            # not stream outputs for partially completed prefills
            # (scheduler.update_from_output makes EngineCoreOutput
            # iff num_computed_tokens == num_tokens).
            assert (num_new_generation_tokens > 0)
            self.num_prompt_tokens += prompt_len

            self.time_to_first_tokens_iter.append(last_token_latency)
            # TODO(andy): we used to assert that num_new_generation_tokens
            # > 0 with an invariant that EngineCore does not stream outputs
            # for partially completed prefills (scheduler.update_from_output
            # makes EngineCoreOutput iff num_computed_tokens == num_tokens).
            # When prompt logprobs are enabled, we currently stream out the
            # partially completed prompt.
            # This will be reverted in a follow up PR and we should re-enable
            # this assertion / invariant.
            if num_new_generation_tokens > 0:
                self.num_prompt_tokens += prompt_len
                self.time_to_first_tokens_iter.append(last_token_latency)
        else:
            self.time_per_output_tokens_iter.append(last_token_latency)

@@ -1,25 +1,51 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import Dict, List, NamedTuple, Optional

import torch


class LogprobsLists(NamedTuple):

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: List[List[int]]
    # [num_reqs, max_num_logprobs + 1]
    logprobs: List[List[float]]
    # [num_reqs]
    sampled_token_ranks: List[int]

    def slice(self, start: int, end: int):
        return LogprobsLists(
            self.logprob_token_ids[start:end],
            self.logprobs[start:end],
            self.sampled_token_ranks[start:end],
        )


class LogprobsTensors(NamedTuple):

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: torch.Tensor
    # [num_reqs, max_num_logprobs + 1]
    logprobs: torch.Tensor
    # [num_reqs]
    selected_token_ranks: torch.Tensor

    def tolists(self):
        return LogprobsLists(
            self.logprob_token_ids.tolist(),
            self.logprobs.tolist(),
            self.selected_token_ranks.tolist(),
        )


@dataclass
class SamplerOutput:

    # [num_reqs]
    sampled_token_ids: torch.Tensor

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: Optional[torch.Tensor]
    # [num_reqs, max_num_logprobs + 1]
    logprobs: Optional[torch.Tensor]

    # TODO: Support prompt logprobs.
    prompt_logprob_token_ids: Optional[torch.Tensor]
    prompt_logprobs: Optional[torch.Tensor]
    logprobs_tensors: Optional[LogprobsTensors]


# ModelRunnerOutput is serialized and sent to the scheduler process.
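Illustration (not part of the diff, made-up values): LogprobsTensors is what the sampler produces on the GPU, tolists() converts it into the serializable LogprobsLists form, and slice() narrows the per-request lists.

import torch

from vllm.v1.outputs import LogprobsLists, LogprobsTensors

# Two requests, max_num_logprobs = 2 -> 3 columns (sampled token + top-2).
tensors = LogprobsTensors(
    logprob_token_ids=torch.tensor([[5, 1, 9], [7, 7, 2]], dtype=torch.int32),
    logprobs=torch.tensor([[-1.2, -0.3, -2.0], [-0.1, -0.1, -2.5]]),
    selected_token_ranks=torch.tensor([4, 1]),
)

lists = tensors.tolists()
assert isinstance(lists, LogprobsLists)
assert lists.sampled_token_ranks == [4, 1]

# Keep only the first request's rows.
first = lists.slice(0, 1)
assert first.logprob_token_ids == [[5, 1, 9]]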
@@ -36,6 +62,12 @@ class ModelRunnerOutput:
    sampled_token_ids: List[int]

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids_cpu: Optional[torch.Tensor]
    # [num_reqs, max_num_logprobs + 1]
    logprobs_cpu: Optional[torch.Tensor]
    # [num_reqs]
    logprobs: Optional[LogprobsLists]

    # req_id -> (token_ids, logprobs, ranks)
    # [prompt_len, num_prompt_logprobs]
    # [prompt_len, num_prompt_logprobs]
    # [prompt_len]
    prompt_logprobs_dict: Dict[str, LogprobsTensors]

@@ -20,7 +20,8 @@ class SamplingMetadata:

    generators: Dict[int, torch.Generator]

    max_num_logprobs: int
    # None means no logprobs, 0 means sampled token logprobs only
    max_num_logprobs: Optional[int]

    no_penalties: bool
    prompt_token_ids: Optional[torch.Tensor]

@@ -1,11 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""
from typing import Tuple

import torch
import torch.nn as nn

from vllm.v1.outputs import SamplerOutput
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
                                          apply_min_token_penalties)
@@ -25,20 +24,16 @@ class Sampler(nn.Module):
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> SamplerOutput:
        needs_logprobs = sampling_metadata.max_num_logprobs > 0
        if needs_logprobs:
            # NOTE(woosuk): Use the original logits (before any penalties or
            # temperature scaling) for the top-k logprobs.
            # This is different from the V0 sampler, which uses the logits that
            # is used for sampling (after penalties and temperature scaling).
            # NOTE: We compute logprobs first because the below ops may
            # modify the logits tensor in-place (and we don't want to clone
            # the logits tensor for memory efficiency).
            topk_logprobs, topk_indices = self.get_topk_logprobs(
                logits, sampling_metadata)
        else:
            topk_logprobs = None
            topk_indices = None

        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).
        # TODO(rob): provide option for logprobs post sampling.
        # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            raw_logprobs = self.compute_logprobs(logits)

        # Use float32 for the logits.
        logits = logits.to(torch.float32)
@@ -48,15 +43,19 @@ class Sampler(nn.Module):
        logits = self.apply_temperature(logits, sampling_metadata.temperature)
        # Sample the next token.
        sampled = self.sample(logits, sampling_metadata)

        # Gather the logprobs of the topk and sampled token (if requested).
        # Get logprobs and rank tensors (if requested)
        logprobs_tensors = None if num_logprobs is None else \
            self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            sampled_token_ids=sampled,
            logprob_token_ids=topk_indices,
            logprobs=topk_logprobs,
            prompt_logprob_token_ids=None,
            prompt_logprobs=None,
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

@@ -103,19 +102,52 @@ class Sampler(nn.Module):
        )
        return sampled

    def get_topk_logprobs(
    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        logprobs = logits.log_softmax(dim=-1, dtype=torch.float32)
        # FIXME: Mask the sampled token_id, get topk logprobs,
        # and concatenate the topk with the sampled token_id.
        topk_logprobs, topk_indices = torch.topk(
            logprobs, sampling_metadata.max_num_logprobs, dim=-1)
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
        """
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)

        # Gather the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token.
        token_ranks = (logprobs >= token_logprobs).sum(-1)

        # Concatenate together with the topk.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        topk_indices = topk_indices.to(torch.int32)
        return topk_logprobs, topk_indices
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    def apply_penalties(
        self,

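Illustration (not part of the diff): a worked sketch of the rank computation above on a toy 4-token vocabulary with made-up numbers. Counting how many entries are at least as likely as the chosen token yields a 1-indexed rank, so the most likely token gets rank 1.

import torch

logprobs = torch.tensor([[-0.5, -1.0, -2.0, -3.0]])   # one position, vocab size 4
token_ids = torch.tensor([2])                         # the chosen (sampled or prompt) token
num_logprobs = 2

topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)
# topk_indices == [[0, 1]], topk_logprobs == [[-0.5, -1.0]]

token_ids = token_ids.unsqueeze(-1)
token_logprobs = logprobs.gather(-1, token_ids)       # [[-2.0]]
token_ranks = (logprobs >= token_logprobs).sum(-1)    # [3]: three tokens are >= -2.0

indices = torch.cat((token_ids, topk_indices), dim=1)          # [[2, 0, 1]]
gathered = torch.cat((token_logprobs, topk_logprobs), dim=1)   # [[-2.0, -0.5, -1.0]]
assert token_ranks.item() == 3   # token 2 is the third most likely -> rank 3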
@@ -1,12 +1,58 @@
# SPDX-License-Identifier: Apache-2.0

import pickle
from typing import Any

import torch
from msgspec import msgpack

CUSTOM_TYPE_CODE_PICKLE = 1


class PickleEncoder:

    def encode(self, obj):
    def encode(self, obj: Any):
        return pickle.dumps(obj)

    def decode(self, data):
    def decode(self, data: Any):
        return pickle.loads(data)


class MsgpackEncoder:
    """Encoder with custom torch tensor serialization."""

    def __init__(self):
        self.encoder = msgpack.Encoder(enc_hook=custom_enc_hook)

    def encode(self, obj: Any) -> bytes:
        return self.encoder.encode(obj)

    def encode_into(self, obj: Any, buf: bytearray) -> None:
        self.encoder.encode_into(obj, buf)


class MsgpackDecoder:
    """Decoder with custom torch tensor serialization."""

    def __init__(self, t: Any):
        self.decoder = msgpack.Decoder(t, ext_hook=custom_ext_hook)

    def decode(self, obj: Any):
        return self.decoder.decode(obj)


def custom_enc_hook(obj: Any) -> Any:
    if isinstance(obj, torch.Tensor):
        # NOTE(rob): it is fastest to use numpy + pickle
        # when serializing torch tensors.
        # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
        return msgpack.Ext(CUSTOM_TYPE_CODE_PICKLE, pickle.dumps(obj.numpy()))

    raise NotImplementedError(f"Objects of type {type(obj)} are not supported")


def custom_ext_hook(code: int, data: memoryview) -> Any:
    if code == CUSTOM_TYPE_CODE_PICKLE:
        return torch.from_numpy(pickle.loads(data))

    raise NotImplementedError(f"Extension type code {code} is not supported")

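Illustration (not part of the diff): a hedged round-trip sketch of the custom hooks above, assuming the classes live in vllm.v1.serial_utils and that the tensor is on CPU (the enc hook calls .numpy()). A tensor nested inside an ordinary container survives msgpack encoding and decoding.

import torch

from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder  # assumed module path

encoder = MsgpackEncoder()
decoder = MsgpackDecoder(dict)   # decode the top-level object as a plain dict

payload = {"req-1": torch.arange(4, dtype=torch.float32)}
blob = encoder.encode(payload)   # tensor -> Ext(CUSTOM_TYPE_CODE_PICKLE, pickled ndarray)
restored = decoder.decode(blob)  # Ext -> torch.from_numpy(...)

assert torch.equal(restored["req-1"], payload["req-1"])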
@@ -176,7 +176,9 @@ class InputBatch:
        self.generators: Dict[int, torch.Generator] = {}

        self.num_logprobs: Dict[str, int] = {}
        self.prompt_logprob_reqs: Set[str] = set()
        # NOTE(rob): num_prompt_logprobs only includes reqs
        # that are currently in the prefill phase.
        self.num_prompt_logprobs: Dict[str, int] = {}

    def add_request(
        self,
@@ -238,11 +240,10 @@ class InputBatch:
        if request.generator is not None:
            self.generators[req_index] = request.generator

        num_logprobs = sampling_params.logprobs
        if num_logprobs is not None and num_logprobs > 0:
            self.num_logprobs[req_id] = num_logprobs
        if sampling_params.prompt_logprobs:
            self.prompt_logprob_reqs.add(req_id)
        if sampling_params.logprobs is not None:
            self.num_logprobs[req_id] = sampling_params.logprobs
        if sampling_params.prompt_logprobs is not None:
            self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs

        # Add request lora ID
        if request.lora_request:
@@ -272,7 +273,7 @@ class InputBatch:
        self.repetition_penalties_reqs.discard(req_id)
        self.generators.pop(req_index, None)
        self.num_logprobs.pop(req_id, None)
        self.prompt_logprob_reqs.discard(req_id)
        self.num_prompt_logprobs.pop(req_id, None)

        # LoRA
        lora_id = self.request_lora_mapping[req_index]
@@ -297,7 +298,7 @@ class InputBatch:
        self.repetition_penalties_reqs.clear()
        self.generators.clear()
        self.num_logprobs.clear()
        self.prompt_logprob_reqs.clear()
        self.num_prompt_logprobs.clear()
        self.request_lora_mapping.fill(0)
        self.lora_id_to_lora_request.clear()
        self.lora_id_to_request_ids.clear()
@@ -489,13 +490,9 @@ class InputBatch:
                and len(self.repetition_penalties_reqs) == 0)

    @property
    def max_num_logprobs(self) -> int:
        return max(self.num_logprobs.values()) if self.num_logprobs else 0

    @property
    def no_logprob(self) -> bool:
        return len(self.num_logprobs) == 0
    def max_num_logprobs(self) -> Optional[int]:
        return max(self.num_logprobs.values()) if self.num_logprobs else None

    @property
    def no_prompt_logprob(self) -> bool:
        return len(self.prompt_logprob_reqs) == 0
        return not self.num_prompt_logprobs

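Illustration (not part of the diff, hypothetical request ids): under the revised bookkeeping, logprobs=0 is still tracked because it requests the sampled token's own logprob, while an empty dict maps to a max_num_logprobs of None, meaning no logprobs work at all.

# Mirrors the InputBatch properties above with plain dicts.
num_logprobs = {"req-a": 0, "req-b": 5}
assert (max(num_logprobs.values()) if num_logprobs else None) == 5

num_logprobs = {}
assert (max(num_logprobs.values()) if num_logprobs else None) is None

num_prompt_logprobs = {}
assert not num_prompt_logprobs   # no_prompt_logprob would be True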
@@ -29,7 +29,7 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@@ -804,8 +804,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            inputs_embeds=inputs_embeds,
        )
        hidden_states = hidden_states[:num_scheduled_tokens]
        hidden_states = hidden_states[logits_indices]
        logits = self.model.compute_logits(hidden_states, None)
        sample_hidden_states = hidden_states[logits_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self._prepare_sampling(batch_changed)
@@ -818,7 +818,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        # the requests one by one. Optimize.
        num_reqs = self.input_batch.num_reqs
        request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
        for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
        for i, req_id in enumerate(  # type: ignore[assignment]
                self.input_batch.req_ids[:num_reqs]):
            assert req_id is not None
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
@@ -847,27 +848,28 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        # NOTE: GPU -> CPU Sync happens here.
        # Move as many CPU operations as possible before this sync point.
        sampled_token_ids = sampler_output.sampled_token_ids.tolist()
        logprobs_tensors = sampler_output.logprobs_tensors
        logprobs_lists = logprobs_tensors.tolists() \
            if logprobs_tensors is not None else None

        # Compute prompt logprobs if needed.
        prompt_logprobs_dict = self._get_prompt_logprobs_dict(
            hidden_states,
            scheduler_output,
        )

        # Update with the actual token ids
        for i, req_state, seq_len in request_seq_lens:
            token_id = sampled_token_ids[i]
            self.input_batch.token_ids_cpu[i, seq_len] = token_id
            req_state.output_token_ids[-1] = token_id

        if sampler_output.logprob_token_ids is None:
            logprob_token_ids = None
        else:
            logprob_token_ids = sampler_output.logprob_token_ids.cpu()
        if sampler_output.logprobs is None:
            logprobs = None
        else:
            logprobs = sampler_output.logprobs.cpu()

        model_runner_output = ModelRunnerOutput(
            req_ids=req_ids,
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids=sampled_token_ids,
            logprob_token_ids_cpu=logprob_token_ids,
            logprobs_cpu=logprobs,
            logprobs=logprobs_lists,
            prompt_logprobs_dict=prompt_logprobs_dict,
        )
        return model_runner_output

@@ -886,6 +888,76 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

    def _get_prompt_logprobs_dict(
        self,
        hidden_states: torch.Tensor,
        scheduler_output: "SchedulerOutput",
    ) -> Dict[str, LogprobsTensors]:
        num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
        if not num_prompt_logprobs_dict:
            return {}

        prompt_logprobs_dict: Dict[str, LogprobsTensors] = {}

        # Since prompt logprobs are a rare feature, prioritize simple,
        # maintainable loop over optimal performance.
        completed_prefill_reqs = []
        for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():

            num_tokens = scheduler_output.num_scheduled_tokens[req_id]

            # Get metadata for this request.
            request = self.requests[req_id]
            num_prompt_tokens = len(request.prompt_token_ids)
            prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
                self.device, non_blocking=True)

            # Determine number of logits to retrieve.
            start_tok = request.num_computed_tokens + 1
            num_remaining_tokens = num_prompt_tokens - start_tok
            if num_tokens < num_remaining_tokens:
                # This is a chunk, more tokens remain.
                num_logits = num_tokens
            else:
                # This is the last chunk of prompt tokens to return.
                num_logits = num_remaining_tokens
                completed_prefill_reqs.append(req_id)

            # Get the logits corresponding to this req's prompt tokens.
            # If this is a partial request (i.e. chunked prefill),
            # then there is prompt logprob generated for each index.
            req_idx = self.input_batch.req_id_to_index[req_id]
            offset = self.query_start_loc_np[req_idx].item()
            prompt_hidden_states = hidden_states[offset:offset + num_logits]
            logits = self.model.compute_logits(prompt_hidden_states, None)

            # Get the "target" tokens for each index. For prompt at index i,
            # the token at prompt index i+1 is the "sampled" token we want
            # to gather the logprob for.
            tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]

            # Compute prompt logprobs.
            logprobs = self.model.sampler.compute_logprobs(logits)
            token_ids, logprobs, ranks = self.model.sampler.gather_logprobs(
                logprobs, num_prompt_logprobs, tgt_token_ids)

            # Transfer GPU->CPU async.
            prompt_logprobs_dict[req_id] = LogprobsTensors(
                token_ids.to("cpu", non_blocking=True),
                logprobs.to("cpu", non_blocking=True),
                ranks.to("cpu", non_blocking=True),
            )

        # Remove requests that have completed prefill from the batch
        # num_prompt_logprobs_dict.
        for req_id in completed_prefill_reqs:
            del num_prompt_logprobs_dict[req_id]

        # Must synchronize the non-blocking GPU->CPU transfers.
        torch.cuda.synchronize()

        return prompt_logprobs_dict

    @torch.inference_mode()
    def _dummy_run(
        self,

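Illustration (not part of the diff): a worked sketch of the chunking arithmetic in _get_prompt_logprobs_dict with made-up numbers. A 100-token prompt with 40 tokens already computed and 32 tokens scheduled this step still has prompt left, so only 32 prompt-logprob rows are produced in this chunk and the request is not yet marked as a completed prefill.

num_prompt_tokens = 100     # len(request.prompt_token_ids)
num_computed_tokens = 40    # request.num_computed_tokens
num_tokens = 32             # scheduler_output.num_scheduled_tokens[req_id]

start_tok = num_computed_tokens + 1                     # 41: first target position
num_remaining_tokens = num_prompt_tokens - start_tok    # 59 targets still owed
if num_tokens < num_remaining_tokens:
    num_logits = num_tokens                             # 32: partial chunk, request stays
else:
    num_logits = num_remaining_tokens                   # last chunk of this prompt

assert (start_tok, num_remaining_tokens, num_logits) == (41, 59, 32)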