[V1] [5/N] API Server: unify Detokenizer and EngineCore input (#11545)

Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Author: Robert Shaw (committed by GitHub)
Date: 2024-12-28 15:51:57 -05:00
Parent: 328841d002
Commit: 4fb8e329fd
6 changed files with 66 additions and 77 deletions
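
In short, per the hunks below: Processor.process_inputs no longer returns a (DetokenizerRequest, EngineCoreRequest) pair; it returns a single EngineCoreRequest, and both the Detokenizer and EngineCore consume that same object. A minimal sketch of the resulting call pattern, abridged from the LLMEngine hunk in this commit (the processor, detokenizer, and engine_core attributes are assumed to be already constructed, as in the real engine):

    # Build one request object and hand it to both components.
    request = self.processor.process_inputs(request_id, prompt, params,
                                            arrival_time, lora_request,
                                            trace_headers,
                                            prompt_adapter_request,
                                            priority)
    self.detokenizer.add_request(request)    # Detokenizer (this process)
    self.engine_core.add_request(request)    # EngineCore (separate process)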

===== Changed file 1 of 6: v1 Detokenizer unit tests =====

@@ -3,9 +3,9 @@ from typing import List
 import pytest
 from transformers import AutoTokenizer

-from vllm.sampling_params import RequestOutputKind
-from vllm.v1.engine import EngineCoreOutput
-from vllm.v1.engine.detokenizer import Detokenizer, DetokenizerRequest
+from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
+from vllm.v1.engine.detokenizer import Detokenizer

 TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
 tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
@@ -71,16 +71,22 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind):

     # Make N requests.
     requests = [
-        DetokenizerRequest(
-            request_id=f"request-{idx}",
-            prompt=prompt,
-            prompt_token_ids=prompt_tokens,
-            skip_special_tokens=False,
-            spaces_between_special_tokens=False,
-            output_kind=request_output_kind,
-            stop=[],
-            include_stop_str_in_output=False,
-        ) for idx, (
+        EngineCoreRequest(request_id=f"request-{idx}",
+                          prompt=prompt,
+                          prompt_token_ids=prompt_tokens,
+                          arrival_time=0,
+                          mm_inputs=None,
+                          mm_hashes=None,
+                          mm_placeholders=None,
+                          eos_token_id=None,
+                          lora_request=None,
+                          sampling_params=SamplingParams(
+                              skip_special_tokens=False,
+                              spaces_between_special_tokens=False,
+                              output_kind=request_output_kind,
+                              stop=[],
+                              include_stop_str_in_output=False))
+        for idx, (
             prompt,
             prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
     ]
@@ -133,16 +139,23 @@ def test_stop_string(include_stop_str_in_output: bool):

     # Make N requests.
     requests = [
-        DetokenizerRequest(
+        EngineCoreRequest(
             request_id=f"request-{idx}",
             prompt=prompt,
             prompt_token_ids=prompt_tokens,
-            skip_special_tokens=False,
-            spaces_between_special_tokens=False,
-            output_kind=RequestOutputKind.DELTA,
-            stop=STOP_STRINGS,
-            include_stop_str_in_output=include_stop_str_in_output,
-        ) for idx, (
+            arrival_time=0,
+            mm_inputs=None,
+            mm_hashes=None,
+            mm_placeholders=None,
+            eos_token_id=None,
+            lora_request=None,
+            sampling_params=SamplingParams(
+                skip_special_tokens=False,
+                spaces_between_special_tokens=False,
+                output_kind=RequestOutputKind.DELTA,
+                stop=STOP_STRINGS,
+                include_stop_str_in_output=include_stop_str_in_output,
+            )) for idx, (
             prompt,
             prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
     ]

===== Changed file 2 of 6: vllm.v1.engine request definitions =====

@@ -6,21 +6,7 @@ import msgspec

 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
-from vllm.sampling_params import RequestOutputKind, SamplingParams
-
-
-@dataclass
-class DetokenizerRequest:
-
-    request_id: str
-    prompt: Optional[str]
-    prompt_token_ids: List[int]
-    skip_special_tokens: bool
-    spaces_between_special_tokens: bool
-    output_kind: RequestOutputKind
-    stop: List[str]
-    include_stop_str_in_output: bool
+from vllm.sampling_params import SamplingParams


 @dataclass
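
Nothing from the removed dataclass is lost; every field has a direct counterpart on the unified request. The mapping below is assembled from the test and Processor hunks in this commit:

    request_id                    -> EngineCoreRequest.request_id
    prompt                        -> EngineCoreRequest.prompt
    prompt_token_ids              -> EngineCoreRequest.prompt_token_ids
    skip_special_tokens           -> EngineCoreRequest.sampling_params.skip_special_tokens
    spaces_between_special_tokens -> EngineCoreRequest.sampling_params.spaces_between_special_tokens
    output_kind                   -> EngineCoreRequest.sampling_params.output_kind
    stop                          -> EngineCoreRequest.sampling_params.stop
    include_stop_str_in_output    -> EngineCoreRequest.sampling_params.include_stop_str_in_output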

===== Changed file 3 of 6: AsyncLLM =====

@@ -158,16 +158,18 @@ class AsyncLLM(EngineClient):
             raise ValueError(f"Request id {request_id} already running.")
         self.rid_to_queue[request_id] = asyncio.Queue()

-        # 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
-        detokenizer_req, engine_core_req = self.processor.process_inputs(
-            request_id, prompt, params, arrival_time, lora_request,
-            trace_headers, prompt_adapter_request, priority)
+        # 2) Convert Input --> Request.
+        request = self.processor.process_inputs(request_id, prompt, params,
+                                                arrival_time, lora_request,
+                                                trace_headers,
+                                                prompt_adapter_request,
+                                                priority)

         # 3) Add the request to Detokenizer (this process).
-        self.detokenizer.add_request(detokenizer_req)
+        self.detokenizer.add_request(request)

         # 4) Add the EngineCoreRequest to EngineCore (separate process).
-        await self.engine_core.add_request_async(engine_core_req)
+        await self.engine_core.add_request_async(request)

         if self.log_requests:
             logger.info("Added request %s.", request_id)

===== Changed file 4 of 6: Detokenizer / IncrementalDetokenizer =====

@@ -8,7 +8,7 @@ from vllm.sampling_params import RequestOutputKind
 from vllm.transformers_utils.detokenizer_utils import (
     AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
 from vllm.transformers_utils.tokenizer import get_tokenizer
-from vllm.v1.engine import DetokenizerRequest, EngineCoreOutput
+from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest

 logger = init_logger(__name__)
@@ -55,19 +55,19 @@ class IncrementalDetokenizer:
     def from_new_request(
         cls,
         tokenizer: AnyTokenizer,
-        request: DetokenizerRequest,
+        request: EngineCoreRequest,
     ) -> "IncrementalDetokenizer":

         tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
             tokenizer=tokenizer,
             prompt_ids=request.prompt_token_ids,
-            skip_special_tokens=request.skip_special_tokens,
+            skip_special_tokens=request.sampling_params.skip_special_tokens,
         )

-        stops = request.stop
+        stops = request.sampling_params.stop
         # Number of chars to hold back when stop strings are to be excluded
         # from streamed output.
-        if stops and not request.include_stop_str_in_output:
+        if stops and not request.sampling_params.include_stop_str_in_output:
             stop_buffer_length = max(len(s) for s in stops) - 1
         else:
             stop_buffer_length = 0
@@ -79,13 +79,14 @@ class IncrementalDetokenizer:
             # NOTE(Nick): could we take ownership of it though?
             token_ids=request.prompt_token_ids.copy(),
             stop=stops,
-            include_stop_str_in_output=request.include_stop_str_in_output,
+            include_stop_str_in_output=request.sampling_params.
+            include_stop_str_in_output,
             prefix_offset=prefix_offset,
             read_offset=read_offset,
-            skip_special_tokens=request.skip_special_tokens,
-            spaces_between_special_tokens=request.
-            spaces_between_special_tokens,
-            output_kind=request.output_kind,
+            skip_special_tokens=request.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=request.sampling_params.
+            spaces_between_special_tokens,
+            output_kind=request.sampling_params.output_kind,
             request_id=request.request_id,
             prompt=request.prompt,
             prompt_token_ids=request.prompt_token_ids,
@@ -227,7 +228,7 @@ class Detokenizer:

     def add_request(
         self,
-        request: DetokenizerRequest,
+        request: EngineCoreRequest,
     ):
         """Add new request to the Detokenizer."""

===== Changed file 5 of 6: LLMEngine =====

@@ -152,15 +152,17 @@ class LLMEngine:
     ) -> None:

         # 1) Process raw inputs into the request.
-        detokenizer_req, engine_core_req = self.processor.process_inputs(
-            request_id, prompt, params, arrival_time, lora_request,
-            trace_headers, prompt_adapter_request, priority)
+        request = self.processor.process_inputs(request_id, prompt, params,
+                                                arrival_time, lora_request,
+                                                trace_headers,
+                                                prompt_adapter_request,
+                                                priority)

         # 2) Add the request to Detokenizer.
-        self.detokenizer.add_request(detokenizer_req)
+        self.detokenizer.add_request(request)

         # 3) Add the request to EngineCore.
-        self.engine_core.add_request(engine_core_req)
+        self.engine_core.add_request(request)

     def step(self) -> List[RequestOutput]:

===== Changed file 6 of 6: Processor =====

@@ -1,5 +1,5 @@
 import time
-from typing import Mapping, Optional, Tuple, Union
+from typing import Mapping, Optional, Union

 from vllm.config import CacheConfig, LoRAConfig, ModelConfig
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
@@ -13,7 +13,7 @@ from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
-from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
+from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
@@ -62,7 +62,7 @@ class Processor:
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> Tuple[DetokenizerRequest, EngineCoreRequest]:
+    ) -> EngineCoreRequest:

         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Check max_logprobs
@@ -123,20 +123,7 @@ class Processor:
                 decoder_inputs.multi_modal_data, mm_hashes,
                 decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)

-        # Make Request for Detokenizer.
-        detokenizer_request = DetokenizerRequest(
-            request_id,
-            decoder_inputs.prompt,
-            decoder_inputs.prompt_token_ids,
-            sampling_params.skip_special_tokens,
-            sampling_params.spaces_between_special_tokens,
-            sampling_params.output_kind,
-            sampling_params.stop,
-            sampling_params.include_stop_str_in_output,
-        )
-
-        # Make Request for EngineCore.
-        engine_core_request = EngineCoreRequest(
+        return EngineCoreRequest(
             request_id,
             decoder_inputs.prompt,
             decoder_inputs.prompt_token_ids,
@@ -149,8 +136,6 @@ class Processor:
             lora_request,
         )

-        return detokenizer_request, engine_core_request
-
     def _validate_model_inputs(self, inputs: ProcessorInputs):
         if is_encoder_decoder_inputs(inputs):
             # For encoder-decoder multimodal models, the max_prompt_len