[V1] Logprobs and prompt logprobs support (#9880)

This PR is adding support for sample logprobs & prompt logprobs to vLLM v1.

New behavior:

- During model execution, model runner computes sample logprobs (if user-provided logprobs setting is not None) and prompt logprobs (if user-provided prompt_logprobs setting is not None). For both sample and prompt logprobs, the engine core returns 3 vectors: token ids, token logprob values, token ranks. Ranks reflect tokens' 1-indexed positions in the vocabulary vector after sorting the vocabulary by log probability in descending order.
- In scheduler.update_from_output(), sample and prompt logprobs are incorporated into the EngineCoreOutput data structure which is transferred to the engine client. If multiprocessing is enabled, then sample and prompt logprobs will be (de)serialized when the EngineCoreOutput data structure is (de)serialized.
- During output processing, the LogprobsProcessor transforms the triplet of token ids, token logprobs values, and token ranks into the OpenAI-compatible List[Dict[token id,Logprob]] format (for sample and prompt logprobs respectively.)
- Each Logprob instance (whether sample- or prompt-) consists of a token's log-probability, rank, and detokenized string representation. Note that logprob detokenization is handled by the LogprobsProcessor not the detokenizer.

Signed-off-by: Andrew Feldman <afeldman@neuralmagic.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>


Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
afeldman-nm 2025-02-07 10:26:20 -05:00 committed by GitHub
parent 538fab93cd
commit 0630d4537a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
30 changed files with 2865 additions and 283 deletions

View File

@ -195,8 +195,8 @@ def test_schedule_partial_requests():
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[0] * len(requests),
logprob_token_ids_cpu=None,
logprobs_cpu=None,
logprobs=None,
prompt_logprobs_dict={},
)
scheduler.update_from_output(output, model_runner_output)

View File

@ -0,0 +1,90 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Tuple
import pytest
import torch
from transformers import AutoTokenizer
from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
NUM_SAMPLE_LOGPROBS_UNDER_TEST, PROMPT_LEN,
TOKENIZER_NAME,
DummyOutputProcessorTestVectors,
generate_dummy_prompt_logprobs_tensors,
generate_dummy_sample_logprobs)
from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from tests.v1.engine.utils import FULL_STRINGS # isort: skip
EngineCoreSampleLogprobsType = List[Tuple[torch.Tensor, torch.Tensor]]
EngineCorePromptLogprobsType = Tuple[torch.Tensor, torch.Tensor]
def _build_test_vectors_no_logprobs() -> DummyOutputProcessorTestVectors:
    """Build output-processor dummy test vectors, with logprob fields empty.

    Returns:
      DummyOutputProcessorTestVectors instance with no logprobs
    """
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    vllm_config = EngineArgs(model=TOKENIZER_NAME).create_engine_config()
    # Tokenize each full string once; the first PROMPT_LEN token ids act as
    # the dummy "prompt" and the remainder as the dummy "generation".
    full_tokens = [tokenizer(text).input_ids for text in FULL_STRINGS]
    prompt_tokens = [token_ids[:PROMPT_LEN] for token_ids in full_tokens]
    generation_tokens = [token_ids[PROMPT_LEN:] for token_ids in full_tokens]
    # Detokenize the prompt prefixes and record each prompt's string length,
    # which locates the generation suffix inside the full string.
    prompt_strings = [
        tokenizer.decode(token_ids, skip_special_tokens=True)
        for token_ids in prompt_tokens
    ]
    prompt_strings_len = [len(prompt_string) for prompt_string in prompt_strings]
    generation_strings = [
        text[prompt_len:]
        for text, prompt_len in zip(FULL_STRINGS, prompt_strings_len)
    ]
    return DummyOutputProcessorTestVectors(
        tokenizer=tokenizer,
        tokenizer_group=init_tokenizer_from_configs(
            vllm_config.model_config, vllm_config.scheduler_config,
            vllm_config.parallel_config, vllm_config.lora_config),
        vllm_config=vllm_config,
        full_tokens=full_tokens,
        prompt_tokens=prompt_tokens,
        generation_tokens=generation_tokens,
        prompt_strings=prompt_strings,
        prompt_strings_len=prompt_strings_len,
        generation_strings=generation_strings,
        # Logprobs are injected later (see the dummy_test_vectors fixture).
        prompt_logprobs=[],
        generation_logprobs=[])
@pytest.fixture
def dummy_test_vectors() -> DummyOutputProcessorTestVectors:
    """Generate output processor dummy test vectors, with logprobs.

    Returns:
      DummyOutputProcessorTestVectors instance with logprobs
    """
    # Start from the logprob-free vectors, then attach fake sample logprobs
    # (one entry per generated token) and fake prompt logprob tensors
    # (one entry per prompt).
    vectors = _build_test_vectors_no_logprobs()
    vectors.generation_logprobs = [
        generate_dummy_sample_logprobs(
            sampled_tokens_list=token_ids,
            num_logprobs=NUM_SAMPLE_LOGPROBS_UNDER_TEST,
            tokenizer=vectors.tokenizer)
        for token_ids in vectors.generation_tokens
    ]
    vectors.prompt_logprobs = [
        generate_dummy_prompt_logprobs_tensors(
            prompt_tokens_list=token_ids,
            num_logprobs=NUM_PROMPT_LOGPROBS_UNDER_TEST,
            tokenizer=vectors.tokenizer)
        for token_ids in vectors.prompt_tokens
    ]
    return vectors

View File

@ -2,10 +2,11 @@
import asyncio
from contextlib import ExitStack
from typing import List, Tuple
from typing import List, Optional, Tuple
import pytest
from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.platforms import current_platform
@ -21,13 +22,19 @@ ENGINE_ARGS = AsyncEngineArgs(model="meta-llama/Llama-3.2-1B",
disable_log_requests=True)
async def generate(engine: AsyncLLM, request_id: str,
async def generate(engine: AsyncLLM,
request_id: str,
output_kind: RequestOutputKind,
max_tokens: int) -> Tuple[int, str]:
max_tokens: int,
prompt_logprobs: Optional[int] = None) -> Tuple[int, str]:
# Ensure generate doesn't complete too fast for cancellation test.
await asyncio.sleep(0.2)
count = 0
sampling_params = SamplingParams(max_tokens=max_tokens,
output_kind=output_kind,
temperature=0)
temperature=0,
prompt_logprobs=prompt_logprobs)
async for out in engine.generate(request_id=request_id,
prompt="Hello my name is Robert and",
sampling_params=sampling_params):
@ -43,6 +50,40 @@ async def generate(engine: AsyncLLM, request_id: str,
return count, request_id
@pytest.mark.parametrize(
    "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio
async def test_async_llm_refuses_prompt_logprobs_with_apc(
        monkeypatch, output_kind: RequestOutputKind):
    """Test passes if AsyncLLM raises an exception when it is configured
    for automatic prefix caching and it receives a request with
    prompt_logprobs enabled, which is incompatible."""

    # TODO(rickyx): Remove monkeypatch VLLM_USE_V1 setting once we have a
    # better way to test V1 so that in the future when we switch, we don't
    # have to change all the tests.
    monkeypatch.setenv("VLLM_USE_V1", "1")
    # Create AsyncLLM engine with APC (automatic prefix caching) enabled;
    # the engine is real, so a model is loaded here.
    apc_engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                                      enable_prefix_caching=True,
                                      gpu_memory_utilization=0.8,
                                      disable_log_requests=True)
    engine = AsyncLLM.from_engine_args(apc_engine_args)
    try:
        with pytest.raises(ValueError) as excinfo:
            # Issue a request with prompt logprobs enabled, which should fail
            await asyncio.create_task(
                generate(engine,
                         "request-0",
                         output_kind,
                         10,
                         prompt_logprobs=5))
        # Validate exception string is correct
        assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG
    finally:
        # Shut down engine even if the assertions above fail, so the test
        # process does not leak engine resources.
        engine.shutdown()
@pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.asyncio

View File

@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
from tests.v1.engine.utils import PLP_APC_UNSUPPORTED_MSG
from vllm import LLM, SamplingParams
def test_llm_engine_refuses_prompt_logprobs_with_apc(monkeypatch):
    """LLMEngine must reject a request carrying prompt_logprobs when it is
    configured with automatic prefix caching, since the two features are
    incompatible; the test passes if the expected ValueError is raised."""
    monkeypatch.setenv("VLLM_USE_V1", "1")
    # TODO(nick): Single-proc to work around a ZMQ shutdown hang for now.
    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
    with pytest.raises(ValueError) as excinfo:
        llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
        params = SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=5)
        llm.generate("Hello, my name is", params)
    # Validate exception string is correct
    assert str(excinfo.value) == PLP_APC_UNSUPPORTED_MSG

View File

@ -1,82 +1,47 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List
import math
from typing import Dict, List, Optional
import pytest
from transformers import AutoTokenizer
from vllm.engine.arg_utils import EngineArgs
from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
NUM_SAMPLE_LOGPROBS_UNDER_TEST,
STOP_STRINGS,
DummyOutputProcessorTestVectors,
MockEngineCore)
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
from vllm.sequence import PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor
TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
VLLM_CONFIG = EngineArgs(model=TOKENIZER_NAME).create_engine_config()
TOKENIZER_GROUP = init_tokenizer_from_configs(VLLM_CONFIG.model_config,
VLLM_CONFIG.scheduler_config,
VLLM_CONFIG.parallel_config,
VLLM_CONFIG.lora_config)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
FULL_STRINGS = [
"My name is Robert from Neural Magic and I love working on vLLM so much!",
"Red Hat is the best open source company by far across Linux, K8s, and AI.",
"Nick is the name of my brother in addition to my colleague from Red Hat.",
]
def _ref_convert_id_to_token(
tokenizer: AnyTokenizer,
token_id: int,
) -> str:
"""Reference impl of logprobs detokenization.
STOP_STRINGS = ["I love working on", "company by far", "brother in"]
Args:
tokenizer: tokenizer used by the model under test
token_id: convert this token id
FULL_TOKENS = [tokenizer(text).input_ids for text in FULL_STRINGS]
PROMPT_LEN = 5
PROMPT_TOKENS = [
tokenizer(text).input_ids[:PROMPT_LEN] for text in FULL_STRINGS
]
GENERATION_TOKENS = [
tokenizer(text).input_ids[PROMPT_LEN:] for text in FULL_STRINGS
]
PROMPT_STRINGS = [
tokenizer.decode(prompt_tokens, skip_special_tokens=True)
for prompt_tokens in PROMPT_TOKENS
]
PROMPT_STRINGS_LEN = [len(prompt_string) for prompt_string in PROMPT_STRINGS]
GENERATION_STRINGS = [
text[prompt_len:]
for text, prompt_len in zip(FULL_STRINGS, PROMPT_STRINGS_LEN)
]
class MockEngineCore:
"""Mock outputs form premade tokens lists."""
def __init__(self, tokens_list: List[List[int]]):
self.tokens_list = tokens_list
self.current_idx = 0
def get_outputs(self) -> List[EngineCoreOutput]:
token_idx = self.current_idx
self.current_idx += 1
outputs = []
for req_idx, token_ids in enumerate(self.tokens_list):
if len(token_ids) > token_idx:
output = EngineCoreOutput(request_id=f"request-{req_idx}",
new_token_ids=[token_ids[token_idx]],
finished=False)
if token_idx == len(token_ids) - 1:
output.finished = True
output.finish_reason = "stopped"
outputs.append(output)
return outputs
Returns:
String representation of input token id
"""
return tokenizer.convert_ids_to_tokens(token_id) or ""
@pytest.mark.parametrize(
"request_output_kind",
[RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
def test_incremental_detokenization(request_output_kind: RequestOutputKind):
output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=False)
engine_core = MockEngineCore(GENERATION_TOKENS)
def test_incremental_detokenization(request_output_kind: RequestOutputKind,
dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens)
# Make N requests.
requests = [
@ -94,10 +59,10 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind):
spaces_between_special_tokens=False,
output_kind=request_output_kind,
stop=[],
include_stop_str_in_output=False))
for idx, (
prompt,
prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
include_stop_str_in_output=False,
)) for idx, (prompt, prompt_tokens) in enumerate(
zip(dummy_test_vectors.prompt_strings,
dummy_test_vectors.prompt_tokens))
]
# Add requests to the detokenizer.
@ -113,7 +78,7 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind):
break
# Step the Detokenizer.
processed_outputs = output_processor.process_outputs(outputs, )
processed_outputs = output_processor.process_outputs(outputs)
request_outputs = processed_outputs.request_outputs
requests_to_abort = processed_outputs.reqs_to_abort
assert len(requests_to_abort) == 0
@ -132,7 +97,8 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind):
# Confirmed tracked values matches what we expected.
for idx, (ref_gen_str, ref_gen_toks) in enumerate(
zip(GENERATION_STRINGS, GENERATION_TOKENS)):
zip(dummy_test_vectors.generation_strings,
dummy_test_vectors.generation_tokens)):
gen_str = gen_strings[f"request-{idx}"]
gen_toks = gen_tokens[f"request-{idx}"]
@ -143,15 +109,390 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind):
assert not output_processor.has_unfinished_requests()
@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
def test_stop_string(include_stop_str_in_output: bool):
output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=False)
engine_core = MockEngineCore(GENERATION_TOKENS)
def _validate_logprobs(
    gen_tokens: Dict[str, List[int]],
    gen_logprobs: Dict[str, Optional[SampleLogprobs]],
    gen_prompt_logprobs: Dict[str, Optional[PromptLogprobs]],
    gen_cumulative_logprob: Dict[str, float],
    dtv: DummyOutputProcessorTestVectors,
    request_id_list: List[str],
    num_sample_logprobs: Optional[int],
    num_prompt_logprobs: Optional[int],
) -> None:
    """Validate tracked sample/prompt logprobs against dummy reference values.

    Args:
      gen_tokens: per-request-id accumulated completion token ids
      gen_logprobs: per-request-id accumulated sample logprobs (None when
                    sample logprobs were disabled for the request)
      gen_prompt_logprobs: per-request-id prompt logprobs (None when prompt
                           logprobs were disabled for the request)
      gen_cumulative_logprob: per-request-id cumulative logprob
      dtv: dummy test vectors holding the reference token/logprob values,
           index-aligned with `request_id_list`
      request_id_list: request ids under test
      num_sample_logprobs: number of sample logprobs requested, or None
      num_prompt_logprobs: number of prompt logprobs requested, or None
    """
    for req_idx, req_id in enumerate(request_id_list):
        new_tokens = gen_tokens[req_id]
        logprobs = gen_logprobs[req_id]
        prompt_logprobs = gen_prompt_logprobs[req_id]
        cumulative_logprob = gen_cumulative_logprob[req_id]
        prompt_token_ids = dtv.prompt_tokens[req_idx]
        ref_logprobs = dtv.generation_logprobs[req_idx]
        ref_prompt_logprobs = dtv.prompt_logprobs[req_idx]
        if num_sample_logprobs is not None:
            # Validate sample logprobs
            assert logprobs is not None, (f"Request {req_id} requires sample"
                                          " logprobs but sample logprobs are"
                                          " None.")
            # Require num sampled tokens to match num
            # sampled logprobs - especially important
            # to check since the detokenizer can cause
            # a request to finish early due to a stop
            # string being hit
            num_new_tokens = len(new_tokens)
            len_sample_logprobs = len(logprobs)
            assert num_new_tokens == len_sample_logprobs, (
                f"Request {req_id} has {num_new_tokens}"
                " completion tokens but has"
                f" {len_sample_logprobs} sample logprobs.")
            ref_cumulative_logprob = 0.0
            for idx, (sampled_token,
                      pos_logprob_dict) in enumerate(zip(new_tokens,
                                                         logprobs)):
                # Break out the reference log probability value &
                # logprob token id tensors associated with this
                # position in the completion. Also break out the
                # sampled token ranks
                (ref_pos_logprob_toks, ref_pos_logprob_vals,
                 ref_sampled_token_rank) = ref_logprobs[idx]
                # For each position in the completion sequence,
                # ensure the actual sampled token is among the
                # logprobs
                assert sampled_token in pos_logprob_dict, (
                    f"Sampled token {sampled_token} not"
                    f" present in logprob at index {idx}")
                # Validate number of sample logprobs: either exactly the
                # requested count, or one extra when the sampled token was
                # not itself among the top logprobs.
                num_lp_toks = len(pos_logprob_dict)
                assert (num_lp_toks == num_sample_logprobs
                        or num_lp_toks == num_sample_logprobs +
                        1), ("Valid numbers of sample logprobs are"
                             f" {num_sample_logprobs} or"
                             f" {num_sample_logprobs+1} but"
                             f" {num_lp_toks} logprobs found at"
                             f" position {idx}. Logprobs dict:"
                             f" {pos_logprob_dict}")
                # Validate sampled token logprob rank
                smp_lp = pos_logprob_dict[sampled_token]
                smp_lp_rank = smp_lp.rank
                assert (ref_sampled_token_rank == smp_lp_rank), (
                    "Sampled token logprob rank"
                    f" {smp_lp_rank} does not match"
                    " correct value"
                    f" {ref_sampled_token_rank}"
                    f" in Logprob {smp_lp}")
                # Validate that the logprob processor yields
                # the correct log probabilities and valid
                # rankings
                rank_one_appears = False
                # NOTE(review): jdx starts at 1 — index 0 appears to hold the
                # sampled token's own entry; confirm against
                # generate_dummy_sample_logprobs.
                for jdx in range(1, len(ref_pos_logprob_toks)):
                    # Iterate over the (logprob val,logprob tok id)
                    # pairs expected by the test fixture at this
                    # position in the completion.
                    ref_lp_val = ref_pos_logprob_vals[jdx]
                    ref_tok_id = ref_pos_logprob_toks[jdx]
                    assert ref_tok_id in pos_logprob_dict, (
                        f"Expected token {ref_tok_id} to be"
                        f" in logprob dict but it is not.")
                    # Extract actually-generated logprob
                    # info
                    lp = pos_logprob_dict[ref_tok_id]
                    lp_val = lp.logprob
                    lp_rank = lp.rank
                    # A "top" (rank 1) logprob must be
                    # present
                    rank_one_appears = (True
                                        if lp_rank == 1 else rank_one_appears)
                    # Rank must be >= 1
                    assert lp_rank >= 1, (f"Logprob {lp} has invalid"
                                          f" rank {lp_rank} < 1."
                                          f" Logprob dict: {pos_logprob_dict}")
                    # Validate log probability
                    assert math.isclose(lp_val, ref_lp_val), (
                        f"Token id {ref_tok_id} appears in logprobs dict"
                        f" at position {idx} in completion with log"
                        f" probability {lp_val} but {ref_lp_val} was"
                        f" expected. Logprob: {lp}")
                assert rank_one_appears, (f"No Logprob has rank 1"
                                          " in the following Logprob"
                                          f" dict: {pos_logprob_dict}")
                # Validate logprobs detokenization
                for lp_tok in pos_logprob_dict:
                    # Confirm that sample logprob decoded token matches
                    # the logprob token id at this sequence position
                    decoded_token = pos_logprob_dict[lp_tok].decoded_token
                    ref_decoded_token = _ref_convert_id_to_token(
                        dtv.tokenizer, lp_tok)
                    assert decoded_token == ref_decoded_token, (
                        f"Sampled logprob token id {lp_tok} decodes to"
                        f" {ref_decoded_token} but Logprob decoded"
                        f" token is {decoded_token} instead"
                        f" (at position {idx})")
                # Accumulate the sampled token's logprob for the
                # cumulative-logprob check below.
                ref_cumulative_logprob += pos_logprob_dict[
                    sampled_token].logprob
            # Assert that cumulative logprobs are correct
            assert math.isclose(cumulative_logprob, ref_cumulative_logprob)
        else:
            # Sample logprobs disabled for this request
            assert logprobs is None
            assert cumulative_logprob is None
        if num_prompt_logprobs is not None:
            # Validate prompt logprobs
            assert prompt_logprobs is not None, (
                f"Request {req_id} requires prompt"
                " logprobs but prompt logprobs are"
                " None.")
            # Require num prompt tokens to match num
            # prompt logprobs
            num_prompt_tokens = len(prompt_token_ids)
            len_prompt_logprobs = len(prompt_logprobs)
            assert num_prompt_tokens == len_prompt_logprobs, (
                f"Request {req_id} has {num_prompt_tokens}"
                " prompt tokens but has"
                f" {len_prompt_logprobs} prompt logprobs.")
            # First prompt logprob is None (no logprob exists for the very
            # first prompt token, which has no preceding context).
            first_plp_dict = prompt_logprobs[0]
            assert first_plp_dict is None, (
                f"Request {req_id} first prompt logprob"
                f" should be None but has following value"
                f" instead: {first_plp_dict}")
            # Break out the reference prompt log prob value &
            # logprob token id matrices for the whole prompt.
            # Also break out the prompt token rank vector
            (ref_prompt_logprob_toks, ref_prompt_logprob_vals,
             ref_prompt_token_ranks) = ref_prompt_logprobs
            for idx, (prompt_token, pos_logprob_dict) in enumerate(
                    zip(prompt_token_ids[1:], prompt_logprobs[1:])):
                # Break out the reference prompt log prob value
                # vector, prompt logprob token id vector, and
                # prompt token rank at the current position.
                (ref_pos_prompt_logprob_toks, ref_pos_prompt_logprob_vals,
                 ref_pos_prompt_token_rank) = (ref_prompt_logprob_toks[idx, :],
                                               ref_prompt_logprob_vals[idx, :],
                                               ref_prompt_token_ranks[idx])
                # For each position in the prompt sequence,
                # ensure the actual prompt token is among the
                # logprobs
                assert prompt_token in pos_logprob_dict, (
                    f"Prompt token {prompt_token} not"
                    f" present in logprob at index {idx}")
                # Validate number of prompt logprobs
                num_plp_toks = len(pos_logprob_dict)
                assert (num_plp_toks == num_prompt_logprobs
                        or num_plp_toks == num_prompt_logprobs +
                        1), ("Valid numbers of prompt logprobs are"
                             f" {num_prompt_logprobs} or"
                             f" {num_prompt_logprobs+1} but"
                             f" {num_plp_toks} logprobs found at"
                             f" position {idx}. Logprobs dict:"
                             f" {pos_logprob_dict}")
                # Validate prompt token logprob rank
                prmpt_tok_lp = pos_logprob_dict[prompt_token]
                prmpt_tok_lp_rank = prmpt_tok_lp.rank
                ref_prmpt_tok_lp_rank = ref_pos_prompt_token_rank
                assert (ref_prmpt_tok_lp_rank == prmpt_tok_lp_rank), (
                    "Prompt token logprob rank"
                    f" {prmpt_tok_lp_rank} does not match"
                    " correct value"
                    f" {ref_prmpt_tok_lp_rank}"
                    f" in Logprob {prmpt_tok_lp}")
                # Validate that the logprob processor yields
                # the correct prompt log probs and valid
                # rankings
                rank_one_appears = False
                for jdx in range(1, len(ref_pos_prompt_logprob_toks)):
                    # Iterate over the (logprob val,logprob tok id)
                    # pairs expected by the test fixture at this
                    # position in the completion.
                    ref_plp_val = float(ref_pos_prompt_logprob_vals[jdx])
                    ref_tok_id = int(ref_pos_prompt_logprob_toks[jdx])
                    assert ref_tok_id in pos_logprob_dict, (
                        f"Expected token {ref_tok_id} to be"
                        f" in logprob dict but it is not.")
                    # Extract actually-generated logprob
                    # info
                    plp = pos_logprob_dict[ref_tok_id]
                    plp_val = plp.logprob
                    plp_rank = plp.rank
                    # A "top" (rank 1) logprob must be
                    # present
                    rank_one_appears = (True
                                        if plp_rank == 1 else rank_one_appears)
                    # Rank must be >= 1
                    assert plp_rank >= 1, (
                        f"Logprob {plp} has invalid"
                        f" rank {plp_rank} < 1."
                        f" Logprob dict: {pos_logprob_dict}")
                    # Validate log probability
                    assert math.isclose(plp_val, ref_plp_val), (
                        f"Token id {ref_tok_id} appears in logprobs dict"
                        f" at position {idx} in completion with log"
                        f" probability {plp_val} but {ref_plp_val} was"
                        f" expected. Logprob: {plp}")
                assert rank_one_appears, (f"No Logprob has rank 1"
                                          " in the following Logprob"
                                          f" dict: {pos_logprob_dict}")
                # Validate prompt logprob detokenization
                for plp_tok in pos_logprob_dict:
                    # Confirm that prompt logprob decoded token matches
                    # the logprob token id at this sequence position
                    decoded_token = pos_logprob_dict[plp_tok].decoded_token
                    ref_decoded_token = _ref_convert_id_to_token(
                        dtv.tokenizer, plp_tok)
                    assert decoded_token == ref_decoded_token, (
                        f"Prompt logprob token id {plp_tok} decodes to"
                        f" {ref_decoded_token} but Logprob decoded"
                        f" token is {decoded_token} instead"
                        f" (at position {idx})")
        else:
            # Prompt logprobs disabled for this request
            assert prompt_logprobs is None
@pytest.mark.parametrize(
    "request_output_kind",
    [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize("num_sample_logprobs",
                         [None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
@pytest.mark.parametrize("num_prompt_logprobs",
                         [None, NUM_PROMPT_LOGPROBS_UNDER_TEST])
def test_logprobs_processor(request_output_kind: RequestOutputKind,
                            num_sample_logprobs: Optional[int],
                            num_prompt_logprobs: Optional[int],
                            dummy_test_vectors):
    """Feed dummy engine-core outputs through OutputProcessor and check that
    the sample/prompt logprobs it reconstructs match the reference vectors,
    for every combination of enabled/disabled sample and prompt logprobs."""
    output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
                                       log_stats=False)
    # Mock engine core only returns raw logprobs for the variants of this
    # test where logprobs are requested (None disables them).
    engine_core = MockEngineCore(
        tokens_list=dummy_test_vectors.generation_tokens,
        generated_logprobs_raw=None if num_sample_logprobs is None else
        dummy_test_vectors.generation_logprobs,
        prompt_logprobs_raw=None
        if num_prompt_logprobs is None else dummy_test_vectors.prompt_logprobs)

    # Make N requests.
    request_id_list = [
        f"request-{idx}"
        for idx in range(len(dummy_test_vectors.prompt_strings))
    ]
    requests = [
        EngineCoreRequest(request_id=request_id_list[idx],
                          prompt=prompt,
                          prompt_token_ids=prompt_tokens,
                          arrival_time=0,
                          mm_inputs=None,
                          mm_hashes=None,
                          mm_placeholders=None,
                          eos_token_id=None,
                          lora_request=None,
                          sampling_params=SamplingParams(
                              skip_special_tokens=False,
                              spaces_between_special_tokens=False,
                              output_kind=request_output_kind,
                              stop=[],
                              include_stop_str_in_output=False,
                              logprobs=num_sample_logprobs,
                              prompt_logprobs=num_prompt_logprobs,
                          )) for idx, (prompt, prompt_tokens) in enumerate(
                              zip(dummy_test_vectors.prompt_strings,
                                  dummy_test_vectors.prompt_tokens))
    ]

    # Add requests to the detokenizer.
    for request in requests:
        output_processor.add_request(request)

    # Per-request trackers accumulated across engine-core steps.
    gen_tokens: Dict[str, List[int]] = {}
    gen_logprobs = {}
    gen_prompt_logprobs = {}
    gen_cumulative_logprobs = {}
    while True:
        # Mock output from the EngineCore.
        outputs = engine_core.get_outputs()
        if len(outputs) == 0:
            break

        # Step the logprobs processor.
        processed_outputs = output_processor.process_outputs(outputs)
        request_outputs = processed_outputs.request_outputs
        requests_to_abort = processed_outputs.reqs_to_abort
        assert len(requests_to_abort) == 0

        # Update tracking.
        for request_output in request_outputs:
            request_id = request_output.request_id
            new_tokens = request_output.outputs[0].token_ids
            prompt_logprobs = request_output.prompt_logprobs
            logprobs = request_output.outputs[0].logprobs
            # Cumulative logprob is overwritten (not accumulated) each step.
            gen_cumulative_logprobs[request_id] = request_output.outputs[
                0].cumulative_logprob
            if request_id not in gen_logprobs:
                # Start tracking sample and prompt logprobs for this request
                gen_tokens[request_id] = new_tokens
                gen_logprobs[request_id] = logprobs
                gen_prompt_logprobs[request_id] = prompt_logprobs
            else:
                # Extend logprobs tracker
                gen_tokens[request_id].extend(new_tokens)
                lp = gen_logprobs[request_id]
                plp = gen_prompt_logprobs[request_id]
                if lp:
                    lp.extend(logprobs)
                if plp:
                    plp.extend(prompt_logprobs)

    # Confirmed tracked logprobs match what we expect
    _validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs,
                       gen_cumulative_logprobs, dummy_test_vectors,
                       request_id_list, num_sample_logprobs,
                       num_prompt_logprobs)

    assert output_processor.get_num_unfinished_requests() == 0
    assert not output_processor.has_unfinished_requests()
@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
@pytest.mark.parametrize("num_sample_logprobs",
[None, NUM_SAMPLE_LOGPROBS_UNDER_TEST])
@pytest.mark.parametrize("num_prompt_logprobs",
[None, NUM_PROMPT_LOGPROBS_UNDER_TEST])
def test_stop_string(include_stop_str_in_output: bool,
num_sample_logprobs: Optional[int],
num_prompt_logprobs: Optional[int], dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
log_stats=False)
engine_core = MockEngineCore(
tokens_list=dummy_test_vectors.generation_tokens,
generated_logprobs_raw=dummy_test_vectors.generation_logprobs
if num_sample_logprobs else None,
prompt_logprobs_raw=dummy_test_vectors.prompt_logprobs
if num_prompt_logprobs else None)
# Make N requests.
request_id_list = [
f"request-{idx}"
for idx in range(len(dummy_test_vectors.prompt_strings))
]
requests = [
EngineCoreRequest(
request_id=f"request-{idx}",
request_id=request_id_list[idx],
prompt=prompt,
prompt_token_ids=prompt_tokens,
arrival_time=0,
@ -166,9 +507,11 @@ def test_stop_string(include_stop_str_in_output: bool):
output_kind=RequestOutputKind.DELTA,
stop=STOP_STRINGS,
include_stop_str_in_output=include_stop_str_in_output,
)) for idx, (
prompt,
prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
logprobs=num_sample_logprobs,
prompt_logprobs=num_prompt_logprobs,
)) for idx, (prompt, prompt_tokens) in enumerate(
zip(dummy_test_vectors.prompt_strings,
dummy_test_vectors.prompt_tokens))
]
# Add requests to the detokenizer.
@ -176,6 +519,10 @@ def test_stop_string(include_stop_str_in_output: bool):
output_processor.add_request(request)
gen_strings = {}
gen_tokens = {}
gen_logprobs = {}
gen_prompt_logprobs = {}
gen_cumulative_logprobs = {}
aborted = []
while True:
# Mock output from the EngineCore.
@ -199,14 +546,29 @@ def test_stop_string(include_stop_str_in_output: bool):
request_id = request_output.request_id
new_text = request_output.outputs[0].text
new_tokens = request_output.outputs[0].token_ids
prompt_logprobs = request_output.prompt_logprobs
logprobs = request_output.outputs[0].logprobs
gen_cumulative_logprobs[request_id] = request_output.outputs[
0].cumulative_logprob
if request_id not in gen_strings:
gen_strings[request_id] = new_text
gen_tokens[request_id] = new_tokens
gen_logprobs[request_id] = logprobs
gen_prompt_logprobs[request_id] = prompt_logprobs
else:
gen_strings[request_id] += new_text
gen_tokens[request_id].extend(new_tokens)
lp = gen_logprobs[request_id]
plp = gen_prompt_logprobs[request_id]
if lp:
lp.extend(logprobs)
if plp:
plp.extend(prompt_logprobs)
# Confirmed tracked values matches what we expected.
for idx, (ref_gen_str,
stop_str) in enumerate(zip(GENERATION_STRINGS, STOP_STRINGS)):
for idx, (ref_gen_str, stop_str) in enumerate(
zip(dummy_test_vectors.generation_strings, STOP_STRINGS)):
# Request should be aborted.
request_id = f"request-{idx}"
@ -227,13 +589,20 @@ def test_stop_string(include_stop_str_in_output: bool):
assert gen_str == ref_str_exc_stop, (
f"{gen_str=}, {ref_str_exc_stop=}")
# Confirmed tracked logprobs match what we expect
_validate_logprobs(gen_tokens, gen_logprobs, gen_prompt_logprobs,
gen_cumulative_logprobs, dummy_test_vectors,
request_id_list, num_sample_logprobs,
num_prompt_logprobs)
assert output_processor.get_num_unfinished_requests() == 0
assert not output_processor.has_unfinished_requests()
def test_iteration_stats():
output_processor = OutputProcessor(TOKENIZER_GROUP, log_stats=True)
engine_core = MockEngineCore(GENERATION_TOKENS)
def test_iteration_stats(dummy_test_vectors):
output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
log_stats=True)
engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
# Make N requests.
requests = [
@ -248,13 +617,13 @@ def test_iteration_stats():
eos_token_id=None,
lora_request=None,
sampling_params=SamplingParams(),
) for idx, (
prompt,
prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
) for idx, (prompt, prompt_tokens) in enumerate(
zip(dummy_test_vectors.prompt_strings,
dummy_test_vectors.prompt_tokens))
]
# Add all requests except one to the OutputProcessor.
num_active = len(GENERATION_TOKENS) - 1
num_active = len(dummy_test_vectors.generation_tokens) - 1
for request in requests[:num_active]:
output_processor.add_request(request)
inactive_request = requests[num_active]
@ -263,8 +632,10 @@ def test_iteration_stats():
outputs = engine_core.get_outputs()[:num_active]
processed_outputs = output_processor.process_outputs(outputs)
iteration_stats = processed_outputs.iteration_stats
total_prompt_tokens = sum(
[len(prompt_tokens) for prompt_tokens in PROMPT_TOKENS[:num_active]])
total_prompt_tokens = sum([
len(prompt_tokens)
for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active]
])
assert iteration_stats.num_prompt_tokens == total_prompt_tokens
assert iteration_stats.num_generation_tokens == num_active
@ -283,7 +654,7 @@ def test_iteration_stats():
outputs = engine_core.get_outputs()[:num_active]
processed_outputs = output_processor.process_outputs(outputs)
iteration_stats = processed_outputs.iteration_stats
total_prompt_tokens = len(PROMPT_TOKENS[num_active - 1])
total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1])
assert iteration_stats.num_prompt_tokens == total_prompt_tokens
assert iteration_stats.num_generation_tokens == num_active

382
tests/v1/engine/utils.py Normal file
View File

@ -0,0 +1,382 @@
# SPDX-License-Identifier: Apache-2.0
import random
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.v1.engine import EngineCoreOutput, FinishReason
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
GeneralTokenizerType = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
# Number of sample logprobs to request when testing sample logprobs
NUM_SAMPLE_LOGPROBS_UNDER_TEST = 5
# Number of prompt logprobs to request when testing prompt logprobs
NUM_PROMPT_LOGPROBS_UNDER_TEST = 7
TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
FULL_STRINGS = [
"My name is Robert from Neural Magic and I love working on vLLM so much!",
"Red Hat is the best open source company by far across Linux, K8s, and AI.",
"Nick is the name of my brother in addition to my colleague from Red Hat.",
]
STOP_STRINGS = ["I love working on", "company by far", "brother in"]
PROMPT_LEN = 5
PLP_APC_UNSUPPORTED_MSG = ("Prefix caching with prompt logprobs not yet "
"supported on VLLM V1.")
random.seed(42)
def _create_random_top_logprob_test_vector(
    num_logprobs: int,
    lower: float,
    upper: float,
) -> torch.Tensor:
    """Build a fake vector of top-logprob float values for testing.

    Used to fabricate sample logprobs. Unlike real engine output, the
    values are *not* sorted in descending order; tests that need ordering
    must not rely on it.

    Args:
      num_logprobs: number of top logprobs
      lower: lower bound of the logprob value range
      upper: upper bound of the logprob value range

    Returns:
      1D length-`num_logprobs` torch Tensor of float logprob values
    """
    # Uniform draw in [0, 1), then affine-map onto [lower, upper).
    span = upper - lower
    return lower + span * torch.rand(num_logprobs)
def _create_random_top_logprob_test_matrix(
    shape: Tuple,
    lower: float,
    upper: float,
) -> torch.Tensor:
    """Build a fake matrix of top-logprob float values for testing.

    Used to fabricate prompt logprobs. Unlike real engine output, rows
    are *not* sorted in descending order.

    Args:
      shape: (num_tokens, num_logprobs) matrix shape
      lower: lower bound of the logprob value range
      upper: upper bound of the logprob value range

    Returns:
      2D num_tokens x num_logprobs torch Tensor of float logprob values
    """
    # Uniform draw in [0, 1), then affine-map onto [lower, upper).
    return lower + (upper - lower) * torch.rand(*shape)
def _create_random_top_token_test_vector(
    num_logprobs: int,
    lower: int,
    upper: int,
    sampled_token_id: int,
    adjust_num_logprobs: bool = True) -> Tuple[torch.Tensor, int]:
    """Build a fake vector of top-logprob token ids for testing.

    The sampled token id is forced into slot 0 so that it always appears
    among the returned ids. `adjust_num_logprobs=True` emulates the
    OpenAI-compatible behavior of returning one extra logprob when the
    sampled token was not among the top-k, by growing the vector by 1.

    Args:
      num_logprobs: number of top logprobs
      lower: lower bound of the token-id range
      upper: upper bound of the token-id range
      sampled_token_id: the token actually sampled
      adjust_num_logprobs: if True, emulate the "sampled token injected
                           into top logprobs" case

    Returns:
      Tuple of (1D token-id tensor of length `num_logprobs+1` when
      `adjust_num_logprobs` else `num_logprobs`, fake rank of the
      sampled token).
    """
    # One extra slot when emulating the injected-sampled-token case.
    num_ids = num_logprobs + (1 if adjust_num_logprobs else 0)

    # Sample distinct token ids from [lower, upper) without replacement,
    # then force the sampled token into the first slot.
    top_token_ids = torch.randperm(upper - lower)[:num_ids] + lower
    top_token_ids[0] = sampled_token_id

    # Fake rank: reuse the position if the sampled id also landed among
    # the remaining slots; otherwise pick a random rank beyond the top-k.
    hits = (top_token_ids[1:] == sampled_token_id).nonzero(as_tuple=True)[0]
    if hits.numel() > 0:
        sampled_token_rank = hits.item()
    else:
        sampled_token_rank = random.randint(num_logprobs, 50700)
    return top_token_ids, sampled_token_rank
def _create_random_top_token_test_matrix(
    shape: Tuple[int, int],
    lower: int,
    upper: int,
    tokens_list: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build a fake matrix of top-logprob token ids for testing.

    Column 0 of the returned matrix holds the actual prompt tokens; the
    remaining columns are token ids sampled without replacement from
    [lower, upper).

    Args:
      shape: (num_tokens, num_logprobs) shape of the random part
      lower: lower bound of the token-id range
      upper: upper bound of the token-id range
      tokens_list: prompt token ids, one per row

    Returns:
      Tuple of (2D num_tokens x (num_logprobs+1) token-id tensor,
      1D tensor of fake per-row ranks of the prompt tokens).
    """
    num_rows, num_cols = shape
    # Draw all random top token ids at once, without replacement.
    sampled_ids = torch.randperm(upper - lower)[:num_rows * num_cols] + lower
    # Prepend the actual prompt tokens as column 0.
    prompt_col = torch.tensor(tokens_list, dtype=torch.int).unsqueeze(-1)
    matrix = torch.cat((prompt_col, sampled_ids.view(shape)), dim=1)

    # Fake rank per row: the prompt token's position among the random
    # columns if present, otherwise a random rank beyond the top-k.
    prompt_token_ranks = torch.empty(num_rows, dtype=torch.int)
    for row_idx in range(num_rows):
        # Skip column 0 — it holds the prompt token itself.
        hits = (matrix[row_idx, 1:] == tokens_list[row_idx]).nonzero(
            as_tuple=True)[0]
        if hits.numel() > 0:
            prompt_token_ranks[row_idx] = hits.item()
        else:
            prompt_token_ranks[row_idx] = random.randint(num_cols, 50700)
    return matrix, prompt_token_ranks
def decode_token(
    tok_id: int,
    tokenizer: PreTrainedTokenizer,
) -> str:
    """Detokenize one token id, mirroring production detokenization.

    Args:
      tok_id: token id to detokenize
      tokenizer: tokenizer used for detokenization

    Returns:
      String representation of the token.
    """
    token_str = tokenizer.convert_ids_to_tokens(tok_id)
    return token_str
def generate_dummy_sample_logprobs(
    sampled_tokens_list: List,
    num_logprobs: int,
    tokenizer: PreTrainedTokenizer,
) -> List[Tuple[List[int], List[float], int]]:
    """Generate dummy sample logprobs.

    Imitates the per-position sample logprobs assembled by the engine
    core during decode.

    Args:
      sampled_tokens_list: list of sampled token ids
      num_logprobs: request `num_logprobs` top logprobs per token (one
                    extra slot is added to emulate sampled-token
                    injection)
      tokenizer: model tokenizer (its vocab size bounds the token ids)

    Returns:
      List with one (top token ids, top logprob values, sampled token
      rank) tuple per sampled token; the two vectors in each tuple have
      length `num_logprobs+1`.
    """
    raw_results = []
    for sampled_token_id in sampled_tokens_list:
        top_token_ids, sampled_token_rank = _create_random_top_token_test_vector(
            num_logprobs, 0,
            len(tokenizer.vocab) - 1, sampled_token_id)
        top_logprob_vals = _create_random_top_logprob_test_vector(
            num_logprobs + 1, -100, 0)
        raw_results.append((top_token_ids, top_logprob_vals,
                            sampled_token_rank))

    # Convert the tensors to plain Python lists for the final result.
    return [(token_ids.tolist(), logprob_vals.tolist(), rank)
            for token_ids, logprob_vals, rank in raw_results]
def generate_dummy_prompt_logprobs_tensors(
    prompt_tokens_list: List,
    num_logprobs: int,
    tokenizer: PreTrainedTokenizer,
) -> LogprobsTensors:
    """Generate dummy prompt logprobs tensors.

    Imitates the torch tensors of prompt logprobs assembled by the
    engine core during chunked prefill.

    Args:
      prompt_tokens_list: list of prompt token ids
      num_logprobs: number of top logprobs per token
      tokenizer: model tokenizer (its vocab size bounds the token ids)

    Returns:
      `LogprobsTensors` whose matrices have
      (len(prompt_tokens_list)-1) rows.
    """
    # Assume the whole prompt is processed in one chunk, so there are
    # `len(prompt_tokens_list)-1` non-`None` prompt logprob positions:
    # position i predicts the distribution of prompt token i+1 (the
    # leading `None` is injected later, in the detokenizer). Hence the
    # token matrix is built from `prompt_tokens_list[1:]`, just as the
    # engine would.
    num_prompt_logprobs = len(prompt_tokens_list) - 1
    token_ids, prompt_token_ranks = _create_random_top_token_test_matrix(
        (num_prompt_logprobs, num_logprobs), 0,
        len(tokenizer.vocab) - 1, prompt_tokens_list[1:])
    logprob_vals = _create_random_top_logprob_test_matrix(
        (num_prompt_logprobs, num_logprobs + 1), -100, 0)
    return LogprobsTensors(token_ids, logprob_vals, prompt_token_ranks)
@dataclass
class DummyOutputProcessorTestVectors:
    """Dummy test vectors for output processor tests"""
    # Tokenizer shared by all requests in the test vectors.
    tokenizer: GeneralTokenizerType
    tokenizer_group: BaseTokenizerGroup
    # NOTE(review): annotated as EngineArgs despite the `vllm_config`
    # name — confirm which type is intended.
    vllm_config: EngineArgs
    full_tokens: List[List[int]]  # Prompt + generated tokens
    prompt_tokens: List[List[int]]
    generation_tokens: List[List[int]]
    # Each request is associated with a tuple of
    # (top tokens, top logprobs, ranks) prompt logprobs tensors
    prompt_logprobs: List[LogprobsTensors]
    # Each request is associated with a sample logprobs; a request's
    # sample logprobs are a list of (top tokens, top logprobs, ranks)
    # sample logprobs tensors at each sequence position
    generation_logprobs: List[List[Tuple[List[int], List[float], int]]]
    # Detokenized prompt strings; prompt_strings_len presumably holds
    # their lengths — TODO confirm against the fixture that fills them.
    prompt_strings: List[str]
    prompt_strings_len: List[int]
    generation_strings: List[str]
class MockEngineCore:
    """Mock engine core that replays pre-made token lists as outputs.

    Each call to `get_outputs()` emits one new token per request (taken
    from `tokens_list`), optionally attaching pre-made sample logprobs
    and — on the first step only — pre-made prompt logprobs tensors.
    """

    def __init__(
        self,
        tokens_list: List[List[int]],
        # Per request, per sampled-token offset:
        # (list of top-k token ids, list of sample logprob vals, rank).
        generated_logprobs_raw: Optional[List[List[Tuple[List[int],
                                                         List[float],
                                                         int]]]] = None,
        # Per request: prompt logprobs tensors with dimensions
        # (num prompt toks) x (num prompt logprobs+1).
        prompt_logprobs_raw: Optional[List[LogprobsTensors]] = None,
    ) -> None:
        self.tokens_list = tokens_list
        self.current_idx = 0
        self.generated_logprobs_raw = generated_logprobs_raw
        self.do_logprobs = generated_logprobs_raw is not None
        self.prompt_logprobs_raw = prompt_logprobs_raw
        self.do_prompt_logprobs = prompt_logprobs_raw is not None

    def get_outputs(self) -> List[EngineCoreOutput]:
        token_idx = self.current_idx

        outputs: List[EngineCoreOutput] = []
        for req_idx, token_ids in enumerate(self.tokens_list):
            # Requests that have exhausted their token list emit nothing.
            if token_idx >= len(token_ids):
                continue

            logprobs = None
            if self.do_logprobs:
                assert self.generated_logprobs_raw is not None
                (top_token_ids, top_logprob_vals,
                 sampled_rank) = self.generated_logprobs_raw[req_idx][token_idx]
                logprobs = LogprobsLists(
                    [top_token_ids],
                    [top_logprob_vals],
                    [sampled_rank],
                )

            # Prompt logprobs are only attached on the first step.
            prompt_logprobs = None
            if self.do_prompt_logprobs and token_idx == 0:
                assert self.prompt_logprobs_raw is not None
                prompt_logprobs = self.prompt_logprobs_raw[req_idx]

            output = EngineCoreOutput(
                request_id=f"request-{req_idx}",
                new_token_ids=[token_ids[token_idx]],
                new_logprobs=logprobs,
                new_prompt_logprobs_tensors=prompt_logprobs,
            )
            if token_idx == len(token_ids) - 1:
                output.finish_reason = FinishReason.STOP
            outputs.append(output)

        self.current_idx += 1
        return outputs

View File

View File

@ -0,0 +1,161 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
@pytest.fixture
def sample_prompts():
    """Short text prompts shared across tests."""
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    return prompts
@pytest.fixture
def sample_token_ids():
    """Token-id prompts of increasing length (1 through 4 tokens)."""
    token_id_lists = [
        [0],
        [0, 1],
        [0, 2, 1],
        [0, 3, 1, 2],
    ]
    return token_id_lists
@pytest.fixture
def sample_regex():
    """Regex matching a dotted IPv4 address (each octet 0-255)."""
    return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
            r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
@pytest.fixture
def sample_json_schema():
    """JSON schema for a person record (name/age/skills/work_history)."""
    return {
        "type": "object",
        "properties": {
            "name": {
                "type": "string"
            },
            "age": {
                "type": "integer"
            },
            "skills": {
                "type": "array",
                "items": {
                    "type": "string",
                    "maxLength": 10
                },
                "minItems": 3
            },
            "work_history": {
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "company": {
                            "type": "string"
                        },
                        "duration": {
                            "type": "number"
                        },
                        "position": {
                            "type": "string"
                        }
                    },
                    "required": ["company", "position"]
                }
            }
        },
        "required": ["name", "age", "skills", "work_history"]
    }
@pytest.fixture
def sample_complex_json_schema():
    """JSON schema exercising numeric ranges and regex-pattern fields."""
    return {
        "type": "object",
        "properties": {
            "score": {
                "type": "integer",
                "minimum": 0,
                "maximum": 100  # Numeric range
            },
            "grade": {
                "type": "string",
                "pattern": "^[A-D]$"  # Regex pattern
            },
            "email": {
                "type": "string",
                "pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
            },
            "tags": {
                "type": "array",
                "items": {
                    "type": "string",
                    "pattern":
                    "^[a-z]{1,10}$"  # Combining length and pattern restrictions
                }
            }
        },
        "required": ["score", "grade", "email", "tags"]
    }
@pytest.fixture
def sample_definition_json_schema():
    """JSON schema using $defs/$ref (multi-step math-reasoning shape)."""
    return {
        '$defs': {
            'Step': {
                'properties': {
                    'explanation': {
                        'title': 'Explanation',
                        'type': 'string'
                    },
                    'output': {
                        'title': 'Output',
                        'type': 'string'
                    }
                },
                'required': ['explanation', 'output'],
                'title': 'Step',
                'type': 'object'
            }
        },
        'properties': {
            'steps': {
                'items': {
                    '$ref': '#/$defs/Step'
                },
                'title': 'Steps',
                'type': 'array'
            },
            'final_answer': {
                'title': 'Final Answer',
                'type': 'string'
            }
        },
        'required': ['steps', 'final_answer'],
        'title': 'MathReasoning',
        'type': 'object'
    }
@pytest.fixture
def sample_guided_choice():
    """Closed set of options for guided-choice decoding tests."""
    return [
        "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript",
        "Ruby", "Swift", "Kotlin"
    ]
@pytest.fixture
def sample_sql_statements():
    """Grammar (EBNF-style) for simple SELECT statements.

    NOTE(review): despite the name, this returns a grammar for guided
    decoding, not literal SQL statements.
    """
    return ("""
start: select_statement
select_statement: "SELECT" column "from" table "where" condition
column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number
number: "1" | "2"
""")

View File

@ -0,0 +1,475 @@
# SPDX-License-Identifier: Apache-2.0
import re
from typing import Dict, List, Optional
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
from openai import BadRequestError
from tests.utils import RemoteOpenAIServer
from vllm.transformers_utils.tokenizer import get_tokenizer
# any model with a chat template should work here
MODEL_NAME = "facebook/opt-125m"
@pytest.fixture(scope="module")
def default_server_args():
    """Baseline CLI args for the test server (small, CI-friendly)."""
    return [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "128",
        "--enforce-eager"
    ]
@pytest.fixture(scope="module",
                params=[["--no-enable-prefix-caching"],
                        [
                            "--no-enable-prefix-caching",
                            "--disable-frontend-multiprocessing"
                        ]])
def server(default_server_args, request):
    """Run an OpenAI-compatible vLLM server for the module.

    Parametrized over frontend-multiprocessing on/off; prefix caching is
    always disabled.
    """
    # Append the parametrized CLI flags, if any.
    if request.param:
        default_server_args.extend(request.param)
    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as srv:
        yield srv
@pytest_asyncio.fixture
async def client(server):
    """Async OpenAI client bound to the running test server."""
    async with server.get_async_client() as api_client:
        yield api_client
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_single_completion(client: openai.AsyncOpenAI,
                                 model_name: str) -> None:
    """Basic non-streaming completion: shape, finish reason, usage."""
    completion = await client.completions.create(model=model_name,
                                                 prompt="Hello, my name is",
                                                 max_tokens=5,
                                                 temperature=0.0)

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 1

    choice = completion.choices[0]
    assert len(choice.text) >= 5
    assert choice.finish_reason == "length"
    # 6 prompt tokens (incl. BOS) + 5 generated = 11 total.
    assert completion.usage == openai.types.CompletionUsage(
        completion_tokens=5, prompt_tokens=6, total_tokens=11)

    # test using token IDs
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    assert len(completion.choices[0].text) >= 1
    # prompt_logprobs were not requested, so none should be returned.
    assert completion.choices[0].prompt_logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
    """With logprobs=None, the choice must carry no logprobs."""
    # Prompt supplied as raw token IDs.
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=None,
    )
    assert completion.choices[0].logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
    """With logprobs=0, only the sampled token's logprob is returned."""
    # Prompt supplied as raw token IDs.
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=0,
    )
    logprobs = completion.choices[0].logprobs
    assert logprobs is not None
    assert logprobs.token_logprobs is not None
    assert logprobs.top_logprobs is not None
    # Top-logprobs dict holds just the sampled token.
    assert len(logprobs.top_logprobs[0]) == 1
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
    """With logprobs=5, 5 (or 5+sampled=6) top logprobs are returned."""
    # Prompt supplied as raw token IDs.
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=5,
    )
    logprobs = completion.choices[0].logprobs
    assert logprobs is not None
    assert logprobs.token_logprobs is not None
    assert logprobs.top_logprobs is not None
    # One extra entry may be added when the sampled token is not in the
    # top 5.
    assert 5 <= len(logprobs.top_logprobs[0]) <= 6
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
                                            model_name: str) -> None:
    """Requests exceeding the server's max_logprobs must fail cleanly.

    Both the non-streaming and streaming forms are rejected, and the
    server must keep serving afterwards.
    """
    with pytest.raises(
        (openai.BadRequestError, openai.APIError)):  # test using token IDs
        await client.completions.create(
            model=model_name,
            prompt=[0, 0, 0, 0, 0],
            max_tokens=5,
            temperature=0.0,
            # vLLM has higher default max_logprobs (20 instead of 5) to support
            # both Completion API and Chat Completion API
            logprobs=21,
        )
    ...
    with pytest.raises(
        (openai.BadRequestError, openai.APIError)):  # test using token IDs
        stream = await client.completions.create(
            model=model_name,
            prompt=[0, 0, 0, 0, 0],
            max_tokens=5,
            temperature=0.0,
            # vLLM has higher default max_logprobs (20 instead of 5) to support
            # both Completion API and Chat Completion API
            logprobs=30,
            stream=True,
        )
        # Drain the stream; the error surfaces while iterating.
        async for chunk in stream:
            ...

    # the server should still work afterwards
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    assert len(completion.choices[0].text) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
                                                         (MODEL_NAME, 0),
                                                         (MODEL_NAME, 1),
                                                         (MODEL_NAME, None)])
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
                                          model_name: str,
                                          prompt_logprobs: Optional[int]):
    """Validate `prompt_logprobs` handling in the Completions API.

    Negative values are rejected with a 400; non-negative values yield
    prompt logprobs for every choice; omitting the field leaves prompt
    logprobs unset.
    """
    params: Dict = {
        "prompt": ["A robot may not injure another robot", "My name is"],
        "model": model_name,
    }
    if prompt_logprobs is not None:
        # prompt_logprobs is a vLLM extension, passed via extra_body.
        params["extra_body"] = {"prompt_logprobs": prompt_logprobs}

    if prompt_logprobs is not None and prompt_logprobs < 0:
        with pytest.raises(BadRequestError):
            await client.completions.create(**params)
    else:
        completion = await client.completions.create(**params)
        if prompt_logprobs is not None:
            assert completion.choices[0].prompt_logprobs is not None
            assert len(completion.choices[0].prompt_logprobs) > 0

            assert completion.choices[1].prompt_logprobs is not None
            assert len(completion.choices[1].prompt_logprobs) > 0

        else:
            assert completion.choices[0].prompt_logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
                                    model_name: str) -> None:
    """Streamed chunks must concatenate to the non-streamed completion."""
    prompt = "What is an LLM?"

    # Reference: the same request served without streaming.
    single_completion = await client.completions.create(
        model=model_name,
        prompt=prompt,
        max_tokens=5,
        temperature=0.0,
    )
    single_output = single_completion.choices[0].text
    stream = await client.completions.create(model=model_name,
                                             prompt=prompt,
                                             max_tokens=5,
                                             temperature=0.0,
                                             stream=True)
    chunks: List[str] = []
    finish_reason_count = 0
    async for chunk in stream:
        chunks.append(chunk.choices[0].text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    # finish reason should only return in last block
    assert finish_reason_count == 1
    assert chunk.choices[0].finish_reason == "length"
    assert chunk.choices[0].text
    assert "".join(chunks) == single_output
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_completion_stream_options(client: openai.AsyncOpenAI,
                                         model_name: str):
    """Exercise every include_usage/continuous_usage_stats combination.

    Streaming requests check where usage appears (never, or only in a
    trailing usage-only chunk, or on every chunk); non-streaming
    requests with stream_options must be rejected with a 400.
    """
    prompt = "What is the capital of France?"

    # Test stream=True, stream_options=
    #     {"include_usage": False, "continuous_usage_stats": False}
    stream = await client.completions.create(model=model_name,
                                             prompt=prompt,
                                             max_tokens=5,
                                             temperature=0.0,
                                             stream=True,
                                             stream_options={
                                                 "include_usage": False,
                                                 "continuous_usage_stats":
                                                 False,
                                             })

    async for chunk in stream:
        assert chunk.usage is None

    # Test stream=True, stream_options=
    #     {"include_usage": False, "continuous_usage_stats": True}
    stream = await client.completions.create(model=model_name,
                                             prompt=prompt,
                                             max_tokens=5,
                                             temperature=0.0,
                                             stream=True,
                                             stream_options={
                                                 "include_usage": False,
                                                 "continuous_usage_stats":
                                                 True,
                                             })
    async for chunk in stream:
        assert chunk.usage is None

    # Test stream=True, stream_options=
    #     {"include_usage": True, "continuous_usage_stats": False}
    stream = await client.completions.create(model=model_name,
                                             prompt=prompt,
                                             max_tokens=5,
                                             temperature=0.0,
                                             stream=True,
                                             stream_options={
                                                 "include_usage": True,
                                                 "continuous_usage_stats":
                                                 False,
                                             })
    async for chunk in stream:
        if chunk.choices[0].finish_reason is None:
            assert chunk.usage is None
        else:
            # Even the final choice-bearing chunk has no usage; usage
            # arrives in a separate trailing chunk with empty choices.
            assert chunk.usage is None
            final_chunk = await stream.__anext__()
            assert final_chunk.usage is not None
            assert final_chunk.usage.prompt_tokens > 0
            assert final_chunk.usage.completion_tokens > 0
            assert final_chunk.usage.total_tokens == (
                final_chunk.usage.prompt_tokens +
                final_chunk.usage.completion_tokens)
            assert final_chunk.choices == []

    # Test stream=True, stream_options=
    #     {"include_usage": True, "continuous_usage_stats": True}
    stream = await client.completions.create(model=model_name,
                                             prompt=prompt,
                                             max_tokens=5,
                                             temperature=0.0,
                                             stream=True,
                                             stream_options={
                                                 "include_usage": True,
                                                 "continuous_usage_stats":
                                                 True,
                                             })
    async for chunk in stream:
        assert chunk.usage is not None
        assert chunk.usage.prompt_tokens > 0
        assert chunk.usage.completion_tokens > 0
        assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
                                            chunk.usage.completion_tokens)
        if chunk.choices[0].finish_reason is not None:
            final_chunk = await stream.__anext__()
            assert final_chunk.usage is not None
            assert final_chunk.usage.prompt_tokens > 0
            assert final_chunk.usage.completion_tokens > 0
            assert final_chunk.usage.total_tokens == (
                final_chunk.usage.prompt_tokens +
                final_chunk.usage.completion_tokens)
            assert final_chunk.choices == []

    # Test stream=False, stream_options=
    #     {"include_usage": None}
    with pytest.raises(BadRequestError):
        await client.completions.create(model=model_name,
                                        prompt=prompt,
                                        max_tokens=5,
                                        temperature=0.0,
                                        stream=False,
                                        stream_options={"include_usage": None})

    # Test stream=False, stream_options=
    #    {"include_usage": True}
    with pytest.raises(BadRequestError):
        await client.completions.create(model=model_name,
                                        prompt=prompt,
                                        max_tokens=5,
                                        temperature=0.0,
                                        stream=False,
                                        stream_options={"include_usage": True})

    # Test stream=False, stream_options=
    #     {"continuous_usage_stats": None}
    with pytest.raises(BadRequestError):
        await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=5,
            temperature=0.0,
            stream=False,
            stream_options={"continuous_usage_stats": None})

    # Test stream=False, stream_options=
    #    {"continuous_usage_stats": True}
    with pytest.raises(BadRequestError):
        await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=5,
            temperature=0.0,
            stream=False,
            stream_options={"continuous_usage_stats": True})
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
    """Batched prompts: plain, n=2 beam search, and streaming variants."""
    # test both text and token IDs
    for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2):
        # test simple list
        batch = await client.completions.create(
            model=model_name,
            prompt=prompts,
            max_tokens=5,
            temperature=0.0,
        )
        assert len(batch.choices) == 2
        # Identical prompts at temperature 0 must produce identical text.
        assert batch.choices[0].text == batch.choices[1].text

        # test n = 2
        batch = await client.completions.create(
            model=model_name,
            prompt=prompts,
            n=2,
            max_tokens=5,
            temperature=0.0,
            extra_body=dict(
                # NOTE: this has to be true for n > 1 in vLLM, but
                # not necessary for official client.
                use_beam_search=True),
        )
        assert len(batch.choices) == 4
        assert batch.choices[0].text != batch.choices[
            1].text, "beam search should be different"
        assert batch.choices[0].text == batch.choices[
            2].text, "two copies of the same prompt should be the same"
        assert batch.choices[1].text == batch.choices[
            3].text, "two copies of the same prompt should be the same"

        # test streaming
        batch = await client.completions.create(
            model=model_name,
            prompt=prompts,
            max_tokens=5,
            temperature=0.0,
            stream=True,
        )
        texts = [""] * 2
        async for chunk in batch:
            assert len(chunk.choices) == 1
            choice = chunk.choices[0]
            # Reassemble each prompt's text by choice index.
            texts[choice.index] += choice.text
        assert texts[0] == texts[1]
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
@pytest.mark.parametrize("logprobs_arg", [1, 0])
async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
                                       model_name: str, logprobs_arg: int):
    """echo=True returns prompt tokens with logprobs (leading None)."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
    # test using text and token IDs
    for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
        completion = await client.completions.create(model=model_name,
                                                     prompt=prompt,
                                                     max_tokens=5,
                                                     temperature=0.0,
                                                     echo=True,
                                                     logprobs=logprobs_arg)

        # The echoed text must begin with the original prompt.
        prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
                                                             list) else prompt
        assert re.search(r"^" + prompt_text, completion.choices[0].text)
        logprobs = completion.choices[0].logprobs
        assert logprobs is not None
        assert len(logprobs.text_offset) > 5
        # The first prompt position has no logprob, hence the leading None.
        assert (len(logprobs.token_logprobs) > 5
                and logprobs.token_logprobs[0] is None)
        assert (len(logprobs.top_logprobs) > 5
                and logprobs.top_logprobs[0] is None)
        for top_logprobs in logprobs.top_logprobs[1:]:
            assert max(logprobs_arg,
                       1) <= len(top_logprobs) <= logprobs_arg + 1
        assert len(logprobs.tokens) > 5

View File

@ -0,0 +1,392 @@
# SPDX-License-Identifier: Apache-2.0
import itertools
from typing import List, Tuple
import pytest
import torch
from tests.kernels.utils import override_backend_env_variable
from tests.v1.sample.utils import (
assert_incr_detok_str_matches_non_incr_detok_str,
compute_correct_cumulative_logprob, get_test_batch)
from vllm import SamplingParams
from ...conftest import VllmRunner
MODEL = "meta-llama/Llama-3.2-1B"
DTYPE = "half"
@pytest.fixture(scope="module")
def vllm_model(vllm_runner):
    """Module-scoped vLLM runner sized to force prefill chunking."""
    with vllm_runner(
            MODEL,
            dtype=DTYPE,
            max_logprobs=7,
            # Very small number of batched tokens to ensure
            # that we test chunking.
            max_num_batched_tokens=16,
            max_num_seqs=16,
            max_model_len=128,
            enforce_eager=True,
            #TODO: enable this once we support it for
            # prompt logprobs.
            enable_prefix_caching=False,
            gpu_memory_utilization=0.5,
    ) as vllm_model:
        yield vllm_model
@pytest.fixture(scope="module")
def hf_model(hf_runner):
    """Module-scoped HuggingFace reference model."""
    with hf_runner(MODEL, dtype=DTYPE) as model:
        yield model
def _repeat_logprob_config(
    test_prompts,
    logprob_prompt_logprob_list: List[Tuple],
) -> List[Tuple]:
    """Give every test prompt a logprob config.

    A logprob config is an (optional num sample logprobs, optional num
    prompt logprobs) tuple. The provided configs are cycled and cut to
    exactly `len(test_prompts)` entries: tiled when there are fewer
    configs than prompts, truncated when there are more, unchanged when
    the counts already match.

    Args:
        test_prompts: list of prompts under test
        logprob_prompt_logprob_list: list of
                            (optional num sample logprob,
                            optional num prompt logprob)
                            tuples

    Returns:
        List of (optional num sample logprob, optional num prompt
        logprob) tuples with exactly one entry per test prompt.
    """
    num_test_prompts = len(test_prompts)
    # Cycle the configs endlessly, then take one per prompt.
    tiled_configs = itertools.cycle(logprob_prompt_logprob_list)
    result = list(itertools.islice(tiled_configs, num_test_prompts))
    # Now the number of prompts should match the number of sample params combos
    assert num_test_prompts == len(result)
    return result
def _test_case_get_logprobs_and_prompt_logprobs(
    hf_model,
    vllm_model,
    batch_logprobs_composition: str,
    temperature: float,
    example_prompts,
) -> None:
    """Run a batch and validate vLLM (prompt) logprobs against HuggingFace.

    Greedy HF generations provide the reference for sampled tokens
    (exact match only when temperature is 0.0) and for logprob values.
    The batch mixes per-request logprobs/prompt_logprobs settings as
    dictated by `batch_logprobs_composition`.

    Args:
        hf_model: HuggingFace reference model fixture
        vllm_model: vLLM model fixture
        batch_logprobs_composition: logprobs configuration for the batch
        temperature: sampling temperature
        example_prompts: test prompt fixture
    """
    test_prompts = example_prompts

    max_tokens = 5
    hf_outputs = hf_model.generate_greedy(
        test_prompts,
        max_tokens=max_tokens,
    )
    hf_logprobs = hf_model.generate_greedy_logprobs(
        test_prompts,
        max_tokens=max_tokens,
    )

    # Batch has mixed sample params
    # (different logprobs/prompt logprobs combos)
    logprob_prompt_logprob_list = get_test_batch(batch_logprobs_composition)

    # Ensure that each test prompt has a logprob config for testing
    logprob_prompt_logprob_list = _repeat_logprob_config(
        test_prompts, logprob_prompt_logprob_list)

    # Generate SamplingParams
    vllm_sampling_params = [
        SamplingParams(max_tokens=max_tokens,
                       logprobs=num_lp,
                       prompt_logprobs=num_plp,
                       temperature=temperature,
                       seed=1984)
        for num_lp, num_plp in logprob_prompt_logprob_list
    ]

    vllm_results = vllm_model.model.generate(
        test_prompts, sampling_params=vllm_sampling_params)

    for vllm_result, hf_logprob, hf_output, logprob_prompt_logprob in zip(
            vllm_results, hf_logprobs, hf_outputs,
            logprob_prompt_logprob_list):

        # Extract request-level (prompt)logprobs config
        num_top_logprobs, num_top_prompt_logprobs = logprob_prompt_logprob

        # Test whether sampled token output is consistent between vLLM and HF
        # vLLM prompt+completion should match HF output
        if temperature == 0.0:
            assert (vllm_result.prompt_token_ids +
                    vllm_result.outputs[0].token_ids == hf_output[0])
        else:
            # Sampled tokens won't match if not greedy
            assert (vllm_result.prompt_token_ids == hf_output[0]
                    [:len(vllm_result.prompt_token_ids)])

        # Validate sample logprobs
        if num_top_logprobs is not None:
            assert num_top_logprobs is not None
            # Confirm that the structure of the sample logprobs in the result is
            # correct
            assert vllm_result.outputs[0].logprobs is not None
            assert len(vllm_result.outputs[0].logprobs) == max_tokens
            for logprobs, token_id in zip(vllm_result.outputs[0].logprobs,
                                          vllm_result.outputs[0].token_ids):
                assert logprobs is not None

                # Confirm that the output token appears among the logprobs
                assert token_id in logprobs
                token_in_topk = logprobs[token_id].rank <= num_top_logprobs

                # If the output token is not included in the top K
                # logprob, it can return 1 more data
                if token_in_topk and num_top_logprobs != 0:
                    assert len(logprobs) == num_top_logprobs
                else:
                    assert len(logprobs) == num_top_logprobs + 1

                if num_top_logprobs > 0:
                    # We should have an entry for each of the topk ranks
                    all_ranks = {lp.rank for lp in logprobs.values()}
                    assert all(r in all_ranks
                               for r in range(1, num_top_logprobs + 1))

            output_text = vllm_result.outputs[0].text
            output_string_from_most_likely_tokens_lst: List[str] = []
            for top_logprobs in vllm_result.outputs[0].logprobs:
                # First entry of each position's dict is the top token.
                top_logprob = next(iter(top_logprobs.values()))
                output_string_from_most_likely_tokens_lst.append(
                    top_logprob.decoded_token)

            output_string_from_most_likely_tokens = "".join(
                output_string_from_most_likely_tokens_lst)
            assert_incr_detok_str_matches_non_incr_detok_str(
                output_text, output_string_from_most_likely_tokens,
                "The output text from the top logprob for each token "
                "position should be the same as the output text in the "
                "result.")

            # Compare vLLM sample logprobs to HF
            vllm_sample_logprobs = vllm_result.outputs[0].logprobs
            for i, top_logprobs in enumerate(vllm_sample_logprobs):
                for token_id, sample_logprob in top_logprobs.items():
                    if temperature == 0.0 or i == 0:
                        logprob = sample_logprob.logprob
                        torch.testing.assert_close(
                            logprob,
                            hf_logprob[i][-1][token_id].item(),
                            atol=1e-2,
                            rtol=1e-2)
                    assert isinstance(
                        sample_logprob.decoded_token,
                        str), ("The token should be decoded by the time it is"
                               " returned to the user.")

            # At this point we know the sample logprobs are correct for this
            # request. Validate that cumulative_logprob is actually the sum.
            # For each request, assert that the returned cumulative logprob
            # matches the correct value, which is computed below.
            torch.testing.assert_close(
                vllm_result.outputs[0].cumulative_logprob,
                compute_correct_cumulative_logprob(vllm_result.outputs[0]),
                atol=1e-6,
                rtol=1e-6)
        else:
            # Logprobs disabled for this request; should be None
            assert vllm_result.outputs[0].logprobs is None

        # Validate prompt logprobs
        if num_top_prompt_logprobs is not None:
            # Confirm that structure of prompt logprobs in result is correct
            assert vllm_result.prompt_logprobs is not None
            # - The first prompt logprob is always None
            assert vllm_result.prompt_logprobs[0] is None
            # - Prompt logprobs are returned for all indices in
            #   the prompt
            assert len(vllm_result.prompt_logprobs) == len(
                vllm_result.prompt_token_ids)

            for prompt_logprobs, prompt_token_id in zip(
                    vllm_result.prompt_logprobs[1:],
                    vllm_result.prompt_token_ids[1:]):
                assert prompt_logprobs is not None

                # Confirm that the prompt token appears among the logprobs
                assert prompt_token_id in prompt_logprobs
                token_in_topk = prompt_logprobs[
                    prompt_token_id].rank <= num_top_prompt_logprobs

                # If the prompt token is not included in the top K
                # logprob, it can return 1 more data
                if token_in_topk and num_top_prompt_logprobs != 0:
                    assert len(prompt_logprobs) == num_top_prompt_logprobs
                else:
                    assert len(prompt_logprobs) == num_top_prompt_logprobs + 1

                if num_top_prompt_logprobs > 0:
                    # We should have an entry for each of the topk ranks
                    all_ranks = {lp.rank for lp in prompt_logprobs.values()}
                    assert all(r in all_ranks
                               for r in range(1, num_top_prompt_logprobs + 1))

            # Compare prompt logprobs to HF
            # The first prompt logprob is always None, so we compare it from
            # 1:.
            vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
            for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
                for token_id, logprob in vllm_prompt_logprob_dict.items():
                    torch.testing.assert_close(
                        logprob.logprob,
                        hf_logprob[0][i][token_id].item(),
                        atol=2e-2,
                        rtol=2e-2)
        else:
            assert vllm_result.prompt_logprobs is None
#@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("batch_logprobs_composition",
                         ["NONE", "SAMPLE", "PROMPT", "SAMPLE_PROMPT"])
@pytest.mark.parametrize("temperature", [0.0, 2.0])
def test_get_logprobs_and_prompt_logprobs(
    hf_model,
    vllm_model,
    batch_logprobs_composition: str,
    temperature: float,
    example_prompts,
) -> None:
    """Test V1 Engine logprobs & prompt logprobs

    Exercise a variety of combinations of `logprobs` and `prompt_logprobs`
    settings and validate that

    * The generated logprobs and prompt logprobs are consistent with the
      configuration settings, in terms of whether or not the logprobs
      (of either type) were requested and how many were requested
    * The generated logprobs are consistent with the generated tokens
    * The generated (prompt)logprobs are consistent with HuggingFace
      (prompt)logprobs, as a reference

    batch_logprobs_composition controls the logprobs configurations for
    requests in the batch under test.

    Args:
      hf_model: HuggingFace reference model fixture
      vllm_model: vLLM model fixture
      batch_logprobs_composition: logprobs configuration for test batch
      temperature: sampling temperature
      example_prompts: test prompt fixture
    """
    _test_case_get_logprobs_and_prompt_logprobs(
        hf_model=hf_model,
        vllm_model=vllm_model,
        batch_logprobs_composition=batch_logprobs_composition,
        temperature=temperature,
        example_prompts=example_prompts)
def test_max_logprobs(monkeypatch):
    """vLLM v1 engine should fail a request with `logprobs > max_logprobs`

    Should also fail for `prompt_logprobs > max_logprobs`

    Args:
      monkeypatch: supports editing env vars and rolling back changes
                   after the test
    """
    override_backend_env_variable(monkeypatch, "FLASH_ATTN")

    # Engine configured to allow at most one logprob per position.
    runner = VllmRunner("facebook/opt-125m",
                        max_logprobs=1,
                        enable_prefix_caching=False,
                        max_model_len=256)

    # Requesting exactly max_logprobs must be accepted.
    ok_sampling_params = SamplingParams(logprobs=1)
    runner.generate(["Hello world"], sampling_params=ok_sampling_params)

    # Requesting one more than max_logprobs must be rejected.
    with pytest.raises(ValueError):
        runner.generate(["Hello world"],
                        sampling_params=SamplingParams(logprobs=2))
def test_none_logprobs(vllm_model, example_prompts, monkeypatch):
    """Engine should return `logprobs` and `prompt_logprobs` as `None`

    Args:
      vllm_model: vLLM model fixture
      example_prompts: list of example prompts (test fixture)
      monkeypatch: supports editing env vars and rolling back changes
                   after the test
    """
    max_tokens = 5

    sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
                                                   logprobs=None,
                                                   prompt_logprobs=None,
                                                   temperature=0.0)
    results_logprobs_none = vllm_model.model.generate(
        example_prompts, sampling_params=sampling_params_logprobs_none)

    # Iterate results directly rather than via range(len(...)).
    for result in results_logprobs_none:
        # Check sample logprobs are None
        assert result.outputs[0].logprobs is None
        assert result.outputs[0].cumulative_logprob is None
        # Check prompt logprobs are None
        assert result.prompt_logprobs is None
def test_zero_logprobs(vllm_model, example_prompts, monkeypatch):
    """Engine should return sampled token and prompt token logprobs

    With `logprobs=0` / `prompt_logprobs=0`, one logprob dict (for the
    sampled/prompt token itself) is expected per position.

    Args:
      vllm_model: vLLM model fixture
      example_prompts: list of example prompts (test fixture)
      monkeypatch: supports editing env vars and rolling back changes
                   after the test
    """
    max_tokens = 5

    sampling_params_logprobs_zero = SamplingParams(max_tokens=max_tokens,
                                                   logprobs=0,
                                                   prompt_logprobs=0,
                                                   temperature=0.0)
    results_logprobs_zero = vllm_model.model.generate(
        example_prompts, sampling_params=sampling_params_logprobs_zero)

    # Iterate results directly rather than via range(len(...)).
    for result in results_logprobs_zero:
        completion = result.outputs[0]
        logprobs = completion.logprobs
        prompt_logprobs = result.prompt_logprobs
        # Check that there is one sample logprob dict for each
        # sample token
        assert logprobs is not None
        assert len(completion.token_ids) == len(logprobs)
        assert completion.cumulative_logprob is not None
        # Check that there is one prompt logprob dict for each
        # prompt token
        assert prompt_logprobs is not None
        assert len(result.prompt_token_ids) == len(prompt_logprobs)

View File

@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
import lm_eval
from ...utils import RemoteOpenAIServer
# arc-easy uses prompt_logprobs=1, logprobs=1
TASK = "arc_easy"
FILTER = "acc_norm,none"
RTOL = 0.03
EXPECTED_VALUE = 0.62
# FIXME(rob): enable prefix caching once supported.
MODEL = "meta-llama/Llama-3.2-1B"
MODEL_ARGS = f"pretrained={MODEL},enforce_eager=True,enable_prefix_caching=False" # noqa: E501
SERVER_ARGS = [
"--enforce_eager", "--no_enable_prefix_caching", "--disable-log-requests"
]
NUM_CONCURRENT = 100
def test_prompt_logprobs_e2e():
    """Run lm_eval in-process and check accuracy is within tolerance."""
    results = lm_eval.simple_evaluate(model="vllm",
                                      model_args=MODEL_ARGS,
                                      tasks=TASK,
                                      batch_size="auto")

    measured_value = results["results"][TASK][FILTER]
    # |measured - expected| < RTOL, written as an absolute-difference check
    # (equivalent to the two-sided comparison form).
    assert abs(measured_value - EXPECTED_VALUE) < RTOL, (
        f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}")
# TODO(review): "promt" looks like a typo for "prompt" in the test name;
# kept as-is to preserve the public test identifier.
def test_promt_logprobs_e2e_server():
    """Run lm_eval against a live OpenAI-compatible server endpoint."""
    with RemoteOpenAIServer(MODEL, SERVER_ARGS) as remote_server:
        url = f"{remote_server.url_for('v1')}/completions"

        model_args = (
            f"model={MODEL},"
            f"base_url={url},"
            f"num_concurrent={NUM_CONCURRENT},tokenized_requests=False")

        results = lm_eval.simple_evaluate(
            model="local-completions",
            model_args=model_args,
            tasks=TASK,
        )

        measured_value = results["results"][TASK][FILTER]
        # |measured - expected| < RTOL (equivalent to the two-sided form).
        assert abs(measured_value - EXPECTED_VALUE) < RTOL, (
            f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}")

120
tests/v1/sample/utils.py Normal file
View File

@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
import re
from typing import List, Tuple
from vllm import CompletionOutput
def get_test_batch(batch_logprobs_composition: str) -> List[Tuple]:
    """Generate logprobs configs for a batch of requests

    A given request's logprobs configuration is (1) num_sample_logprobs and (2)
    num_prompt_logprobs. The batch logprobs configuration is the list of
    request logprobs configs.

    batch_logprobs_composition == "NONE" yields a batch with no sample or
    prompt logprobs

    batch_logprobs_composition == "SAMPLE" yields a batch with some requests
    configured for sample logprobs only, and others configured for no logprobs

    batch_logprobs_composition == "PROMPT" yields a batch with some requests
    configured for prompt logprobs only, and others configured for no logprobs

    batch_logprobs_composition == "SAMPLE_PROMPT" yields a batch with some
    requests configured for sample logprobs and prompt logprobs, some
    configured for only sample logprobs or only prompt logprobs, and some
    configured for no logprobs

    Args:
      batch_logprobs_composition: types of logprobs configs to include in batch

    Returns:
      List of (Optional[num_sample_logprobs], Optional[num_prompt_logprobs])
      tuples
    """
    # Composition name -> list of per-request
    # (num_sample_logprobs, num_prompt_logprobs) configs.
    compositions = {
        # No requests with sample or prompt logprobs
        "NONE": [(None, None)],
        # Requests requiring sample logprobs or no logprobs
        "SAMPLE": [
            (None, None),
            (0, None),
            (5, None),
            (3, None),
        ],
        # Requests requiring prompt logprobs or no logprobs
        "PROMPT": [
            (None, None),
            (None, 0),
            (None, 6),
            (None, 5),
        ],
        # Requests requiring either no logprobs, just
        # sample logprobs, just prompt logprobs, or
        # both sample and prompt logprobs
        "SAMPLE_PROMPT": [
            (None, None),
            (0, None),
            (5, None),
            (3, None),
            (0, 3),
            (6, 0),
            (6, 3),
            (None, 6),
            (None, 5),
            (None, 0),
        ],
    }
    try:
        return compositions[batch_logprobs_composition]
    except KeyError:
        raise ValueError(
            "Invalid logprobs batch configuration for test.") from None
def assert_incr_detok_str_matches_non_incr_detok_str(
    incremental_detokenization_str: str,
    non_incremental_detokenization_str: str,
    msg: str,
) -> None:
    """Compare incrementally detok. text to non-incrementally detok. text

    Fail if the strings mismatch after non-alphanumeric characters are
    stripped out.

    Rationale: incremental detokenization in the text generation process
    allows the tokenizer to adjust the next token text output based on the
    token's context in the string. However, logprobs detokenization
    detokenizes each token individually, and the resultant strings may
    include some non-alphanumeric placeholder characters where there could
    be i.e. whitespace. So, this function compares only the alphanumeric
    text between two strings and fails if there is a mismatch, which helps
    with validating logprobs detokenization.

    Args:
      incremental_detokenization_str: incrementally-detokenized generated text
      non_incremental_detokenization_str: non-incrementally-detokenized
                                          logprob tokens
      msg: error message if `assert` fails
    """
    # Strip everything that is not a letter or digit, then compare.
    non_alnum = re.compile(r'[^a-zA-Z0-9]+')
    stripped_incr = non_alnum.sub('', incremental_detokenization_str)
    stripped_non_incr = non_alnum.sub('', non_incremental_detokenization_str)
    assert stripped_incr == stripped_non_incr, msg
def compute_correct_cumulative_logprob(
        completion_output: CompletionOutput) -> float:
    """Compute known-good value for evaluating cumulative logprob

    Sums, over each generated position, the logprob recorded for the token
    that was actually sampled at that position.

    Args:
      completion_output: completion output from engine

    Returns:
      Known-good cumulative logprob value
    """
    per_position_logprobs = completion_output.logprobs
    assert per_position_logprobs is not None
    # For each position, look up the sampled token's Logprob entry.
    return sum(pos_dict[token_id].logprob
               for token_id, pos_dict in zip(completion_output.token_ids,
                                             per_position_logprobs))

View File

@ -142,6 +142,9 @@ class RequestOutput:
prompt_token_ids: Optional[List[int]],
text: str,
token_ids: List[int],
logprobs: Optional[SampleLogprobs],
prompt_logprobs: Optional[PromptLogprobs],
cumulative_logprob: Optional[float],
finished: bool = False,
) -> "RequestOutput":
"""Initialize a new RequestOutput object."""
@ -151,15 +154,14 @@ class RequestOutput:
index=0,
text=text,
token_ids=token_ids,
cumulative_logprob=None,
logprobs=None, # TODO
)
cumulative_logprob=cumulative_logprob,
logprobs=logprobs)
return RequestOutput(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
prompt_logprobs=None, # TODO
prompt_logprobs=prompt_logprobs,
outputs=[completion_output],
finished=finished,
)

View File

@ -74,6 +74,25 @@ def convert_prompt_ids_to_tokens(
return new_tokens, prefix_offset, read_offset
def convert_ids_list_to_tokens(
    tokenizer: AnyTokenizer,
    token_ids: List[int],
) -> List[str]:
    """Detokenize the input ids individually.

    Args:
      tokenizer: tokenizer used by model under test
      token_ids: convert these tokens (Python list form)

    Returns:
      Python list of token string representations
    """
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    # The tokenizer may map some ids to None; normalize those to "".
    _replace_none_with_empty(tokens)  # type: ignore
    return tokens
# Based on
# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15
# under Apache 2.0 license

View File

@ -437,6 +437,8 @@ class Scheduler:
) -> EngineCoreOutputs:
# NOTE(woosuk): This method doesn't consider speculative decoding.
sampled_token_ids = model_runner_output.sampled_token_ids
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
new_running: List[Request] = []
outputs: List[EngineCoreOutput] = []
@ -471,6 +473,13 @@ class Scheduler:
self.encoder_cache_manager.free_encoder_input(
request, input_id)
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
stopped = False
new_logprobs = None
new_token_ids = None
if request.num_computed_tokens == request.num_tokens:
req_index = model_runner_output.req_id_to_index[req_id]
# NOTE(woosuk): Currently, we assume that each request
@ -486,20 +495,30 @@ class Scheduler:
if stopped:
self._free_request(request)
# Extract sample logprobs if needed.
if request.sampling_params.logprobs is not None:
assert logprobs is not None
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)
new_token_ids = request.output_token_ids[-num_new_tokens:]
# Transmit partial if chunked prefill & prompt logprobs is enabled
if new_token_ids or prompt_logprobs_tensors is not None:
# Add EngineCoreOutput for this Request.
output = EngineCoreOutput(
request_id=req_id,
new_token_ids=request.output_token_ids[-num_new_tokens:],
finished=request.is_finished(),
finish_reason=request.get_finished_reason(),
stop_reason=request.stop_reason)
outputs.append(output)
outputs.append(
EngineCoreOutput(
request_id=req_id,
new_token_ids=new_token_ids or [],
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
stop_reason=request.stop_reason))
# Breakout of the loop.
if stopped:
continue
if not stopped:
new_running.append(request)
new_running.append(request)
self.running = new_running
return EngineCoreOutputs(
outputs=outputs,

View File

@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, Optional, Union
import msgspec
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
@ -67,10 +68,17 @@ class EngineCoreOutput(
request_id: str
new_token_ids: List[int]
finished: bool
new_logprobs: Optional[LogprobsLists] = None
new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None
finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None
@property
def finished(self) -> bool:
return self.finish_reason is not None
class EngineCoreOutputs(
msgspec.Struct,

View File

@ -11,7 +11,6 @@ from typing import List, Tuple, Type
import psutil
import zmq
import zmq.asyncio
from msgspec import msgpack
from vllm.config import VllmConfig
from vllm.logger import init_logger
@ -26,7 +25,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import PickleEncoder
from vllm.v1.serial_utils import MsgpackEncoder, PickleEncoder
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@ -292,7 +291,7 @@ class EngineCoreProc(EngineCore):
"""Output socket IO thread."""
# Msgpack serialization encoding.
encoder = msgpack.Encoder()
encoder = MsgpackEncoder()
# Reuse send buffer.
buffer = bytearray()

View File

@ -7,7 +7,6 @@ import weakref
from abc import ABC, abstractmethod
from typing import List, Optional, Type
import msgspec
import zmq
import zmq.asyncio
@ -20,7 +19,7 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
EngineCoreRequestUnion, EngineCoreResetPrefixCache)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import PickleEncoder
from vllm.v1.serial_utils import MsgpackDecoder, PickleEncoder
from vllm.v1.utils import BackgroundProcHandle
logger = init_logger(__name__)
@ -163,7 +162,7 @@ class MPClient(EngineCoreClient):
# Serialization setup.
self.encoder = PickleEncoder()
self.decoder = msgspec.msgpack.Decoder(EngineCoreOutputs)
self.decoder = MsgpackDecoder(EngineCoreOutputs)
# ZMQ setup.
self.ctx = (

View File

@ -1,27 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import List, Optional
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine import EngineCoreRequest
logger = init_logger(__name__)
@dataclass
class DetokenizerOutput:
output_text: str
token_ids: List[int]
finished: bool
finish_reason: Optional[FinishReason] = None
stop_reason: Union[int, str, None] = None
@dataclass
class IncrementalDetokenizer:
@ -42,7 +32,6 @@ class IncrementalDetokenizer:
# Parameters for detokenization
skip_special_tokens: bool
spaces_between_special_tokens: bool
output_kind: RequestOutputKind
# Tokenizer for this request
tokenizer: AnyTokenizer
@ -90,25 +79,19 @@ class IncrementalDetokenizer:
skip_special_tokens=request.sampling_params.skip_special_tokens,
spaces_between_special_tokens=request.sampling_params.
spaces_between_special_tokens,
output_kind=request.sampling_params.output_kind,
prompt_len=len(request.prompt_token_ids),
tokenizer=tokenizer,
stop_buffer_length=stop_buffer_length,
)
def update_from_output(
self,
output: EngineCoreOutput,
) -> Optional[DetokenizerOutput]:
def update(self, new_token_ids: List[int]) -> Optional[str]:
"""
Update RequestState for the request_id by:
1) Detokenize the new token ids incrementally.
2) Update the RequestOutput with the new text.
"""
2) Evaluate stop criteria.
new_token_ids = output.new_token_ids
finish_reason = output.finish_reason
stop_reason = output.stop_reason
Return matched stop string or None.
"""
# 1) Detokenize the new token ids incrementally.
# TODO(woosuk): This method becomes very inefficient when the number of
@ -131,11 +114,13 @@ class IncrementalDetokenizer:
self.tokens.extend(new_tokens)
self.prefix_offset = prefix_offset
self.read_offset = read_offset
self.output_text += new_decoded_token_text
decoded_text += new_decoded_token_text
self.output_text += decoded_text
# 2) Evaluate stop criteria.
stop_string = None
if self.stop:
stop = StopChecker.check_stop_strings(
output_text=self.output_text,
@ -144,28 +129,13 @@ class IncrementalDetokenizer:
include_in_output=self.include_stop_str_in_output,
)
if stop is not None:
stop_str, truncate_to = stop
stop_string, truncate_to = stop
if truncate_to != -1:
self.output_text = self.output_text[:truncate_to]
finish_reason = FinishReason.STOP
stop_reason = stop_str
# TODO: handle stop_token_ids here too?
return stop_string
# 3) Update the RequestOutput object with the new text.
finished = finish_reason is not None
if self.output_kind == RequestOutputKind.FINAL_ONLY \
and not finished:
return None
delta = self.output_kind == RequestOutputKind.DELTA
output_text = self._get_next_output_text(finished, delta)
token_ids = new_token_ids if delta else self.output_token_ids
return DetokenizerOutput(output_text, token_ids, finished,
finish_reason, stop_reason)
def _get_next_output_text(self, finished: bool, delta: bool) -> str:
def get_next_output_text(self, finished: bool, delta: bool) -> str:
"""If delta is True, only new text since the last call to
this method is returned"""

View File

@ -45,6 +45,7 @@ class LLMEngine:
multiprocess_mode: bool = False,
) -> None:
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
# Tokenizer (+ ensure liveness if running in another process).
self.tokenizer = init_tokenizer_from_configs(

194
vllm/v1/engine/logprobs.py Normal file
View File

@ -0,0 +1,194 @@
# SPDX-License-Identifier: Apache-2.0
import itertools
from dataclasses import dataclass
from typing import Dict, List, Optional
from vllm.logger import init_logger
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.detokenizer_utils import (
AnyTokenizer, convert_ids_list_to_tokens)
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
logger = init_logger(__name__)
@dataclass
class LogprobsProcessor:
    """Per-request accumulator for sample and prompt logprobs.

    Transforms the (token ids, logprobs, ranks) triplets returned by the
    engine core into OpenAI-style ``Dict[token_id, Logprob]`` entries,
    detokenizing each logprob token individually (non-incrementally).
    """

    # Tokenizer for this request
    tokenizer: AnyTokenizer

    # Logprobs for this request; each is None when the corresponding
    # sampling-params setting is disabled for this request.
    logprobs: Optional[SampleLogprobs]
    prompt_logprobs: Optional[PromptLogprobs]
    cumulative_logprob: Optional[float]
    # Number of logprobs requested by the user (None => disabled).
    num_logprobs: Optional[int]
    num_prompt_logprobs: Optional[int]

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
    ) -> "LogprobsProcessor":
        """Construct a processor configured from the request's sampling params.

        Fields start empty ([] / 0.) when enabled, or None when disabled.
        """
        num_logprobs = request.sampling_params.logprobs
        num_prompt_logprobs = request.sampling_params.prompt_logprobs
        return cls(
            tokenizer=tokenizer,
            cumulative_logprob=(None if num_logprobs is None else 0.),
            logprobs=(None if num_logprobs is None else []),
            # NOTE: logprob of first prompt token is None.
            prompt_logprobs=(None if num_prompt_logprobs is None else [None]),
            num_prompt_logprobs=num_prompt_logprobs,
            num_logprobs=num_logprobs,
        )

    def _update_sample_logprobs(self, logprobs_lists: LogprobsLists) -> None:
        """Update with sample logprobs from EngineCore.

        Outer lists are only of len > 1 if EngineCore made
        >1 tokens in prior step (e.g. in spec decoding).

        Args:
          logprobs_lists: the lists of logprob tokens, logprobs, and ranks.
        """
        # Sample logprobs must be enabled if this was called.
        assert self.num_logprobs is not None
        assert self.logprobs is not None
        assert self.cumulative_logprob is not None

        token_ids_lst, logprobs_lst, ranks_lst = logprobs_lists

        for rank, logprobs, token_ids in zip(ranks_lst, logprobs_lst,
                                             token_ids_lst):

            # Detokenize (non-incrementally).
            decoded_tokens = convert_ids_list_to_tokens(
                self.tokenizer, token_ids)

            # Sampler puts the sampled logprob in first.
            sampled_token_logprob = logprobs[0]
            self.cumulative_logprob += sampled_token_logprob

            # Update with the Logprob dictionary for this pos.
            self.logprobs.append(
                self._make_logprob_dict(
                    logprobs,
                    token_ids,
                    decoded_tokens,
                    rank,
                    self.num_logprobs,
                ))

    def _update_prompt_logprobs(
        self,
        prompt_logprobs_tensors: LogprobsTensors,
    ) -> None:
        """Update with prompt logprobs from EngineCore.

        Args:
          prompt_logprobs_tensors: tuple containing the prompt logprobs
                                   tensors.
        """
        # Prompt logprobs are enabled.
        assert self.num_prompt_logprobs is not None
        assert self.prompt_logprobs is not None

        token_ids, logprobs, ranks = prompt_logprobs_tensors

        # Detokenize non-incrementally.
        # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
        decoded_tokens = convert_ids_list_to_tokens(
            self.tokenizer,
            token_ids.flatten().tolist())

        # Recover shapes.
        # NOTE: local `num_logprobs` (columns of the tensor) shadows the
        # `self.num_logprobs` field, which holds the sample-logprobs count.
        num_prompt_tokens, num_logprobs = logprobs.shape

        # Pythonize the torch tensors.
        # TODO(rob): experiment with doing this in EngineCore?
        prompt_token_ranks = ranks.tolist()
        prompt_logprobs = logprobs.tolist()
        token_ids = token_ids.tolist()

        # Make Logprob for each position.
        for pos in range(num_prompt_tokens):
            # Handle flattening: each position owns a contiguous slice of
            # the flat decoded-token list.
            offset = pos * num_logprobs
            offset_end = offset + num_logprobs
            decoded_tokens_for_pos = decoded_tokens[offset:offset_end]

            # Update with the Logprob dictionary for this pos.
            self.prompt_logprobs.append(
                self._make_logprob_dict(prompt_logprobs[pos], token_ids[pos],
                                        decoded_tokens_for_pos,
                                        prompt_token_ranks[pos],
                                        self.num_prompt_logprobs))

    def pop_prompt_logprobs(self) -> Optional[PromptLogprobs]:
        """Pop and return all request prompt logprobs

        The logprobs processor aggregates prompt chunk logprobs
        over one or more prefill chunks. This method returns
        all prompt logprobs at once and then forgets them.
        Ensures correct RequestOutputKind.DELTA semantics
        wherein all prompt logprobs are returned at once at
        the end of prefill.

        Returns:
          None if prompt logprobs are disabled for this request.
          List of all prompt logprobs, otherwise.
        """
        plp = self.prompt_logprobs
        # Only reset when there is something to hand over; a None (disabled)
        # or already-empty list is returned unchanged.
        if plp:
            self.prompt_logprobs = []
        return plp

    @staticmethod
    def _make_logprob_dict(
        logprobs: List[float],
        logprob_token_ids: List[int],
        decoded_tokens: List[str],
        rank: int,
        num_logprobs: int,
    ) -> Dict[int, Logprob]:
        """Make a Logprob dictionary for a position.

        Args:
          logprobs: list of log probabilities
          logprob_token_ids: list of top token ids
          decoded_tokens: list of decoded top tokens
          rank: rank of the sampled token
          num_logprobs: number of logprobs requested
            by the user (in addition to sampled logprob)

        Returns:
          Dict[token id, Logprob]
        """
        # We do not need a special case for the sampled token
        # being in the topk, since inserting duplicated data
        # into a dictionary twice is the same as doing it once.
        topk_ranks = range(1, num_logprobs + 1)
        # First entry is the sampled token (with its true rank); the rest
        # are the top-k tokens whose ranks are 1..num_logprobs by position.
        ranks = itertools.chain((rank, ), topk_ranks)

        return {
            token_id: Logprob(
                logprob=logprob,
                rank=rank,
                decoded_token=token,
            )
            for token_id, logprob, rank, token in zip(
                logprob_token_ids, logprobs, ranks, decoded_tokens)
        }

    def update_from_output(self, output: EngineCoreOutput) -> None:
        """Fold one engine-core output's logprobs into this processor."""
        if output.new_logprobs is not None:
            self._update_sample_logprobs(output.new_logprobs)
        if output.new_prompt_logprobs_tensors is not None:
            self._update_prompt_logprobs(output.new_prompt_logprobs_tensors)

View File

@ -5,11 +5,12 @@ from dataclasses import dataclass
from typing import Dict, List, Optional
from vllm.outputs import RequestOutput
from vllm.transformers_utils.detokenizer_utils import AnyTokenizer
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
from vllm.v1.engine.detokenizer import (DetokenizerOutput,
IncrementalDetokenizer)
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.metrics.stats import IterationStats, RequestStateStats
@ -26,16 +27,20 @@ class RequestState:
def __init__(
self,
request_id: str,
output_kind: RequestOutputKind,
prompt: Optional[str],
prompt_token_ids: List[int],
logprobs_processor: LogprobsProcessor,
detokenizer: IncrementalDetokenizer,
arrival_time: float,
queue: Optional[asyncio.Queue[RequestOutput]],
):
self.request_id = request_id
self.output_kind = output_kind
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.prompt_len = len(prompt_token_ids)
self.logprobs_processor = logprobs_processor
self.detokenizer = detokenizer
self.is_prefilling = True
self.queue = queue
@ -51,8 +56,13 @@ class RequestState:
) -> "RequestState":
return cls(
request_id=request.request_id,
output_kind=request.sampling_params.output_kind,
prompt=request.prompt,
prompt_token_ids=request.prompt_token_ids,
logprobs_processor=LogprobsProcessor.from_new_request(
tokenizer=tokenizer,
request=request,
),
detokenizer=IncrementalDetokenizer.from_new_request(
tokenizer=tokenizer,
request=request,
@ -127,13 +137,8 @@ class OutputProcessor:
batch to ensure system overheads are minimized. This is the
only function that should loop over EngineCoreOutputs.
If you need to touch every element of the batch, implement a
method called XXXClass.update_from_output() to be called
within the loop below. For examples, see:
* IterationStats.update_from_output()
* Detokenizer.update_from_output()
TODO(rob): add Protocol makes update_from_output explicit.
If you need to touch every element of the batch, do it from
within the loop below.
**********************************************************
"""
@ -154,17 +159,37 @@ class OutputProcessor:
req_state.is_prefilling,
req_state.prompt_len,
req_state.stats)
req_state.is_prefilling = False
# 2) Detokenize the token ids into text.
detokenizer_output = req_state.detokenizer.update_from_output(
engine_core_output)
new_token_ids = engine_core_output.new_token_ids
finish_reason = engine_core_output.finish_reason
# 3) Create and handle RequestOutput objects.
if detokenizer_output is not None:
request_output = self._make_request_output(
req_state, detokenizer_output)
# TODO(andy): prompt logprobs + chunked prefill can
# result in engine core returning an output for a
# partial prefill (in order to send back partial
# prompt logprobs.) This breaks the invariant that
# process_outputs is only operating on engine core
# outputs associated with non-partial completions.
# Currently this is handled by having `is_prefilling`
# check for new decoded tokens, indicating that
# the completion is not partial.
#
# Follow up will aggregate partial prompt logprobs
# in the EngineCore.
req_state.is_prefilling = not new_token_ids
# 2) Detokenize the token ids into text and check for stop
# strings.
stop_reason = req_state.detokenizer.update(new_token_ids)
if stop_reason:
finish_reason = FinishReason.STOP
# 3) Compute sample and prompt logprobs for request,
# if required.
req_state.logprobs_processor.update_from_output(engine_core_output)
# 4) Create and handle RequestOutput objects.
if request_output := self._make_request_output(
req_state, new_token_ids, finish_reason, stop_reason):
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put_nowait(request_output)
@ -174,18 +199,16 @@ class OutputProcessor:
# Free completed requests.
if request_output.finished:
assert detokenizer_output.finish_reason is not None
self.request_states.pop(req_id)
if not engine_core_output.finished:
# If req not finished in EngineCore, but Detokenizer
# detected stop string, abort needed in EngineCore.
reqs_to_abort.append(req_id)
# Track per-request stats
# Track per-request stats.
assert finish_reason is not None
iteration_stats.update_from_finished_request(
detokenizer_output.finish_reason, request_output,
req_state.stats)
finish_reason, request_output, req_state.stats)
return OutputProcessorOutput(
request_outputs=request_outputs,
@ -196,20 +219,47 @@ class OutputProcessor:
@staticmethod
def _make_request_output(
request_state: RequestState,
detokenizer_output: DetokenizerOutput,
) -> RequestOutput:
new_token_ids: List[int],
finish_reason: Optional[FinishReason],
stop_reason: Optional[str],
) -> Optional[RequestOutput]:
finished = finish_reason is not None
output_kind = request_state.output_kind
# In follow up, we will switch to invariant where EngineCore
# does not stream partial prefills.
if not finished and (request_state.is_prefilling
or output_kind == RequestOutputKind.FINAL_ONLY):
# Only the final output is required in FINAL_ONLY mode.
return None
detokenizer = request_state.detokenizer
logprobs_processor = request_state.logprobs_processor
delta = output_kind == RequestOutputKind.DELTA
logprobs = logprobs_processor.logprobs
if delta:
if logprobs:
logprobs = logprobs[-len(new_token_ids):]
# Side effect: logprobs processor forgets prompt logprobs
prompt_logprobs = logprobs_processor.pop_prompt_logprobs()
else:
prompt_logprobs = logprobs_processor.prompt_logprobs
request_output = RequestOutput.new(
request_state.request_id,
request_state.prompt,
request_state.prompt_token_ids,
detokenizer_output.output_text,
detokenizer_output.token_ids,
detokenizer_output.finished,
request_id=request_state.request_id,
prompt=request_state.prompt,
prompt_token_ids=request_state.prompt_token_ids,
text=detokenizer.get_next_output_text(finished, delta),
token_ids=new_token_ids if delta else detokenizer.output_token_ids,
logprobs=logprobs,
prompt_logprobs=prompt_logprobs,
cumulative_logprob=logprobs_processor.cumulative_logprob,
finished=finished,
)
if detokenizer_output.finished:
if finished:
completion_output = request_output.outputs[0]
completion_output.finish_reason = str(
detokenizer_output.finish_reason)
completion_output.stop_reason = detokenizer_output.stop_reason
completion_output.finish_reason = str(finish_reason)
completion_output.stop_reason = stop_reason
return request_output

View File

@ -33,6 +33,7 @@ class Processor:
):
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.tokenizer = tokenizer
@ -51,6 +52,37 @@ class Processor:
self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
cache_config.enable_prefix_caching
def _validate_logprobs(
self,
params: Union[SamplingParams, PoolingParams],
) -> None:
if not isinstance(params, SamplingParams):
return
max_logprobs = self.model_config.max_logprobs
# Validate sample logprobs.
if params.logprobs and params.logprobs > max_logprobs:
raise ValueError(
f"Requested sample logprobs of {params.logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
# Validate prompt logprobs.
if params.prompt_logprobs and params.prompt_logprobs > max_logprobs:
raise ValueError(
f"Requested prompt logprobs of {params.prompt_logprobs}, "
f"which is greater than max allowed: {max_logprobs}")
# TODO(andy): enable this in follow up by recomputing.
if (params.prompt_logprobs is not None
and self.cache_config.enable_prefix_caching):
raise ValueError("Prefix caching with prompt logprobs not yet "
"supported on VLLM V1.")
def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
def process_inputs(
self,
request_id: str,
@ -64,12 +96,11 @@ class Processor:
) -> EngineCoreRequest:
# TODO(woosuk): Support pooling models.
# TODO(woosuk): Check max_logprobs
# TODO(woosuk): Support encoder-decoder models.
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
self._validate_logprobs(params)
self._validate_lora(lora_request)
if arrival_time is None:
arrival_time = time.time()
assert priority == 0, "vLLM V1 does not support priority at the moment."

View File

@ -60,14 +60,17 @@ class IterationStats:
self.num_generation_tokens += num_new_generation_tokens
if is_prefilling:
# This relies on the invariant that EngineCore does
# not stream outputs for partially completed prefills
# (scheduler.update_from_output makes EngineCoreOutput
# iff num_computed_tokens == num_tokens).
assert (num_new_generation_tokens > 0)
self.num_prompt_tokens += prompt_len
self.time_to_first_tokens_iter.append(last_token_latency)
# TODO(andy): we used to assert that num_new_generation_tokens
# > 0 with an invariant that EngineCore does not stream outputs
# for partially completed prefills (scheduler.update_from_output
# makes EngineCoreOutput iff num_computed_tokens == num_tokens).
# When prompt logprobs are enabled, we currently stream out the
# partially completed prompt.
# This will be reverted in a follow up PR and we should re-enable
# this assertion / invariant.
if num_new_generation_tokens > 0:
self.num_prompt_tokens += prompt_len
self.time_to_first_tokens_iter.append(last_token_latency)
else:
self.time_per_output_tokens_iter.append(last_token_latency)

View File

@ -1,25 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Dict, List, Optional
from typing import Dict, List, NamedTuple, Optional
import torch
class LogprobsLists(NamedTuple):
    """CPU (Python list) form of per-request logprobs data.

    Row i of each field corresponds to request i in the batch.
    """

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: List[List[int]]
    # [num_reqs, max_num_logprobs + 1]
    logprobs: List[List[float]]
    # [num_reqs]
    sampled_token_ranks: List[int]

    def slice(self, start: int, end: int):
        """Return a new LogprobsLists restricted to requests [start, end)."""
        token_ids_view = self.logprob_token_ids[start:end]
        logprobs_view = self.logprobs[start:end]
        ranks_view = self.sampled_token_ranks[start:end]
        return LogprobsLists(token_ids_view, logprobs_view, ranks_view)
class LogprobsTensors(NamedTuple):
    """Tensor form of per-request logprobs data.

    Convert to Python lists (e.g. before serialization) via tolists().
    """

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: torch.Tensor
    # [num_reqs, max_num_logprobs + 1]
    logprobs: torch.Tensor
    # [num_reqs]
    selected_token_ranks: torch.Tensor

    def tolists(self):
        """Materialize every tensor field as nested Python lists."""
        token_ids_list = self.logprob_token_ids.tolist()
        logprobs_list = self.logprobs.tolist()
        ranks_list = self.selected_token_ranks.tolist()
        return LogprobsLists(token_ids_list, logprobs_list, ranks_list)
@dataclass
class SamplerOutput:
    """Output of one sampler step for the whole batch (GPU tensors)."""

    # [num_reqs]
    sampled_token_ids: torch.Tensor

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: Optional[torch.Tensor]
    # [num_reqs, max_num_logprobs + 1]
    logprobs: Optional[torch.Tensor]

    # TODO: Support prompt logprobs.
    prompt_logprob_token_ids: Optional[torch.Tensor]
    prompt_logprobs: Optional[torch.Tensor]

    # NOTE(review): logprob_token_ids/logprobs and the prompt_* fields
    # appear to be superseded by logprobs_tensors (the sampler at this
    # revision only populates sampled_token_ids and logprobs_tensors) —
    # confirm which field set is current before relying on the others.
    logprobs_tensors: Optional[LogprobsTensors]
# ModelRunnerOutput is serialized and sent to the scheduler process.
@ -36,6 +62,12 @@ class ModelRunnerOutput:
sampled_token_ids: List[int]
# [num_reqs, max_num_logprobs + 1]
logprob_token_ids_cpu: Optional[torch.Tensor]
# [num_reqs, max_num_logprobs + 1]
logprobs_cpu: Optional[torch.Tensor]
# [num_reqs]
logprobs: Optional[LogprobsLists]
# req_id -> (token_ids, logprobs, ranks)
# [prompt_len, num_prompt_logprobs]
# [prompt_len, num_prompt_logprobs]
# [prompt_len]
prompt_logprobs_dict: Dict[str, LogprobsTensors]

View File

@ -20,7 +20,8 @@ class SamplingMetadata:
generators: Dict[int, torch.Generator]
max_num_logprobs: int
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: Optional[int]
no_penalties: bool
prompt_token_ids: Optional[torch.Tensor]

View File

@ -1,11 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
"""A layer that samples the next tokens from the model's outputs."""
from typing import Tuple
import torch
import torch.nn as nn
from vllm.v1.outputs import SamplerOutput
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.penalties import (apply_all_penalties,
apply_min_token_penalties)
@ -25,20 +24,16 @@ class Sampler(nn.Module):
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
needs_logprobs = sampling_metadata.max_num_logprobs > 0
if needs_logprobs:
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that
# is used for sampling (after penalties and temperature scaling).
# NOTE: We compute logprobs first because the below ops may
# modify the logits tensor in-place (and we don't want to clone
# the logits tensor for memory efficiency).
topk_logprobs, topk_indices = self.get_topk_logprobs(
logits, sampling_metadata)
else:
topk_logprobs = None
topk_indices = None
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that
# is used for sampling (after penalties and temperature scaling).
# TODO(rob): provide option for logprobs post sampling.
# See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
num_logprobs = sampling_metadata.max_num_logprobs
if num_logprobs is not None:
raw_logprobs = self.compute_logprobs(logits)
# Use float32 for the logits.
logits = logits.to(torch.float32)
@ -48,15 +43,19 @@ class Sampler(nn.Module):
logits = self.apply_temperature(logits, sampling_metadata.temperature)
# Sample the next token.
sampled = self.sample(logits, sampling_metadata)
# Gather the logprobs of the topk and sampled token (if requested).
# Get logprobs and rank tensors (if requested)
logprobs_tensors = None if num_logprobs is None else \
self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)
# Use int32 to reduce the tensor size.
sampled = sampled.to(torch.int32)
# These are GPU tensors.
sampler_output = SamplerOutput(
sampled_token_ids=sampled,
logprob_token_ids=topk_indices,
logprobs=topk_logprobs,
prompt_logprob_token_ids=None,
prompt_logprobs=None,
logprobs_tensors=logprobs_tensors,
)
return sampler_output
@ -103,19 +102,52 @@ class Sampler(nn.Module):
)
return sampled
# (removed stray leftover `def get_topk_logprobs(` header — merge/diff
# residue from the implementation this method replaced)
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
    """Return per-token log-probabilities.

    Args:
        logits: (num tokens) x (vocab) raw logits tensor.

    Returns:
        float32 tensor of the same shape with log_softmax applied over
        the vocab dimension.
    """
    return logits.log_softmax(dim=-1, dtype=torch.float32)
# NOTE: this span previously interleaved the removed `get_topk_logprobs`
# signature/body with the new implementation (two parameter lists, two
# return statements) — diff residue; the coherent new version is kept.
def gather_logprobs(
    self,
    logprobs: torch.Tensor,
    num_logprobs: int,
    token_ids: torch.Tensor,
) -> LogprobsTensors:
    """
    Gather logprobs for topk and sampled/prompt token.

    Args:
        logprobs: (num tokens) x (vocab) log-probability tensor
        num_logprobs: minimum number of logprobs to retain per token
        token_ids: prompt tokens (if prompt logprobs) or sampled tokens
            (if sampled logprobs); 1D token ID tensor with
            (num tokens) elements

    Returns:
        Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
        Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
        Sampled token rank tensor, (num tokens)
    """
    # Find the topK values.
    topk_logprobs, topk_indices = torch.topk(logprobs,
                                             num_logprobs,
                                             dim=-1)

    # Get the logprob of the prompt or sampled token itself.
    token_ids = token_ids.unsqueeze(-1)
    token_logprobs = logprobs.gather(-1, token_ids)

    # Compute the rank of the actual token: 1-indexed position in the
    # vocab sorted by descending logprob (count of entries >= its value).
    token_ranks = (logprobs >= token_logprobs).sum(-1)

    # Concatenate together with the topk; the actual token comes first.
    indices = torch.cat((token_ids, topk_indices), dim=1)
    logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

    # Use int32 to reduce the tensor size.
    indices = indices.to(torch.int32)

    return LogprobsTensors(indices, logprobs, token_ranks)
def apply_penalties(
self,

View File

@ -1,12 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
import pickle
from typing import Any
import torch
from msgspec import msgpack
CUSTOM_TYPE_CODE_PICKLE = 1
class PickleEncoder:
    """Round-trips arbitrary Python objects via pickle.

    (The duplicated un-annotated `def encode`/`def decode` signature
    lines — a merge/diff artifact that made this class a syntax error —
    have been removed; the annotated versions are kept.)
    """

    def encode(self, obj: Any):
        """Serialize obj to bytes with pickle."""
        return pickle.dumps(obj)

    def decode(self, data: Any):
        """Deserialize a pickle payload back into an object."""
        return pickle.loads(data)
class MsgpackEncoder:
    """Msgpack encoder with custom torch tensor serialization.

    Delegates to a msgspec Encoder configured with custom_enc_hook so
    that torch tensors can be embedded as Ext payloads.
    """

    def __init__(self):
        # enc_hook handles types msgpack cannot natively encode.
        self.encoder = msgpack.Encoder(enc_hook=custom_enc_hook)

    def encode(self, obj: Any) -> bytes:
        """Serialize obj into a new bytes object."""
        return self.encoder.encode(obj)

    def encode_into(self, obj: Any, buf: bytearray) -> None:
        """Serialize obj into a caller-provided buffer (avoids a copy)."""
        self.encoder.encode_into(obj, buf)
class MsgpackDecoder:
    """Msgpack decoder with custom torch tensor deserialization.

    Delegates to a msgspec Decoder typed to `t`, with custom_ext_hook
    restoring torch tensors from Ext payloads.
    """

    def __init__(self, t: Any):
        # `t` is the msgspec target type to decode into.
        self.decoder = msgpack.Decoder(t, ext_hook=custom_ext_hook)

    def decode(self, obj: Any):
        """Deserialize a msgpack payload into an instance of `t`."""
        return self.decoder.decode(obj)
def custom_enc_hook(obj: Any) -> Any:
    """msgspec enc_hook: encode torch tensors as pickled-numpy Ext values.

    Raises:
        NotImplementedError: for any type other than torch.Tensor.
    """
    if not isinstance(obj, torch.Tensor):
        raise NotImplementedError(
            f"Objects of type {type(obj)} are not supported")
    # NOTE(rob): it is fastest to use numpy + pickle
    # when serializing torch tensors.
    # https://gist.github.com/tlrmchlsmth/8067f1b24a82b6e2f90450e7764fa103 # noqa: E501
    payload = pickle.dumps(obj.numpy())
    return msgpack.Ext(CUSTOM_TYPE_CODE_PICKLE, payload)
def custom_ext_hook(code: int, data: memoryview) -> Any:
    """msgspec ext_hook: decode Ext payloads produced by custom_enc_hook.

    Raises:
        NotImplementedError: for any unrecognized extension type code.
    """
    if code != CUSTOM_TYPE_CODE_PICKLE:
        raise NotImplementedError(
            f"Extension type code {code} is not supported")
    # Inverse of custom_enc_hook: unpickle the numpy array, wrap as tensor.
    return torch.from_numpy(pickle.loads(data))

View File

@ -176,7 +176,9 @@ class InputBatch:
self.generators: Dict[int, torch.Generator] = {}
self.num_logprobs: Dict[str, int] = {}
self.prompt_logprob_reqs: Set[str] = set()
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: Dict[str, int] = {}
def add_request(
self,
@ -238,11 +240,10 @@ class InputBatch:
if request.generator is not None:
self.generators[req_index] = request.generator
num_logprobs = sampling_params.logprobs
if num_logprobs is not None and num_logprobs > 0:
self.num_logprobs[req_id] = num_logprobs
if sampling_params.prompt_logprobs:
self.prompt_logprob_reqs.add(req_id)
if sampling_params.logprobs is not None:
self.num_logprobs[req_id] = sampling_params.logprobs
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
# Add request lora ID
if request.lora_request:
@ -272,7 +273,7 @@ class InputBatch:
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.prompt_logprob_reqs.discard(req_id)
self.num_prompt_logprobs.pop(req_id, None)
# LoRA
lora_id = self.request_lora_mapping[req_index]
@ -297,7 +298,7 @@ class InputBatch:
self.repetition_penalties_reqs.clear()
self.generators.clear()
self.num_logprobs.clear()
self.prompt_logprob_reqs.clear()
self.num_prompt_logprobs.clear()
self.request_lora_mapping.fill(0)
self.lora_id_to_lora_request.clear()
self.lora_id_to_request_ids.clear()
@ -489,13 +490,9 @@ class InputBatch:
and len(self.repetition_penalties_reqs) == 0)
# NOTE: this span interleaved pre- and post-change property variants
# (two `max_num_logprobs` definitions and a dead second return in
# `no_prompt_logprob`) — diff residue. The post-change versions are
# kept, consistent with the num_prompt_logprobs bookkeeping used by
# add_request/remove_request/clear in this class; the removed
# `no_logprob` helper had no remaining callers in the new code path.
@property
def max_num_logprobs(self) -> Optional[int]:
    """Largest sample-logprob count requested by any running request.

    None means no request wants logprobs; 0 means sampled-token
    logprobs only (mirrors SamplingMetadata.max_num_logprobs).
    """
    return max(self.num_logprobs.values()) if self.num_logprobs else None

@property
def no_prompt_logprob(self) -> bool:
    """True when no running request wants prompt logprobs."""
    return not self.num_prompt_logprobs

View File

@ -29,7 +29,7 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_mapper import MMInputMapperClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@ -804,8 +804,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
inputs_embeds=inputs_embeds,
)
hidden_states = hidden_states[:num_scheduled_tokens]
hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(hidden_states, None)
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
# Sample the next token and get logprobs if needed.
sampling_metadata = self._prepare_sampling(batch_changed)
@ -818,7 +818,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# the requests one by one. Optimize.
num_reqs = self.input_batch.num_reqs
request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
for i, req_id in enumerate( # type: ignore[assignment]
self.input_batch.req_ids[:num_reqs]):
assert req_id is not None
req_state = self.requests[req_id]
seq_len = (req_state.num_computed_tokens +
@ -847,27 +848,28 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE: GPU -> CPU Sync happens here.
# Move as many CPU operations as possible before this sync point.
sampled_token_ids = sampler_output.sampled_token_ids.tolist()
logprobs_tensors = sampler_output.logprobs_tensors
logprobs_lists = logprobs_tensors.tolists() \
if logprobs_tensors is not None else None
# Compute prompt logprobs if needed.
prompt_logprobs_dict = self._get_prompt_logprobs_dict(
hidden_states,
scheduler_output,
)
# Update with the actual token ids
for i, req_state, seq_len in request_seq_lens:
token_id = sampled_token_ids[i]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids[-1] = token_id
if sampler_output.logprob_token_ids is None:
logprob_token_ids = None
else:
logprob_token_ids = sampler_output.logprob_token_ids.cpu()
if sampler_output.logprobs is None:
logprobs = None
else:
logprobs = sampler_output.logprobs.cpu()
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=sampled_token_ids,
logprob_token_ids_cpu=logprob_token_ids,
logprobs_cpu=logprobs,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
)
return model_runner_output
@ -886,6 +888,76 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logger.info("Loading model weights took %.4f GB",
self.model_memory_usage / float(2**30))
def _get_prompt_logprobs_dict(
    self,
    hidden_states: torch.Tensor,
    scheduler_output: "SchedulerOutput",
) -> Dict[str, LogprobsTensors]:
    """Compute prompt logprobs for requests that asked for them.

    Args:
        hidden_states: hidden states for all tokens scheduled this step,
            before the sampling-position gather.
        scheduler_output: per-request scheduling info for this step.

    Returns:
        Mapping req_id -> LogprobsTensors of CPU tensors (token ids,
        logprobs, ranks) covering the prompt positions processed this
        step. Requests whose prefill completes are removed from
        input_batch.num_prompt_logprobs as a side effect.
    """
    num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
    if not num_prompt_logprobs_dict:
        # Fast path: nobody wants prompt logprobs this step.
        return {}

    prompt_logprobs_dict: Dict[str, LogprobsTensors] = {}

    # Since prompt logprobs are a rare feature, prioritize simple,
    # maintainable loop over optimal performance.
    completed_prefill_reqs: List[str] = []
    for req_id, num_prompt_logprobs in num_prompt_logprobs_dict.items():

        num_tokens = scheduler_output.num_scheduled_tokens[req_id]

        # Get metadata for this request.
        request = self.requests[req_id]
        num_prompt_tokens = len(request.prompt_token_ids)
        prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
            self.device, non_blocking=True)

        # Determine number of logits to retrieve.
        # +1: logits at position i predict prompt token i+1, so the first
        # target is the token after the last already-computed one.
        start_tok = request.num_computed_tokens + 1
        num_remaining_tokens = num_prompt_tokens - start_tok
        if num_tokens < num_remaining_tokens:
            # This is a chunk, more tokens remain.
            num_logits = num_tokens
        else:
            # This is the last chunk of prompt tokens to return.
            num_logits = num_remaining_tokens
            completed_prefill_reqs.append(req_id)

        # Get the logits corresponding to this req's prompt tokens.
        # If this is a partial request (i.e. chunked prefill),
        # then there is prompt logprob generated for each index.
        req_idx = self.input_batch.req_id_to_index[req_id]
        # query_start_loc_np gives this request's offset into the
        # flattened per-step hidden_states.
        offset = self.query_start_loc_np[req_idx].item()
        prompt_hidden_states = hidden_states[offset:offset + num_logits]
        logits = self.model.compute_logits(prompt_hidden_states, None)

        # Get the "target" tokens for each index. For prompt at index i,
        # the token at prompt index i+1 is the "sampled" token we want
        # to gather the logprob for.
        tgt_token_ids = prompt_token_ids[start_tok:start_tok + num_logits]

        # Compute prompt logprobs.
        logprobs = self.model.sampler.compute_logprobs(logits)
        # gather_logprobs returns a LogprobsTensors NamedTuple;
        # unpack so each tensor can be moved to CPU individually.
        token_ids, logprobs, ranks = self.model.sampler.gather_logprobs(
            logprobs, num_prompt_logprobs, tgt_token_ids)

        # Transfer GPU->CPU async.
        prompt_logprobs_dict[req_id] = LogprobsTensors(
            token_ids.to("cpu", non_blocking=True),
            logprobs.to("cpu", non_blocking=True),
            ranks.to("cpu", non_blocking=True),
        )

    # Remove requests that have completed prefill from the batch
    # num_prompt_logprobs_dict.
    for req_id in completed_prefill_reqs:
        del num_prompt_logprobs_dict[req_id]

    # Must synchronize the non-blocking GPU->CPU transfers.
    torch.cuda.synchronize()

    return prompt_logprobs_dict
@torch.inference_mode()
def _dummy_run(
self,