Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-15 01:55:36 +08:00)

[V1] [5/N] API Server: unify Detokenizer and EngineCore input (#11545)

Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>

Parent: 328841d002
Commit: 4fb8e329fd
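Summary of the change, as reflected in the hunks below: Processor.process_inputs now builds a single EngineCoreRequest, with the detokenization options carried inside its SamplingParams, instead of returning a (DetokenizerRequest, EngineCoreRequest) pair, and that one object is handed to both the Detokenizer and the EngineCore. A minimal sketch of the unified request shape, modeled on the updated tests; the prompt text and token ids are illustrative placeholders, not values from the repository:

# Sketch only: assumes vllm at this commit is importable; values are illustrative.
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.v1.engine import EngineCoreRequest

request = EngineCoreRequest(
    request_id="request-0",
    prompt="Hello, my name is",      # placeholder prompt
    prompt_token_ids=[1, 2, 3, 4],   # placeholder token ids
    arrival_time=0,
    mm_inputs=None,
    mm_hashes=None,
    mm_placeholders=None,
    eos_token_id=None,
    lora_request=None,
    sampling_params=SamplingParams(
        skip_special_tokens=False,
        spaces_between_special_tokens=False,
        output_kind=RequestOutputKind.DELTA,
        stop=[],
        include_stop_str_in_output=False))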
@@ -3,9 +3,9 @@ from typing import List
 import pytest
 from transformers import AutoTokenizer
 
-from vllm.sampling_params import RequestOutputKind
-from vllm.v1.engine import EngineCoreOutput
-from vllm.v1.engine.detokenizer import Detokenizer, DetokenizerRequest
+from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
+from vllm.v1.engine.detokenizer import Detokenizer
 
 TOKENIZER_NAME = "mistralai/Mistral-7B-Instruct-v0.3"
 tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
@@ -71,16 +71,22 @@ def test_incremental_detokenization(request_output_kind: RequestOutputKind):
 
     # Make N requests.
     requests = [
-        DetokenizerRequest(
-            request_id=f"request-{idx}",
-            prompt=prompt,
-            prompt_token_ids=prompt_tokens,
-            skip_special_tokens=False,
-            spaces_between_special_tokens=False,
-            output_kind=request_output_kind,
-            stop=[],
-            include_stop_str_in_output=False,
-        ) for idx, (
+        EngineCoreRequest(request_id=f"request-{idx}",
+                          prompt=prompt,
+                          prompt_token_ids=prompt_tokens,
+                          arrival_time=0,
+                          mm_inputs=None,
+                          mm_hashes=None,
+                          mm_placeholders=None,
+                          eos_token_id=None,
+                          lora_request=None,
+                          sampling_params=SamplingParams(
+                              skip_special_tokens=False,
+                              spaces_between_special_tokens=False,
+                              output_kind=request_output_kind,
+                              stop=[],
+                              include_stop_str_in_output=False))
+        for idx, (
             prompt,
             prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
     ]
@@ -133,16 +139,23 @@ def test_stop_string(include_stop_str_in_output: bool):
 
     # Make N requests.
     requests = [
-        DetokenizerRequest(
+        EngineCoreRequest(
             request_id=f"request-{idx}",
             prompt=prompt,
             prompt_token_ids=prompt_tokens,
-            skip_special_tokens=False,
-            spaces_between_special_tokens=False,
-            output_kind=RequestOutputKind.DELTA,
-            stop=STOP_STRINGS,
-            include_stop_str_in_output=include_stop_str_in_output,
-        ) for idx, (
+            arrival_time=0,
+            mm_inputs=None,
+            mm_hashes=None,
+            mm_placeholders=None,
+            eos_token_id=None,
+            lora_request=None,
+            sampling_params=SamplingParams(
+                skip_special_tokens=False,
+                spaces_between_special_tokens=False,
+                output_kind=RequestOutputKind.DELTA,
+                stop=STOP_STRINGS,
+                include_stop_str_in_output=include_stop_str_in_output,
+            )) for idx, (
             prompt,
             prompt_tokens) in enumerate(zip(PROMPT_STRINGS, PROMPT_TOKENS))
     ]
@@ -6,21 +6,7 @@ import msgspec
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MultiModalKwargs, MultiModalPlaceholderDict
-from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.sampling_params import SamplingParams
 
 
-@dataclass
-class DetokenizerRequest:
-
-    request_id: str
-    prompt: Optional[str]
-    prompt_token_ids: List[int]
-    skip_special_tokens: bool
-    spaces_between_special_tokens: bool
-    output_kind: RequestOutputKind
-
-    stop: List[str]
-    include_stop_str_in_output: bool
-
-
 @dataclass
@@ -158,16 +158,18 @@ class AsyncLLM(EngineClient):
             raise ValueError(f"Request id {request_id} already running.")
         self.rid_to_queue[request_id] = asyncio.Queue()
 
-        # 2) Convert input --> DetokenizerRequest / EngineCoreRequest.
-        detokenizer_req, engine_core_req = self.processor.process_inputs(
-            request_id, prompt, params, arrival_time, lora_request,
-            trace_headers, prompt_adapter_request, priority)
+        # 2) Convert Input --> Request.
+        request = self.processor.process_inputs(request_id, prompt, params,
+                                                arrival_time, lora_request,
+                                                trace_headers,
+                                                prompt_adapter_request,
+                                                priority)
 
         # 3) Add the request to Detokenizer (this process).
-        self.detokenizer.add_request(detokenizer_req)
+        self.detokenizer.add_request(request)
 
         # 4) Add the EngineCoreRequest to EngineCore (separate process).
-        await self.engine_core.add_request_async(engine_core_req)
+        await self.engine_core.add_request_async(request)
 
         if self.log_requests:
             logger.info("Added request %s.", request_id)
@ -8,7 +8,7 @@ from vllm.sampling_params import RequestOutputKind
|
|||||||
from vllm.transformers_utils.detokenizer_utils import (
|
from vllm.transformers_utils.detokenizer_utils import (
|
||||||
AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
|
AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
|
||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
from vllm.v1.engine import DetokenizerRequest, EngineCoreOutput
|
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -55,19 +55,19 @@ class IncrementalDetokenizer:
     def from_new_request(
         cls,
         tokenizer: AnyTokenizer,
-        request: DetokenizerRequest,
+        request: EngineCoreRequest,
     ) -> "IncrementalDetokenizer":
 
         tokens, prefix_offset, read_offset = convert_prompt_ids_to_tokens(
             tokenizer=tokenizer,
             prompt_ids=request.prompt_token_ids,
-            skip_special_tokens=request.skip_special_tokens,
+            skip_special_tokens=request.sampling_params.skip_special_tokens,
         )
 
-        stops = request.stop
+        stops = request.sampling_params.stop
         # Number of chars to hold back when stop strings are to be excluded
         # from streamed output.
-        if stops and not request.include_stop_str_in_output:
+        if stops and not request.sampling_params.include_stop_str_in_output:
             stop_buffer_length = max(len(s) for s in stops) - 1
         else:
             stop_buffer_length = 0
@@ -79,13 +79,14 @@ class IncrementalDetokenizer:
             # NOTE(Nick): could we take ownership of it though?
             token_ids=request.prompt_token_ids.copy(),
             stop=stops,
-            include_stop_str_in_output=request.include_stop_str_in_output,
+            include_stop_str_in_output=request.sampling_params.
+            include_stop_str_in_output,
             prefix_offset=prefix_offset,
             read_offset=read_offset,
-            skip_special_tokens=request.skip_special_tokens,
-            spaces_between_special_tokens=request.
+            skip_special_tokens=request.sampling_params.skip_special_tokens,
+            spaces_between_special_tokens=request.sampling_params.
             spaces_between_special_tokens,
-            output_kind=request.output_kind,
+            output_kind=request.sampling_params.output_kind,
             request_id=request.request_id,
             prompt=request.prompt,
             prompt_token_ids=request.prompt_token_ids,
@@ -227,7 +228,7 @@ class Detokenizer:
 
     def add_request(
         self,
-        request: DetokenizerRequest,
+        request: EngineCoreRequest,
     ):
         """Add new request to the Detokenizer."""
 
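With DetokenizerRequest gone, every per-request detokenization option the Detokenizer needs now lives on request.sampling_params, as the accesses in from_new_request above show. A small illustrative helper (not vllm code; the DetokenizationOptions container and function name are invented here) that collects exactly those fields:

# Illustrative only: mirrors the sampling_params accesses in the hunks above.
from dataclasses import dataclass
from typing import List

from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine import EngineCoreRequest


@dataclass
class DetokenizationOptions:
    # The fields that used to live on the removed DetokenizerRequest.
    skip_special_tokens: bool
    spaces_between_special_tokens: bool
    output_kind: RequestOutputKind
    stop: List[str]
    include_stop_str_in_output: bool


def options_from_request(request: EngineCoreRequest) -> DetokenizationOptions:
    params = request.sampling_params
    return DetokenizationOptions(
        skip_special_tokens=params.skip_special_tokens,
        spaces_between_special_tokens=params.spaces_between_special_tokens,
        output_kind=params.output_kind,
        stop=params.stop,
        include_stop_str_in_output=params.include_stop_str_in_output,
    )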
@@ -152,15 +152,17 @@ class LLMEngine:
     ) -> None:
 
         # 1) Process raw inputs into the request.
-        detokenizer_req, engine_core_req = self.processor.process_inputs(
-            request_id, prompt, params, arrival_time, lora_request,
-            trace_headers, prompt_adapter_request, priority)
+        request = self.processor.process_inputs(request_id, prompt, params,
+                                                arrival_time, lora_request,
+                                                trace_headers,
+                                                prompt_adapter_request,
+                                                priority)
 
         # 2) Add the request to Detokenizer.
-        self.detokenizer.add_request(detokenizer_req)
+        self.detokenizer.add_request(request)
 
         # 3) Add the request to EngineCore.
-        self.engine_core.add_request(engine_core_req)
+        self.engine_core.add_request(request)
 
     def step(self) -> List[RequestOutput]:
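The AsyncLLM and LLMEngine hunks above now share one flow: process the raw inputs into a single request, then hand the same object to the Detokenizer and to the EngineCore. A condensed sketch of the synchronous path (object names follow the call sites above; the free function and its parameter list are a simplification, not the real method):

# Condensed sketch of the post-change add_request flow; `processor`,
# `detokenizer`, and `engine_core` stand in for the corresponding
# self attributes of LLMEngine.
def add_request_flow(processor, detokenizer, engine_core, request_id, prompt,
                     params, arrival_time, lora_request, trace_headers,
                     prompt_adapter_request, priority):
    # 1) Process raw inputs into a single EngineCoreRequest.
    request = processor.process_inputs(request_id, prompt, params,
                                       arrival_time, lora_request,
                                       trace_headers, prompt_adapter_request,
                                       priority)

    # 2) The Detokenizer (this process) and the EngineCore (separate process)
    #    both receive the same request object.
    detokenizer.add_request(request)
    engine_core.add_request(request)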
@@ -1,5 +1,5 @@
 import time
-from typing import Mapping, Optional, Tuple, Union
+from typing import Mapping, Optional, Union
 
 from vllm.config import CacheConfig, LoRAConfig, ModelConfig
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
@@ -13,7 +13,7 @@ from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
-from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
+from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
 
 
@@ -62,7 +62,7 @@ class Processor:
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
-    ) -> Tuple[DetokenizerRequest, EngineCoreRequest]:
+    ) -> EngineCoreRequest:
 
         # TODO(woosuk): Support pooling models.
         # TODO(woosuk): Check max_logprobs
@@ -123,20 +123,7 @@ class Processor:
             decoder_inputs.multi_modal_data, mm_hashes,
             decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)
 
-        # Make Request for Detokenizer.
-        detokenizer_request = DetokenizerRequest(
-            request_id,
-            decoder_inputs.prompt,
-            decoder_inputs.prompt_token_ids,
-            sampling_params.skip_special_tokens,
-            sampling_params.spaces_between_special_tokens,
-            sampling_params.output_kind,
-            sampling_params.stop,
-            sampling_params.include_stop_str_in_output,
-        )
-
-        # Make Request for EngineCore.
-        engine_core_request = EngineCoreRequest(
+        return EngineCoreRequest(
             request_id,
             decoder_inputs.prompt,
             decoder_inputs.prompt_token_ids,
@@ -149,8 +136,6 @@ class Processor:
             lora_request,
         )
 
-        return detokenizer_request, engine_core_request
-
     def _validate_model_inputs(self, inputs: ProcessorInputs):
         if is_encoder_decoder_inputs(inputs):
             # For encoder-decoder multimodal models, the max_prompt_len