[Core] [Bugfix] Add Input Embeddings (#15428)

Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: 临景 <linjing.yx@alibaba-inc.com>
Co-authored-by: Bryce1010 <bryceyx@gmail.com>
Co-authored-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Andrew Sansom on 2025-05-02 03:06:39 -05:00 (committed by GitHub)
parent 9e2de9b9e9
commit cc2a77d7f1
22 changed files with 691 additions and 113 deletions

View File

@ -787,7 +787,7 @@ class VllmRunner:
def get_inputs(
self,
prompts: list[str],
prompts: Union[list[str], list[torch.Tensor]],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
@ -809,16 +809,18 @@ class VllmRunner:
if audios is not None and (audio := audios[i]) is not None:
multi_modal_data["audio"] = audio
inputs.append(
TextPrompt(prompt=prompt,
multi_modal_data=multi_modal_data
if multi_modal_data else None))
text_prompt_kwargs = {
("prompt" if isinstance(prompt, str) else "prompt_embeds"):
prompt,
"multi_modal_data": multi_modal_data or None
}
inputs.append(TextPrompt(**text_prompt_kwargs))
return inputs
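
For clarity, the kwargs trick above just routes a string to the "prompt" field and a tensor to the "prompt_embeds" field of the prompt dict. A minimal standalone sketch of that dispatch (the helper name is hypothetical, not part of vLLM):

from typing import Union

import torch


def build_prompt_kwargs(prompt: Union[str, torch.Tensor]) -> dict:
    # Strings are treated as text prompts; tensors are treated as
    # precomputed prompt embeddings of shape (seq_len, hidden_size).
    key = "prompt" if isinstance(prompt, str) else "prompt_embeds"
    return {key: prompt}


assert build_prompt_kwargs("hello")["prompt"] == "hello"
assert build_prompt_kwargs(torch.rand(4, 8))["prompt_embeds"].shape == (4, 8)
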
def generate(
self,
prompts: list[str],
prompts: Union[list[str], list[torch.Tensor]],
sampling_params: SamplingParams,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
@ -844,7 +846,7 @@ class VllmRunner:
output_str = sample.text
output_ids = list(sample.token_ids)
req_sample_output_ids.append(prompt_ids + output_ids)
req_sample_output_strs.append(prompt_str + output_str)
req_sample_output_strs.append((prompt_str or "") + output_str)
outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs
@ -911,7 +913,7 @@ class VllmRunner:
def generate_greedy(
self,
prompts: list[str],
prompts: Union[list[str], list[torch.Tensor]],
max_tokens: int,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,

View File

@ -2,16 +2,18 @@
import time
from collections import deque
from typing import Optional
from unittest.mock import MagicMock
import pytest # noqa
import torch
from torch import Use # noqa
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
from vllm.sequence import SequenceGroup
from vllm.sequence import SequenceGroup, SequenceStatus
from .utils import (append_new_token, append_new_token_seq,
append_new_token_seq_group, create_dummy_prompt,
@ -968,3 +970,73 @@ def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
), "A partial prefix of C (4 tokens) should be prefilled, with the "
"remaining tokens fit into 3 token budget (4-1 from the seqA). It will "
"then be rounded down to 2 tokens on block size, thus 6 tokens in total."
def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds():
"""
Test that the scheduler does not schedule batches with prompt tokens and
prompt embeddings co-mingled.
"""
block_size = 2
max_seq_group = 3
scheduler = initialize_scheduler(
block_size=block_size,
num_cpu_blocks=16,
num_gpu_blocks=16,
max_num_seqs=max_seq_group,
max_model_len=100,
enable_prefix_caching=True,
)
# the even-indexed inputs should be passed in via embeddings,
# odd-indexed ones via token_ids
seq_length = 7
embedding_size = 5
num_seqs = 11
seq_tokens: list[list[int]] = []
seq_embeds: list[Optional[torch.Tensor]] = []
for i in range(num_seqs):
if i % 2:
seq_tokens.append(list(range(seq_length)))
seq_embeds.append(None)
else:
seq_tokens.append([0] * seq_length)
seq_embeds.append(torch.rand(embedding_size))
seq_and_seq_groups = [
create_dummy_prompt(f"{i}",
prompt_tokens=seq_tokens[i],
prompt_embeds=seq_embeds[i],
block_size=block_size)
for i in range(len(seq_tokens))
]
for _, seq_group in seq_and_seq_groups:
scheduler.add_seq_group(seq_group)
while not all(seq.is_finished() for seq, _ in seq_and_seq_groups):
unfinished_seq_groups = [
seq_group for _, seq_group in seq_and_seq_groups
if not seq_group.is_finished()
]
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) > 0
batch_is_prompt_embeds = out.scheduled_seq_groups[
0].seq_group.uses_prompt_embeds()
expected_scheduled_seq_groups = [
seq_group for seq_group in unfinished_seq_groups
if seq_group.uses_prompt_embeds() == batch_is_prompt_embeds
]
# We should have as many scheduled groups as possible, without mixing
assert len(out.scheduled_seq_groups) == min(
max_seq_group, len(expected_scheduled_seq_groups))
assert all(scheduled_seq_group.seq_group.uses_prompt_embeds() ==
batch_is_prompt_embeds
for scheduled_seq_group in out.scheduled_seq_groups)
# Finish the scheduled groups
for scheduled_seq_group in out.scheduled_seq_groups:
for seq in scheduled_seq_group.seq_group.seqs:
seq.status = SequenceStatus.FINISHED_STOPPED
scheduler.free_finished_seq_groups()
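
The invariant this test exercises can be restated without any scheduler machinery: within one scheduled batch, either every request supplies prompt embeddings or none does. A toy sketch of such a batching rule under that assumption (the ToyRequest type and schedule_batch helper are illustrative, not the real Scheduler):

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyRequest:
    request_id: str
    prompt_embeds: Optional[torch.Tensor] = None

    def uses_prompt_embeds(self) -> bool:
        return self.prompt_embeds is not None


def schedule_batch(waiting: list[ToyRequest],
                   max_batch: int) -> list[ToyRequest]:
    """Greedily pick requests that match the first request's embeds mode."""
    if not waiting:
        return []
    use_embeds = waiting[0].uses_prompt_embeds()
    batch = [r for r in waiting if r.uses_prompt_embeds() == use_embeds]
    return batch[:max_batch]


reqs = [
    ToyRequest("0", prompt_embeds=torch.rand(7, 5)),
    ToyRequest("1"),
    ToyRequest("2", prompt_embeds=torch.rand(7, 5)),
]
batch = schedule_batch(reqs, max_batch=3)
# Only the embeds-based requests ("0" and "2") are co-scheduled here.
assert all(r.uses_prompt_embeds() for r in batch) and len(batch) == 2
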

View File

@ -5,9 +5,11 @@ from collections import defaultdict
from collections.abc import Sequence as GenericSequence
from typing import Any, Optional
import torch
from vllm import SamplingParams
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.inputs import EncoderDecoderInputs, token_inputs
from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs
from vllm.lora.request import LoRARequest
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
SequenceGroupMetadata)
@ -19,6 +21,7 @@ def create_dummy_prompt(
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
prompt_tokens: Optional[list[int]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
min_tokens: int = 0,
max_tokens: int = 16,
) -> tuple[Sequence, SequenceGroup]:
@ -31,9 +34,13 @@ def create_dummy_prompt(
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
inputs = token_inputs(
prompt_token_ids=prompt_tokens,
prompt=prompt_str) if prompt_embeds is None else embeds_inputs(
prompt_embeds=prompt_embeds)
prompt = Sequence(
int(request_id),
inputs=token_inputs(prompt_tokens, prompt=prompt_str),
inputs=inputs,
block_size=block_size,
)
seq_group = SequenceGroup(

View File

@ -1,4 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Optional
import pytest
import torch
@ -110,6 +113,18 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
prompt_embeds: Optional[list[torch.Tensor]] = [] if os.getenv(
"VLLM_USE_V1") == "0" else None
prompt_token_ids = []
for prompt in example_prompts:
token_ids = hf_model.tokenizer(prompt,
return_tensors="pt").input_ids.to(
hf_model.model.device)
prompt_token_ids.append(token_ids)
if prompt_embeds is not None:
prompt_embeds.append(hf_model.model.get_input_embeddings()(
token_ids).squeeze(0))
with vllm_runner(
model,
tokenizer_name=model_info.tokenizer or model,
@ -119,6 +134,9 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
if prompt_embeds is not None:
vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs(
prompt_embeds, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
@ -126,6 +144,14 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
name_0="hf",
name_1="vllm",
)
if prompt_embeds is not None:
check_logprobs_close(
outputs_0_lst=vllm_outputs,
outputs_1_lst=vllm_outputs_from_embeds,
name_0="vllm",
name_1="vllm_from_embeds",
)
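
The test above also suggests what a user-level call might look like after this change. A hedged sketch, assuming the V0 engine (the diff rejects prompt_embeds under V1) and the EmbedsPrompt dict schema added in this commit; the model choice and the exact generate() call are illustrative, not a confirmed recipe:

import os

os.environ["VLLM_USE_V1"] = "0"  # per this diff, prompt_embeds is V0-only

from transformers import AutoModelForCausalLM, AutoTokenizer  # noqa: E402

from vllm import LLM, SamplingParams  # noqa: E402

model_name = "facebook/opt-125m"  # illustrative choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)

# Turn a text prompt into (seq_len, hidden_size) embeddings, exactly as the
# test does with hf_model.get_input_embeddings().
token_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
prompt_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

llm = LLM(model=model_name)
outputs = llm.generate({"prompt_embeds": prompt_embeds},
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
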
if use_rocm_aiter:
# this is to ensure that vllm engine
# has deallocated the memory before running the next

View File

@ -31,8 +31,13 @@ def test_deepseek_mla_attn_backend_module():
assert model_runner.attn_backend.__name__ == "TritonMLABackend"
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size):
@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
@pytest.mark.parametrize("use_prompt_embeds", [True, False])
def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch):
if use_prompt_embeds:
# Prompt embeddings are currently only supported on V0
monkeypatch.setenv("VLLM_USE_V1", "0")
model_runner = _create_model_runner(
"facebook/opt-125m",
max_num_batched_tokens=100000,
@ -43,11 +48,20 @@ def test_prepare_prompt(batch_size):
seq_lens: list[int] = []
seq_group_metadata_list: list[SequenceGroupMetadata] = []
block_tables = {0: [1]}
expected_input_embeds_len = 0
for i in range(batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData.from_seqs(range(seq_len))
if use_prompt_embeds:
seq_data = SequenceData.from_seqs(
prompt_token_ids=[0] * seq_len,
prompt_embeds=torch.rand(seq_len, 10),
)
expected_input_embeds_len += seq_len
else:
seq_data = SequenceData.from_seqs(prompt_token_ids=range(seq_len))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
@ -68,6 +82,7 @@ def test_prepare_prompt(batch_size):
seq_group_metadata_list)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
input_embeds = model_input.inputs_embeds
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
slot_mapping = attn_metadata.slot_mapping
@ -121,7 +136,11 @@ def test_prepare_prompt(batch_size):
assert len(input_tokens) == sum(seq_lens)
assert len(input_positions) == sum(seq_lens)
torch.testing.assert_close(input_tokens, input_positions)
if expected_input_embeds_len == 0:
torch.testing.assert_close(input_tokens, input_positions)
assert input_embeds is None
else:
assert len(input_embeds) == expected_input_embeds_len
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
@ -145,8 +164,13 @@ def test_prepare_prompt(batch_size):
torch.testing.assert_close(actual, expected)
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size):
@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
@pytest.mark.parametrize("use_prompt_embeds", [True, False])
def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch):
if use_prompt_embeds:
# Prompt embeddings are currently only supported on V0
monkeypatch.setenv("VLLM_USE_V1", "0")
model_runner = _create_model_runner(
"facebook/opt-125m",
seed=0,
@ -164,10 +188,19 @@ def test_prepare_decode_cuda_graph(batch_size):
# make sure all tokens fit into one block
context_len = i % (model_runner.block_size - 1) + 1
context_lens.append(context_len)
seq_data = SequenceData.from_seqs(range(context_len))
if use_prompt_embeds:
seq_data = SequenceData.from_seqs(
prompt_token_ids=[0] * context_len,
prompt_embeds=torch.rand(context_len, 10),
)
output_embed = torch.rand(10)
else:
seq_data = SequenceData.from_seqs(
prompt_token_ids=range(context_len))
output_embed = None
seq_data.update_num_computed_tokens(context_len)
# Append one token ID since prefill is finished.
seq_data.append_token_id(1, 0)
seq_data.append_token_id(1, 0, output_embed)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
@ -180,9 +213,12 @@ def test_prepare_decode_cuda_graph(batch_size):
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens, model_input.input_positions,
model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
input_embeds = model_input.inputs_embeds
attn_metadata = model_input.attn_metadata
slot_mapping = attn_metadata.slot_mapping
assert len(slot_mapping) == len(input_tokens)
expected_bs = model_runner.vllm_config.pad_for_cudagraph(
@ -227,7 +263,7 @@ def test_prepare_decode_cuda_graph(batch_size):
# block table's first index corresponds to each batch, meaning in
# decoding it is each token.
assert attn_metadata.block_tables.shape[0] == len(input_tokens)
# Block table's second dim correspondsd to each token's block number.
# Block table's second dim corresponds to each token's block number.
# It is padded up to
assert attn_metadata.block_tables.shape[1] == (
model_runner.get_max_block_per_batch())
@ -235,7 +271,12 @@ def test_prepare_decode_cuda_graph(batch_size):
assert len(input_tokens) == expected_bs
assert len(input_positions) == expected_bs
torch.allclose(input_tokens, input_positions)
if use_prompt_embeds:
expected_input_embeds_length = start_loc[-1]
assert len(input_embeds) == expected_input_embeds_length
assert expected_input_embeds_length <= expected_bs
else:
assert input_embeds is None
# Verify Sampling
expected_selected_token_indices = []
@ -266,25 +307,27 @@ def test_empty_seq_group():
seq_group_metadata_list: list[SequenceGroupMetadata] = []
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens, input_positions, attn_metadata = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
assert input_tokens is None
assert input_positions is None
assert attn_metadata is None
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, return_seq_lens) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.seq_lens,
)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
input_embeds = model_input.inputs_embeds
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
assert input_tokens is None
assert input_positions is None
assert input_embeds is None
assert attn_metadata is None
assert return_seq_lens is None
@ -299,9 +342,15 @@ def distributed_init():
ensure_model_parallel_initialized(1, 1)
@pytest.mark.parametrize("batch_size", list(range(2, 128)))
@pytest.mark.parametrize("batch_size", list(range(2, 128, 3)))
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
@pytest.mark.parametrize('use_prompt_embeds', [True, False])
def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds,
distributed_init, monkeypatch):
if use_prompt_embeds:
# Prompt embeddings are currently only supported on V0
monkeypatch.setenv("VLLM_USE_V1", "0")
model_runner = _create_model_runner(
"facebook/opt-125m",
seed=0,
@ -320,11 +369,20 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
block_tables = {0: [1]}
prefill_batch_size = batch_size // 2
decode_batch_size = batch_size - prefill_batch_size
expected_input_embeds_len = 0
for i in range(prefill_batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData.from_seqs(range(seq_len))
if use_prompt_embeds:
seq_data = SequenceData.from_seqs(
prompt_token_ids=[0] * seq_len,
prompt_embeds=torch.rand(seq_len, 10),
)
expected_input_embeds_len += seq_len
else:
seq_data = SequenceData.from_seqs(
prompt_token_ids=range(seq_len), )
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
@ -340,8 +398,21 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block
context_len = i % (model_runner.block_size - 1) + 1
seq_data = SequenceData.from_seqs(range(context_len))
seq_data.append_token_id(1, 0)
if use_prompt_embeds:
seq_data = SequenceData.from_seqs(
prompt_token_ids=[0] * context_len,
prompt_embeds=torch.rand(context_len, 10),
)
output_embed = torch.rand(10)
# This also counts toward the expected input_embeds length, because the
# model needs both the input and output embeddings passed in together.
expected_input_embeds_len += 1
else:
seq_data = SequenceData.from_seqs(
prompt_token_ids=range(context_len), )
output_embed = None
assert len(seq_data.prompt_token_ids) == context_len
seq_data.append_token_id(1, 0, output_embed)
seq_data.update_num_computed_tokens(context_len)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
@ -355,11 +426,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
decode_metadata_list.append(seq_group_metadata)
model_input = model_runner.prepare_model_input(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
input_embeds = model_input.inputs_embeds
attn_metadata = model_input.attn_metadata
prefill_meta_actual = attn_metadata.prefill_metadata
decode_meta_actual = attn_metadata.decode_metadata
@ -369,6 +440,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
assert attn_metadata.num_prefills == prefill_batch_size
assert attn_metadata.num_decode_tokens == decode_batch_size
assert attn_metadata.num_prefill_tokens == sum(seq_lens)
if expected_input_embeds_len == 0:
assert input_embeds is None
else:
assert len(input_embeds) == expected_input_embeds_len
# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.

View File

@ -367,9 +367,17 @@ class FlashInferState(AttentionState):
# scheduled while CUDA graph mode is enabled. We don't run graph in that
# case.
if use_cuda_graph and is_decode:
batch_size = model_input.input_tokens.shape[0]
state = (self.runner.graph_runners[model_input.virtual_engine]
[batch_size].attn_state)
if model_input.inputs_embeds is None:
batch_size = model_input.input_tokens.shape[0]
state = (
self.runner.graph_runners[model_input.virtual_engine][(
batch_size, False)].attn_state)
else:
batch_size = model_input.inputs_embeds.shape[0]
state = (
self.runner.graph_runners[model_input.virtual_engine][(
batch_size, True)].attn_state)
model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
)
model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()

View File

@ -1071,6 +1071,7 @@ class Scheduler:
)
ignored_seq_groups: List[SequenceGroup] = []
seq_groups: List[ScheduledSequenceGroup] = []
using_prompt_embeds: bool = False
waiting_queue = self.waiting
@ -1138,6 +1139,15 @@ class Scheduler:
waiting_queue.popleft()
continue
# We cannot mix sequence groups that use prompt embeds and
# those that do not.
if len(seq_groups) == 0:
using_prompt_embeds = seq_group.uses_prompt_embeds()
if using_prompt_embeds != seq_group.uses_prompt_embeds():
leftover_waiting_sequences.appendleft(seq_group)
waiting_queue.popleft()
continue
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
@ -1295,17 +1305,39 @@ class Scheduler:
# Merge lists
num_prefill_groups = len(prefills.seq_groups)
ignored_seq_groups_for_embeds = list[SequenceGroup]()
if num_prefill_groups > 0:
scheduled_seq_groups = prefills.seq_groups
scheduled_seq_groups.extend(running_scheduled.decode_seq_groups)
ignored_seq_groups_for_embeds.clear()
else:
scheduled_seq_groups = running_scheduled.decode_seq_groups
if len(scheduled_seq_groups) > 0:
using_prompt_embeds = scheduled_seq_groups[
0].seq_group.uses_prompt_embeds()
ignored_seq_groups_for_embeds.clear()
indices_ignored = list[int]()
for i, schedule_seq_group in enumerate(scheduled_seq_groups):
if using_prompt_embeds !=\
schedule_seq_group.seq_group.uses_prompt_embeds():
ignored_seq_groups_for_embeds.append(
schedule_seq_group.seq_group)
indices_ignored.append(i)
if len(ignored_seq_groups_for_embeds) > 0:
scheduled_seq_groups = [
group for i, group in enumerate(scheduled_seq_groups)
if i not in indices_ignored
]
else:
ignored_seq_groups_for_embeds.clear()
scheduled_seq_groups.extend(swapped_in.decode_seq_groups)
blocks_to_copy = running_scheduled.blocks_to_copy
blocks_to_copy.extend(swapped_in.blocks_to_copy)
ignored_seq_groups = prefills.ignored_seq_groups
ignored_seq_groups.extend(ignored_seq_groups_for_embeds)
ignored_seq_groups.extend(swapped_in.infeasible_seq_groups)
return SchedulerOutputs(

View File

@ -489,6 +489,14 @@ class _AsyncLLMEngine(LLMEngine):
if arrival_time is None:
arrival_time = time.time()
if (isinstance(prompt, dict)
and prompt.get("prompt_embeds", None) is not None
and not prompt.get("prompt_token_ids", None)):
# We use the -2 dimension (instead of 0) in case a batched input
# of batch size 1 is passed in.
prompt["prompt_token_ids"] = [0
] * prompt["prompt_embeds"].shape[-2]
if self.tokenizer is not None:
tokenizer = await self.get_tokenizer_async(lora_request)
self._validate_token_prompt(prompt, tokenizer=tokenizer)
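
The shape[-2] indexing above is what lets the placeholder token ids work for both an un-batched (seq_len, hidden_size) tensor and a batch-of-one (1, seq_len, hidden_size) tensor. A tiny standalone check of that claim:

import torch

for shape in [(7, 4096), (1, 7, 4096)]:
    prompt_embeds = torch.rand(*shape)
    # In both layouts, dim -2 is the sequence length, so the number of
    # placeholder token ids matches the number of embedded positions.
    placeholder_token_ids = [0] * prompt_embeds.shape[-2]
    assert len(placeholder_token_ids) == 7
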

View File

@ -753,6 +753,12 @@ class LLMEngine:
if arrival_time is None:
arrival_time = time.time()
if (isinstance(prompt, dict)
and prompt.get("prompt_embeds", None) is not None
and not prompt.get("prompt_token_ids", None)):
seq_len = prompt["prompt_embeds"].shape[0]
prompt["prompt_token_ids"] = [0] * seq_len
if self.tokenizer is not None:
self._validate_token_prompt(
prompt,
@ -1267,11 +1273,13 @@ class LLMEngine:
if self.scheduler_config.is_multi_step:
is_prefill_append = seq.data.get_num_uncomputed_tokens(
) == 0
seq.append_token_id(sample.output_token, sample.logprobs)
seq.append_token_id(sample.output_token, sample.logprobs,
sample.output_embed)
if not is_prefill_append:
seq_group.update_num_computed_tokens(1)
else:
seq.append_token_id(sample.output_token, sample.logprobs)
seq.append_token_id(sample.output_token, sample.logprobs,
sample.output_embed)
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
@ -2032,10 +2040,12 @@ class LLMEngine:
tokenizer = (None if self.tokenizer is None else
self.tokenizer.get_lora_tokenizer(lora_request))
prompt_ids = prompt_inputs["prompt_token_ids"]
prompt_ids = prompt_inputs.get("prompt_token_ids", [])
if not prompt_ids:
if prompt_type == "encoder" and model_config.is_multimodal_model:
pass # Mllama may have empty encoder inputs for text-only data
if prompt_inputs["type"] == "embeds":
pass
else:
raise ValueError(f"The {prompt_type} prompt cannot be empty")

View File

@ -167,6 +167,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
sampling_params: SamplingParams) -> None:
output_token_ids = [sample.output_token for sample in valid_samples]
output_logprobs = [sample.logprobs for sample in valid_samples]
output_embeds = [sample.output_embed for sample in valid_samples]
# Truncate to max_tokens if necessary.
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
@ -190,11 +191,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
# Incrementally append tokens to the sequence, as if we had only one new
# token.
for output_token_id, output_logprob in zip(output_token_ids,
output_logprobs):
for output_token_id, output_logprob, output_embed in zip(
output_token_ids, output_logprobs, output_embeds):
seq.append_token_id(
token_id=output_token_id,
logprobs=output_logprob,
token_embed=output_embed,
)
if is_prefill_sampled_token:

View File

@ -119,7 +119,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
sample = outputs.samples[0]
seq = seq_group.first_seq
if not is_async:
seq.append_token_id(sample.output_token, sample.logprobs)
seq.append_token_id(sample.output_token, sample.logprobs,
sample.output_embed)
if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)

View File

@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
TokensPrompt, build_explicit_enc_dec_prompt,
TokensPrompt, build_explicit_enc_dec_prompt, embeds_inputs,
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
from .registry import (DummyData, InputContext, InputProcessingContext,
InputRegistry)
@ -21,7 +21,9 @@ __all__ = [
"SingletonPrompt",
"ExplicitEncoderDecoderPrompt",
"TokenInputs",
"EmbedsInputs",
"token_inputs",
"embeds_inputs",
"DecoderOnlyInputs",
"EncoderDecoderInputs",
"ProcessorInputs",

View File

@ -2,6 +2,7 @@
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast
import torch
from typing_extensions import NotRequired, TypedDict, TypeVar
if TYPE_CHECKING:
@ -63,12 +64,20 @@ class TokensPrompt(TypedDict):
"""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
class EmbedsPrompt(TypedDict):
"""Schema for a prompt provided via token embeddings."""
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
"""
Set of possible schemas for a single prompt:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
- An embeddings prompt (:class:`EmbedsPrompt`)
Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
@ -129,6 +138,7 @@ both decoder-only and encoder/decoder input types:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
- An embeddings prompt (:class:`EmbedsPrompt`)
- A single data structure containing both an encoder and a decoder prompt
(:class:`ExplicitEncoderDecoderPrompt`)
"""
@ -176,7 +186,27 @@ def token_inputs(
return inputs
DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputs"]
class EmbedsInputs(TypedDict):
"""Represents embeddings-based inputs."""
type: Literal["embeds"]
"""The type of inputs."""
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""
def embeds_inputs(prompt_embeds: torch.Tensor) -> EmbedsInputs:
"""Construct :class:`EmbedsInputs` from optional values."""
inputs = EmbedsInputs(
type="embeds",
prompt_embeds=prompt_embeds,
)
return inputs
DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
@ -198,7 +228,7 @@ class EncoderDecoderInputs(TypedDict):
"""The inputs for the decoder portion."""
SingletonInputs = Union[TokenInputs, "MultiModalInputs"]
SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.

View File

@ -6,8 +6,9 @@ from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)
from .data import (EmbedsInputs, EmbedsPrompt, ExplicitEncoderDecoderPrompt,
ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokensPrompt)
class ParsedText(TypedDict):
@ -84,30 +85,69 @@ class ParsedTokensPrompt(TypedDict):
content: TokensPrompt
class ParsedEmbedsPrompt(TypedDict):
type: Literal['embeds']
content: EmbedsPrompt
@overload
def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
...
@overload
def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt:
...
@overload
def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt:
...
@overload
def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
...
def parse_singleton_prompt(
prompt: SingletonPrompt,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
ParsedEmbedsPrompt]:
if isinstance(prompt, str):
return ParsedStrPrompt(type="str", content=prompt)
elif isinstance(prompt, dict):
if "prompt_token_ids" in prompt:
return ParsedTokensPrompt(type="tokens",
content=prompt) # type: ignore
# Type ignores are because mypy does not correctly infer the TypedDicts;
# Pyright does succeed.
if "prompt_embeds" in prompt:
return ParsedEmbedsPrompt(
type="embeds", content=prompt) # type: ignore[typeddict-item]
elif "prompt_token_ids" in prompt:
return ParsedTokensPrompt(
type="tokens", content=prompt) # type: ignore[typeddict-item]
elif "prompt" in prompt:
return ParsedTextPrompt(type="text", content=prompt)
raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
raise TypeError(
"inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt")
def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
return isinstance(prompt, dict) and "prompt_token_ids" in prompt
def is_embeds_prompt(prompt: PromptType) -> TypeIs[EmbedsPrompt]:
return isinstance(prompt, dict) and "prompt_embeds" in prompt
def is_explicit_encoder_decoder_prompt(
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(prompt, dict) and "encoder_prompt" in prompt
def is_embeds_inputs(inputs: SingletonInputs) -> TypeIs[EmbedsInputs]:
return isinstance(inputs, dict) and inputs["type"] == "embeds"
def split_enc_dec_inputs(
inputs: ProcessorInputs,
) -> tuple[Optional[SingletonInputs], SingletonInputs]:

View File

@ -6,6 +6,7 @@ from typing import Any, Optional, Union, cast
from typing_extensions import assert_never
from vllm import envs
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -15,9 +16,11 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
PromptType, SingletonInputs, SingletonPrompt, token_inputs)
from .parse import (ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TokenInputs, embeds_inputs, token_inputs)
from .parse import (ParsedEmbedsPrompt, ParsedStrPrompt, ParsedTextPrompt,
ParsedTokensPrompt, is_embeds_inputs,
is_explicit_encoder_decoder_prompt, parse_singleton_prompt)
logger = init_logger(__name__)
@ -328,6 +331,10 @@ class InputPreprocessor:
* :class:`SingletonInputs` instance
"""
parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "embeds":
return self._process_prompt_embeds(parsed)
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
self._get_prompt_data(parsed)
@ -359,6 +366,8 @@ class InputPreprocessor:
cache_salt=cache_salt,
)
assert_never(parsed)
async def _prompt_to_llm_inputs_async(
self,
prompt: SingletonPrompt,
@ -369,6 +378,9 @@ class InputPreprocessor:
"""Async version of :meth:`_extract_prompt_components`."""
parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "embeds":
return self._process_prompt_embeds(parsed)
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
self._get_prompt_data(parsed)
@ -399,10 +411,34 @@ class InputPreprocessor:
cache_salt=cache_salt,
)
def _process_prompt_embeds(self,
parsed: ParsedEmbedsPrompt) -> EmbedsInputs:
if envs.VLLM_USE_V1:
raise ValueError("prompt_embeds is only available in V0.")
prompt_embeds_content = parsed["content"]
prompt_embeds = prompt_embeds_content["prompt_embeds"]
# prompt_embeds must be (seq_len, hidden_size), but if the user
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1:
prompt_embeds = prompt_embeds.squeeze(dim=0)
if prompt_embeds.ndim != 2:
raise ValueError(
"prompt_embeds must be of shape (seq_len, hidden_size).")
return embeds_inputs(prompt_embeds=prompt_embeds)
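
A standalone restatement of the shape handling in _process_prompt_embeds, useful for checking which inputs are accepted (the helper name is hypothetical; the error message mirrors the one above):

import torch


def normalize_prompt_embeds(prompt_embeds: torch.Tensor) -> torch.Tensor:
    # A (1, seq_len, hidden_size) batch of one is unambiguous, so the batch
    # dimension is squeezed away; anything else must already be 2-D.
    if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1:
        prompt_embeds = prompt_embeds.squeeze(dim=0)
    if prompt_embeds.ndim != 2:
        raise ValueError(
            "prompt_embeds must be of shape (seq_len, hidden_size).")
    return prompt_embeds


assert normalize_prompt_embeds(torch.rand(5, 16)).shape == (5, 16)
assert normalize_prompt_embeds(torch.rand(1, 5, 16)).shape == (5, 16)
try:
    normalize_prompt_embeds(torch.rand(2, 5, 16))  # true batch: rejected
except ValueError:
    pass
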
assert_never(parsed)
def _build_enc_dec_llm_inputs(
self,
encoder_inputs: SingletonInputs,
decoder_inputs: Optional[SingletonInputs],
encoder_inputs: Union[TokenInputs, MultiModalInputs],
decoder_inputs: Optional[Union[TokenInputs, MultiModalInputs]],
) -> EncoderDecoderInputs:
if (encoder_inputs["type"] == "token"
or encoder_inputs["type"] == "multimodal"):
@ -410,6 +446,9 @@ class InputPreprocessor:
else:
assert_never(encoder_inputs) # type: ignore[arg-type]
# Mypy does not correctly infer that EmbedsInputs is impossible
assert "prompt_token_ids" in encoder_inputs
if decoder_inputs is None:
if self.model_config.hf_config.model_type == "whisper":
# For Whisper models, the text prompt should go to the decoder.
@ -441,7 +480,8 @@ class InputPreprocessor:
def _separate_enc_dec_inputs_from_mm_processor_outputs(
self,
inputs: SingletonInputs,
decoder_inputs_to_override: Optional[SingletonInputs] = None,
decoder_inputs_to_override: Optional[Union[TokenInputs,
MultiModalInputs]] = None,
) -> tuple[SingletonInputs, SingletonInputs]:
"""
For encoder/decoder models only:
@ -540,6 +580,8 @@ class InputPreprocessor:
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model:
assert decoder_inputs is None or not is_embeds_inputs(
decoder_inputs)
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
@ -555,9 +597,12 @@ class InputPreprocessor:
inputs))
else:
encoder_inputs = inputs
decoder_inputs = None
# Mypy does not do type inference well with TypedDicts with Literal
# values.
assert not is_embeds_inputs(encoder_inputs)
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
async def _process_encoder_decoder_prompt_async(
@ -590,6 +635,8 @@ class InputPreprocessor:
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model:
assert decoder_inputs is None or not is_embeds_inputs(
decoder_inputs)
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
@ -605,9 +652,12 @@ class InputPreprocessor:
inputs))
else:
encoder_inputs = inputs
decoder_inputs = None
# Mypy does not do type inference well with TypedDicts with Literal
# values.
assert not is_embeds_inputs(encoder_inputs)
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
def _build_decoder_only_llm_inputs(
@ -617,10 +667,15 @@ class InputPreprocessor:
) -> DecoderOnlyInputs:
if (prompt_inputs["type"] == "token"
or prompt_inputs["type"] == "multimodal"):
# Mypy does not do type inference well with TypedDicts and Literal
# values.
assert not is_embeds_inputs(prompt_inputs)
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
prompt_inputs["prompt_token_ids"],
prompt_adapter_request=prompt_adapter_request,
)
elif (prompt_inputs["type"] == "embeds"):
pass
else:
assert_never(prompt_inputs) # type: ignore[arg-type]

View File

@ -110,6 +110,11 @@ class SamplerOutput(
# 'broadcasted' to all other PP ranks for next step.
sampled_token_ids_cpu: Optional[torch.Tensor] = None
# On-device tensor containing the sampled token embeddings (embeddings
# corresponding to the sampled token ids). Used when prompt embeddings are
# specified in lieu of prompt token ids or text.
sampled_token_embeds: Optional[torch.Tensor] = None
# Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
@ -183,7 +188,7 @@ class Sampler(nn.Module):
# Whether or not the SamplerOutput should have on-device tensors
# containing the sampled token ids and probabilities. This is used by
# speculative decoding.
# speculative decoding and when prompt embeddings are specified.
self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False

View File

@ -166,6 +166,9 @@ class SequenceData(msgspec.Struct,
_output_token_ids: array = msgspec.field(
default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))
_prompt_embeds: Optional[torch.Tensor] = None
_output_embeds: Optional[torch.Tensor] = None
### The below fields should not be passed as an argument ###
_cumulative_logprob: float = 0.0
_prompt_token_ids_tuple: tuple[int,
@ -176,6 +179,7 @@ class SequenceData(msgspec.Struct,
_num_cached_tokens: int = 0
_stage: SequenceStage = SequenceStage.PREFILL
_cached_all_token_ids: list[int] = msgspec.field(default_factory=list)
_cached_all_token_embeds: Optional[torch.Tensor] = None
# It is used to get delta input. It is reset when `get_delta_and_reset`
# is called.
@ -208,6 +212,8 @@ class SequenceData(msgspec.Struct,
def from_seqs(
prompt_token_ids: GenericSequence[int],
output_token_ids: Optional[GenericSequence[int]] = None,
*,
prompt_embeds: Optional[torch.Tensor] = None,
) -> "SequenceData":
"""
Construct a :class:`SequenceData` instance from prompt and output
@ -217,13 +223,15 @@ class SequenceData(msgspec.Struct,
prompt_token_ids)
if output_token_ids is None:
return SequenceData(prompt_token_ids_arr)
return SequenceData(prompt_token_ids_arr,
_prompt_embeds=prompt_embeds)
output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
output_token_ids)
return SequenceData(prompt_token_ids_arr,
_output_token_ids=output_token_ids_arr)
_output_token_ids=output_token_ids_arr,
_prompt_embeds=prompt_embeds)
def __post_init__(self) -> None:
assert self._prompt_token_ids.typecode == "l"
@ -231,6 +239,8 @@ class SequenceData(msgspec.Struct,
self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
self._prompt_token_ids)
self._update_cached_all_tokens()
if self._prompt_embeds is not None:
self._update_cached_all_token_embeds()
def _update_cached_all_tokens(self):
assert isinstance(self._prompt_token_ids, array)
@ -238,6 +248,13 @@ class SequenceData(msgspec.Struct,
self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
self._output_token_ids)
def _update_cached_all_token_embeds(self):
assert isinstance(self._prompt_embeds, torch.Tensor)
self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds
if self._output_embeds is not None:
self._cached_all_token_embeds = torch.cat(
(self._cached_all_token_embeds, self._output_embeds), dim=0)
@property
def cumulative_logprob(self) -> float:
return self._cumulative_logprob
@ -270,6 +287,15 @@ class SequenceData(msgspec.Struct,
new_output_token_ids)
self._update_cached_all_tokens()
@property
def output_embeds(self) -> Optional[torch.Tensor]:
return self._output_embeds
@output_embeds.setter
def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None:
self._output_embeds = new_output_token_embeds
self._update_cached_all_token_embeds()
@property
def output_token_ids_array(self) -> array:
"""Return the prompt token ids in array type.
@ -280,6 +306,15 @@ class SequenceData(msgspec.Struct,
assert isinstance(self._output_token_ids, array)
return self._output_token_ids
@property
def prompt_embeds(self) -> Optional[torch.Tensor]:
return self._prompt_embeds
@prompt_embeds.setter
def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
self._prompt_embeds = prompt_embeds
self._update_cached_all_token_embeds()
@property
def mrope_position_delta(self) -> Optional[int]:
return self._mrope_position_delta
@ -288,11 +323,28 @@ class SequenceData(msgspec.Struct,
def mrope_position_delta(self, new_mrope_position_delta):
self._mrope_position_delta = new_mrope_position_delta
def append_token_id(self, token_id: int, logprob: float) -> None:
def append_token_id(self,
token_id: int,
logprob: float,
token_embed: Optional[torch.Tensor] = None) -> None:
self._output_token_ids.append(token_id)
self._new_appended_tokens.append(token_id)
self._cached_all_token_ids.append(token_id)
self._cumulative_logprob += logprob
if token_embed is not None:
# Do not pass in with batch or sequence dimensions
assert token_embed.ndim == 1
token_embed = token_embed.detach().cpu().unsqueeze(0)
if self._output_embeds is None:
self._output_embeds = token_embed
else:
self._output_embeds = torch.cat(
(self._output_embeds, token_embed), dim=0)
assert self._cached_all_token_embeds is not None
self._cached_all_token_embeds = torch.cat(
(self._cached_all_token_embeds,
token_embed.to(device=self._cached_all_token_embeds.device)),
dim=0)
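
The bookkeeping above keeps two growing tensors in sync: the per-step output embeddings and the cached prompt-plus-output embeddings. A minimal sketch of that accumulation pattern outside the SequenceData class:

from typing import Optional

import torch

hidden_size = 8
prompt_embeds = torch.rand(4, hidden_size)           # (prompt_len, hidden)
output_embeds: Optional[torch.Tensor] = None
cached_all_embeds = prompt_embeds                    # prompt + outputs so far

for _ in range(3):  # three decode steps
    token_embed = torch.rand(hidden_size)            # 1-D, as asserted above
    token_embed = token_embed.detach().cpu().unsqueeze(0)
    output_embeds = (token_embed if output_embeds is None else
                     torch.cat((output_embeds, token_embed), dim=0))
    cached_all_embeds = torch.cat((cached_all_embeds, token_embed), dim=0)

assert output_embeds.shape == (3, hidden_size)
assert cached_all_embeds.shape == (4 + 3, hidden_size)
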
def get_len(self) -> int:
return len(self._output_token_ids) + len(self._prompt_token_ids)
@ -306,6 +358,9 @@ class SequenceData(msgspec.Struct,
def get_token_ids(self) -> list[int]:
return self._cached_all_token_ids
def get_token_embeddings(self) -> Optional[torch.Tensor]:
return self._cached_all_token_embeds
def get_prefix_token_ids(
self, num_tokens: int
) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
@ -387,6 +442,8 @@ class SequenceData(msgspec.Struct,
def __repr__(self) -> str:
return (f"SequenceData("
f"prompt_token_ids={self._prompt_token_ids}, "
f"prompt_embeds.shape="
f"{getattr(self._prompt_embeds, 'shape', None)}, "
f"output_token_ids={self.output_token_ids}, "
f"cumulative_logprob={self.cumulative_logprob}, "
f"get_num_computed_tokens={self.get_num_computed_tokens()})")
@ -425,7 +482,10 @@ class Sequence:
self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
self.data = SequenceData.from_seqs(self.prompt_token_ids)
self.data = SequenceData.from_seqs(
self.prompt_token_ids,
prompt_embeds=self.inputs["prompt_embeds"]
if self.inputs["type"] == "embeds" else None)
self.output_logprobs: SampleLogprobs = []
self.output_text = ""
@ -448,14 +508,20 @@ class Sequence:
@property
def prompt(self) -> Optional[str]:
if self.inputs["type"] == "embeds":
return None
return self.inputs.get("prompt")
@property
def prompt_token_ids(self) -> list[int]:
if self.inputs["type"] == "embeds":
return [0] * len(self.inputs["prompt_embeds"])
return self.inputs["prompt_token_ids"]
@property
def token_type_ids(self) -> list[int]:
if self.inputs["type"] == "embeds":
return []
return self.inputs.get("token_type_ids", [])
@property
@ -554,11 +620,14 @@ class Sequence:
"""Reset the sequence states for recomputation."""
self.data.reset_state_for_recompute()
def append_token_id(self, token_id: int, logprobs: dict[int,
Logprob]) -> None:
def append_token_id(self,
token_id: int,
logprobs: dict[int, Logprob],
token_embed: Optional[torch.Tensor] = None) -> None:
assert token_id in logprobs
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob)
self.data.append_token_id(token_id, logprobs[token_id].logprob,
token_embed)
def get_len(self) -> int:
return self.data.get_len()
@ -889,6 +958,10 @@ class SequenceGroup:
f"sampling_params={self.sampling_params}, "
f"num_seqs={len(self.seqs)})")
def uses_prompt_embeds(self) -> bool:
"""Returns True if the sequence group uses input embeds."""
return any(seq.data.prompt_embeds is not None for seq in self.seqs)
class SequenceGroupMetadataDelta(
msgspec.Struct,
@ -1043,10 +1116,14 @@ class SequenceOutput(
parent_seq_id: int
output_token: int
logprobs: dict[int, Logprob]
output_embed: Optional[torch.Tensor] = None
def __repr__(self) -> str:
output_embed_shape = \
self.output_embed.shape if self.output_embed is not None else None
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
f"output_token={self.output_token}, "
f"output_embed.shape={output_embed_shape}"
f"logprobs={self.logprobs})")
def __eq__(self, other: object) -> bool:

View File

@ -201,6 +201,9 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
if self.prompt_adapter_config is not None:
raise ValueError("TP1DraftModelRunner has no support for "
"prompt_adapter_config")
if model_input.inputs_embeds is not None:
raise ValueError("TP1DraftModelRunner has no support for "
"inputs_embeds")
if model_input.multi_modal_kwargs:
raise ValueError(
"TP1DraftModelRunner has no support for multi_modal_kwargs"
@ -242,9 +245,16 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
# Get model
if use_cuda_graph:
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = (self.graph_runners[model_input.virtual_engine]
[graph_batch_size])
if model_input.inputs_embeds is None:
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, False)])
else:
graph_batch_size = model_input.inputs_embeds.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, True)])
if previous_hidden_states is not None:
hidden_states = torch.cat([
@ -281,6 +291,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
self.vllm_config):
hidden_states = model_executable(
input_ids=model_input.input_tokens,
inputs_embeds=None,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,

View File

@ -282,7 +282,8 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
else:
count += 1
seq.append_token_id(token_id, token_logprob.logprob)
seq.append_token_id(token_id, token_logprob.logprob,
seq_output.output_embed)
seq.update_num_computed_tokens(1)
@staticmethod

View File

@ -49,6 +49,7 @@ class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"inputs_embeds": self.inputs_embeds,
"input_positions": self.input_positions,
"encoder_input_tokens": self.encoder_input_tokens,
"encoder_input_positions": self.encoder_input_positions,
@ -172,10 +173,17 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
if (model_input.attn_metadata is not None
and model_input.attn_metadata.prefill_metadata is None
and model_input.attn_metadata.decode_metadata.use_cuda_graph):
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[
model_input.virtual_engine][graph_batch_size]
if model_input.inputs_embeds is None:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, False)])
else:
graph_batch_size = model_input.inputs_embeds.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, True)])
else:
model_executable = self.model
@ -189,6 +197,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
model_input.virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
inputs_embeds=model_input.inputs_embeds,
positions=model_input.input_positions,
encoder_input_ids=model_input.encoder_input_tokens,
encoder_positions=model_input.encoder_input_positions,

View File

@ -35,7 +35,8 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
get_sampler)
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import supports_lora, supports_multimodal
@ -85,6 +86,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
additional fields.
"""
input_tokens: Optional[torch.Tensor] = None
inputs_embeds: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None
token_types: Optional[torch.Tensor] = None
seq_lens: Optional[List[int]] = None
@ -105,6 +107,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"inputs_embeds": self.inputs_embeds,
"input_positions": self.input_positions,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
@ -155,6 +158,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
"input_tokens": self.input_tokens,
"inputs_embeds": self.inputs_embeds,
"input_positions": self.input_positions,
"lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping,
@ -194,6 +198,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
def simple_reinit(self):
self.input_tokens[0].clear() # type: ignore
self.inputs_embeds = None # type: ignore
self.input_positions[0].clear() # type: ignore
self.token_types[0].clear() # type: ignore
self.mrope_input_positions = None # type: ignore
@ -221,6 +226,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Input tokens and positions.
input_tokens: Optional[List[List[int]]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
input_positions: Optional[List[List[int]]] = None,
token_types: Optional[List[List[int]]] = None,
mrope_input_positions: Optional[List[List[List[int]]]] = None,
@ -282,6 +288,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for seq_id in range(len(self.seq_ids)):
self.input_tokens[seq_id].clear()
self.inputs_embeds = inputs_embeds
if input_positions:
self.input_positions = input_positions
else:
@ -356,6 +364,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
else:
self.input_tokens = input_tokens or []
self.inputs_embeds = inputs_embeds
self.input_positions = input_positions or []
self.token_types = token_types or []
self.mrope_input_positions = mrope_input_positions or None
@ -401,6 +410,26 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.lora_index_mapping = []
self.lora_prompt_mapping = []
def __repr__(self) -> str:
return (f"InterDataForSeqGroup("
f"request_id={self.request_id}, "
f"seq_ids={self.seq_ids}, "
f"is_prompt={self.is_prompt}, "
f"block_tables={self.block_tables}, "
f"computed_block_nums={self.computed_block_nums}, "
f"n_seqs={self.n_seqs}, "
f"input_tokens={self.input_tokens}, "
f"inputs_embeds.shape="
f"{getattr(self.inputs_embeds, 'shape', None)}, "
f"input_positions={self.input_positions}, "
f"token_types={self.token_types}, "
f"mrope_input_positions={self.mrope_input_positions}, "
f"seq_lens={self.seq_lens}, "
f"orig_seq_lens={self.orig_seq_lens}, "
f"query_lens={self.query_lens}, "
f"context_lens={self.context_lens}, "
f"multi_modal_kwargs={self.multi_modal_kwargs}")
def gen_inter_data_builder(self, num_seqs: int):
return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup(
request_id="",
@ -511,13 +540,21 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
context_len = seq_data.get_num_computed_tokens()
# Compute tokens.
tokens = seq_data.get_token_ids()[context_len:seq_len]
if seq_data.prompt_embeds is None:
tokens = seq_data.get_token_ids()[context_len:seq_len]
prompt_embeds = None
else:
tokens = [0] * (seq_len - context_len)
prompt_embeds = seq_data.get_token_embeddings(
)[context_len:seq_len]
token_types = seq_group_metadata.token_type_ids
inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.context_lens[seq_idx] = context_len
inter_data.input_tokens[seq_idx].extend(tokens)
inter_data.inputs_embeds = prompt_embeds
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
inter_data.token_types[seq_idx].extend(
token_types if token_types else [])
@ -822,15 +859,29 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
create on-device tensors.
"""
# Combine and flatten intermediate data.
input_tokens = []
token_types = []
input_tokens = list[int]()
inputs_embeds_lst = list[torch.Tensor]()
token_types = list[int]()
for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens:
input_tokens.extend(cur_input_tokens)
for cur_token_types in inter_data.token_types:
token_types.extend(cur_token_types)
if inter_data.inputs_embeds is not None:
inputs_embeds_lst.append(
inter_data.inputs_embeds.to(
dtype=self.runner.model_config.dtype,
device=self.runner.device))
inputs_embeds: Optional[torch.Tensor]
if len(inputs_embeds_lst) == 0:
inputs_embeds = None
else:
inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to(
dtype=self.runner.model_config.dtype,
device=self.runner.device)
assert len(inputs_embeds) == len(input_tokens)
if not input_tokens:
if not input_tokens and inputs_embeds is None:
# This may happen when all prefill requests hit
# prefix caching and there is no decode request.
return self.model_input_cls()
@ -980,6 +1031,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
return self.model_input_cls(
input_tokens=input_tokens_tensor,
inputs_embeds=inputs_embeds,
input_positions=input_positions_tensor,
token_types=token_types_tensor,
attn_metadata=attn_metadata,
@ -1029,7 +1081,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.max_batchsize_to_capture = \
self.vllm_config.compilation_config.max_capture_size
self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
self.graph_runners: List[Dict[Tuple[int, bool], CUDAGraphRunner]] = [
{} for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.graph_memory_pool: Optional[Tuple[
@ -1466,6 +1519,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
input_positions = torch.zeros(max_batch_size,
dtype=torch.long,
device=self.device)
inputs_embeds = torch.zeros(
(max_batch_size, self.model_config.get_hidden_size()),
dtype=self.model_config.dtype,
device=self.device)
if self.model_config.uses_mrope:
input_positions = torch.tile(input_positions,
(3, 1)).cuda(device=self.device)
@ -1503,15 +1560,22 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# memory usage of CUDA graph.
for virtual_engine in range(
self.parallel_config.pipeline_parallel_size):
# Only rank 0 should print progress bar during capture
cudagraph_capture_sizes = (tqdm(
self.vllm_config.compilation_config.
# We need to not only iterate over batch sizes, but also whether
# to use inputs_embeds or not, hence we use the cartesian
# product.
cudagraph_capture_sizes = self.vllm_config.compilation_config\
.cudagraph_capture_sizes
cudagraph_inputs_embeds = (True, False)
compilation_cases = itertools.product(
cudagraph_capture_sizes,
desc="Capturing CUDA graph shapes",
) if get_tensor_model_parallel_rank() == 0 else
self.vllm_config.compilation_config.
cudagraph_capture_sizes)
for batch_size in cudagraph_capture_sizes:
cudagraph_inputs_embeds,
)
# Only rank 0 should print progress bar during capture
if get_tensor_model_parallel_rank() == 0:
compilation_cases = tqdm(
list(compilation_cases),
desc="Capturing CUDA graph shapes")
for batch_size, use_inputs_embeds in compilation_cases:
attn_metadata = (
self.attn_state.graph_capture_get_metadata_for_batch(
batch_size,
@ -1542,6 +1606,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
capture_inputs = {
"input_ids":
input_tokens[:batch_size],
"inputs_embeds":
inputs_embeds[:batch_size]
if use_inputs_embeds else None,
"positions":
input_positions[..., :batch_size],
"intermediate_inputs":
@ -1578,8 +1645,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
virtual_engine):
graph_runner.capture(**capture_inputs)
self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[virtual_engine][batch_size] = (
graph_runner)
self.graph_runners[virtual_engine][(
batch_size, use_inputs_embeds)] = graph_runner
if self.lora_config:
self._remove_dummy_loras()
@ -1711,8 +1778,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[virtual_engine][
graph_batch_size]
use_inputs_embeds = model_input.inputs_embeds is not None
model_executable = self.graph_runners[virtual_engine][(
graph_batch_size, use_inputs_embeds)]
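
Keying the captured CUDA graphs by (batch_size, use_inputs_embeds) instead of batch size alone means each padded batch size is captured twice, once per input mode, and every lookup must carry the same flag. A toy sketch of the keying scheme (placeholder strings stand in for graph runners):

from itertools import product

captured: dict[tuple[int, bool], str] = {}

# Capture phase: the cartesian product of batch sizes and the embeds flag,
# mirroring the capture loop above.
for batch_size, use_inputs_embeds in product((1, 2, 4, 8), (True, False)):
    captured[(batch_size, use_inputs_embeds)] = (
        f"graph(bs={batch_size}, embeds={use_inputs_embeds})")

# Execute phase: the lookup key is derived from whether inputs_embeds is set.
inputs_embeds = None
graph_batch_size = 4
use_inputs_embeds = inputs_embeds is not None
assert captured[(graph_batch_size, use_inputs_embeds)] == \
    "graph(bs=4, embeds=False)"
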
if previous_hidden_states is not None:
previous_hidden_states = torch.cat([
previous_hidden_states,
@ -1763,6 +1831,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self.vllm_config, virtual_engine):
hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens,
inputs_embeds=model_input.inputs_embeds,
positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
@ -1817,6 +1886,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_input.async_callback()
# Sample the next token.
assert isinstance(self.sampler, Sampler)
orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor
if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = True
output: SamplerOutput = self.sampler(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
@ -1838,6 +1912,18 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
output.model_forward_time = (orig_model_forward_time +
model_forward_time)
if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = \
orig_include_gpu_probs_tensor
if output.sampled_token_ids is not None:
output.sampled_token_embeds = self.model.get_input_embeddings(
output.sampled_token_ids.squeeze(1))
for token_embed, sequence_group_output in zip(
output.sampled_token_embeds, output.outputs):
assert len(sequence_group_output.samples) == 1
sequence_group_output.samples[0].output_embed = token_embed
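
When a request runs from prompt embeddings, the next step still needs an input embedding for each sampled token, so the runner looks it up in the model's input embedding table and attaches it to the corresponding sample. A standalone sketch of that lookup with a toy embedding table (not the vLLM model API):

import torch
import torch.nn as nn

vocab_size, hidden_size = 32, 8
# Toy stand-in for the model's input embedding table.
embed_tokens = nn.Embedding(vocab_size, hidden_size)

# One sampled token id per sequence, shaped (num_seqs, 1) as in the sampler.
sampled_token_ids = torch.tensor([[3], [17], [9]])
sampled_token_embeds = embed_tokens(sampled_token_ids.squeeze(1))
assert sampled_token_embeds.shape == (3, hidden_size)

# Each per-sequence sample then carries its own 1-D embedding, which later
# feeds SequenceData.append_token_id(..., token_embed=...).
for token_embed in sampled_token_embeds:
    assert token_embed.shape == (hidden_size,)
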
if self.return_hidden_states:
# we only need to pass hidden states of most recent token
assert model_input.sampling_metadata is not None
@ -1931,6 +2017,7 @@ class CUDAGraphRunner(nn.Module):
def capture(
self,
input_ids: torch.Tensor,
inputs_embeds: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_inputs: Optional[IntermediateTensors],
kv_caches: List[torch.Tensor],
@ -1947,6 +2034,7 @@ class CUDAGraphRunner(nn.Module):
for _ in range(_NUM_WARMUP_ITERS):
self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
positions=positions,
intermediate_tensors=intermediate_inputs,
**kwargs,
@ -1959,6 +2047,9 @@ class CUDAGraphRunner(nn.Module):
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
output_hidden_or_intermediate_states = self.model(
input_ids=input_ids,
**({
"inputs_embeds": inputs_embeds,
} if inputs_embeds is not None else {}),
positions=positions,
intermediate_tensors=intermediate_inputs,
**kwargs,
@ -1986,6 +2077,9 @@ class CUDAGraphRunner(nn.Module):
self.input_buffers = {
"input_ids":
input_ids,
**({
"inputs_embeds": inputs_embeds,
} if inputs_embeds is not None else {}),
"positions":
positions,
"kv_caches":
@ -2006,6 +2100,7 @@ class CUDAGraphRunner(nn.Module):
def forward(
self,
input_ids: torch.Tensor,
inputs_embeds: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
**kwargs,
@ -2020,6 +2115,9 @@ class CUDAGraphRunner(nn.Module):
# so the shape is not padded, we need to copy partial only
self.input_buffers["positions"][:positions.shape[0]].copy_(
positions, non_blocking=True)
if inputs_embeds is not None:
self.input_buffers["inputs_embeds"][:inputs_embeds.shape[0]].copy_(
inputs_embeds, non_blocking=True)
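
Like the other graph inputs, inputs_embeds is copied into a buffer preallocated for the captured batch size, overwriting only the leading rows when the live batch is smaller. A small sketch of that partial-copy pattern on plain tensors:

import torch

max_batch_size, hidden_size = 8, 16
# Static buffer captured with the CUDA graph (CPU here for illustration).
inputs_embeds_buffer = torch.zeros(max_batch_size, hidden_size)

# A smaller live batch at replay time.
inputs_embeds = torch.rand(3, hidden_size)
inputs_embeds_buffer[:inputs_embeds.shape[0]].copy_(inputs_embeds,
                                                    non_blocking=True)

assert torch.equal(inputs_embeds_buffer[:3], inputs_embeds)
assert torch.all(inputs_embeds_buffer[3:] == 0)  # padding rows untouched
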
if self.backend_name != "NO_ATTENTION":
self.input_buffers["slot_mapping"].copy_(

View File

@ -84,10 +84,17 @@ class PoolingModelRunner(
# explore how to leverage it.
if (prefill_meta is None and decode_meta is not None
and decode_meta.use_cuda_graph):
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[virtual_engine][
graph_batch_size]
if model_input.inputs_embeds is None:
assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, False)])
else:
graph_batch_size = model_input.inputs_embeds.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, True)])
else:
model_executable = self.model