Push logprob generation to LLMEngine (#3065)

Co-authored-by: Avnish Narayan <avnish@anyscale.com>
Authored by Antoni Baum on 2024-03-04 11:54:06 -08:00; committed by GitHub
parent 76e8a70476
commit 22de45235c
13 changed files with 551 additions and 331 deletions

View File

@@ -213,14 +213,14 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
         messages=messages,
         max_tokens=10,
         logprobs=True,
-        top_logprobs=10)
+        top_logprobs=5)
     assert chat_completion.id is not None
     assert chat_completion.choices is not None and len(
         chat_completion.choices) == 1
     assert chat_completion.choices[0].message is not None
     assert chat_completion.choices[0].logprobs is not None
     assert chat_completion.choices[0].logprobs.top_logprobs is not None
-    assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 10
+    assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
     message = chat_completion.choices[0].message
     assert message.content is not None and len(message.content) >= 10
     assert message.role == "assistant"
@@ -229,7 +229,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
     # test multi-turn dialogue
     messages.append({"role": "user", "content": "express your result in json"})
     chat_completion = await client.chat.completions.create(
-        model=MODEL_NAME,
+        model=model_name,
         messages=messages,
         max_tokens=10,
     )
@@ -237,6 +237,61 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
     assert message.content is not None and len(message.content) >= 0
 
 
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
+                                 model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    # Default max_logprobs is 5, so this should raise an error
+    with pytest.raises((openai.BadRequestError, openai.APIError)):
+        stream = await client.chat.completions.create(model=model_name,
+                                                       messages=messages,
+                                                       max_tokens=10,
+                                                       logprobs=True,
+                                                       top_logprobs=10,
+                                                       stream=True)
+        async for chunk in stream:
+            ...
+
+    with pytest.raises(openai.BadRequestError):
+        await client.chat.completions.create(model=model_name,
+                                             messages=messages,
+                                             max_tokens=10,
+                                             logprobs=True,
+                                             top_logprobs=10,
+                                             stream=False)
+
+    with pytest.raises((openai.BadRequestError, openai.APIError)):
+        stream = await client.completions.create(model=model_name,
+                                                 prompt="Test",
+                                                 max_tokens=10,
+                                                 logprobs=10,
+                                                 stream=True)
+        async for chunk in stream:
+            ...
+
+    with pytest.raises(openai.BadRequestError):
+        await client.completions.create(model=model_name,
+                                        prompt="Test",
+                                        max_tokens=10,
+                                        logprobs=10,
+                                        stream=False)
+
+    # the server should still work afterwards
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=10,
+                                                           stream=False)
+    message = chat_completion.choices[0].message
+    assert message.content is not None and len(message.content) >= 0
+
+
 @pytest.mark.parametrize(
     # just test 1 lora hereafter
     "model_name",

View File

@@ -1,5 +1,6 @@
 import pytest
 import torch
 
+from tests.conftest import VllmRunner
 from vllm import SamplingParams
@@ -16,6 +17,7 @@ def test_get_prompt_logprobs(
     example_prompts,
 ):
     max_tokens = 5
+    num_top_logprobs = 6
     hf_model = hf_runner(model, dtype=dtype)
     hf_logprobs = hf_model.generate_greedy_logprobs(
         example_prompts,
@@ -23,19 +25,32 @@ def test_get_prompt_logprobs(
     )
     del hf_model
 
-    vllm_model = vllm_runner(model, dtype=dtype)
+    vllm_model = vllm_runner(model, dtype=dtype, max_logprobs=num_top_logprobs)
     vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
-                                          logprobs=5,
+                                          logprobs=num_top_logprobs,
                                           prompt_logprobs=5,
                                           temperature=0.0)
     vllm_results = vllm_model.model.generate(
         example_prompts, sampling_params=vllm_sampling_params)
+    del vllm_model
 
     # Test whether logprobs are included in the results.
     for result in vllm_results:
         assert result.prompt_logprobs is not None
         assert result.outputs[0].logprobs is not None
+        assert len(result.outputs[0].logprobs) == max_tokens
+        for logprobs in result.outputs[0].logprobs:
+            assert len(logprobs) == num_top_logprobs
+        output_text = result.outputs[0].text
+        output_string_from_most_likely_tokens = []
+        for top_logprobs in result.outputs[0].logprobs:
+            top_logprob = next(iter(top_logprobs.values()))
+            output_string_from_most_likely_tokens.append(
+                top_logprob.decoded_token)
+        output_string_from_most_likely_tokens = "".join(
+            output_string_from_most_likely_tokens)
+        assert output_text == output_string_from_most_likely_tokens, (
+            "The output text from the top logprob for each token position "
+            "should be the same as the output text in the result.")
 
     # Test whether prompt logprobs are consistent with HF
     for vllm_result, hf_logprob in zip(vllm_results, hf_logprobs):
@@ -43,14 +58,29 @@ def test_get_prompt_logprobs(
         vllm_prompt_logprobs = vllm_result.prompt_logprobs[1:]
         for i, vllm_prompt_logprob_dict in enumerate(vllm_prompt_logprobs):
             for token_id, logprob in vllm_prompt_logprob_dict.items():
-                torch.testing.assert_close(logprob,
+                torch.testing.assert_close(logprob.logprob,
                                            hf_logprob[0][i][token_id].item(),
                                            atol=1e-2,
                                            rtol=1e-2)
         vllm_sample_logprobs = vllm_result.outputs[0].logprobs
-        for i, vllm_sample_logprob_dict in enumerate(vllm_sample_logprobs):
-            for token_id, logprob in vllm_sample_logprob_dict.items():
+        for i, top_logprobs in enumerate(vllm_sample_logprobs):
+            for token_id, sample_logprob in top_logprobs.items():
+                logprob = sample_logprob.logprob
                 torch.testing.assert_close(logprob,
                                            hf_logprob[i][-1][token_id].item(),
                                            atol=1e-2,
                                            rtol=1e-2)
+                assert isinstance(sample_logprob.decoded_token, str), \
+                    ("The token should be decoded by the time it is returned "
+                     " to the user.")
+
+
+def test_max_logprobs():
+    runner = VllmRunner("facebook/opt-125m", max_logprobs=1)
+    vllm_sampling_params = SamplingParams(logprobs=1)
+    # should pass
+    runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
+
+    bad_sampling_params = SamplingParams(logprobs=2)
+    with pytest.raises(ValueError):
+        runner.generate(["Hello world"], sampling_params=bad_sampling_params)
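
With this change the sampled logprobs come back as Logprob objects whose text is already decoded, so callers no longer need a tokenizer to display them. A minimal offline sketch of reading them; the model name and prompt are examples, not values from the diff:

    # assumes a local vLLM install; mirrors what the updated test asserts
    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    params = SamplingParams(max_tokens=4, logprobs=3, temperature=0.0)
    result = llm.generate(["Hello world"], params)[0]

    for step in result.outputs[0].logprobs:      # one dict per generated token
        for token_id, logprob in step.items():   # token_id -> Logprob
            # logprob.logprob is the float value, logprob.decoded_token the detokenized text
            print(token_id, logprob.decoded_token, logprob.logprob)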

View File

@@ -4,7 +4,7 @@ from typing import List, Optional, Dict
 from vllm.worker.worker import Worker
 from vllm.utils import get_distributed_init_method, get_ip, get_open_port
 from vllm.engine.arg_utils import EngineArgs
-from vllm.sequence import SequenceGroupMetadata, SequenceData
+from vllm.sequence import Logprob, SequenceGroupMetadata, SequenceData
 from vllm.sampling_params import SamplingParams
 from vllm.worker.cache_engine import CacheEngine
 from vllm.model_executor.utils import set_random_seed
@@ -166,13 +166,15 @@ def create_seq_group_metadata_from_prompts(
 def assert_logprobs_dict_allclose(
-        actual_logprobs: List[Dict[int, float]],
-        expected_logprobs: List[Dict[int, float]]) -> None:
+        actual_logprobs: List[Dict[int, Logprob]],
+        expected_logprobs: List[Dict[int, Logprob]]) -> None:
     for single_step_actual_logprobs, single_step_expected_logprobs in zip(
             actual_logprobs, expected_logprobs):
         assert set(single_step_actual_logprobs.keys()) == set(
             single_step_expected_logprobs.keys())
         for token_id in single_step_actual_logprobs:
-            actual = torch.tensor(single_step_actual_logprobs[token_id])
-            expected = torch.tensor(single_step_expected_logprobs[token_id])
+            actual = torch.tensor(
+                single_step_actual_logprobs[token_id].logprob)
+            expected = torch.tensor(
+                single_step_expected_logprobs[token_id].logprob)
             assert torch.allclose(actual, expected)

View File

@@ -79,6 +79,7 @@ class ModelConfig:
         quantization: Optional[str] = None,
         enforce_eager: bool = False,
         max_context_len_to_capture: Optional[int] = None,
+        max_logprobs: int = 5,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -93,6 +94,7 @@ class ModelConfig:
         self.quantization = quantization
         self.enforce_eager = enforce_eager
         self.max_context_len_to_capture = max_context_len_to_capture
+        self.max_logprobs = max_logprobs
 
         if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
             # download model from ModelScope hub,

View File

@@ -31,6 +31,7 @@ class EngineArgs:
     max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
     max_paddings: int = 256
+    max_logprobs: int = 5  # OpenAI default value
     disable_log_stats: bool = False
     revision: Optional[str] = None
     code_revision: Optional[str] = None
@@ -212,6 +213,12 @@ class EngineArgs:
                             type=int,
                             default=EngineArgs.max_paddings,
                             help='maximum number of paddings in a batch')
+        parser.add_argument(
+            '--max-logprobs',
+            type=int,
+            default=EngineArgs.max_logprobs,
+            help=('max number of log probs to return logprobs is specified in'
+                  ' SamplingParams'))
         parser.add_argument('--disable-log-stats',
                             action='store_true',
                             help='disable logging statistics')
@@ -300,7 +307,8 @@ class EngineArgs:
             self.trust_remote_code, self.download_dir, self.load_format,
             self.dtype, self.seed, self.revision, self.code_revision,
             self.tokenizer_revision, self.max_model_len, self.quantization,
-            self.enforce_eager, self.max_context_len_to_capture)
+            self.enforce_eager, self.max_context_len_to_capture,
+            self.max_logprobs)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space, self.kv_cache_dtype,
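
The new max_logprobs knob caps how many logprobs a single request may ask for and defaults to 5 to match the OpenAI API. A rough sketch of setting it programmatically; the model name is an example, and this assumes the LLM entrypoint forwards extra keyword arguments to EngineArgs as it does for the other engine options (the --max-logprobs flag added above is the server-side equivalent):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m", max_logprobs=10)   # assumed pass-through to EngineArgs
    llm.generate(["Hi"], SamplingParams(logprobs=10))       # within the cap, accepted
    # SamplingParams(logprobs=11) would make LLMEngine.add_request raise ValueError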

View File

@@ -47,7 +47,7 @@ class AsyncStream:
         self._queue = asyncio.Queue()
         self._finished = False
 
-    def put(self, item: RequestOutput) -> None:
+    def put(self, item: Union[RequestOutput, Exception]) -> None:
         if self._finished:
             return
         self._queue.put_nowait(item)
@@ -110,6 +110,17 @@ class RequestTracker:
             logger.info(f"Finished request {request_id}.")
         self.abort_request(request_id)
 
+    def process_exception(self,
+                          request_id: str,
+                          exception: Exception,
+                          *,
+                          verbose: bool = False) -> None:
+        """Propagate an exception from the engine."""
+        self._request_streams[request_id].put(exception)
+        if verbose:
+            logger.info(f"Finished request {request_id}.")
+        self.abort_request(request_id)
+
     def add_request(self, request_id: str,
                     **engine_add_request_kwargs) -> AsyncStream:
         """Add a request to be sent to the engine on the next background
@@ -377,10 +388,18 @@ class AsyncLLMEngine:
         for new_request in new_requests:
             # Add the request into the vLLM engine's waiting queue.
             # TODO: Maybe add add_request_batch to reduce Ray overhead
-            if self.engine_use_ray:
-                await self.engine.add_request.remote(**new_request)
-            else:
-                await self.engine.add_request_async(**new_request)
+            try:
+                if self.engine_use_ray:
+                    await self.engine.add_request.remote(**new_request)
+                else:
+                    await self.engine.add_request_async(**new_request)
+            except ValueError as e:
+                # TODO: use a vLLM specific error for failed validation
+                self._request_tracker.process_exception(
+                    new_request["request_id"],
+                    e,
+                    verbose=self.log_requests,
+                )
 
         if finished_requests:
             await self._engine_abort(finished_requests)

View File

@@ -18,7 +18,7 @@ from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
+from vllm.sequence import (Logprob, SamplerOutput, Sequence, SequenceGroup,
                            SequenceGroupOutput, SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                TokenizerGroup)
@@ -473,6 +473,13 @@ class LLMEngine:
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
+        max_logprobs = self.get_model_config().max_logprobs
+        if (sampling_params.logprobs
+                and sampling_params.logprobs > max_logprobs) or (
+                    sampling_params.prompt_logprobs
+                    and sampling_params.prompt_logprobs > max_logprobs):
+            raise ValueError(f"Cannot request more than "
+                             f"{max_logprobs} logprobs.")
         if arrival_time is None:
             arrival_time = time.monotonic()
         prompt_token_ids = self.encode_request(
@@ -583,6 +590,13 @@ class LLMEngine:
         # Process prompt logprobs
         prompt_logprobs = outputs.prompt_logprobs
         if prompt_logprobs is not None:
+            # We can pick any sequence for the prompt.
+            seq = next(iter(seq_group.seqs_dict.values()))
+            all_token_ids = seq.get_token_ids()
+            for i, prompt_logprobs_for_token in enumerate(prompt_logprobs):
+                self._decode_logprobs(seq, seq_group.sampling_params,
+                                      prompt_logprobs_for_token,
+                                      all_token_ids[:i])
             seq_group.prompt_logprobs = prompt_logprobs
 
         # Process samples
@@ -930,12 +944,36 @@ class LLMEngine:
             time_e2e_requests=time_e2e_requests,
         )
 
+    def _decode_logprobs(self, seq: Sequence, prms: SamplingParams,
+                         logprobs: Dict[int, Logprob],
+                         all_input_ids: List[int]) -> None:
+        if not logprobs:
+            return
+        for token_id, sample_logprob in logprobs.items():
+            if (sample_logprob.decoded_token is None and token_id != -1):
+                all_input_ids_with_logprob = all_input_ids[:-1] + [token_id]
+                _, new_text, prefix_offset, read_offset = detokenize_incrementally(
+                    self.get_tokenizer_for_seq(seq),
+                    all_input_ids=all_input_ids_with_logprob,
+                    prev_tokens=seq.tokens,
+                    prefix_offset=seq.prefix_offset,
+                    read_offset=seq.read_offset,
+                    skip_special_tokens=prms.skip_special_tokens,
+                    spaces_between_special_tokens=prms.
+                    spaces_between_special_tokens,
+                )
+                sample_logprob.decoded_token = new_text
+
     def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
         """Decodes the new token for a sequence."""
+        all_input_ids = seq.get_token_ids()
+        self._decode_logprobs(seq, prms, seq.output_logprobs[-1],
+                              all_input_ids)
+
         (new_tokens, new_output_text, prefix_offset,
          read_offset) = detokenize_incrementally(
              self.get_tokenizer_for_seq(seq),
-             all_input_ids=seq.get_token_ids(),
+             all_input_ids=all_input_ids,
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,

View File

@@ -82,8 +82,12 @@ class OpenAIServingChat(OpenAIServing):
             return self.chat_completion_stream_generator(
                 request, result_generator, request_id)
         else:
-            return await self.chat_completion_full_generator(
-                request, raw_request, result_generator, request_id)
+            try:
+                return await self.chat_completion_full_generator(
+                    request, raw_request, result_generator, request_id)
+            except ValueError as e:
+                # TODO: Use a vllm-specific Validation Error
+                return self.create_error_response(str(e))
 
     def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
         if request.add_generation_prompt:
def get_chat_request_role(self, request: ChatCompletionRequest) -> str: def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt: if request.add_generation_prompt:
@@ -99,7 +103,19 @@ class OpenAIServingChat(OpenAIServing):
         model_name = request.model
         created_time = int(time.monotonic())
         chunk_object_type = "chat.completion.chunk"
+        first_iteration = True
 
+        # Send response for each token for each request.n (index)
+        previous_texts = [""] * request.n
+        previous_num_tokens = [0] * request.n
+        finish_reason_sent = [False] * request.n
+        try:
+            async for res in result_generator:
+                res: RequestOutput
+                # We need to do it here, because if there are exceptions in
+                # the result_generator, it needs to be sent as the FIRST
+                # response (by the try...catch).
+                if first_iteration:
                     # Send first response for each request.n (index) with the role
                     role = self.get_chat_request_role(request)
                     for i in range(request.n):
@@ -108,7 +124,8 @@ class OpenAIServingChat(OpenAIServing):
                             delta=DeltaMessage(role=role),
                             logprobs=None,
                             finish_reason=None)
-                        chunk = ChatCompletionStreamResponse(id=request_id,
+                        chunk = ChatCompletionStreamResponse(
+                            id=request_id,
                             object=chunk_object_type,
                             created=created_time,
                             choices=[choice_data],
@@ -120,7 +137,8 @@ class OpenAIServingChat(OpenAIServing):
                     if request.echo:
                         last_msg_content = ""
                         if request.messages and isinstance(
-                                request.messages, list) and request.messages[-1].get(
+                                request.messages,
+                                list) and request.messages[-1].get(
                                     "content") and request.messages[-1].get(
                                         "role") == role:
                             last_msg_content = request.messages[-1]["content"]
@@ -129,7 +147,8 @@ class OpenAIServingChat(OpenAIServing):
                             for i in range(request.n):
                                 choice_data = ChatCompletionResponseStreamChoice(
                                     index=i,
-                                    delta=DeltaMessage(content=last_msg_content),
+                                    delta=DeltaMessage(
+                                        content=last_msg_content),
                                     finish_reason=None)
                                 chunk = ChatCompletionStreamResponse(
                                     id=request_id,
@@ -138,15 +157,11 @@ class OpenAIServingChat(OpenAIServing):
                                     choices=[choice_data],
                                     logprobs=None,
                                     model=model_name)
-                                data = chunk.model_dump_json(exclude_unset=True)
+                                data = chunk.model_dump_json(
+                                    exclude_unset=True)
                                 yield f"data: {data}\n\n"
+                    first_iteration = False
 
-        # Send response for each token for each request.n (index)
-        previous_texts = [""] * request.n
-        previous_num_tokens = [0] * request.n
-        finish_reason_sent = [False] * request.n
-        async for res in result_generator:
-            res: RequestOutput
             for output in res.outputs:
                 i = output.index
@@ -191,7 +206,8 @@ class OpenAIServingChat(OpenAIServing):
                     final_usage = UsageInfo(
                         prompt_tokens=prompt_tokens,
                         completion_tokens=previous_num_tokens[i],
-                        total_tokens=prompt_tokens + previous_num_tokens[i],
+                        total_tokens=prompt_tokens +
+                        previous_num_tokens[i],
                     )
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=i,
@@ -210,6 +226,10 @@ class OpenAIServingChat(OpenAIServing):
                         exclude_none=True)
                     yield f"data: {data}\n\n"
                     finish_reason_sent[i] = True
+        except ValueError as e:
+            # TODO: Use a vllm-specific Validation Error
+            data = self.create_streaming_error_response(str(e))
+            yield f"data: {data}\n\n"
         # Send the final done message after all response.n are finished
         yield "data: [DONE]\n\n"

View File

@@ -26,107 +26,6 @@ TypeCreateLogProbsFn = Callable[
     [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
 
 
-async def completion_stream_generator(
-    request: CompletionRequest,
-    raw_request: Request,
-    on_abort,
-    result_generator: AsyncIterator[Tuple[int, RequestOutput]],
-    create_logprobs_fn: TypeCreateLogProbsFn,
-    request_id: str,
-    created_time: int,
-    model_name: str,
-    num_prompts: int,
-) -> AsyncGenerator[str, None]:
-    previous_texts = [""] * request.n * num_prompts
-    previous_num_tokens = [0] * request.n * num_prompts
-    has_echoed = [False] * request.n * num_prompts
-
-    async for prompt_idx, res in result_generator:
-
-        # Abort the request if the client disconnects.
-        if await raw_request.is_disconnected():
-            await on_abort(f"{request_id}-{prompt_idx}")
-            raise StopAsyncIteration()
-
-        for output in res.outputs:
-            i = output.index + prompt_idx * request.n
-            # TODO(simon): optimize the performance by avoiding full text O(n^2) sending.
-
-            if request.echo and request.max_tokens == 0:
-                # only return the prompt
-                delta_text = res.prompt
-                delta_token_ids = res.prompt_token_ids
-                top_logprobs = res.prompt_logprobs
-                has_echoed[i] = True
-            elif request.echo and request.max_tokens > 0 and not has_echoed[i]:
-                # echo the prompt and first token
-                delta_text = res.prompt + output.text
-                delta_token_ids = res.prompt_token_ids + output.token_ids
-                top_logprobs = res.prompt_logprobs + (output.logprobs or [])
-                has_echoed[i] = True
-            else:
-                # return just the delta
-                delta_text = output.text[len(previous_texts[i]):]
-                delta_token_ids = output.token_ids[previous_num_tokens[i]:]
-                top_logprobs = output.logprobs[
-                    previous_num_tokens[i]:] if output.logprobs else None
-
-            if request.logprobs is not None:
-                assert top_logprobs is not None, "top_logprobs must be provided when logprobs is requested"
-                logprobs = create_logprobs_fn(
-                    token_ids=delta_token_ids,
-                    top_logprobs=top_logprobs,
-                    num_output_top_logprobs=request.logprobs,
-                    initial_text_offset=len(previous_texts[i]),
-                )
-            else:
-                logprobs = None
-
-            previous_texts[i] = output.text
-            previous_num_tokens[i] = len(output.token_ids)
-            finish_reason = output.finish_reason
-            response_json = CompletionStreamResponse(
-                id=request_id,
-                created=created_time,
-                model=model_name,
-                choices=[
-                    CompletionResponseStreamChoice(
-                        index=i,
-                        text=delta_text,
-                        logprobs=logprobs,
-                        finish_reason=finish_reason,
-                    )
-                ]).model_dump_json()
-            yield f"data: {response_json}\n\n"
-
-            if output.finish_reason is not None:  # return final usage
-                logprobs = LogProbs() if request.logprobs is not None else None
-                prompt_tokens = len(res.prompt_token_ids)
-                completion_tokens = len(output.token_ids)
-                final_usage = UsageInfo(
-                    prompt_tokens=prompt_tokens,
-                    completion_tokens=completion_tokens,
-                    total_tokens=prompt_tokens + completion_tokens,
-                )
-                response_json = CompletionStreamResponse(
-                    id=request_id,
-                    created=created_time,
-                    model=model_name,
-                    choices=[
-                        CompletionResponseStreamChoice(
-                            index=i,
-                            text="",
-                            logprobs=logprobs,
-                            finish_reason=output.finish_reason,
-                        )
-                    ],
-                    usage=final_usage,
-                ).model_dump_json()
-                yield f"data: {response_json}\n\n"
-
-    yield "data: [DONE]\n\n"
-
-
 def parse_prompt_format(prompt) -> Tuple[bool, list]:
     # get the prompt, openai supports the following
     # "a string, array of strings, array of tokens, or array of token arrays."
@@ -151,73 +50,6 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
     return prompt_is_tokens, prompts
 
 
-def request_output_to_completion_response(
-    final_res_batch: List[RequestOutput],
-    request: CompletionRequest,
-    create_logprobs_fn: TypeCreateLogProbsFn,
-    request_id: str,
-    created_time: int,
-    model_name: str,
-) -> CompletionResponse:
-    choices = []
-    num_prompt_tokens = 0
-    num_generated_tokens = 0
-    for final_res in final_res_batch:
-        assert final_res is not None
-        prompt_token_ids = final_res.prompt_token_ids
-        prompt_logprobs = final_res.prompt_logprobs
-        prompt_text = final_res.prompt
-        for output in final_res.outputs:
-            if request.echo and request.max_tokens == 0:
-                token_ids = prompt_token_ids
-                top_logprobs = prompt_logprobs
-                output_text = prompt_text
-            elif request.echo and request.max_tokens > 0:
-                token_ids = prompt_token_ids + output.token_ids
-                top_logprobs = prompt_logprobs + output.logprobs
-                output_text = prompt_text + output.text
-            else:
-                token_ids = output.token_ids
-                top_logprobs = output.logprobs
-                output_text = output.text
-
-            if request.logprobs is not None:
-                logprobs = create_logprobs_fn(
-                    token_ids=token_ids,
-                    top_logprobs=top_logprobs,
-                    num_output_top_logprobs=request.logprobs,
-                )
-            else:
-                logprobs = None
-
-            choice_data = CompletionResponseChoice(
-                index=len(choices),
-                text=output_text,
-                logprobs=logprobs,
-                finish_reason=output.finish_reason,
-            )
-            choices.append(choice_data)
-
-        num_prompt_tokens += len(prompt_token_ids)
-        num_generated_tokens += sum(
-            len(output.token_ids) for output in final_res.outputs)
-
-    usage = UsageInfo(
-        prompt_tokens=num_prompt_tokens,
-        completion_tokens=num_generated_tokens,
-        total_tokens=num_prompt_tokens + num_generated_tokens,
-    )
-
-    return CompletionResponse(
-        id=request_id,
-        created=created_time,
-        model=model_name,
-        choices=choices,
-        usage=usage,
-    )
-
-
 def merge_async_iterators(*iterators):
     """Merge multiple asynchronous iterators into a single iterator.
@@ -230,8 +62,11 @@ def merge_async_iterators(*iterators):
     finished = [False] * len(iterators)
 
     async def producer(i, iterator):
+        try:
             async for item in iterator:
                 await queue.put((i, item))
+        except Exception as e:
+            await queue.put(e)
         finished[i] = True
 
     _tasks = [
@@ -242,6 +77,8 @@ def merge_async_iterators(*iterators):
     async def consumer():
         while not all(finished) or not queue.empty():
             item = await queue.get()
+            if isinstance(item, Exception):
+                raise item
             yield item
         await asyncio.gather(*_tasks)
@@ -312,6 +149,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     prompt_token_ids=input_ids,
                     lora_request=lora_request))
         except ValueError as e:
+            # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
         result_generator: AsyncIterator[Tuple[
@@ -325,11 +163,9 @@ class OpenAIServingCompletion(OpenAIServing):
         # Streaming response
         if stream:
-            return completion_stream_generator(request,
+            return self.completion_stream_generator(request,
                                                 raw_request,
-                                                self.engine.abort,
                                                 result_generator,
-                                                self._create_logprobs,
                                                 request_id,
                                                 created_time,
                                                 model_name,
@@ -337,15 +173,18 @@ class OpenAIServingCompletion(OpenAIServing):
         # Non-streaming response
         final_res_batch: RequestOutput = [None] * len(prompts)
-        async for i, res in result_generator:
-            if await raw_request.is_disconnected():
-                # Abort the request if the client disconnects.
-                await self.engine.abort(f"{request_id}-{i}")
-                return self.create_error_response("Client disconnected")
-            final_res_batch[i] = res
-        response = request_output_to_completion_response(
-            final_res_batch, request, self._create_logprobs, request_id,
-            created_time, model_name)
+        try:
+            async for i, res in result_generator:
+                if await raw_request.is_disconnected():
+                    # Abort the request if the client disconnects.
+                    await self.engine.abort(f"{request_id}-{i}")
+                    return self.create_error_response("Client disconnected")
+                final_res_batch[i] = res
+            response = self.request_output_to_completion_response(
+                final_res_batch, request, request_id, created_time, model_name)
+        except ValueError as e:
+            # TODO: Use a vllm-specific Validation Error
+            return self.create_error_response(str(e))
 
         # When user requests streaming but we don't stream, we still need to
         # return a streaming response with a single event.
@@ -359,3 +198,179 @@ class OpenAIServingCompletion(OpenAIServing):
             return fake_stream_generator()
 
         return response
+
+    async def completion_stream_generator(
+            self,
+            request: CompletionRequest,
+            raw_request: Request,
+            result_generator: AsyncIterator[Tuple[int, RequestOutput]],
+            request_id: str,
+            created_time: int,
+            model_name: str,
+            num_prompts: int,
+    ) -> AsyncGenerator[str, None]:
+        previous_texts = [""] * request.n * num_prompts
+        previous_num_tokens = [0] * request.n * num_prompts
+        has_echoed = [False] * request.n * num_prompts
+
+        try:
+            async for prompt_idx, res in result_generator:
+
+                # Abort the request if the client disconnects.
+                if await raw_request.is_disconnected():
+                    await self.engine.abort(f"{request_id}-{prompt_idx}")
+                    raise StopAsyncIteration()
+
+                for output in res.outputs:
+                    i = output.index + prompt_idx * request.n
+                    # TODO(simon): optimize the performance by avoiding full text O(n^2) sending.
+
+                    if request.echo and request.max_tokens == 0:
+                        # only return the prompt
+                        delta_text = res.prompt
+                        delta_token_ids = res.prompt_token_ids
+                        top_logprobs = res.prompt_logprobs
+                        has_echoed[i] = True
+                    elif request.echo and request.max_tokens > 0 and not has_echoed[
+                            i]:
+                        # echo the prompt and first token
+                        delta_text = res.prompt + output.text
+                        delta_token_ids = res.prompt_token_ids + output.token_ids
+                        top_logprobs = res.prompt_logprobs + (output.logprobs
+                                                              or [])
+                        has_echoed[i] = True
+                    else:
+                        # return just the delta
+                        delta_text = output.text[len(previous_texts[i]):]
+                        delta_token_ids = output.token_ids[
+                            previous_num_tokens[i]:]
+                        top_logprobs = output.logprobs[previous_num_tokens[
+                            i]:] if output.logprobs else None
+
+                    if request.logprobs is not None:
+                        assert top_logprobs is not None, "top_logprobs must be provided when logprobs is requested"
+                        logprobs = self._create_logprobs(
+                            token_ids=delta_token_ids,
+                            top_logprobs=top_logprobs,
+                            num_output_top_logprobs=request.logprobs,
+                            initial_text_offset=len(previous_texts[i]),
+                        )
+                    else:
+                        logprobs = None
+
+                    previous_texts[i] = output.text
+                    previous_num_tokens[i] = len(output.token_ids)
+                    finish_reason = output.finish_reason
+                    response_json = CompletionStreamResponse(
+                        id=request_id,
+                        created=created_time,
+                        model=model_name,
+                        choices=[
+                            CompletionResponseStreamChoice(
+                                index=i,
+                                text=delta_text,
+                                logprobs=logprobs,
+                                finish_reason=finish_reason,
+                            )
+                        ]).model_dump_json()
+                    yield f"data: {response_json}\n\n"
+
+                    if output.finish_reason is not None:  # return final usage
+                        logprobs = LogProbs(
+                        ) if request.logprobs is not None else None
+                        prompt_tokens = len(res.prompt_token_ids)
+                        completion_tokens = len(output.token_ids)
+                        final_usage = UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        )
+                        response_json = CompletionStreamResponse(
+                            id=request_id,
+                            created=created_time,
+                            model=model_name,
+                            choices=[
+                                CompletionResponseStreamChoice(
+                                    index=i,
+                                    text="",
+                                    logprobs=logprobs,
+                                    finish_reason=output.finish_reason,
+                                )
+                            ],
+                            usage=final_usage,
+                        ).model_dump_json()
+                        yield f"data: {response_json}\n\n"
+        except ValueError as e:
+            # TODO: Use a vllm-specific Validation Error
+            data = self.create_streaming_error_response(str(e))
+            print("yield", f"data: {data}\n\n")
+            yield f"data: {data}\n\n"
+
+        print("yield", "data: [DONE]\n\n")
+        yield "data: [DONE]\n\n"
+
+    def request_output_to_completion_response(
+            self,
+            final_res_batch: List[RequestOutput],
+            request: CompletionRequest,
+            request_id: str,
+            created_time: int,
+            model_name: str,
+    ) -> CompletionResponse:
+        choices = []
+        num_prompt_tokens = 0
+        num_generated_tokens = 0
+        for final_res in final_res_batch:
+            assert final_res is not None
+            prompt_token_ids = final_res.prompt_token_ids
+            prompt_logprobs = final_res.prompt_logprobs
+            prompt_text = final_res.prompt
+            for output in final_res.outputs:
+                if request.echo and request.max_tokens == 0:
+                    token_ids = prompt_token_ids
+                    top_logprobs = prompt_logprobs
+                    output_text = prompt_text
+                elif request.echo and request.max_tokens > 0:
+                    token_ids = prompt_token_ids + output.token_ids
+                    top_logprobs = prompt_logprobs + output.logprobs
+                    output_text = prompt_text + output.text
+                else:
+                    token_ids = output.token_ids
+                    top_logprobs = output.logprobs
+                    output_text = output.text
+
+                if request.logprobs is not None:
+                    logprobs = self._create_logprobs(
+                        token_ids=token_ids,
+                        top_logprobs=top_logprobs,
+                        num_output_top_logprobs=request.logprobs,
+                    )
+                else:
+                    logprobs = None
+
+                choice_data = CompletionResponseChoice(
+                    index=len(choices),
+                    text=output_text,
+                    logprobs=logprobs,
+                    finish_reason=output.finish_reason,
+                )
+                choices.append(choice_data)
+
+            num_prompt_tokens += len(prompt_token_ids)
+            num_generated_tokens += sum(
+                len(output.token_ids) for output in final_res.outputs)
+
+        usage = UsageInfo(
+            prompt_tokens=num_prompt_tokens,
+            completion_tokens=num_generated_tokens,
+            total_tokens=num_prompt_tokens + num_generated_tokens,
+        )
+
+        return CompletionResponse(
+            id=request_id,
+            created=created_time,
+            model=model_name,
+            choices=choices,
+            usage=usage,
+        )
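
The change to merge_async_iterators above forwards producer exceptions through the queue so the consumer re-raises them instead of silently dropping the failed iterator. A standalone sketch of that pattern follows; it is not vLLM code, only an illustration of the idea:

    import asyncio

    def merge(*iterators):
        queue: asyncio.Queue = asyncio.Queue()
        finished = [False] * len(iterators)

        async def producer(i, iterator):
            try:
                async for item in iterator:
                    await queue.put((i, item))
            except Exception as e:
                await queue.put(e)              # hand the failure to the consumer
            finished[i] = True

        tasks = [
            asyncio.create_task(producer(i, it))
            for i, it in enumerate(iterators)
        ]

        async def consumer():
            while not all(finished) or not queue.empty():
                item = await queue.get()
                if isinstance(item, Exception):
                    raise item                  # surfaces in the caller's async for
                yield item
            await asyncio.gather(*tasks)

        return consumer()

    async def ok():
        yield 1
        yield 2

    async def boom():
        yield 3
        raise ValueError("bad request")

    async def main():
        try:
            async for idx, item in merge(ok(), boom()):
                print(idx, item)
        except ValueError as e:
            print("caught:", e)

    asyncio.run(main())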

View File

@@ -1,4 +1,5 @@
 import asyncio
+import json
 from dataclasses import dataclass
 from http import HTTPStatus
 from typing import Dict, List, Optional, Union
@@ -11,6 +12,7 @@ from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                               ModelCard, ModelList,
                                               ModelPermission)
 from vllm.lora.request import LoRARequest
+from vllm.sequence import Logprob
 
 logger = init_logger(__name__)
@@ -83,7 +85,7 @@ class OpenAIServing:
     def _create_logprobs(
         self,
         token_ids: List[int],
-        top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
+        top_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None,
         num_output_top_logprobs: Optional[int] = None,
         initial_text_offset: int = 0,
     ) -> LogProbs:
@@ -95,10 +97,10 @@ class OpenAIServing:
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is not None:
-                token_logprob = step_top_logprobs[token_id]
+                token_logprob = step_top_logprobs[token_id].logprob
             else:
                 token_logprob = None
-            token = self.tokenizer.convert_ids_to_tokens(token_id)
+            token = step_top_logprobs[token_id].decoded_token
             logprobs.tokens.append(token)
             logprobs.token_logprobs.append(token_logprob)
             if len(logprobs.text_offset) == 0:
@@ -110,7 +112,7 @@ class OpenAIServing:
             if num_output_top_logprobs:
                 logprobs.top_logprobs.append({
-                    self.tokenizer.convert_ids_to_tokens(i): p
+                    p.decoded_token: p.logprob
                     for i, p in step_top_logprobs.items()
                 } if step_top_logprobs else None)
 
         return logprobs
@@ -124,6 +126,19 @@ class OpenAIServing:
             type=err_type,
             code=status_code.value)
 
+    def create_streaming_error_response(
+            self,
+            message: str,
+            err_type: str = "BadRequestError",
+            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str:
+        json_str = json.dumps({
+            "error":
+            self.create_error_response(message=message,
+                                       err_type=err_type,
+                                       status_code=status_code).model_dump()
+        })
+        return json_str
+
     async def _check_model(self, request) -> Optional[ErrorResponse]:
         if request.model == self.served_model:
             return
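
With create_streaming_error_response, a validation failure during streaming is delivered to the client as a JSON error object on the SSE stream (followed by data: [DONE]) instead of an aborted connection. A rough sketch of detecting it client-side; the URL, model name, and use of the requests library are illustrative assumptions, not part of the diff:

    import json
    import requests

    resp = requests.post(
        "http://localhost:8000/v1/completions",        # example server URL
        json={"model": "facebook/opt-125m", "prompt": "Hi",
              "stream": True, "logprobs": 100},         # exceeds the max_logprobs cap
        stream=True,
    )
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        payload = line[len(b"data: "):]
        if payload == b"[DONE]":
            break
        event = json.loads(payload)
        if "error" in event:                            # shape produced above: {"error": {...}}
            print("server error:", event["error"].get("message"))
            break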

View File

@@ -8,8 +8,9 @@ from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_gather)
 from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
 from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
-                           SequenceData, SequenceGroupOutput, SequenceOutput)
+from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
+                           SamplerOutput, SequenceData, SequenceGroupOutput,
+                           SequenceOutput)
 from vllm.utils import is_neuron
@@ -528,7 +529,10 @@ def _get_logprobs(
             prompt_logprobs_dict.update(
                 zip(top_token_ids[sample_idx, :num_logprobs].tolist(),
                     top_logprobs[sample_idx, :num_logprobs].tolist()))
-            group_prompt_logprobs.append(prompt_logprobs_dict)
+            group_prompt_logprobs.append({
+                token_id: Logprob(logprob)
+                for token_id, logprob in prompt_logprobs_dict.items()
+            })
             sample_idx += 1
             query_result_idx += 1
         result_prompt_logprobs.append(group_prompt_logprobs)
@@ -553,7 +557,10 @@ def _get_logprobs(
                         parent_id, :num_logprobs].tolist(),
                     top_logprobs[sample_idx +
                                  parent_id, :num_logprobs].tolist()))
-            group_sample_logprobs.append(sample_logprobs_dict)
+            group_sample_logprobs.append({
+                token_id: Logprob(logprob)
+                for token_id, logprob in sample_logprobs_dict.items()
+            })
         result_sample_logprobs.append(group_sample_logprobs)
         sample_idx += len(seq_ids)

View File

@@ -8,8 +8,16 @@ from vllm.block import LogicalTokenBlock
 from vllm.sampling_params import SamplingParams
 from vllm.lora.request import LoRARequest
 
-PromptLogprobs = List[Optional[Dict[int, float]]]
-SampleLogprobs = List[Dict[int, float]]
+
+@dataclass
+class Logprob:
+    """Infos for supporting OpenAI compatible logprobs."""
+    logprob: float
+    decoded_token: Optional[str] = None
+
+
+PromptLogprobs = List[Optional[Dict[int, Logprob]]]
+SampleLogprobs = List[Dict[int, Logprob]]
 
 
 class SequenceStatus(enum.Enum):
@@ -196,12 +204,12 @@ class Sequence:
     def append_token_id(
         self,
         token_id: int,
-        logprobs: Dict[int, float],
+        logprobs: Dict[int, Logprob],
     ) -> None:
         assert token_id in logprobs
         self._append_tokens_to_blocks([token_id])
         self.output_logprobs.append(logprobs)
-        self.data.append_token_id(token_id, logprobs[token_id])
+        self.data.append_token_id(token_id, logprobs[token_id].logprob)
 
     def get_len(self) -> int:
         return self.data.get_len()
@@ -456,7 +464,7 @@ class SequenceOutput:
         self,
         parent_seq_id: int,
         output_token: int,
-        logprobs: Dict[int, float],
+        logprobs: Dict[int, Logprob],
     ) -> None:
         self.parent_seq_id = parent_seq_id
         self.output_token = output_token
@@ -470,9 +478,10 @@ class SequenceOutput:
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, SequenceOutput):
             raise NotImplementedError()
-        return (self.parent_seq_id == other.parent_seq_id
-                and self.output_token == other.output_token
-                and self.logprobs == other.logprobs)
+        equal = (self.parent_seq_id == other.parent_seq_id
+                 and self.output_token == other.output_token)
+        log_probs_equal = other.logprobs == self.logprobs
+        return equal and log_probs_equal
 
 
 class SequenceGroupOutput:
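
For readers skimming the type changes above: PromptLogprobs and SampleLogprobs are now lists of token_id -> Logprob dicts rather than raw floats. A standalone mirror of the shapes involved (token ids and values are made up):

    from dataclasses import dataclass
    from typing import Dict, List, Optional

    @dataclass
    class Logprob:                       # mirrors the dataclass defined above
        logprob: float
        decoded_token: Optional[str] = None

    SampleLogprobs = List[Dict[int, Logprob]]

    steps: SampleLogprobs = [
        {464: Logprob(-0.2, "The"), 32: Logprob(-2.8, "A")},   # candidates at step 0
        {3290: Logprob(-0.5, " dog")},                         # candidates at step 1
    ]
    sampled = [464, 3290]
    text = "".join(steps[i][tid].decoded_token for i, tid in enumerate(sampled))
    assert text == "The dog"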

View File

@@ -77,7 +77,7 @@ class MultiStepWorker(Worker):
                 token_id = seq_output.output_token
                 token_logprob = seq_output.logprobs[token_id]
-                seq.append_token_id(token_id, token_logprob)
+                seq.append_token_id(token_id, token_logprob.logprob)
 
     def _shallow_copy_inputs(
         self, seq_group_metadata_list: List[SequenceGroupMetadata]