mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 14:27:19 +08:00
[CORE] Prompt Embeddings Support for v1 Engine (#24278)
Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: Andrew Sansom <qthequartermasterman@gmail.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
parent
9fac6aa30b
commit
9a4600e4dc
@ -76,11 +76,6 @@ def test_models(
|
||||
model_executor: str,
|
||||
enable_prompt_embeds: bool,
|
||||
) -> None:
|
||||
|
||||
if enable_prompt_embeds and envs.is_set(
|
||||
"VLLM_USE_V1") and envs.VLLM_USE_V1:
|
||||
pytest.skip("enable_prompt_embeds is not supported in v1.")
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
if async_scheduling:
|
||||
pytest.skip("async_scheduling only supported in v1.")
|
||||
@ -164,11 +159,6 @@ def test_models_distributed(
|
||||
extra_env: dict[str, str],
|
||||
enable_prompt_embeds: bool,
|
||||
) -> None:
|
||||
|
||||
if enable_prompt_embeds and envs.is_set(
|
||||
"VLLM_USE_V1") and envs.VLLM_USE_V1:
|
||||
pytest.skip("enable_prompt_embeds is not supported in v1.")
|
||||
|
||||
if test_suite != TARGET_TEST_SUITE:
|
||||
pytest.skip(f"Skip test for {test_suite}")
|
||||
|
||||
|
||||
@ -36,7 +36,6 @@ def default_server_args() -> list[str]:
|
||||
"--enforce-eager",
|
||||
# Prompt Embeds server args
|
||||
"--enable-prompt-embeds",
|
||||
"--no-enable-chunked-prefill",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -125,12 +125,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
|
||||
# in parts of the operators
|
||||
pytest.skip(f"Skipping '{model}' model test with AITER kernel.")
|
||||
|
||||
# Note: can be removed when
|
||||
# https://github.com/vllm-project/vllm/pull/24278 finished
|
||||
if current_platform.is_cpu() and use_prompt_embeds:
|
||||
pytest.skip("Skipping use_prompt_embeds=True with "
|
||||
"V1-only CPU backend.")
|
||||
|
||||
with hf_runner(model) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
@ -1513,12 +1513,6 @@ class EngineArgs:
|
||||
recommend_to_remove=False)
|
||||
return False
|
||||
|
||||
# No text embedding inputs so far.
|
||||
if self.enable_prompt_embeds:
|
||||
_raise_or_fallback(feature_name="--enable-prompt-embeds",
|
||||
recommend_to_remove=False)
|
||||
return False
|
||||
|
||||
# No Mamba or Encoder-Decoder so far.
|
||||
if not model_config.is_v1_compatible:
|
||||
_raise_or_fallback(feature_name=model_config.architectures,
|
||||
@ -1651,6 +1645,13 @@ class EngineArgs:
|
||||
"models in V0 and has been disabled.")
|
||||
self.enable_prefix_caching = False
|
||||
|
||||
if self.enable_prompt_embeds:
|
||||
logger.warning(
|
||||
"--enable-prompt-embeds and --enable-prefix-caching "
|
||||
"are not supported together in V0. Prefix caching has "
|
||||
"been disabled.")
|
||||
self.enable_prefix_caching = False
|
||||
|
||||
# Set max_num_seqs to 256 for VLLM_V0.
|
||||
if self.max_num_seqs is None:
|
||||
self.max_num_seqs = 256
|
||||
@ -1664,6 +1665,17 @@ class EngineArgs:
|
||||
# For pooling tasks the default is False
|
||||
if model_config.runner_type != "pooling":
|
||||
self.enable_chunked_prefill = True
|
||||
|
||||
# TODO: When prefix caching supports prompt embeds inputs, this
|
||||
# check can be removed.
|
||||
if (self.enable_prompt_embeds
|
||||
and self.enable_prefix_caching is not False):
|
||||
logger.warning(
|
||||
"--enable-prompt-embeds and --enable-prefix-caching "
|
||||
"are not supported together in V1. Prefix caching has "
|
||||
"been disabled.")
|
||||
self.enable_prefix_caching = False
|
||||
|
||||
if self.enable_prefix_caching is None:
|
||||
self.enable_prefix_caching = True
|
||||
else:
|
||||
|
||||
@ -973,7 +973,6 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
# https://platform.openai.com/docs/api-reference/completions/create
|
||||
model: Optional[str] = None
|
||||
prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None
|
||||
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
|
||||
best_of: Optional[int] = None
|
||||
echo: Optional[bool] = False
|
||||
frequency_penalty: Optional[float] = 0.0
|
||||
@ -1009,6 +1008,7 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
# --8<-- [end:completion-sampling-params]
|
||||
|
||||
# --8<-- [start:completion-extra-params]
|
||||
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
|
||||
add_special_tokens: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
|
||||
@ -3443,3 +3443,30 @@ def decorate_logs(process_name: Optional[str] = None) -> None:
|
||||
pid = os.getpid()
|
||||
_add_prefix(sys.stdout, process_name, pid)
|
||||
_add_prefix(sys.stderr, process_name, pid)
|
||||
|
||||
|
||||
def length_from_prompt_token_ids_or_embeds(
|
||||
prompt_token_ids: Optional[list[int]],
|
||||
prompt_embeds: Optional[torch.Tensor],
|
||||
) -> int:
|
||||
"""Calculate the request length (in number of tokens) give either
|
||||
prompt_token_ids or prompt_embeds.
|
||||
"""
|
||||
prompt_token_len = None if prompt_token_ids is None else len(
|
||||
prompt_token_ids)
|
||||
prompt_embeds_len = \
|
||||
None if prompt_embeds is None else len(prompt_embeds)
|
||||
|
||||
if prompt_token_len is None:
|
||||
if prompt_embeds_len is None:
|
||||
raise ValueError(
|
||||
"Neither prompt_token_ids nor prompt_embeds were defined.")
|
||||
return prompt_embeds_len
|
||||
else:
|
||||
if (prompt_embeds_len is not None
|
||||
and prompt_embeds_len != prompt_token_len):
|
||||
raise ValueError(
|
||||
"Prompt token ids and prompt embeds had different lengths"
|
||||
f" prompt_token_ids={prompt_token_len}"
|
||||
f" prompt_embeds={prompt_embeds_len}")
|
||||
return prompt_token_len
|
||||
|
||||
@ -11,6 +11,7 @@ from vllm._bc_linter import bc_linter_include
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorMetadata)
|
||||
@ -26,13 +27,14 @@ if TYPE_CHECKING:
|
||||
class NewRequestData:
|
||||
|
||||
req_id: str
|
||||
prompt_token_ids: list[int]
|
||||
prompt_token_ids: Optional[list[int]]
|
||||
mm_features: list[MultiModalFeatureSpec]
|
||||
sampling_params: Optional[SamplingParams]
|
||||
pooling_params: Optional[PoolingParams]
|
||||
block_ids: tuple[list[int], ...]
|
||||
num_computed_tokens: int
|
||||
lora_request: Optional[LoRARequest]
|
||||
prompt_embeds: Optional[torch.Tensor] = None
|
||||
|
||||
@classmethod
|
||||
def from_request(
|
||||
@ -49,9 +51,12 @@ class NewRequestData:
|
||||
block_ids=block_ids,
|
||||
num_computed_tokens=request.num_computed_tokens,
|
||||
lora_request=request.lora_request,
|
||||
prompt_embeds=request.prompt_embeds,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
prompt_embeds_shape = (self.prompt_embeds.shape
|
||||
if self.prompt_embeds else None)
|
||||
return (f"NewRequestData("
|
||||
f"req_id={self.req_id},"
|
||||
f"prompt_token_ids={self.prompt_token_ids},"
|
||||
@ -59,19 +64,26 @@ class NewRequestData:
|
||||
f"sampling_params={self.sampling_params},"
|
||||
f"block_ids={self.block_ids},"
|
||||
f"num_computed_tokens={self.num_computed_tokens},"
|
||||
f"lora_request={self.lora_request}"
|
||||
f"lora_request={self.lora_request},"
|
||||
f"prompt_embeds_shape={prompt_embeds_shape}"
|
||||
")")
|
||||
|
||||
# Version of __repr__ with the prompt data obfuscated
|
||||
def anon_repr(self):
|
||||
def anon_repr(self) -> str:
|
||||
prompt_token_ids_len = len(
|
||||
self.prompt_token_ids
|
||||
) if self.prompt_token_ids is not None else None
|
||||
prompt_embeds_shape = (self.prompt_embeds.shape
|
||||
if self.prompt_embeds else None)
|
||||
return (f"NewRequestData("
|
||||
f"req_id={self.req_id},"
|
||||
f"prompt_token_ids_len={len(self.prompt_token_ids)},"
|
||||
f"prompt_token_ids_len={prompt_token_ids_len},"
|
||||
f"mm_features={self.mm_features},"
|
||||
f"sampling_params={self.sampling_params},"
|
||||
f"block_ids={self.block_ids},"
|
||||
f"num_computed_tokens={self.num_computed_tokens},"
|
||||
f"lora_request={self.lora_request}"
|
||||
f"lora_request={self.lora_request},"
|
||||
f"prompt_embeds_shape={prompt_embeds_shape}"
|
||||
")")
|
||||
|
||||
|
||||
|
||||
@ -47,7 +47,7 @@ class EngineCoreRequest(
|
||||
gc=False): # type: ignore[call-arg]
|
||||
|
||||
request_id: str
|
||||
prompt_token_ids: list[int]
|
||||
prompt_token_ids: Optional[list[int]]
|
||||
mm_features: Optional[list[MultiModalFeatureSpec]]
|
||||
sampling_params: Optional[SamplingParams]
|
||||
pooling_params: Optional[PoolingParams]
|
||||
@ -56,6 +56,7 @@ class EngineCoreRequest(
|
||||
lora_request: Optional[LoRARequest]
|
||||
cache_salt: Optional[str]
|
||||
data_parallel_rank: Optional[int]
|
||||
prompt_embeds: Optional[torch.Tensor] = None
|
||||
|
||||
# Index of the client, used to ensure outputs are sent back to the same
|
||||
# client for this request when scaling out the front-end.
|
||||
|
||||
@ -13,6 +13,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.detokenizer_utils import (
|
||||
AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -179,11 +180,12 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
||||
self.tokenizer: Tokenizer = tokenizer._tokenizer
|
||||
|
||||
# Find a safe place to start.
|
||||
prompt_suffix = request.prompt_token_ids
|
||||
prompt_token_ids = request.prompt_token_ids or []
|
||||
prompt_suffix = prompt_token_ids
|
||||
prompt_len = len(prompt_suffix)
|
||||
if prompt_len > 4:
|
||||
for i in range(4, min(prompt_len + 1, 24)):
|
||||
suffix = request.prompt_token_ids[-i:]
|
||||
suffix = prompt_token_ids[-i:]
|
||||
if '<EFBFBD>' not in self.tokenizer.decode(suffix):
|
||||
prompt_suffix = suffix
|
||||
break
|
||||
@ -260,16 +262,25 @@ class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
|
||||
params = request.sampling_params
|
||||
assert params is not None
|
||||
|
||||
# Metadata for incremental detokenization.
|
||||
self.tokens, self.prefix_offset, self.read_offset = (
|
||||
convert_prompt_ids_to_tokens(
|
||||
tokenizer=tokenizer,
|
||||
prompt_ids=request.prompt_token_ids,
|
||||
skip_special_tokens=params.skip_special_tokens,
|
||||
))
|
||||
self.prompt_len = length_from_prompt_token_ids_or_embeds(
|
||||
request.prompt_token_ids, request.prompt_embeds)
|
||||
|
||||
self.token_ids.extend(request.prompt_token_ids)
|
||||
self.prompt_len = len(request.prompt_token_ids)
|
||||
# Metadata for incremental detokenization.
|
||||
if request.prompt_token_ids is not None:
|
||||
self.tokens, self.prefix_offset, self.read_offset = (
|
||||
convert_prompt_ids_to_tokens(
|
||||
tokenizer=tokenizer,
|
||||
prompt_ids=request.prompt_token_ids,
|
||||
skip_special_tokens=params.skip_special_tokens,
|
||||
))
|
||||
else:
|
||||
# Prompt embedding requests cannot be detokenized, in general.
|
||||
self.tokens = [""] * self.prompt_len
|
||||
self.prefix_offset = 0
|
||||
self.read_offest = 0
|
||||
|
||||
self.token_ids.extend(request.prompt_token_ids
|
||||
or [0] * self.prompt_len)
|
||||
|
||||
self.skip_special_tokens = params.skip_special_tokens
|
||||
self.spaces_between_special_tokens = (
|
||||
|
||||
@ -14,6 +14,7 @@ from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.tracing import (SpanAttributes, SpanKind, Tracer,
|
||||
extract_trace_context)
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
|
||||
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
|
||||
from vllm.v1.engine.logprobs import LogprobsProcessor
|
||||
@ -86,7 +87,8 @@ class RequestState:
|
||||
lora_name: Optional[str],
|
||||
output_kind: RequestOutputKind,
|
||||
prompt: Optional[str],
|
||||
prompt_token_ids: list[int],
|
||||
prompt_token_ids: Optional[list[int]],
|
||||
prompt_embeds: Optional[torch.Tensor],
|
||||
logprobs_processor: Optional[LogprobsProcessor],
|
||||
detokenizer: Optional[IncrementalDetokenizer],
|
||||
max_tokens_param: Optional[int],
|
||||
@ -104,7 +106,9 @@ class RequestState:
|
||||
self.output_kind = output_kind
|
||||
self.prompt = prompt
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.prompt_len = len(prompt_token_ids)
|
||||
self.prompt_embeds = prompt_embeds
|
||||
self.prompt_len = length_from_prompt_token_ids_or_embeds(
|
||||
self.prompt_token_ids, self.prompt_embeds)
|
||||
self.logprobs_processor = logprobs_processor
|
||||
self.detokenizer = detokenizer
|
||||
self.max_tokens_param = max_tokens_param
|
||||
@ -165,6 +169,7 @@ class RequestState:
|
||||
output_kind=output_kind,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=request.prompt_token_ids,
|
||||
prompt_embeds=request.prompt_embeds,
|
||||
logprobs_processor=logprobs_processor,
|
||||
detokenizer=detokenizer,
|
||||
max_tokens_param=max_tokens_param,
|
||||
@ -223,6 +228,8 @@ class RequestState:
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, PoolingOutput):
|
||||
assert len(outputs) == 1
|
||||
# Prompt embeddings are currently not supported by pooling requests.
|
||||
assert self.prompt_token_ids is not None
|
||||
return PoolingRequestOutput(
|
||||
request_id=request_id,
|
||||
outputs=first_output,
|
||||
@ -236,10 +243,15 @@ class RequestState:
|
||||
else:
|
||||
prompt_logprobs = self.logprobs_processor.prompt_logprobs
|
||||
|
||||
# If prompt embeds were used, put placeholder prompt token ids
|
||||
prompt_token_ids = self.prompt_token_ids
|
||||
if prompt_token_ids is None and self.prompt_embeds is not None:
|
||||
prompt_token_ids = [0] * len(self.prompt_embeds)
|
||||
|
||||
return RequestOutput(
|
||||
request_id=request_id,
|
||||
prompt=self.prompt,
|
||||
prompt_token_ids=self.prompt_token_ids,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
outputs=cast(list[CompletionOutput], outputs),
|
||||
finished=finished,
|
||||
@ -469,6 +481,8 @@ class OutputProcessor:
|
||||
|
||||
arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
|
||||
trace_context = extract_trace_context(engine_core_output.trace_headers)
|
||||
prompt_length = length_from_prompt_token_ids_or_embeds(
|
||||
req_state.prompt_token_ids, req_state.prompt_embeds)
|
||||
with (self.tracer.start_as_current_span(
|
||||
"llm_request",
|
||||
kind=SpanKind.SERVER,
|
||||
@ -488,7 +502,7 @@ class OutputProcessor:
|
||||
span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
|
||||
queued_time)
|
||||
span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
|
||||
len(req_state.prompt_token_ids))
|
||||
prompt_length)
|
||||
span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
|
||||
metrics.num_generation_tokens)
|
||||
span.set_attribute(
|
||||
@ -544,7 +558,8 @@ class OutputProcessor:
|
||||
assert req_state.stats is not None
|
||||
iteration_stats.update_from_finished_request(
|
||||
finish_reason=finish_reason,
|
||||
num_prompt_tokens=len(req_state.prompt_token_ids),
|
||||
num_prompt_tokens=length_from_prompt_token_ids_or_embeds(
|
||||
req_state.prompt_token_ids, req_state.prompt_embeds),
|
||||
max_tokens_param=req_state.max_tokens_param,
|
||||
req_stats=req_state.stats)
|
||||
self.lora_states.finish_request(req_state)
|
||||
|
||||
@ -19,6 +19,7 @@ from vllm.multimodal.utils import argsort_mm_positions
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.structured_output.backend_guidance import (
|
||||
validate_guidance_grammar)
|
||||
@ -390,6 +391,16 @@ class Processor:
|
||||
self._validate_model_inputs(processed_inputs)
|
||||
|
||||
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
|
||||
# Mypy does not always properly infer the types of some elements of
|
||||
# discriminated unions of TypedDicts, because of how it handles
|
||||
# inheritance of TypedDict. If we explicitly extract the items we want
|
||||
# we can avoid type errors from using `dict.get` later in the method.
|
||||
prompt_str: Optional[str] = None if decoder_inputs[
|
||||
"type"] == "embeds" else decoder_inputs.get("prompt")
|
||||
prompt_token_ids = decoder_inputs[
|
||||
"prompt_token_ids"] if decoder_inputs["type"] != "embeds" else None
|
||||
prompt_embeds = decoder_inputs["prompt_embeds"] if decoder_inputs[
|
||||
"type"] == "embeds" else None
|
||||
|
||||
sampling_params = None
|
||||
pooling_params = None
|
||||
@ -398,9 +409,10 @@ class Processor:
|
||||
sampling_params = params.clone()
|
||||
# If unset max tokens, then generate up to the max_model_len.
|
||||
if sampling_params.max_tokens is None:
|
||||
sampling_params.max_tokens = (
|
||||
self.model_config.max_model_len -
|
||||
len(decoder_inputs["prompt_token_ids"]))
|
||||
seq_len = length_from_prompt_token_ids_or_embeds(
|
||||
prompt_token_ids, prompt_embeds)
|
||||
sampling_params.max_tokens = \
|
||||
self.model_config.max_model_len - seq_len
|
||||
sampling_params.update_from_generation_config(
|
||||
self.generation_config_fields, eos_token_id)
|
||||
if self.tokenizer is not None:
|
||||
@ -430,9 +442,10 @@ class Processor:
|
||||
identifier=decoder_mm_hashes[modality][idx],
|
||||
mm_position=decoder_mm_positions[modality][idx]))
|
||||
|
||||
return decoder_inputs.get("prompt"), EngineCoreRequest(
|
||||
return prompt_str, EngineCoreRequest(
|
||||
request_id=request_id,
|
||||
prompt_token_ids=decoder_inputs["prompt_token_ids"],
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
prompt_embeds=prompt_embeds,
|
||||
mm_features=mm_features,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=pooling_params,
|
||||
@ -461,10 +474,17 @@ class Processor:
|
||||
):
|
||||
model_config = self.model_config
|
||||
|
||||
prompt_ids = prompt_inputs["prompt_token_ids"]
|
||||
prompt_ids = None if prompt_inputs[
|
||||
"type"] == "embeds" else prompt_inputs["prompt_token_ids"]
|
||||
prompt_embeds = prompt_inputs["prompt_embeds"] if prompt_inputs[
|
||||
"type"] == "embeds" else None
|
||||
prompt_len = length_from_prompt_token_ids_or_embeds(
|
||||
prompt_ids, prompt_embeds)
|
||||
if not prompt_ids:
|
||||
if prompt_type == "encoder" and model_config.is_multimodal_model:
|
||||
pass # Mllama may have empty encoder inputs for text-only data
|
||||
elif prompt_inputs["type"] == "embeds":
|
||||
pass # Prompt embeds should not have prompt_ids.
|
||||
else:
|
||||
raise ValueError(f"The {prompt_type} prompt cannot be empty")
|
||||
|
||||
@ -472,7 +492,7 @@ class Processor:
|
||||
tokenizer = None
|
||||
else:
|
||||
tokenizer = self.tokenizer
|
||||
max_input_id = max(prompt_ids, default=0)
|
||||
max_input_id = max(prompt_ids or [], default=0)
|
||||
|
||||
# NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while
|
||||
# self.model_config.get_vocab_size() is the model’s vocab size.
|
||||
@ -490,7 +510,7 @@ class Processor:
|
||||
f"Token id {max_input_id} is out of vocabulary")
|
||||
|
||||
max_prompt_len = self.model_config.max_model_len
|
||||
if len(prompt_ids) > max_prompt_len:
|
||||
if prompt_len > max_prompt_len:
|
||||
if prompt_type == "encoder" and model_config.is_multimodal_model:
|
||||
mm_registry = self.input_preprocessor.mm_registry
|
||||
mm_processor = mm_registry.create_processor(
|
||||
@ -514,7 +534,7 @@ class Processor:
|
||||
"number of text tokens.")
|
||||
|
||||
raise ValueError(
|
||||
f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
|
||||
f"The {prompt_type} prompt (length {prompt_len}) is "
|
||||
f"longer than the maximum model length of {max_prompt_len}. "
|
||||
f"{suggestion}")
|
||||
|
||||
|
||||
@ -7,9 +7,12 @@ from collections.abc import Mapping
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
||||
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
|
||||
EngineCoreRequest, FinishReason)
|
||||
from vllm.v1.structured_output.request import StructuredOutputRequest
|
||||
@ -25,12 +28,13 @@ class Request:
|
||||
def __init__(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt_token_ids: list[int],
|
||||
prompt_token_ids: Optional[list[int]],
|
||||
sampling_params: Optional[SamplingParams],
|
||||
pooling_params: Optional[PoolingParams],
|
||||
eos_token_id: Optional[int],
|
||||
client_index: int = 0,
|
||||
arrival_time: Optional[float] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
mm_features: Optional[list[MultiModalFeatureSpec]] = None,
|
||||
lora_request: Optional["LoRARequest"] = None,
|
||||
structured_output_request: Optional["StructuredOutputRequest"] = None,
|
||||
@ -79,9 +83,13 @@ class Request:
|
||||
"sampling_params and pooling_params can't both be unset")
|
||||
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.num_prompt_tokens = len(self.prompt_token_ids)
|
||||
self.prompt_embeds = prompt_embeds
|
||||
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
prompt_token_ids, prompt_embeds)
|
||||
self._output_token_ids: list[int] = []
|
||||
self._all_token_ids: list[int] = self.prompt_token_ids.copy()
|
||||
self._all_token_ids: list[int] = self.prompt_token_ids.copy(
|
||||
) if self.prompt_token_ids is not None else [0
|
||||
] * self.num_prompt_tokens
|
||||
self.num_output_placeholders = 0 # Used in async scheduling.
|
||||
self.spec_token_ids: list[int] = []
|
||||
self.num_computed_tokens = 0
|
||||
@ -123,6 +131,7 @@ class Request:
|
||||
request_id=request.request_id,
|
||||
client_index=request.client_index,
|
||||
prompt_token_ids=request.prompt_token_ids,
|
||||
prompt_embeds=request.prompt_embeds,
|
||||
mm_features=request.mm_features,
|
||||
sampling_params=request.sampling_params,
|
||||
pooling_params=request.pooling_params,
|
||||
|
||||
@ -243,7 +243,7 @@ class AdapterLogitsProcessor(LogitsProcessor):
|
||||
def _new_state(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
prompt_ids: list[int],
|
||||
prompt_ids: Optional[list[int]],
|
||||
output_ids: list[int],
|
||||
) -> Optional[partial[torch.Tensor]]:
|
||||
"""Return state representation for new request
|
||||
|
||||
@ -187,7 +187,8 @@ class MinTokensLogitsProcessor(LogitsProcessor):
|
||||
|
||||
@staticmethod
|
||||
def add_request(
|
||||
params: SamplingParams, _: list[int], output_tok_ids: list[int]
|
||||
params: SamplingParams, _: Optional[list[int]],
|
||||
output_tok_ids: list[int]
|
||||
) -> Optional[tuple[int, Sequence[int], set[int]]]:
|
||||
min_tokens = params.min_tokens
|
||||
if not min_tokens or len(output_tok_ids) >= min_tokens:
|
||||
@ -234,7 +235,8 @@ class MinTokensLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def process_dict_updates(
|
||||
req_entries: dict[int, T], batch_update: Optional[BatchUpdate],
|
||||
new_state: Callable[[SamplingParams, list[int], list[int]], Optional[T]]
|
||||
new_state: Callable[[SamplingParams, Optional[list[int]], list[int]],
|
||||
Optional[T]]
|
||||
) -> bool:
|
||||
"""Utility function to update dict state for sparse LogitsProcessors."""
|
||||
|
||||
|
||||
@ -26,7 +26,7 @@ RemovedRequest = int
|
||||
|
||||
# (index, params, prompt_tok_ids, output_tok_ids) tuples for new
|
||||
# requests added to the batch.
|
||||
AddedRequest = tuple[int, SamplingParams, list[int], list[int]]
|
||||
AddedRequest = tuple[int, SamplingParams, Optional[list[int]], list[int]]
|
||||
|
||||
# (index 1, index 2, directionality) tuples representing
|
||||
# one-way moves or two-way swaps of requests in batch
|
||||
|
||||
@ -174,7 +174,7 @@ class MsgpackEncoder:
|
||||
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
|
||||
assert self.aux_buffers is not None
|
||||
# view the tensor as a contiguous 1D array of bytes
|
||||
arr = obj.flatten().contiguous().view(torch.uint8).numpy()
|
||||
arr = obj.flatten().contiguous().cpu().view(torch.uint8).numpy()
|
||||
if obj.nbytes < self.size_threshold:
|
||||
# Smaller tensors are encoded inline, just like ndarrays.
|
||||
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
|
||||
|
||||
@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import swap_dict_values
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
|
||||
@ -29,7 +29,7 @@ from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||
class CachedRequestState:
|
||||
|
||||
req_id: str
|
||||
prompt_token_ids: list[int]
|
||||
prompt_token_ids: Optional[list[int]]
|
||||
mm_features: list[MultiModalFeatureSpec]
|
||||
sampling_params: Optional[SamplingParams]
|
||||
pooling_params: Optional[PoolingParams]
|
||||
@ -43,9 +43,11 @@ class CachedRequestState:
|
||||
mrope_position_delta: Optional[int] = None
|
||||
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
prompt_embeds: Optional[torch.Tensor] = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.num_prompt_tokens = len(self.prompt_token_ids)
|
||||
self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
self.prompt_token_ids, self.prompt_embeds)
|
||||
|
||||
@property
|
||||
def num_tokens(self) -> int:
|
||||
@ -63,6 +65,10 @@ class CachedRequestState:
|
||||
|
||||
def get_token_id(self, idx: int) -> int:
|
||||
if idx < self.num_prompt_tokens:
|
||||
if self.prompt_token_ids is None:
|
||||
raise ValueError(
|
||||
f"Tried to access token index {idx}, but that token was "
|
||||
"provided via prompt_embeds, and its ID is unknown.")
|
||||
return self.prompt_token_ids[idx]
|
||||
elif idx - self.num_prompt_tokens < len(self.output_token_ids):
|
||||
return self.output_token_ids[idx - self.num_prompt_tokens]
|
||||
@ -109,6 +115,14 @@ class InputBatch:
|
||||
pin_memory=False,
|
||||
)
|
||||
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
||||
self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
|
||||
device="cpu",
|
||||
dtype=bool,
|
||||
pin_memory=False)
|
||||
# Store prompt embeddings per request to avoid OOM from large upfront
|
||||
# allocation if max_model_len is big.
|
||||
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
|
||||
self.req_prompt_embeds: dict[int, torch.Tensor] = {}
|
||||
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
@ -310,15 +324,23 @@ class InputBatch:
|
||||
self.req_id_to_index[req_id] = req_index
|
||||
|
||||
# Copy the prompt token ids and output token ids.
|
||||
num_prompt_tokens = len(request.prompt_token_ids)
|
||||
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
request.prompt_token_ids, request.prompt_embeds)
|
||||
self.num_prompt_tokens[req_index] = num_prompt_tokens
|
||||
self.token_ids_cpu[
|
||||
req_index, :num_prompt_tokens] = request.prompt_token_ids
|
||||
start_idx = num_prompt_tokens
|
||||
end_idx = start_idx + len(request.output_token_ids)
|
||||
if request.prompt_token_ids is not None:
|
||||
self.token_ids_cpu[
|
||||
req_index, :num_prompt_tokens] = request.prompt_token_ids
|
||||
self.is_token_ids[req_index, :num_prompt_tokens] = True
|
||||
else:
|
||||
self.is_token_ids[req_index, :num_prompt_tokens] = False
|
||||
if request.prompt_embeds is not None:
|
||||
self.req_prompt_embeds[req_index] = request.prompt_embeds
|
||||
self.token_ids_cpu[req_index,
|
||||
start_idx:end_idx] = request.output_token_ids
|
||||
# Number of token ids in token_ids_cpu.
|
||||
self.is_token_ids[req_index, start_idx:end_idx] = True
|
||||
# Number of token ids in prompt (token_ids_cpu or prompt_embeds).
|
||||
# NOTE(woosuk): This may include spec decode tokens.
|
||||
self.num_tokens[req_index] = request.num_tokens
|
||||
# Number of tokens without spec decode tokens.
|
||||
@ -503,6 +525,20 @@ class InputBatch:
|
||||
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
|
||||
self.token_ids_cpu[i2, ...] = tmp
|
||||
|
||||
self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
|
||||
|
||||
# Swap prompt embeddings if they exist
|
||||
embeds_i1 = self.req_prompt_embeds.get(i1)
|
||||
embeds_i2 = self.req_prompt_embeds.get(i2)
|
||||
if embeds_i1 is not None:
|
||||
self.req_prompt_embeds[i2] = embeds_i1
|
||||
else:
|
||||
self.req_prompt_embeds.pop(i2, None)
|
||||
if embeds_i2 is not None:
|
||||
self.req_prompt_embeds[i1] = embeds_i2
|
||||
else:
|
||||
self.req_prompt_embeds.pop(i1, None)
|
||||
|
||||
self.block_table.swap_row(i1, i2)
|
||||
|
||||
self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
|
||||
@ -592,6 +628,11 @@ class InputBatch:
|
||||
num_tokens = self.num_tokens[last_req_index]
|
||||
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
||||
last_req_index, :num_tokens]
|
||||
self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
|
||||
last_req_index, :num_tokens]
|
||||
if last_req_index in self.req_prompt_embeds:
|
||||
self.req_prompt_embeds[
|
||||
empty_index] = self.req_prompt_embeds.pop(last_req_index)
|
||||
self.num_tokens[empty_index] = num_tokens
|
||||
self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
|
||||
last_req_index]
|
||||
|
||||
@ -56,7 +56,9 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
GiB_bytes, check_use_alibi, get_dtype_size,
|
||||
is_pin_memory_available, round_up, supports_dynamo)
|
||||
is_pin_memory_available,
|
||||
length_from_prompt_token_ids_or_embeds, round_up,
|
||||
supports_dynamo)
|
||||
from vllm.v1.attention.backends.flash_attn import AttentionMetadata
|
||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
@ -197,6 +199,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
cache_config.cache_dtype]
|
||||
|
||||
self.is_pooling_model = (model_config.runner_type == 'pooling')
|
||||
self.enable_prompt_embeds = model_config.enable_prompt_embeds
|
||||
self.is_multimodal_raw_input_only_model = (
|
||||
model_config.is_multimodal_raw_input_only_model)
|
||||
|
||||
@ -342,6 +345,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.hidden_size,
|
||||
dtype=self.dtype,
|
||||
numpy=False)
|
||||
self.is_token_ids = self._make_buffer(self.max_num_tokens,
|
||||
dtype=torch.bool)
|
||||
self.discard_request_indices = self._make_buffer(self.max_num_reqs,
|
||||
dtype=torch.int64)
|
||||
self.num_discarded_requests = 0
|
||||
@ -574,6 +579,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
req_state = CachedRequestState(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||
prompt_embeds=new_req_data.prompt_embeds,
|
||||
mm_features=new_req_data.mm_features,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=pooling_params,
|
||||
@ -819,6 +825,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
if self.input_batch.prev_sampled_token_ids is None:
|
||||
# Normal scheduling case
|
||||
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
|
||||
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
|
||||
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
|
||||
return
|
||||
|
||||
# Async scheduling case, where some decode requests from the previous
|
||||
@ -844,6 +852,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# If not all requests are decodes from the last iteration,
|
||||
# We need to copy the input_ids_cpu to the GPU first.
|
||||
self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
|
||||
self.inputs_embeds.copy_to_gpu(total_num_scheduled_tokens)
|
||||
self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
|
||||
if num_commmon_tokens == 0:
|
||||
# No requests in common with the previous iteration
|
||||
# So input_ids_cpu will have all the input ids.
|
||||
@ -857,6 +867,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.input_batch.prev_sampled_token_ids[:num_commmon_tokens,
|
||||
0],
|
||||
non_blocking=True)
|
||||
self.is_token_ids.gpu[:num_commmon_tokens] = True
|
||||
return
|
||||
# Upload the index tensors asynchronously
|
||||
# so the scatter can be non-blocking.
|
||||
@ -947,14 +958,60 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# where M is the max_model_len.
|
||||
token_indices = (positions_np +
|
||||
req_indices * self.input_batch.token_ids_cpu.shape[1])
|
||||
token_indices_tensor = torch.from_numpy(token_indices)
|
||||
|
||||
# NOTE(woosuk): We use torch.index_select instead of np.take here
|
||||
# because torch.index_select is much faster than np.take for large
|
||||
# tensors.
|
||||
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
|
||||
0,
|
||||
torch.from_numpy(token_indices),
|
||||
token_indices_tensor,
|
||||
out=self.input_ids.cpu[:total_num_scheduled_tokens])
|
||||
is_token_ids = self.input_batch.is_token_ids.flatten()
|
||||
torch.index_select(
|
||||
is_token_ids,
|
||||
0,
|
||||
token_indices_tensor,
|
||||
out=self.is_token_ids.cpu[:total_num_scheduled_tokens])
|
||||
|
||||
# Because we did not pre-allocate a massive prompt_embeds CPU tensor on
|
||||
# the InputBatch, we need to fill in the prompt embeds into the expected
|
||||
# spots in the GpuModelRunner's pre-allocated prompt_embeds tensor.
|
||||
if self.input_batch.req_prompt_embeds:
|
||||
output_idx = 0
|
||||
for req_idx in range(num_reqs):
|
||||
num_sched = num_scheduled_tokens[req_idx]
|
||||
|
||||
# Skip if this request doesn't have embeddings
|
||||
if req_idx not in self.input_batch.req_prompt_embeds:
|
||||
output_idx += num_sched
|
||||
continue
|
||||
|
||||
# Skip if no tokens scheduled
|
||||
if num_sched <= 0:
|
||||
output_idx += num_sched
|
||||
continue
|
||||
|
||||
req_embeds = self.input_batch.req_prompt_embeds[req_idx]
|
||||
start_pos = self.input_batch.num_computed_tokens_cpu[req_idx]
|
||||
|
||||
# Skip if trying to read beyond available embeddings
|
||||
if start_pos >= req_embeds.shape[0]:
|
||||
output_idx += num_sched
|
||||
continue
|
||||
|
||||
# Copy available embeddings
|
||||
end_pos = start_pos + num_sched
|
||||
actual_end = min(end_pos, req_embeds.shape[0])
|
||||
actual_num_sched = actual_end - start_pos
|
||||
|
||||
if actual_num_sched > 0:
|
||||
self.inputs_embeds.cpu[output_idx:output_idx +
|
||||
actual_num_sched].copy_(
|
||||
req_embeds[start_pos:actual_end]
|
||||
)
|
||||
|
||||
output_idx += num_sched
|
||||
|
||||
self.input_batch.block_table.compute_slot_mapping(
|
||||
req_indices, positions_np)
|
||||
@ -1279,7 +1336,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.input_batch.num_computed_tokens_cpu[index]
|
||||
num_scheduled_tokens = \
|
||||
scheduler_output.num_scheduled_tokens[req_id]
|
||||
num_prompt_tokens = len(req.prompt_token_ids)
|
||||
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
req.prompt_token_ids, req.prompt_embeds)
|
||||
|
||||
if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
|
||||
prompt_part_len = max(0,
|
||||
@ -1845,6 +1903,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
**self._init_model_kwargs(num_scheduled_tokens),
|
||||
**self._extract_mm_kwargs(scheduler_output),
|
||||
}
|
||||
elif (self.enable_prompt_embeds and get_pp_group().is_first_rank):
|
||||
# Get the input embeddings for the tokens that are not input embeds,
|
||||
# then put them into the appropriate positions.
|
||||
# TODO(qthequartermasterman): Since even when prompt embeds are
|
||||
# enabled, (a) not all requests will use prompt embeds, and (b)
|
||||
# after the initial prompt is processed, the rest of the generated
|
||||
# tokens will be token ids, it is not desirable to have the
|
||||
# embedding layer outside of the CUDA graph all the time. The v0
|
||||
# engine avoids this by "double compiling" the CUDA graph, once
|
||||
# with input_ids and again with inputs_embeds, for all num_tokens.
|
||||
# If a batch only has token ids, then including the embedding layer
|
||||
# in the CUDA graph will be more performant (like in the else case
|
||||
# below).
|
||||
token_ids_idx = self.is_token_ids.gpu[:num_scheduled_tokens] \
|
||||
.nonzero(as_tuple=False) \
|
||||
.squeeze(1)
|
||||
# Some tokens ids may need to become embeds
|
||||
if token_ids_idx.numel() > 0:
|
||||
token_ids = self.input_ids.gpu[token_ids_idx]
|
||||
tokens_to_embeds = self.model.get_input_embeddings(
|
||||
input_ids=token_ids)
|
||||
self.inputs_embeds.gpu[token_ids_idx] = tokens_to_embeds
|
||||
|
||||
inputs_embeds = self.inputs_embeds.gpu[:num_input_tokens]
|
||||
model_kwargs = self._init_model_kwargs(num_input_tokens)
|
||||
input_ids = None
|
||||
else:
|
||||
# For text-only models, we use token ids as input.
|
||||
# While it is possible to use embeddings as input just like the
|
||||
@ -2023,6 +2107,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
self.input_batch.token_ids_cpu[req_idx,
|
||||
start_idx:end_idx] = sampled_ids
|
||||
self.input_batch.is_token_ids[req_idx, start_idx:end_idx] = True
|
||||
self.input_batch.num_tokens_no_spec[req_idx] = end_idx
|
||||
self.input_batch.num_tokens[req_idx] = end_idx
|
||||
|
||||
@ -2570,6 +2655,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
# Get metadata for this request.
|
||||
request = self.requests[req_id]
|
||||
if request.prompt_token_ids is None:
|
||||
# Prompt logprobs is incompatible with prompt embeddings
|
||||
continue
|
||||
|
||||
num_prompt_tokens = len(request.prompt_token_ids)
|
||||
prompt_token_ids = torch.tensor(request.prompt_token_ids).to(
|
||||
self.device, non_blocking=True)
|
||||
@ -2922,6 +3011,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
**model_kwargs,
|
||||
**self._dummy_mm_kwargs(num_reqs),
|
||||
}
|
||||
elif self.enable_prompt_embeds:
|
||||
input_ids = None
|
||||
inputs_embeds = self.inputs_embeds.gpu[:num_tokens]
|
||||
model_kwargs = self._init_model_kwargs(num_tokens)
|
||||
else:
|
||||
input_ids = self.input_ids.gpu[:num_tokens]
|
||||
inputs_embeds = None
|
||||
|
||||
@ -9,7 +9,7 @@ import torch
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.utils import swap_dict_values
|
||||
from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState
|
||||
@ -213,7 +213,9 @@ class InputBatch:
|
||||
self.req_id_to_index[req_id] = req_index
|
||||
|
||||
# Copy the prompt token ids and output token ids.
|
||||
num_prompt_tokens = len(request.prompt_token_ids)
|
||||
num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
|
||||
request.prompt_token_ids, request.prompt_embeds)
|
||||
# TODO: copy prompt_embeds
|
||||
self.num_prompt_tokens[req_index] = num_prompt_tokens
|
||||
self.token_ids_cpu[
|
||||
req_index, :num_prompt_tokens] = request.prompt_token_ids
|
||||
|
||||
@ -387,6 +387,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.requests[req_id] = CachedRequestState(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||
prompt_embeds=new_req_data.prompt_embeds,
|
||||
mm_features=new_req_data.mm_features,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=None,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user