[Core] [Bugfix] Add Input Embeddings (#15428)

Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: 临景 <linjing.yx@alibaba-inc.com>
Co-authored-by: Bryce1010 <bryceyx@gmail.com>
Co-authored-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Andrew Sansom
Date: 2025-05-02 03:06:39 -05:00 (committed by GitHub)
parent 9e2de9b9e9
commit cc2a77d7f1
22 changed files with 691 additions and 113 deletions

View File

@ -787,7 +787,7 @@ class VllmRunner:
def get_inputs( def get_inputs(
self, self,
prompts: list[str], prompts: Union[list[str], list[torch.Tensor]],
images: Optional[PromptImageInput] = None, images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None, videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None, audios: Optional[PromptAudioInput] = None,
@ -809,16 +809,18 @@ class VllmRunner:
if audios is not None and (audio := audios[i]) is not None: if audios is not None and (audio := audios[i]) is not None:
multi_modal_data["audio"] = audio multi_modal_data["audio"] = audio
inputs.append( text_prompt_kwargs = {
TextPrompt(prompt=prompt, ("prompt" if isinstance(prompt, str) else "prompt_embeds"):
multi_modal_data=multi_modal_data prompt,
if multi_modal_data else None)) "multi_modal_data": multi_modal_data or None
}
inputs.append(TextPrompt(**text_prompt_kwargs))
return inputs return inputs
def generate( def generate(
self, self,
prompts: list[str], prompts: Union[list[str], list[torch.Tensor]],
sampling_params: SamplingParams, sampling_params: SamplingParams,
images: Optional[PromptImageInput] = None, images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None, videos: Optional[PromptVideoInput] = None,
@ -844,7 +846,7 @@ class VllmRunner:
output_str = sample.text output_str = sample.text
output_ids = list(sample.token_ids) output_ids = list(sample.token_ids)
req_sample_output_ids.append(prompt_ids + output_ids) req_sample_output_ids.append(prompt_ids + output_ids)
req_sample_output_strs.append(prompt_str + output_str) req_sample_output_strs.append((prompt_str or "") + output_str)
outputs.append((req_sample_output_ids, req_sample_output_strs)) outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs return outputs
@ -911,7 +913,7 @@ class VllmRunner:
def generate_greedy( def generate_greedy(
self, self,
prompts: list[str], prompts: Union[list[str], list[torch.Tensor]],
max_tokens: int, max_tokens: int,
images: Optional[PromptImageInput] = None, images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None, videos: Optional[PromptVideoInput] = None,
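
The harness change above lets the same test helpers accept either text prompts or precomputed embedding tensors. A minimal sketch of the dispatch, using only standard types (the helper name is illustrative):

from typing import Optional, Union

import torch

def build_text_prompt_kwargs(
        prompt: Union[str, torch.Tensor],
        multi_modal_data: Optional[dict] = None) -> dict:
    # A string goes in as "prompt"; an embedding tensor goes in as
    # "prompt_embeds", mirroring the get_inputs() change above.
    key = "prompt" if isinstance(prompt, str) else "prompt_embeds"
    return {key: prompt, "multi_modal_data": multi_modal_data or None}

# Both forms produce kwargs that TextPrompt(**kwargs) can accept.
assert "prompt" in build_text_prompt_kwargs("Hello")
assert "prompt_embeds" in build_text_prompt_kwargs(torch.rand(4, 8))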

View File

@ -2,16 +2,18 @@
import time import time
from collections import deque from collections import deque
from typing import Optional
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest # noqa import pytest # noqa
import torch
from torch import Use # noqa from torch import Use # noqa
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler, SchedulingBudget from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import SequenceGroup from vllm.sequence import SequenceGroup, SequenceStatus
from .utils import (append_new_token, append_new_token_seq, from .utils import (append_new_token, append_new_token_seq,
append_new_token_seq_group, create_dummy_prompt, append_new_token_seq_group, create_dummy_prompt,
@ -968,3 +970,73 @@ def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
), "A partial prefix of C (4 tokens) should be prefilled, with the " ), "A partial prefix of C (4 tokens) should be prefilled, with the "
"remaining tokens fit into 3 token budget (4-1 from the seqA). It will " "remaining tokens fit into 3 token budget (4-1 from the seqA). It will "
"then be rounded down to 2 tokens on block size, thus 6 tokens in total." "then be rounded down to 2 tokens on block size, thus 6 tokens in total."
def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds():
"""
Test that the scheduler does not schedule batches with prompt tokens and
prompt embeddings co-mingled.
"""
block_size = 2
max_seq_group = 3
scheduler = initialize_scheduler(
block_size=block_size,
num_cpu_blocks=16,
num_gpu_blocks=16,
max_num_seqs=max_seq_group,
max_model_len=100,
enable_prefix_caching=True,
)
# the odd indexed inputs should be passed in via embeddings,
# evens via token_ids
seq_length = 7
embedding_size = 5
num_seqs = 11
seq_tokens: list[list[int]] = []
seq_embeds: list[Optional[torch.Tensor]] = []
for i in range(num_seqs):
if i % 2:
seq_tokens.append(list(range(seq_length)))
seq_embeds.append(None)
else:
seq_tokens.append([0] * seq_length)
seq_embeds.append(torch.rand(embedding_size))
seq_and_seq_groups = [
create_dummy_prompt(f"{i}",
prompt_tokens=seq_tokens[i],
prompt_embeds=seq_embeds[i],
block_size=block_size)
for i in range(len(seq_tokens))
]
for _, seq_group in seq_and_seq_groups:
scheduler.add_seq_group(seq_group)
while not all(seq.is_finished() for seq, _ in seq_and_seq_groups):
unfinished_seq_groups = [
seq_group for _, seq_group in seq_and_seq_groups
if not seq_group.is_finished()
]
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) > 0
batch_is_prompt_embeds = out.scheduled_seq_groups[
0].seq_group.uses_prompt_embeds()
expected_scheduled_seq_groups = [
seq_group for seq_group in unfinished_seq_groups
if seq_group.uses_prompt_embeds() == batch_is_prompt_embeds
]
# We should have as many scheduled groups as possible, without mixing
assert len(out.scheduled_seq_groups) == min(
max_seq_group, len(expected_scheduled_seq_groups))
assert all(scheduled_seq_group.seq_group.uses_prompt_embeds() ==
batch_is_prompt_embeds
for scheduled_seq_group in out.scheduled_seq_groups)
# Finish the scheduled groups
for scheduled_seq_group in out.scheduled_seq_groups:
for seq in scheduled_seq_group.seq_group.seqs:
seq.status = SequenceStatus.FINISHED_STOPPED
scheduler.free_finished_seq_groups()
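
The invariant this test enforces can be stated compactly: within a single scheduled batch, either every sequence group uses prompt embeddings or none does. A minimal sketch of that check (uses_prompt_embeds is the SequenceGroup helper added later in this PR):

def batch_embeds_usage_is_consistent(scheduled_seq_groups) -> bool:
    # Collect the per-group flag; a consistent batch has at most one value.
    flags = {
        scheduled.seq_group.uses_prompt_embeds()
        for scheduled in scheduled_seq_groups
    }
    return len(flags) <= 1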

View File

@ -5,9 +5,11 @@ from collections import defaultdict
from collections.abc import Sequence as GenericSequence from collections.abc import Sequence as GenericSequence
from typing import Any, Optional from typing import Any, Optional
import torch
from vllm import SamplingParams from vllm import SamplingParams
from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.inputs import EncoderDecoderInputs, token_inputs from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import (Logprob, Sequence, SequenceGroup, from vllm.sequence import (Logprob, Sequence, SequenceGroup,
SequenceGroupMetadata) SequenceGroupMetadata)
@ -19,6 +21,7 @@ def create_dummy_prompt(
block_size: Optional[int] = None, block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_tokens: Optional[list[int]] = None, prompt_tokens: Optional[list[int]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
min_tokens: int = 0, min_tokens: int = 0,
max_tokens: int = 16, max_tokens: int = 16,
) -> tuple[Sequence, SequenceGroup]: ) -> tuple[Sequence, SequenceGroup]:
@ -31,9 +34,13 @@ def create_dummy_prompt(
prompt_tokens = list(range(prompt_length)) prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt_str = " ".join([str(t) for t in prompt_tokens])
inputs = token_inputs(
prompt_token_ids=prompt_tokens,
prompt=prompt_str) if prompt_embeds is None else embeds_inputs(
prompt_embeds=prompt_embeds)
prompt = Sequence( prompt = Sequence(
int(request_id), int(request_id),
inputs=token_inputs(prompt_tokens, prompt=prompt_str), inputs=inputs,
block_size=block_size, block_size=block_size,
) )
seq_group = SequenceGroup( seq_group = SequenceGroup(

View File

@ -1,4 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
from typing import Optional
import pytest import pytest
import torch import torch
@ -110,6 +113,18 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
hf_outputs = hf_model.generate_greedy_logprobs_limit( hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
prompt_embeds: Optional[list[torch.Tensor]] = [] if os.getenv(
"VLLM_USE_V1") == "0" else None
prompt_token_ids = []
for prompt in example_prompts:
token_ids = hf_model.tokenizer(prompt,
return_tensors="pt").input_ids.to(
hf_model.model.device)
prompt_token_ids.append(token_ids)
if prompt_embeds is not None:
prompt_embeds.append(hf_model.model.get_input_embeddings()(
token_ids).squeeze(0))
with vllm_runner( with vllm_runner(
model, model,
tokenizer_name=model_info.tokenizer or model, tokenizer_name=model_info.tokenizer or model,
@ -119,6 +134,9 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
) as vllm_model: ) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs( vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs) example_prompts, max_tokens, num_logprobs)
if prompt_embeds is not None:
vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs(
prompt_embeds, max_tokens, num_logprobs)
check_logprobs_close( check_logprobs_close(
outputs_0_lst=hf_outputs, outputs_0_lst=hf_outputs,
@ -126,6 +144,14 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
name_0="hf", name_0="hf",
name_1="vllm", name_1="vllm",
) )
if prompt_embeds is not None:
check_logprobs_close(
outputs_0_lst=vllm_outputs,
outputs_1_lst=vllm_outputs_from_embeds,
name_0="vllm",
name_1="vllm_from_embeds",
)
if use_rocm_aiter: if use_rocm_aiter:
# this is to ensure that vllm engine # this is to ensure that vllm engine
# has deallocated the memory before running the next # has deallocated the memory before running the next
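
For reference, the prompt embeddings used above come straight from the Hugging Face model's input embedding layer. A sketch of that preparation step, assuming transformers is installed (the model name is just the one used elsewhere in these tests):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

token_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
with torch.no_grad():
    # Shape (seq_len, hidden_size) after squeezing the batch dimension,
    # which is the layout vLLM expects for prompt_embeds.
    prompt_embeds = model.get_input_embeddings()(token_ids).squeeze(0)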

View File

@ -31,8 +31,13 @@ def test_deepseek_mla_attn_backend_module():
assert model_runner.attn_backend.__name__ == "TritonMLABackend" assert model_runner.attn_backend.__name__ == "TritonMLABackend"
@pytest.mark.parametrize("batch_size", list(range(1, 257))) @pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
def test_prepare_prompt(batch_size): @pytest.mark.parametrize("use_prompt_embeds", [True, False])
def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch):
if use_prompt_embeds:
# Prompt embeddings are currently only supported on V0
monkeypatch.setenv("VLLM_USE_V1", "0")
model_runner = _create_model_runner( model_runner = _create_model_runner(
"facebook/opt-125m", "facebook/opt-125m",
max_num_batched_tokens=100000, max_num_batched_tokens=100000,
@ -43,11 +48,20 @@ def test_prepare_prompt(batch_size):
seq_lens: list[int] = [] seq_lens: list[int] = []
seq_group_metadata_list: list[SequenceGroupMetadata] = [] seq_group_metadata_list: list[SequenceGroupMetadata] = []
block_tables = {0: [1]} block_tables = {0: [1]}
expected_input_embeds_len = 0
for i in range(batch_size): for i in range(batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1 seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len) seq_lens.append(seq_len)
seq_data = SequenceData.from_seqs(range(seq_len)) if use_prompt_embeds:
seq_data = SequenceData.from_seqs(
prompt_token_ids=[0] * seq_len,
prompt_embeds=torch.rand(seq_len, 10),
)
expected_input_embeds_len += seq_len
else:
seq_data = SequenceData.from_seqs(prompt_token_ids=range(seq_len))
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=True, is_prompt=True,
@ -68,6 +82,7 @@ def test_prepare_prompt(batch_size):
seq_group_metadata_list) seq_group_metadata_list)
input_tokens = model_input.input_tokens input_tokens = model_input.input_tokens
input_positions = model_input.input_positions input_positions = model_input.input_positions
input_embeds = model_input.inputs_embeds
attn_metadata = model_input.attn_metadata attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens return_seq_lens = model_input.seq_lens
slot_mapping = attn_metadata.slot_mapping slot_mapping = attn_metadata.slot_mapping
@ -121,7 +136,11 @@ def test_prepare_prompt(batch_size):
assert len(input_tokens) == sum(seq_lens) assert len(input_tokens) == sum(seq_lens)
assert len(input_positions) == sum(seq_lens) assert len(input_positions) == sum(seq_lens)
torch.testing.assert_close(input_tokens, input_positions) if expected_input_embeds_len == 0:
torch.testing.assert_close(input_tokens, input_positions)
assert input_embeds is None
else:
assert len(input_embeds) == expected_input_embeds_len
sampling_metadata = SamplingMetadata.prepare( sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list, seq_group_metadata_list,
@ -145,8 +164,13 @@ def test_prepare_prompt(batch_size):
torch.testing.assert_close(actual, expected) torch.testing.assert_close(actual, expected)
@pytest.mark.parametrize("batch_size", list(range(1, 257))) @pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
def test_prepare_decode_cuda_graph(batch_size): @pytest.mark.parametrize("use_prompt_embeds", [True, False])
def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch):
if use_prompt_embeds:
# Prompt embeddings are currently only supported on V0
monkeypatch.setenv("VLLM_USE_V1", "0")
model_runner = _create_model_runner( model_runner = _create_model_runner(
"facebook/opt-125m", "facebook/opt-125m",
seed=0, seed=0,
@ -164,10 +188,19 @@ def test_prepare_decode_cuda_graph(batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
context_len = i % (model_runner.block_size - 1) + 1 context_len = i % (model_runner.block_size - 1) + 1
context_lens.append(context_len) context_lens.append(context_len)
seq_data = SequenceData.from_seqs(range(context_len)) if use_prompt_embeds:
seq_data = SequenceData.from_seqs(
prompt_token_ids=[0] * context_len,
prompt_embeds=torch.rand(context_len, 10),
)
output_embed = torch.rand(10)
else:
seq_data = SequenceData.from_seqs(
prompt_token_ids=range(context_len))
output_embed = None
seq_data.update_num_computed_tokens(context_len) seq_data.update_num_computed_tokens(context_len)
# Append one token ID since prefill is finished. # Append one token ID since prefill is finished.
seq_data.append_token_id(1, 0) seq_data.append_token_id(1, 0, output_embed)
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=False, is_prompt=False,
@ -180,9 +213,12 @@ def test_prepare_decode_cuda_graph(batch_size):
model_input = model_runner._prepare_model_input_tensors( model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list) seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = ( input_tokens = model_input.input_tokens
model_input.input_tokens, model_input.input_positions, input_positions = model_input.input_positions
model_input.attn_metadata, model_input.attn_metadata.slot_mapping) input_embeds = model_input.inputs_embeds
attn_metadata = model_input.attn_metadata
slot_mapping = attn_metadata.slot_mapping
assert len(slot_mapping) == len(input_tokens) assert len(slot_mapping) == len(input_tokens)
expected_bs = model_runner.vllm_config.pad_for_cudagraph( expected_bs = model_runner.vllm_config.pad_for_cudagraph(
@ -227,7 +263,7 @@ def test_prepare_decode_cuda_graph(batch_size):
# block table's first index corresponds to each batch, meaning in # block table's first index corresponds to each batch, meaning in
# decoding it is each token. # decoding it is each token.
assert attn_metadata.block_tables.shape[0] == len(input_tokens) assert attn_metadata.block_tables.shape[0] == len(input_tokens)
# Block table's second dim correspondsd to each token's block number. # Block table's second dim corresponds to each token's block number.
# It is padded up to # It is padded up to
assert attn_metadata.block_tables.shape[1] == ( assert attn_metadata.block_tables.shape[1] == (
model_runner.get_max_block_per_batch()) model_runner.get_max_block_per_batch())
@ -235,7 +271,12 @@ def test_prepare_decode_cuda_graph(batch_size):
assert len(input_tokens) == expected_bs assert len(input_tokens) == expected_bs
assert len(input_positions) == expected_bs assert len(input_positions) == expected_bs
torch.allclose(input_tokens, input_positions) if use_prompt_embeds:
expected_input_embeds_length = start_loc[-1]
assert len(input_embeds) == expected_input_embeds_length
assert expected_input_embeds_length <= expected_bs
else:
assert input_embeds is None
# Verify Sampling # Verify Sampling
expected_selected_token_indices = [] expected_selected_token_indices = []
@ -266,25 +307,27 @@ def test_empty_seq_group():
seq_group_metadata_list: list[SequenceGroupMetadata] = [] seq_group_metadata_list: list[SequenceGroupMetadata] = []
model_input = model_runner._prepare_model_input_tensors( model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list) seq_group_metadata_list)
input_tokens, input_positions, attn_metadata = (
model_input.input_tokens, input_tokens = model_input.input_tokens
model_input.input_positions, input_positions = model_input.input_positions
model_input.attn_metadata, attn_metadata = model_input.attn_metadata
)
assert input_tokens is None assert input_tokens is None
assert input_positions is None assert input_positions is None
assert attn_metadata is None assert attn_metadata is None
model_input = model_runner._prepare_model_input_tensors( model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list) seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, return_seq_lens) = (
model_input.input_tokens, input_tokens = model_input.input_tokens
model_input.input_positions, input_positions = model_input.input_positions
model_input.attn_metadata, input_embeds = model_input.inputs_embeds
model_input.seq_lens, attn_metadata = model_input.attn_metadata
) return_seq_lens = model_input.seq_lens
assert input_tokens is None assert input_tokens is None
assert input_positions is None assert input_positions is None
assert input_embeds is None
assert attn_metadata is None assert attn_metadata is None
assert return_seq_lens is None assert return_seq_lens is None
@ -299,9 +342,15 @@ def distributed_init():
ensure_model_parallel_initialized(1, 1) ensure_model_parallel_initialized(1, 1)
@pytest.mark.parametrize("batch_size", list(range(2, 128))) @pytest.mark.parametrize("batch_size", list(range(2, 128, 3)))
@pytest.mark.parametrize("enforce_eager", [True, False]) @pytest.mark.parametrize("enforce_eager", [True, False])
def test_hybrid_batches(batch_size, enforce_eager, distributed_init): @pytest.mark.parametrize('use_prompt_embeds', [True, False])
def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds,
distributed_init, monkeypatch):
if use_prompt_embeds:
# Prompt embeddings are currently only supported on V0
monkeypatch.setenv("VLLM_USE_V1", "0")
model_runner = _create_model_runner( model_runner = _create_model_runner(
"facebook/opt-125m", "facebook/opt-125m",
seed=0, seed=0,
@ -320,11 +369,20 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
block_tables = {0: [1]} block_tables = {0: [1]}
prefill_batch_size = batch_size // 2 prefill_batch_size = batch_size // 2
decode_batch_size = batch_size - prefill_batch_size decode_batch_size = batch_size - prefill_batch_size
expected_input_embeds_len = 0
for i in range(prefill_batch_size): for i in range(prefill_batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1 seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len) seq_lens.append(seq_len)
seq_data = SequenceData.from_seqs(range(seq_len)) if use_prompt_embeds:
seq_data = SequenceData.from_seqs(
prompt_token_ids=[0] * seq_len,
prompt_embeds=torch.rand(seq_len, 10),
)
expected_input_embeds_len += seq_len
else:
seq_data = SequenceData.from_seqs(
prompt_token_ids=range(seq_len), )
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
is_prompt=True, is_prompt=True,
@ -340,8 +398,21 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
for i in range(prefill_batch_size, batch_size): for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block # make sure all tokens fit into one block
context_len = i % (model_runner.block_size - 1) + 1 context_len = i % (model_runner.block_size - 1) + 1
seq_data = SequenceData.from_seqs(range(context_len)) if use_prompt_embeds:
seq_data.append_token_id(1, 0) seq_data = SequenceData.from_seqs(
prompt_token_ids=[0] * context_len,
prompt_embeds=torch.rand(context_len, 10),
)
output_embed = torch.rand(10)
# This also increments the expected input_embeds length, because the
# model needs both the input and output embeddings passed in together
expected_input_embeds_len += 1
else:
seq_data = SequenceData.from_seqs(
prompt_token_ids=range(context_len), )
output_embed = None
assert len(seq_data.prompt_token_ids) == context_len
seq_data.append_token_id(1, 0, output_embed)
seq_data.update_num_computed_tokens(context_len) seq_data.update_num_computed_tokens(context_len)
seq_group_metadata = SequenceGroupMetadata( seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}", request_id=f"test_{i}",
@ -355,11 +426,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
decode_metadata_list.append(seq_group_metadata) decode_metadata_list.append(seq_group_metadata)
model_input = model_runner.prepare_model_input(seq_group_metadata_list) model_input = model_runner.prepare_model_input(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata) = (
model_input.input_tokens, input_tokens = model_input.input_tokens
model_input.input_positions, input_positions = model_input.input_positions
model_input.attn_metadata, input_embeds = model_input.inputs_embeds
) attn_metadata = model_input.attn_metadata
prefill_meta_actual = attn_metadata.prefill_metadata prefill_meta_actual = attn_metadata.prefill_metadata
decode_meta_actual = attn_metadata.decode_metadata decode_meta_actual = attn_metadata.decode_metadata
@ -369,6 +440,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
assert attn_metadata.num_prefills == prefill_batch_size assert attn_metadata.num_prefills == prefill_batch_size
assert attn_metadata.num_decode_tokens == decode_batch_size assert attn_metadata.num_decode_tokens == decode_batch_size
assert attn_metadata.num_prefill_tokens == sum(seq_lens) assert attn_metadata.num_prefill_tokens == sum(seq_lens)
if expected_input_embeds_len == 0:
assert input_embeds is None
else:
assert len(input_embeds) == expected_input_embeds_len
# Verify attn metadata is consistent. We don't need to test individual # Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above. # values here because they are tested above.

View File

@ -367,9 +367,17 @@ class FlashInferState(AttentionState):
# scheduled while CUDA graph mode is enabled. We don't run graph in that # scheduled while CUDA graph mode is enabled. We don't run graph in that
# case. # case.
if use_cuda_graph and is_decode: if use_cuda_graph and is_decode:
batch_size = model_input.input_tokens.shape[0] if model_input.inputs_embeds is None:
state = (self.runner.graph_runners[model_input.virtual_engine] batch_size = model_input.input_tokens.shape[0]
[batch_size].attn_state) state = (
self.runner.graph_runners[model_input.virtual_engine][(
batch_size, False)].attn_state)
else:
batch_size = model_input.inputs_embeds.shape[0]
state = (
self.runner.graph_runners[model_input.virtual_engine][(
batch_size, True)].attn_state)
model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper( model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
) )
model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
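
With prompt embeddings in play, CUDA graph runners are now keyed by (batch_size, uses_inputs_embeds) rather than by batch size alone. A sketch of the lookup performed in the hunk above:

def lookup_graph_runner_state(graph_runners, model_input):
    # Batch size comes from whichever input is present; the boolean half of
    # the key records whether the graph was captured with inputs_embeds.
    if model_input.inputs_embeds is None:
        key = (model_input.input_tokens.shape[0], False)
    else:
        key = (model_input.inputs_embeds.shape[0], True)
    return graph_runners[model_input.virtual_engine][key].attn_state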

View File

@ -1071,6 +1071,7 @@ class Scheduler:
) )
ignored_seq_groups: List[SequenceGroup] = [] ignored_seq_groups: List[SequenceGroup] = []
seq_groups: List[ScheduledSequenceGroup] = [] seq_groups: List[ScheduledSequenceGroup] = []
using_prompt_embeds: bool = False
waiting_queue = self.waiting waiting_queue = self.waiting
@ -1138,6 +1139,15 @@ class Scheduler:
waiting_queue.popleft() waiting_queue.popleft()
continue continue
# We cannot mix sequence groups that use prompt embeds and
# those that do not.
if len(seq_groups) == 0:
using_prompt_embeds = seq_group.uses_prompt_embeds()
if using_prompt_embeds != seq_group.uses_prompt_embeds():
leftover_waiting_sequences.appendleft(seq_group)
waiting_queue.popleft()
continue
lora_int_id = 0 lora_int_id = 0
if self.lora_enabled: if self.lora_enabled:
lora_int_id = seq_group.lora_int_id lora_int_id = seq_group.lora_int_id
@ -1295,17 +1305,39 @@ class Scheduler:
# Merge lists # Merge lists
num_prefill_groups = len(prefills.seq_groups) num_prefill_groups = len(prefills.seq_groups)
ignored_seq_groups_for_embeds = list[SequenceGroup]()
if num_prefill_groups > 0: if num_prefill_groups > 0:
scheduled_seq_groups = prefills.seq_groups scheduled_seq_groups = prefills.seq_groups
scheduled_seq_groups.extend(running_scheduled.decode_seq_groups) scheduled_seq_groups.extend(running_scheduled.decode_seq_groups)
ignored_seq_groups_for_embeds.clear()
else: else:
scheduled_seq_groups = running_scheduled.decode_seq_groups scheduled_seq_groups = running_scheduled.decode_seq_groups
if len(scheduled_seq_groups) > 0:
using_prompt_embeds = scheduled_seq_groups[
0].seq_group.uses_prompt_embeds()
ignored_seq_groups_for_embeds.clear()
indices_ignored = list[int]()
for i, schedule_seq_group in enumerate(scheduled_seq_groups):
if using_prompt_embeds !=\
schedule_seq_group.seq_group.uses_prompt_embeds():
ignored_seq_groups_for_embeds.append(
schedule_seq_group.seq_group)
indices_ignored.append(i)
if len(ignored_seq_groups_for_embeds) > 0:
scheduled_seq_groups = [
group for i, group in enumerate(scheduled_seq_groups)
if i not in indices_ignored
]
else:
ignored_seq_groups_for_embeds.clear()
scheduled_seq_groups.extend(swapped_in.decode_seq_groups) scheduled_seq_groups.extend(swapped_in.decode_seq_groups)
blocks_to_copy = running_scheduled.blocks_to_copy blocks_to_copy = running_scheduled.blocks_to_copy
blocks_to_copy.extend(swapped_in.blocks_to_copy) blocks_to_copy.extend(swapped_in.blocks_to_copy)
ignored_seq_groups = prefills.ignored_seq_groups ignored_seq_groups = prefills.ignored_seq_groups
ignored_seq_groups.extend(ignored_seq_groups_for_embeds)
ignored_seq_groups.extend(swapped_in.infeasible_seq_groups) ignored_seq_groups.extend(swapped_in.infeasible_seq_groups)
return SchedulerOutputs( return SchedulerOutputs(
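
In the decode-only path, the scheduler keeps the groups that match the first group's prompt-embeds usage and moves the rest to the ignored list for this step. A condensed sketch of that filtering:

def filter_decode_groups_by_embeds_usage(scheduled_seq_groups):
    if not scheduled_seq_groups:
        return scheduled_seq_groups, []
    # Match everything against the first group's usage flag.
    using_prompt_embeds = scheduled_seq_groups[0].seq_group.uses_prompt_embeds()
    kept, ignored = [], []
    for scheduled in scheduled_seq_groups:
        if scheduled.seq_group.uses_prompt_embeds() == using_prompt_embeds:
            kept.append(scheduled)
        else:
            ignored.append(scheduled.seq_group)
    return kept, ignored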

View File

@ -489,6 +489,14 @@ class _AsyncLLMEngine(LLMEngine):
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
if (isinstance(prompt, dict)
and prompt.get("prompt_embeds", None) is not None
and not prompt.get("prompt_token_ids", None)):
# We use the -2 dimension (instead of 0) in case a batched input
# of batch size 1 is passed in.
prompt["prompt_token_ids"] = [0
] * prompt["prompt_embeds"].shape[-2]
if self.tokenizer is not None: if self.tokenizer is not None:
tokenizer = await self.get_tokenizer_async(lora_request) tokenizer = await self.get_tokenizer_async(lora_request)
self._validate_token_prompt(prompt, tokenizer=tokenizer) self._validate_token_prompt(prompt, tokenizer=tokenizer)
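
The placeholder logic above can be read in isolation: when a request supplies only prompt_embeds, the engine fabricates zero token ids of matching length so downstream bookkeeping that counts tokens keeps working. A sketch (dimension -2 is used so a leading batch dimension of 1 is also handled):

import torch

def ensure_placeholder_token_ids(prompt: dict) -> dict:
    if (prompt.get("prompt_embeds") is not None
            and not prompt.get("prompt_token_ids")):
        # One placeholder id per embedded position.
        prompt["prompt_token_ids"] = [0] * prompt["prompt_embeds"].shape[-2]
    return prompt

prompt = {"prompt_embeds": torch.rand(1, 16, 512)}
assert len(ensure_placeholder_token_ids(prompt)["prompt_token_ids"]) == 16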

View File

@ -753,6 +753,12 @@ class LLMEngine:
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
if (isinstance(prompt, dict)
and prompt.get("prompt_embeds", None) is not None
and not prompt.get("prompt_token_ids", None)):
seq_len = prompt["prompt_embeds"].shape[0]
prompt["prompt_token_ids"] = [0] * seq_len
if self.tokenizer is not None: if self.tokenizer is not None:
self._validate_token_prompt( self._validate_token_prompt(
prompt, prompt,
@ -1267,11 +1273,13 @@ class LLMEngine:
if self.scheduler_config.is_multi_step: if self.scheduler_config.is_multi_step:
is_prefill_append = seq.data.get_num_uncomputed_tokens( is_prefill_append = seq.data.get_num_uncomputed_tokens(
) == 0 ) == 0
seq.append_token_id(sample.output_token, sample.logprobs) seq.append_token_id(sample.output_token, sample.logprobs,
sample.output_embed)
if not is_prefill_append: if not is_prefill_append:
seq_group.update_num_computed_tokens(1) seq_group.update_num_computed_tokens(1)
else: else:
seq.append_token_id(sample.output_token, sample.logprobs) seq.append_token_id(sample.output_token, sample.logprobs,
sample.output_embed)
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results. """Performs one decoding iteration and returns newly generated results.
@ -2032,10 +2040,12 @@ class LLMEngine:
tokenizer = (None if self.tokenizer is None else tokenizer = (None if self.tokenizer is None else
self.tokenizer.get_lora_tokenizer(lora_request)) self.tokenizer.get_lora_tokenizer(lora_request))
prompt_ids = prompt_inputs["prompt_token_ids"] prompt_ids = prompt_inputs.get("prompt_token_ids", [])
if not prompt_ids: if not prompt_ids:
if prompt_type == "encoder" and model_config.is_multimodal_model: if prompt_type == "encoder" and model_config.is_multimodal_model:
pass # Mllama may have empty encoder inputs for text-only data pass # Mllama may have empty encoder inputs for text-only data
if prompt_inputs["type"] == "embeds":
pass
else: else:
raise ValueError(f"The {prompt_type} prompt cannot be empty") raise ValueError(f"The {prompt_type} prompt cannot be empty")

View File

@ -167,6 +167,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
sampling_params: SamplingParams) -> None: sampling_params: SamplingParams) -> None:
output_token_ids = [sample.output_token for sample in valid_samples] output_token_ids = [sample.output_token for sample in valid_samples]
output_logprobs = [sample.logprobs for sample in valid_samples] output_logprobs = [sample.logprobs for sample in valid_samples]
output_embeds = [sample.output_embed for sample in valid_samples]
# Truncate to max_tokens if necessary. # Truncate to max_tokens if necessary.
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
@ -190,11 +191,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0 is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
# Incrementally append tokens to the sequence, as if we had only one new # Incrementally append tokens to the sequence, as if we had only one new
# token. # token.
for output_token_id, output_logprob in zip(output_token_ids, for output_token_id, output_logprob, output_embed in zip(
output_logprobs): output_token_ids, output_logprobs, output_embeds):
seq.append_token_id( seq.append_token_id(
token_id=output_token_id, token_id=output_token_id,
logprobs=output_logprob, logprobs=output_logprob,
token_embed=output_embed,
) )
if is_prefill_sampled_token: if is_prefill_sampled_token:

View File

@ -119,7 +119,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
sample = outputs.samples[0] sample = outputs.samples[0]
seq = seq_group.first_seq seq = seq_group.first_seq
if not is_async: if not is_async:
seq.append_token_id(sample.output_token, sample.logprobs) seq.append_token_id(sample.output_token, sample.logprobs,
sample.output_embed)
if sampling_params.detokenize and self.detokenizer: if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace( new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params) seq, sampling_params)

View File

@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
TokensPrompt, build_explicit_enc_dec_prompt, TokensPrompt, build_explicit_enc_dec_prompt, embeds_inputs,
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
from .registry import (DummyData, InputContext, InputProcessingContext, from .registry import (DummyData, InputContext, InputProcessingContext,
InputRegistry) InputRegistry)
@ -21,7 +21,9 @@ __all__ = [
"SingletonPrompt", "SingletonPrompt",
"ExplicitEncoderDecoderPrompt", "ExplicitEncoderDecoderPrompt",
"TokenInputs", "TokenInputs",
"EmbedsInputs",
"token_inputs", "token_inputs",
"embeds_inputs",
"DecoderOnlyInputs", "DecoderOnlyInputs",
"EncoderDecoderInputs", "EncoderDecoderInputs",
"ProcessorInputs", "ProcessorInputs",

View File

@ -2,6 +2,7 @@
from collections.abc import Iterable from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast
import torch
from typing_extensions import NotRequired, TypedDict, TypeVar from typing_extensions import NotRequired, TypedDict, TypeVar
if TYPE_CHECKING: if TYPE_CHECKING:
@ -63,12 +64,20 @@ class TokensPrompt(TypedDict):
""" """
SingletonPrompt = Union[str, TextPrompt, TokensPrompt] class EmbedsPrompt(TypedDict):
"""Schema for a prompt provided via token embeddings."""
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
""" """
Set of possible schemas for a single prompt: Set of possible schemas for a single prompt:
- A text prompt (:class:`str` or :class:`TextPrompt`) - A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`) - A tokenized prompt (:class:`TokensPrompt`)
- An embeddings prompt (:class:`EmbedsPrompt`)
Note that "singleton" is as opposed to a data structure Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort which encapsulates multiple prompts, i.e. of the sort
@ -129,6 +138,7 @@ both decoder-only and encoder/decoder input types:
- A text prompt (:class:`str` or :class:`TextPrompt`) - A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`) - A tokenized prompt (:class:`TokensPrompt`)
- An embeddings prompt (:class:`EmbedsPrompt`)
- A single data structure containing both an encoder and a decoder prompt - A single data structure containing both an encoder and a decoder prompt
(:class:`ExplicitEncoderDecoderPrompt`) (:class:`ExplicitEncoderDecoderPrompt`)
""" """
@ -176,7 +186,27 @@ def token_inputs(
return inputs return inputs
DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputs"] class EmbedsInputs(TypedDict):
"""Represents embeddings-based inputs."""
type: Literal["embeds"]
"""The type of inputs."""
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""
def embeds_inputs(prompt_embeds: torch.Tensor) -> EmbedsInputs:
"""Construct :class:`EmbedsInputs` from optional values."""
inputs = EmbedsInputs(
type="embeds",
prompt_embeds=prompt_embeds,
)
return inputs
DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
""" """
The inputs in :class:`~vllm.LLMEngine` before they are The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor. passed to the model executor.
@ -198,7 +228,7 @@ class EncoderDecoderInputs(TypedDict):
"""The inputs for the decoder portion.""" """The inputs for the decoder portion."""
SingletonInputs = Union[TokenInputs, "MultiModalInputs"] SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
""" """
A processed :class:`SingletonPrompt` which can be passed to A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`. :class:`vllm.sequence.Sequence`.
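
End to end, these new schemas let callers hand vLLM an embedding tensor where they would normally hand text or token ids. A sketch of the intended usage on the V0 engine (run with VLLM_USE_V1=0; the hidden-size lookup is illustrative):

import torch
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", enforce_eager=True)
hidden_size = llm.llm_engine.model_config.get_hidden_size()

# An EmbedsPrompt is just a dict with a (seq_len, hidden_size) tensor.
prompt = {"prompt_embeds": torch.rand(16, hidden_size)}
outputs = llm.generate(prompt, SamplingParams(max_tokens=8))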

View File

@ -6,8 +6,9 @@ from typing_extensions import TypeIs
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, from .data import (EmbedsInputs, EmbedsPrompt, ExplicitEncoderDecoderPrompt,
SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt) ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokensPrompt)
class ParsedText(TypedDict): class ParsedText(TypedDict):
@ -84,30 +85,69 @@ class ParsedTokensPrompt(TypedDict):
content: TokensPrompt content: TokensPrompt
class ParsedEmbedsPrompt(TypedDict):
type: Literal['embeds']
content: EmbedsPrompt
@overload
def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
...
@overload
def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt:
...
@overload
def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt:
...
@overload
def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
...
def parse_singleton_prompt( def parse_singleton_prompt(
prompt: SingletonPrompt, prompt: SingletonPrompt,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: ) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
ParsedEmbedsPrompt]:
if isinstance(prompt, str): if isinstance(prompt, str):
return ParsedStrPrompt(type="str", content=prompt) return ParsedStrPrompt(type="str", content=prompt)
elif isinstance(prompt, dict): elif isinstance(prompt, dict):
if "prompt_token_ids" in prompt: # Type ignores are because mypy does not correctly infer the TypedDicts
return ParsedTokensPrompt(type="tokens", # Pyright does succeed.
content=prompt) # type: ignore if "prompt_embeds" in prompt:
return ParsedEmbedsPrompt(
type="embeds", content=prompt) # type: ignore[typeddict-item]
elif "prompt_token_ids" in prompt:
return ParsedTokensPrompt(
type="tokens", content=prompt) # type: ignore[typeddict-item]
elif "prompt" in prompt: elif "prompt" in prompt:
return ParsedTextPrompt(type="text", content=prompt) return ParsedTextPrompt(type="text", content=prompt)
raise TypeError(
raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt")
def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]: def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
return isinstance(prompt, dict) and "prompt_token_ids" in prompt return isinstance(prompt, dict) and "prompt_token_ids" in prompt
def is_embeds_prompt(prompt: PromptType) -> TypeIs[EmbedsPrompt]:
return isinstance(prompt, dict) and "prompt_embeds" in prompt
def is_explicit_encoder_decoder_prompt( def is_explicit_encoder_decoder_prompt(
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(prompt, dict) and "encoder_prompt" in prompt return isinstance(prompt, dict) and "encoder_prompt" in prompt
def is_embeds_inputs(inputs: SingletonInputs) -> TypeIs[EmbedsInputs]:
return isinstance(inputs, dict) and inputs["type"] == "embeds"
def split_enc_dec_inputs( def split_enc_dec_inputs(
inputs: ProcessorInputs, inputs: ProcessorInputs,
) -> tuple[Optional[SingletonInputs], SingletonInputs]: ) -> tuple[Optional[SingletonInputs], SingletonInputs]:
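
A quick sketch of how the extended parser behaves (module paths as introduced in this PR):

import torch

from vllm.inputs.parse import is_embeds_prompt, parse_singleton_prompt

# A dict carrying "prompt_embeds" is classified before the token/text checks.
prompt = {"prompt_embeds": torch.rand(4, 8)}
parsed = parse_singleton_prompt(prompt)
assert parsed["type"] == "embeds"
assert is_embeds_prompt(prompt)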

View File

@ -6,6 +6,7 @@ from typing import Any, Optional, Union, cast
from typing_extensions import assert_never from typing_extensions import assert_never
from vllm import envs
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
@ -15,9 +16,11 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
PromptType, SingletonInputs, SingletonPrompt, token_inputs) ProcessorInputs, PromptType, SingletonInputs,
from .parse import (ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt, SingletonPrompt, TokenInputs, embeds_inputs, token_inputs)
from .parse import (ParsedEmbedsPrompt, ParsedStrPrompt, ParsedTextPrompt,
ParsedTokensPrompt, is_embeds_inputs,
is_explicit_encoder_decoder_prompt, parse_singleton_prompt) is_explicit_encoder_decoder_prompt, parse_singleton_prompt)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -328,6 +331,10 @@ class InputPreprocessor:
* :class:`SingletonInputs` instance * :class:`SingletonInputs` instance
""" """
parsed = parse_singleton_prompt(prompt) parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "embeds":
return self._process_prompt_embeds(parsed)
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \ prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
self._get_prompt_data(parsed) self._get_prompt_data(parsed)
@ -359,6 +366,8 @@ class InputPreprocessor:
cache_salt=cache_salt, cache_salt=cache_salt,
) )
assert_never(parsed)
async def _prompt_to_llm_inputs_async( async def _prompt_to_llm_inputs_async(
self, self,
prompt: SingletonPrompt, prompt: SingletonPrompt,
@ -369,6 +378,9 @@ class InputPreprocessor:
"""Async version of :meth:`_extract_prompt_components`.""" """Async version of :meth:`_extract_prompt_components`."""
parsed = parse_singleton_prompt(prompt) parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "embeds":
return self._process_prompt_embeds(parsed)
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \ prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
self._get_prompt_data(parsed) self._get_prompt_data(parsed)
@ -399,10 +411,34 @@ class InputPreprocessor:
cache_salt=cache_salt, cache_salt=cache_salt,
) )
def _process_prompt_embeds(self,
parsed: ParsedEmbedsPrompt) -> EmbedsInputs:
if envs.VLLM_USE_V1:
raise ValueError("prompt_embeds is only available in V0.")
prompt_embeds_content = parsed["content"]
prompt_embeds = prompt_embeds_content["prompt_embeds"]
# prompt_embeds must be (seq_len, hidden_size), but if the user
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1:
prompt_embeds = prompt_embeds.squeeze(dim=0)
if prompt_embeds.ndim != 2:
raise ValueError(
"prompt_embeds must be of shape (seq_len, hidden_size).")
return embeds_inputs(prompt_embeds=prompt_embeds)
assert_never(parsed)
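
The shape normalization in _process_prompt_embeds accepts a batch of exactly one and rejects anything that is not already (seq_len, hidden_size). In isolation:

import torch

prompt_embeds = torch.rand(1, 16, 512)  # batch of size 1
if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1:
    prompt_embeds = prompt_embeds.squeeze(dim=0)
if prompt_embeds.ndim != 2:
    raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")
assert prompt_embeds.shape == (16, 512)
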
def _build_enc_dec_llm_inputs( def _build_enc_dec_llm_inputs(
self, self,
encoder_inputs: SingletonInputs, encoder_inputs: Union[TokenInputs, MultiModalInputs],
decoder_inputs: Optional[SingletonInputs], decoder_inputs: Optional[Union[TokenInputs, MultiModalInputs]],
) -> EncoderDecoderInputs: ) -> EncoderDecoderInputs:
if (encoder_inputs["type"] == "token" if (encoder_inputs["type"] == "token"
or encoder_inputs["type"] == "multimodal"): or encoder_inputs["type"] == "multimodal"):
@ -410,6 +446,9 @@ class InputPreprocessor:
else: else:
assert_never(encoder_inputs) # type: ignore[arg-type] assert_never(encoder_inputs) # type: ignore[arg-type]
# Mypy does not correctly infer that EmbedsInputs is impossible
assert "prompt_token_ids" in encoder_inputs
if decoder_inputs is None: if decoder_inputs is None:
if self.model_config.hf_config.model_type == "whisper": if self.model_config.hf_config.model_type == "whisper":
# For Whisper models, the text prompt should go to the decoder. # For Whisper models, the text prompt should go to the decoder.
@ -441,7 +480,8 @@ class InputPreprocessor:
def _separate_enc_dec_inputs_from_mm_processor_outputs( def _separate_enc_dec_inputs_from_mm_processor_outputs(
self, self,
inputs: SingletonInputs, inputs: SingletonInputs,
decoder_inputs_to_override: Optional[SingletonInputs] = None, decoder_inputs_to_override: Optional[Union[TokenInputs,
MultiModalInputs]] = None,
) -> tuple[SingletonInputs, SingletonInputs]: ) -> tuple[SingletonInputs, SingletonInputs]:
""" """
For encoder/decoder models only: For encoder/decoder models only:
@ -540,6 +580,8 @@ class InputPreprocessor:
# For multimodal model, override decoder prompt from processor # For multimodal model, override decoder prompt from processor
# with explicit decoder prompt. # with explicit decoder prompt.
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
assert decoder_inputs is None or not is_embeds_inputs(
decoder_inputs)
encoder_inputs, decoder_inputs = ( encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs( self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs)) encoder_inputs, decoder_inputs))
@ -555,9 +597,12 @@ class InputPreprocessor:
inputs)) inputs))
else: else:
encoder_inputs = inputs encoder_inputs = inputs
decoder_inputs = None decoder_inputs = None
# Mypy does not do type inference well with TypedDicts with Literal
# values.
assert not is_embeds_inputs(encoder_inputs)
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
async def _process_encoder_decoder_prompt_async( async def _process_encoder_decoder_prompt_async(
@ -590,6 +635,8 @@ class InputPreprocessor:
# For multimodal model, override decoder prompt from processor # For multimodal model, override decoder prompt from processor
# with explicit decoder prompt. # with explicit decoder prompt.
if self.model_config.is_multimodal_model: if self.model_config.is_multimodal_model:
assert decoder_inputs is None or not is_embeds_inputs(
decoder_inputs)
encoder_inputs, decoder_inputs = ( encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs( self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs)) encoder_inputs, decoder_inputs))
@ -605,9 +652,12 @@ class InputPreprocessor:
inputs)) inputs))
else: else:
encoder_inputs = inputs encoder_inputs = inputs
decoder_inputs = None decoder_inputs = None
# Mypy does not do type inference well with TypedDicts with Literal
# values.
assert not is_embeds_inputs(encoder_inputs)
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
def _build_decoder_only_llm_inputs( def _build_decoder_only_llm_inputs(
@ -617,10 +667,15 @@ class InputPreprocessor:
) -> DecoderOnlyInputs: ) -> DecoderOnlyInputs:
if (prompt_inputs["type"] == "token" if (prompt_inputs["type"] == "token"
or prompt_inputs["type"] == "multimodal"): or prompt_inputs["type"] == "multimodal"):
# Mypy does not do type inference well with TypedDicts and Literal
# values
assert not is_embeds_inputs(prompt_inputs)
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
prompt_inputs["prompt_token_ids"], prompt_inputs["prompt_token_ids"],
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
) )
elif (prompt_inputs["type"] == "embeds"):
pass
else: else:
assert_never(prompt_inputs) # type: ignore[arg-type] assert_never(prompt_inputs) # type: ignore[arg-type]

View File

@ -110,6 +110,11 @@ class SamplerOutput(
# 'broadcasted' to all other PP ranks for next step. # 'broadcasted' to all other PP ranks for next step.
sampled_token_ids_cpu: Optional[torch.Tensor] = None sampled_token_ids_cpu: Optional[torch.Tensor] = None
# On-device tensor containing the sampled token embeddings (embeddings
# corresponding to the sampled token ids). Used when prompt embeddings are
# specified in lieu of prompt token ids or text.
sampled_token_embeds: Optional[torch.Tensor] = None
# Spec decode metrics populated by workers. # Spec decode metrics populated by workers.
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
@ -183,7 +188,7 @@ class Sampler(nn.Module):
# Whether or not the SamplerOutput should have on-device tensors # Whether or not the SamplerOutput should have on-device tensors
# containing the sampled token ids and probabilities. This is used by # containing the sampled token ids and probabilities. This is used by
# speculative decoding. # speculative decoding and when prompt embeddings are specified.
self.include_gpu_probs_tensor = False self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False self.should_modify_greedy_probs_inplace = False

View File

@ -166,6 +166,9 @@ class SequenceData(msgspec.Struct,
_output_token_ids: array = msgspec.field( _output_token_ids: array = msgspec.field(
default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))
_prompt_embeds: Optional[torch.Tensor] = None
_output_embeds: Optional[torch.Tensor] = None
### The below fields should not be passed as an argument ### ### The below fields should not be passed as an argument ###
_cumulative_logprob: float = 0.0 _cumulative_logprob: float = 0.0
_prompt_token_ids_tuple: tuple[int, _prompt_token_ids_tuple: tuple[int,
@ -176,6 +179,7 @@ class SequenceData(msgspec.Struct,
_num_cached_tokens: int = 0 _num_cached_tokens: int = 0
_stage: SequenceStage = SequenceStage.PREFILL _stage: SequenceStage = SequenceStage.PREFILL
_cached_all_token_ids: list[int] = msgspec.field(default_factory=list) _cached_all_token_ids: list[int] = msgspec.field(default_factory=list)
_cached_all_token_embeds: Optional[torch.Tensor] = None
# It is used to get delta input. It is reset when `get_delta_and_reset` # It is used to get delta input. It is reset when `get_delta_and_reset`
# is called. # is called.
@ -208,6 +212,8 @@ class SequenceData(msgspec.Struct,
def from_seqs( def from_seqs(
prompt_token_ids: GenericSequence[int], prompt_token_ids: GenericSequence[int],
output_token_ids: Optional[GenericSequence[int]] = None, output_token_ids: Optional[GenericSequence[int]] = None,
*,
prompt_embeds: Optional[torch.Tensor] = None,
) -> "SequenceData": ) -> "SequenceData":
""" """
Construct a :class:`SequenceData` instance from prompt and output Construct a :class:`SequenceData` instance from prompt and output
@ -217,13 +223,15 @@ class SequenceData(msgspec.Struct,
prompt_token_ids) prompt_token_ids)
if output_token_ids is None: if output_token_ids is None:
return SequenceData(prompt_token_ids_arr) return SequenceData(prompt_token_ids_arr,
_prompt_embeds=prompt_embeds)
output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
output_token_ids) output_token_ids)
return SequenceData(prompt_token_ids_arr, return SequenceData(prompt_token_ids_arr,
_output_token_ids=output_token_ids_arr) _output_token_ids=output_token_ids_arr,
_prompt_embeds=prompt_embeds)
def __post_init__(self) -> None: def __post_init__(self) -> None:
assert self._prompt_token_ids.typecode == "l" assert self._prompt_token_ids.typecode == "l"
@ -231,6 +239,8 @@ class SequenceData(msgspec.Struct,
self._prompt_token_ids_tuple: tuple[int, ...] = tuple( self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
self._prompt_token_ids) self._prompt_token_ids)
self._update_cached_all_tokens() self._update_cached_all_tokens()
if self._prompt_embeds is not None:
self._update_cached_all_token_embeds()
def _update_cached_all_tokens(self): def _update_cached_all_tokens(self):
assert isinstance(self._prompt_token_ids, array) assert isinstance(self._prompt_token_ids, array)
@ -238,6 +248,13 @@ class SequenceData(msgspec.Struct,
self._cached_all_token_ids: list[int] = list(self._prompt_token_ids + self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
self._output_token_ids) self._output_token_ids)
def _update_cached_all_token_embeds(self):
assert isinstance(self._prompt_embeds, torch.Tensor)
self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds
if self._output_embeds is not None:
self._cached_all_token_embeds = torch.cat(
(self._cached_all_token_embeds, self._output_embeds), dim=0)
@property @property
def cumulative_logprob(self) -> float: def cumulative_logprob(self) -> float:
return self._cumulative_logprob return self._cumulative_logprob
@ -270,6 +287,15 @@ class SequenceData(msgspec.Struct,
new_output_token_ids) new_output_token_ids)
self._update_cached_all_tokens() self._update_cached_all_tokens()
@property
def output_embeds(self) -> Optional[torch.Tensor]:
return self._output_embeds
@output_embeds.setter
def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None:
self._output_embeds = new_output_token_embeds
self._update_cached_all_token_embeds()
@property @property
def output_token_ids_array(self) -> array: def output_token_ids_array(self) -> array:
"""Return the prompt token ids in array type. """Return the prompt token ids in array type.
@ -280,6 +306,15 @@ class SequenceData(msgspec.Struct,
assert isinstance(self._output_token_ids, array) assert isinstance(self._output_token_ids, array)
return self._output_token_ids return self._output_token_ids
@property
def prompt_embeds(self) -> Optional[torch.Tensor]:
return self._prompt_embeds
@prompt_embeds.setter
def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
self._prompt_embeds = prompt_embeds
self._update_cached_all_token_embeds()
@property @property
def mrope_position_delta(self) -> Optional[int]: def mrope_position_delta(self) -> Optional[int]:
return self._mrope_position_delta return self._mrope_position_delta
@ -288,11 +323,28 @@ class SequenceData(msgspec.Struct,
def mrope_position_delta(self, new_mrope_position_delta): def mrope_position_delta(self, new_mrope_position_delta):
self._mrope_position_delta = new_mrope_position_delta self._mrope_position_delta = new_mrope_position_delta
def append_token_id(self, token_id: int, logprob: float) -> None: def append_token_id(self,
token_id: int,
logprob: float,
token_embed: Optional[torch.Tensor] = None) -> None:
self._output_token_ids.append(token_id) self._output_token_ids.append(token_id)
self._new_appended_tokens.append(token_id) self._new_appended_tokens.append(token_id)
self._cached_all_token_ids.append(token_id) self._cached_all_token_ids.append(token_id)
self._cumulative_logprob += logprob self._cumulative_logprob += logprob
if token_embed is not None:
# Do not pass in with batch or sequence dimensions
assert token_embed.ndim == 1
token_embed = token_embed.detach().cpu().unsqueeze(0)
if self._output_embeds is None:
self._output_embeds = token_embed
else:
self._output_embeds = torch.cat(
(self._output_embeds, token_embed), dim=0)
assert self._cached_all_token_embeds is not None
self._cached_all_token_embeds = torch.cat(
(self._cached_all_token_embeds,
token_embed.to(device=self._cached_all_token_embeds.device)),
dim=0)
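
The accumulation above keeps prompt and output embeddings on CPU as a single growing (num_tokens, hidden_size) tensor. A standalone sketch of the same pattern:

import torch

hidden_size = 8
cached_all_token_embeds = torch.rand(4, hidden_size)  # prompt embeddings
output_embeds = None
for _ in range(3):  # three decoded tokens
    token_embed = torch.rand(hidden_size).detach().cpu().unsqueeze(0)
    output_embeds = (token_embed if output_embeds is None else torch.cat(
        (output_embeds, token_embed), dim=0))
    cached_all_token_embeds = torch.cat(
        (cached_all_token_embeds, token_embed), dim=0)
assert output_embeds.shape == (3, hidden_size)
assert cached_all_token_embeds.shape == (7, hidden_size)
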
def get_len(self) -> int: def get_len(self) -> int:
return len(self._output_token_ids) + len(self._prompt_token_ids) return len(self._output_token_ids) + len(self._prompt_token_ids)
@ -306,6 +358,9 @@ class SequenceData(msgspec.Struct,
def get_token_ids(self) -> list[int]: def get_token_ids(self) -> list[int]:
return self._cached_all_token_ids return self._cached_all_token_ids
def get_token_embeddings(self) -> Optional[torch.Tensor]:
return self._cached_all_token_embeds
def get_prefix_token_ids( def get_prefix_token_ids(
self, num_tokens: int self, num_tokens: int
) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]: ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
@ -387,6 +442,8 @@ class SequenceData(msgspec.Struct,
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"SequenceData(" return (f"SequenceData("
f"prompt_token_ids={self._prompt_token_ids}, " f"prompt_token_ids={self._prompt_token_ids}, "
f"prompt_embeds.shape="
f"{getattr(self._prompt_embeds, 'shape', None)}, "
f"output_token_ids={self.output_token_ids}, " f"output_token_ids={self.output_token_ids}, "
f"cumulative_logprob={self.cumulative_logprob}, " f"cumulative_logprob={self.cumulative_logprob}, "
f"get_num_computed_tokens={self.get_num_computed_tokens()})") f"get_num_computed_tokens={self.get_num_computed_tokens()})")
@ -425,7 +482,10 @@ class Sequence:
self.lora_request = lora_request self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request self.prompt_adapter_request = prompt_adapter_request
self.data = SequenceData.from_seqs(self.prompt_token_ids) self.data = SequenceData.from_seqs(
self.prompt_token_ids,
prompt_embeds=self.inputs["prompt_embeds"]
if self.inputs["type"] == "embeds" else None)
self.output_logprobs: SampleLogprobs = [] self.output_logprobs: SampleLogprobs = []
self.output_text = "" self.output_text = ""
@ -448,14 +508,20 @@ class Sequence:
@property @property
def prompt(self) -> Optional[str]: def prompt(self) -> Optional[str]:
if self.inputs["type"] == "embeds":
return None
return self.inputs.get("prompt") return self.inputs.get("prompt")
@property @property
def prompt_token_ids(self) -> list[int]: def prompt_token_ids(self) -> list[int]:
if self.inputs["type"] == "embeds":
return [0] * len(self.inputs["prompt_embeds"])
return self.inputs["prompt_token_ids"] return self.inputs["prompt_token_ids"]
@property @property
def token_type_ids(self) -> list[int]: def token_type_ids(self) -> list[int]:
if self.inputs["type"] == "embeds":
return []
return self.inputs.get("token_type_ids", []) return self.inputs.get("token_type_ids", [])
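When a request supplies only embeddings, the properties above fall back to placeholders: no prompt text, one zero token id per embedding row so that length-based bookkeeping keeps working, and an empty token-type list. A minimal sketch of that fallback, using a hypothetical inputs dict shaped like the branches above:

```python
# Hypothetical stand-in for the embeds-only branch of the properties above.
from typing import Optional

import torch


def prompt_token_ids_for(inputs: dict) -> list[int]:
    if inputs["type"] == "embeds":
        # One placeholder id per embedding row keeps length accounting
        # consistent even though no real token ids exist.
        return [0] * len(inputs["prompt_embeds"])
    return inputs["prompt_token_ids"]


def prompt_for(inputs: dict) -> Optional[str]:
    return None if inputs["type"] == "embeds" else inputs.get("prompt")


embeds_inputs = {"type": "embeds", "prompt_embeds": torch.zeros(6, 8)}
assert prompt_token_ids_for(embeds_inputs) == [0] * 6
assert prompt_for(embeds_inputs) is None
```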
@property @property
@ -554,11 +620,14 @@ class Sequence:
"""Reset the sequence states for recomputation.""" """Reset the sequence states for recomputation."""
self.data.reset_state_for_recompute() self.data.reset_state_for_recompute()
def append_token_id(self, token_id: int, logprobs: dict[int, def append_token_id(self,
Logprob]) -> None: token_id: int,
logprobs: dict[int, Logprob],
token_embed: Optional[torch.Tensor] = None) -> None:
assert token_id in logprobs assert token_id in logprobs
self.output_logprobs.append(logprobs) self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob) self.data.append_token_id(token_id, logprobs[token_id].logprob,
token_embed)
def get_len(self) -> int: def get_len(self) -> int:
return self.data.get_len() return self.data.get_len()
@ -889,6 +958,10 @@ class SequenceGroup:
f"sampling_params={self.sampling_params}, " f"sampling_params={self.sampling_params}, "
f"num_seqs={len(self.seqs)})") f"num_seqs={len(self.seqs)})")
def uses_prompt_embeds(self) -> bool:
"""Returns True if the sequence group uses input embeds."""
return any(seq.data.prompt_embeds is not None for seq in self.seqs)
class SequenceGroupMetadataDelta( class SequenceGroupMetadataDelta(
msgspec.Struct, msgspec.Struct,
@ -1043,10 +1116,14 @@ class SequenceOutput(
parent_seq_id: int parent_seq_id: int
output_token: int output_token: int
logprobs: dict[int, Logprob] logprobs: dict[int, Logprob]
output_embed: Optional[torch.Tensor] = None
def __repr__(self) -> str: def __repr__(self) -> str:
output_embed_shape = \
self.output_embed.shape if self.output_embed is not None else None
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
f"output_token={self.output_token}, " f"output_token={self.output_token}, "
f"output_embed.shape={output_embed_shape}"
f"logprobs={self.logprobs})") f"logprobs={self.logprobs})")
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:


@ -201,6 +201,9 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
if self.prompt_adapter_config is not None: if self.prompt_adapter_config is not None:
raise ValueError("TP1DraftModelRunner has no support for " raise ValueError("TP1DraftModelRunner has no support for "
"prompt_adapter_config") "prompt_adapter_config")
if model_input.inputs_embeds is not None:
raise ValueError("TP1DraftModelRunner has no support for "
"inputs_embeds")
if model_input.multi_modal_kwargs: if model_input.multi_modal_kwargs:
raise ValueError( raise ValueError(
"TP1DraftModelRunner has no support for multi_modal_kwargs" "TP1DraftModelRunner has no support for multi_modal_kwargs"
@ -242,9 +245,16 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
# Get model # Get model
if use_cuda_graph: if use_cuda_graph:
graph_batch_size = model_input.input_tokens.shape[0] if model_input.inputs_embeds is None:
model_executable = (self.graph_runners[model_input.virtual_engine] graph_batch_size = model_input.input_tokens.shape[0]
[graph_batch_size]) model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, False)])
else:
graph_batch_size = model_input.inputs_embeds.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, True)])
if previous_hidden_states is not None: if previous_hidden_states is not None:
hidden_states = torch.cat([ hidden_states = torch.cat([
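The graph lookup above, and the same branch repeated below for the encoder-decoder, pooling, and main GPU runners, now keys captured CUDA graphs by a (batch_size, use_inputs_embeds) pair instead of batch size alone, since a graph captured with an inputs_embeds buffer has a different input signature. A reduced sketch of the selection logic, with illustrative stand-ins for the runner dict:

```python
# Sketch of selecting a captured graph by (batch_size, use_inputs_embeds).
# `graph_runners` and `Runner` are illustrative stand-ins.
from typing import Optional

import torch


class Runner:
    def __init__(self, name: str) -> None:
        self.name = name


graph_runners: dict[tuple[int, bool], Runner] = {
    (8, False): Runner("bs8-token-ids"),
    (8, True): Runner("bs8-inputs-embeds"),
}


def select_runner(input_tokens: Optional[torch.Tensor],
                  inputs_embeds: Optional[torch.Tensor]) -> Runner:
    if inputs_embeds is None:
        assert input_tokens is not None
        return graph_runners[(input_tokens.shape[0], False)]
    return graph_runners[(inputs_embeds.shape[0], True)]


assert select_runner(torch.zeros(8, dtype=torch.long), None).name \
    == "bs8-token-ids"
assert select_runner(None, torch.zeros(8, 16)).name == "bs8-inputs-embeds"
```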
@ -281,6 +291,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
self.vllm_config): self.vllm_config):
hidden_states = model_executable( hidden_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
inputs_embeds=None,
positions=model_input.input_positions, positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(multi_modal_kwargs,


@ -282,7 +282,8 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
else: else:
count += 1 count += 1
seq.append_token_id(token_id, token_logprob.logprob) seq.append_token_id(token_id, token_logprob.logprob,
seq_output.output_embed)
seq.update_num_computed_tokens(1) seq.update_num_computed_tokens(1)
@staticmethod @staticmethod


@ -49,6 +49,7 @@ class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = { tensor_dict = {
"input_tokens": self.input_tokens, "input_tokens": self.input_tokens,
"inputs_embeds": self.inputs_embeds,
"input_positions": self.input_positions, "input_positions": self.input_positions,
"encoder_input_tokens": self.encoder_input_tokens, "encoder_input_tokens": self.encoder_input_tokens,
"encoder_input_positions": self.encoder_input_positions, "encoder_input_positions": self.encoder_input_positions,
@ -172,10 +173,17 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
if (model_input.attn_metadata is not None if (model_input.attn_metadata is not None
and model_input.attn_metadata.prefill_metadata is None and model_input.attn_metadata.prefill_metadata is None
and model_input.attn_metadata.decode_metadata.use_cuda_graph): and model_input.attn_metadata.decode_metadata.use_cuda_graph):
assert model_input.input_tokens is not None if model_input.inputs_embeds is None:
graph_batch_size = model_input.input_tokens.shape[0] assert model_input.input_tokens is not None
model_executable = self.graph_runners[ graph_batch_size = model_input.input_tokens.shape[0]
model_input.virtual_engine][graph_batch_size] model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, False)])
else:
graph_batch_size = model_input.inputs_embeds.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, True)])
else: else:
model_executable = self.model model_executable = self.model
@ -189,6 +197,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
model_input.virtual_engine): model_input.virtual_engine):
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
inputs_embeds=model_input.inputs_embeds,
positions=model_input.input_positions, positions=model_input.input_positions,
encoder_input_ids=model_input.encoder_input_tokens, encoder_input_ids=model_input.encoder_input_tokens,
encoder_positions=model_input.encoder_input_positions, encoder_positions=model_input.encoder_input_positions,


@ -35,7 +35,8 @@ from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
get_sampler)
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.model_executor.models import supports_lora, supports_multimodal
@ -85,6 +86,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
additional fields. additional fields.
""" """
input_tokens: Optional[torch.Tensor] = None input_tokens: Optional[torch.Tensor] = None
inputs_embeds: Optional[torch.Tensor] = None
input_positions: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None
token_types: Optional[torch.Tensor] = None token_types: Optional[torch.Tensor] = None
seq_lens: Optional[List[int]] = None seq_lens: Optional[List[int]] = None
@ -105,6 +107,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = { tensor_dict = {
"input_tokens": self.input_tokens, "input_tokens": self.input_tokens,
"inputs_embeds": self.inputs_embeds,
"input_positions": self.input_positions, "input_positions": self.input_positions,
"lora_requests": self.lora_requests, "lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping, "lora_mapping": self.lora_mapping,
@ -155,6 +158,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = { tensor_dict = {
"input_tokens": self.input_tokens, "input_tokens": self.input_tokens,
"inputs_embeds": self.inputs_embeds,
"input_positions": self.input_positions, "input_positions": self.input_positions,
"lora_requests": self.lora_requests, "lora_requests": self.lora_requests,
"lora_mapping": self.lora_mapping, "lora_mapping": self.lora_mapping,
@ -194,6 +198,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
def simple_reinit(self): def simple_reinit(self):
self.input_tokens[0].clear() # type: ignore self.input_tokens[0].clear() # type: ignore
self.inputs_embeds = None # type: ignore
self.input_positions[0].clear() # type: ignore self.input_positions[0].clear() # type: ignore
self.token_types[0].clear() # type: ignore self.token_types[0].clear() # type: ignore
self.mrope_input_positions = None # type: ignore self.mrope_input_positions = None # type: ignore
@ -221,6 +226,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# Input tokens and positions. # Input tokens and positions.
input_tokens: Optional[List[List[int]]] = None, input_tokens: Optional[List[List[int]]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
input_positions: Optional[List[List[int]]] = None, input_positions: Optional[List[List[int]]] = None,
token_types: Optional[List[List[int]]] = None, token_types: Optional[List[List[int]]] = None,
mrope_input_positions: Optional[List[List[List[int]]]] = None, mrope_input_positions: Optional[List[List[List[int]]]] = None,
@ -282,6 +288,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for seq_id in range(len(self.seq_ids)): for seq_id in range(len(self.seq_ids)):
self.input_tokens[seq_id].clear() self.input_tokens[seq_id].clear()
self.inputs_embeds = inputs_embeds
if input_positions: if input_positions:
self.input_positions = input_positions self.input_positions = input_positions
else: else:
@ -356,6 +364,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
else: else:
self.input_tokens = input_tokens or [] self.input_tokens = input_tokens or []
self.inputs_embeds = inputs_embeds
self.input_positions = input_positions or [] self.input_positions = input_positions or []
self.token_types = token_types or [] self.token_types = token_types or []
self.mrope_input_positions = mrope_input_positions or None self.mrope_input_positions = mrope_input_positions or None
@ -401,6 +410,26 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.lora_index_mapping = [] self.lora_index_mapping = []
self.lora_prompt_mapping = [] self.lora_prompt_mapping = []
def __repr__(self) -> str:
return (f"InterDataForSeqGroup("
f"request_id={self.request_id}, "
f"seq_ids={self.seq_ids}, "
f"is_prompt={self.is_prompt}, "
f"block_tables={self.block_tables}, "
f"computed_block_nums={self.computed_block_nums}, "
f"n_seqs={self.n_seqs}, "
f"input_tokens={self.input_tokens}, "
f"inputs_embeds.shape="
f"{getattr(self.inputs_embeds, 'shape', None)}, "
f"input_positions={self.input_positions}, "
f"token_types={self.token_types}, "
f"mrope_input_positions={self.mrope_input_positions}, "
f"seq_lens={self.seq_lens}, "
f"orig_seq_lens={self.orig_seq_lens}, "
f"query_lens={self.query_lens}, "
f"context_lens={self.context_lens}, "
f"multi_modal_kwargs={self.multi_modal_kwargs}")
def gen_inter_data_builder(self, num_seqs: int): def gen_inter_data_builder(self, num_seqs: int):
return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup( return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup(
request_id="", request_id="",
@ -511,13 +540,21 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
context_len = seq_data.get_num_computed_tokens() context_len = seq_data.get_num_computed_tokens()
# Compute tokens. # Compute tokens.
tokens = seq_data.get_token_ids()[context_len:seq_len] if seq_data.prompt_embeds is None:
tokens = seq_data.get_token_ids()[context_len:seq_len]
prompt_embeds = None
else:
tokens = [0] * (seq_len - context_len)
prompt_embeds = seq_data.get_token_embeddings(
)[context_len:seq_len]
token_types = seq_group_metadata.token_type_ids token_types = seq_group_metadata.token_type_ids
inter_data.seq_lens[seq_idx] = seq_len inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.context_lens[seq_idx] = context_len inter_data.context_lens[seq_idx] = context_len
inter_data.input_tokens[seq_idx].extend(tokens) inter_data.input_tokens[seq_idx].extend(tokens)
inter_data.inputs_embeds = prompt_embeds
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
inter_data.token_types[seq_idx].extend( inter_data.token_types[seq_idx].extend(
token_types if token_types else []) token_types if token_types else [])
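During prefill, a sequence that carries prompt embeddings feeds placeholder token ids for the uncomputed range and slices the cached embedding tensor instead. A sketch of that choice; SeqDataLike is an illustrative stand-in for the sequence data interface used above:

```python
# Sketch of choosing tokens vs. prompt embeddings for the uncomputed range
# [context_len, seq_len); `SeqDataLike` is an illustrative stand-in.
from typing import Optional

import torch


class SeqDataLike:
    def __init__(self, token_ids: list[int],
                 prompt_embeds: Optional[torch.Tensor] = None) -> None:
        self.prompt_embeds = prompt_embeds
        self._token_ids = token_ids

    def get_token_ids(self) -> list[int]:
        return self._token_ids

    def get_token_embeddings(self) -> Optional[torch.Tensor]:
        return self.prompt_embeds


def slice_inputs(seq_data: SeqDataLike, context_len: int, seq_len: int):
    if seq_data.prompt_embeds is None:
        return seq_data.get_token_ids()[context_len:seq_len], None
    # Placeholder ids keep downstream shapes consistent; the real signal
    # is the sliced embedding tensor.
    return ([0] * (seq_len - context_len),
            seq_data.get_token_embeddings()[context_len:seq_len])


tokens, embeds = slice_inputs(
    SeqDataLike([0] * 6, prompt_embeds=torch.zeros(6, 8)), 2, 6)
assert tokens == [0, 0, 0, 0] and embeds.shape == (4, 8)
```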
@ -822,15 +859,29 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
create on-device tensors. create on-device tensors.
""" """
# Combine and flatten intermediate data. # Combine and flatten intermediate data.
input_tokens = [] input_tokens = list[int]()
token_types = [] inputs_embeds_lst = list[torch.Tensor]()
token_types = list[int]()
for inter_data in self.inter_data_list: for inter_data in self.inter_data_list:
for cur_input_tokens in inter_data.input_tokens: for cur_input_tokens in inter_data.input_tokens:
input_tokens.extend(cur_input_tokens) input_tokens.extend(cur_input_tokens)
for cur_token_types in inter_data.token_types: for cur_token_types in inter_data.token_types:
token_types.extend(cur_token_types) token_types.extend(cur_token_types)
if inter_data.inputs_embeds is not None:
inputs_embeds_lst.append(
inter_data.inputs_embeds.to(
dtype=self.runner.model_config.dtype,
device=self.runner.device))
inputs_embeds: Optional[torch.Tensor]
if len(inputs_embeds_lst) == 0:
inputs_embeds = None
else:
inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to(
dtype=self.runner.model_config.dtype,
device=self.runner.device)
assert len(inputs_embeds) == len(input_tokens)
if not input_tokens: if not input_tokens and inputs_embeds is None:
# This may happen when all prefill requests hit # This may happen when all prefill requests hit
# prefix caching and there is no decode request. # prefix caching and there is no decode request.
return self.model_input_cls() return self.model_input_cls()
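In build(), the per-sequence embeddings are collected, cast to the model dtype, moved to the runner's device, and concatenated, and the early-return guard now also checks that no embeddings were gathered. A condensed sketch of the flattening step; model_dtype and device stand in for the runner's configuration:

```python
# Condensed sketch of flattening optional per-sequence embeddings;
# `model_dtype` and `device` stand in for the runner's configuration.
from typing import Optional

import torch


def flatten_embeds(per_seq_embeds: list[Optional[torch.Tensor]],
                   model_dtype: torch.dtype,
                   device: torch.device) -> Optional[torch.Tensor]:
    lst = [
        e.to(dtype=model_dtype, device=device)
        for e in per_seq_embeds if e is not None
    ]
    return torch.cat(lst, dim=0) if lst else None


inputs_embeds = flatten_embeds([torch.zeros(3, 8), None, torch.ones(2, 8)],
                               torch.float32, torch.device("cpu"))
assert inputs_embeds is not None and inputs_embeds.shape == (5, 8)
```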
@ -980,6 +1031,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
return self.model_input_cls( return self.model_input_cls(
input_tokens=input_tokens_tensor, input_tokens=input_tokens_tensor,
inputs_embeds=inputs_embeds,
input_positions=input_positions_tensor, input_positions=input_positions_tensor,
token_types=token_types_tensor, token_types=token_types_tensor,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
@ -1029,7 +1081,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.max_batchsize_to_capture = \ self.max_batchsize_to_capture = \
self.vllm_config.compilation_config.max_capture_size self.vllm_config.compilation_config.max_capture_size
self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ # Graph runners, keyed by (batch_size, use_inputs_embeds).
self.graph_runners: List[Dict[Tuple[int, bool], CUDAGraphRunner]] = [
{} for _ in range(self.parallel_config.pipeline_parallel_size) {} for _ in range(self.parallel_config.pipeline_parallel_size)
] ]
self.graph_memory_pool: Optional[Tuple[ self.graph_memory_pool: Optional[Tuple[
@ -1466,6 +1519,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
input_positions = torch.zeros(max_batch_size, input_positions = torch.zeros(max_batch_size,
dtype=torch.long, dtype=torch.long,
device=self.device) device=self.device)
inputs_embeds = torch.zeros(
(max_batch_size, self.model_config.get_hidden_size()),
dtype=self.model_config.dtype,
device=self.device)
if self.model_config.uses_mrope: if self.model_config.uses_mrope:
input_positions = torch.tile(input_positions, input_positions = torch.tile(input_positions,
(3, 1)).cuda(device=self.device) (3, 1)).cuda(device=self.device)
@ -1503,15 +1560,22 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# memory usage of CUDA graph. # memory usage of CUDA graph.
for virtual_engine in range( for virtual_engine in range(
self.parallel_config.pipeline_parallel_size): self.parallel_config.pipeline_parallel_size):
# Only rank 0 should print progress bar during capture # We need to not only iterate over batch sizes, but also whether
cudagraph_capture_sizes = (tqdm( # to use inputs_embeds or not, hence we use the cartesian
self.vllm_config.compilation_config. # product.
cudagraph_capture_sizes = self.vllm_config.compilation_config\
.cudagraph_capture_sizes
cudagraph_inputs_embeds = (True, False)
compilation_cases = itertools.product(
cudagraph_capture_sizes, cudagraph_capture_sizes,
desc="Capturing CUDA graph shapes", cudagraph_inputs_embeds,
) if get_tensor_model_parallel_rank() == 0 else )
self.vllm_config.compilation_config. # Only rank 0 should print progress bar during capture
cudagraph_capture_sizes) if get_tensor_model_parallel_rank() == 0:
for batch_size in cudagraph_capture_sizes: compilation_cases = tqdm(
list(compilation_cases),
desc="Capturing CUDA graph shapes")
for batch_size, use_inputs_embeds in compilation_cases:
attn_metadata = ( attn_metadata = (
self.attn_state.graph_capture_get_metadata_for_batch( self.attn_state.graph_capture_get_metadata_for_batch(
batch_size, batch_size,
@ -1542,6 +1606,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
capture_inputs = { capture_inputs = {
"input_ids": "input_ids":
input_tokens[:batch_size], input_tokens[:batch_size],
"inputs_embeds":
inputs_embeds[:batch_size]
if use_inputs_embeds else None,
"positions": "positions":
input_positions[..., :batch_size], input_positions[..., :batch_size],
"intermediate_inputs": "intermediate_inputs":
@ -1578,8 +1645,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
virtual_engine): virtual_engine):
graph_runner.capture(**capture_inputs) graph_runner.capture(**capture_inputs)
self.graph_memory_pool = graph_runner.graph.pool() self.graph_memory_pool = graph_runner.graph.pool()
self.graph_runners[virtual_engine][batch_size] = ( self.graph_runners[virtual_engine][(
graph_runner) batch_size, use_inputs_embeds)] = graph_runner
if self.lora_config: if self.lora_config:
self._remove_dummy_loras() self._remove_dummy_loras()
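As the comment in the capture hunk explains, capture now iterates over the cartesian product of the configured batch sizes and the inputs_embeds on/off flag, so each (size, flag) pair gets its own graph, with the progress bar wrapping the cases only on tensor-parallel rank 0. A reduced sketch of that iteration; capture_graph and rank are illustrative stand-ins:

```python
# Reduced sketch of iterating capture cases as the cartesian product of
# batch sizes and the use-inputs_embeds flag. `capture_graph` and `rank`
# are illustrative stand-ins; in vLLM the progress bar is tqdm on rank 0.
import itertools

cudagraph_capture_sizes = [1, 2, 4, 8]
cudagraph_inputs_embeds = (True, False)

compilation_cases = itertools.product(cudagraph_capture_sizes,
                                      cudagraph_inputs_embeds)

graph_runners: dict[tuple[int, bool], str] = {}


def capture_graph(batch_size: int, use_inputs_embeds: bool) -> str:
    # Placeholder for CUDAGraphRunner.capture(...).
    return f"graph(bs={batch_size}, embeds={use_inputs_embeds})"


rank = 0
if rank == 0:
    # Materialize the iterator so a progress bar could wrap it.
    compilation_cases = list(compilation_cases)

for batch_size, use_inputs_embeds in compilation_cases:
    graph_runners[(batch_size, use_inputs_embeds)] = capture_graph(
        batch_size, use_inputs_embeds)

assert len(graph_runners) == len(cudagraph_capture_sizes) * 2
```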
@ -1711,8 +1778,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if prefill_meta is None and decode_meta.use_cuda_graph: if prefill_meta is None and decode_meta.use_cuda_graph:
assert model_input.input_tokens is not None assert model_input.input_tokens is not None
graph_batch_size = model_input.input_tokens.shape[0] graph_batch_size = model_input.input_tokens.shape[0]
model_executable = self.graph_runners[virtual_engine][ use_inputs_embeds = model_input.inputs_embeds is not None
graph_batch_size] model_executable = self.graph_runners[virtual_engine][(
graph_batch_size, use_inputs_embeds)]
if previous_hidden_states is not None: if previous_hidden_states is not None:
previous_hidden_states = torch.cat([ previous_hidden_states = torch.cat([
previous_hidden_states, previous_hidden_states,
@ -1763,6 +1831,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
self.vllm_config, virtual_engine): self.vllm_config, virtual_engine):
hidden_or_intermediate_states = model_executable( hidden_or_intermediate_states = model_executable(
input_ids=model_input.input_tokens, input_ids=model_input.input_tokens,
inputs_embeds=model_input.inputs_embeds,
positions=model_input.input_positions, positions=model_input.input_positions,
intermediate_tensors=intermediate_tensors, intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(multi_modal_kwargs, **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
@ -1817,6 +1886,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_input.async_callback() model_input.async_callback()
# Sample the next token. # Sample the next token.
assert isinstance(self.sampler, Sampler)
orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor
if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = True
output: SamplerOutput = self.sampler( output: SamplerOutput = self.sampler(
logits=logits, logits=logits,
sampling_metadata=model_input.sampling_metadata, sampling_metadata=model_input.sampling_metadata,
@ -1838,6 +1912,18 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
output.model_forward_time = (orig_model_forward_time + output.model_forward_time = (orig_model_forward_time +
model_forward_time) model_forward_time)
if model_input.inputs_embeds is not None:
self.sampler.include_gpu_probs_tensor = \
orig_include_gpu_probs_tensor
if output.sampled_token_ids is not None:
output.sampled_token_embeds = self.model.get_input_embeddings(
output.sampled_token_ids.squeeze(1))
for token_embed, sequence_group_output in zip(
output.sampled_token_embeds, output.outputs):
assert len(sequence_group_output.samples) == 1
sequence_group_output.samples[0].output_embed = token_embed
if self.return_hidden_states: if self.return_hidden_states:
# we only need to pass hidden states of most recent token # we only need to pass hidden states of most recent token
assert model_input.sampling_metadata is not None assert model_input.sampling_metadata is not None
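When the batch contains prompt embeddings, the runner temporarily forces the sampler to keep sampled ids on the GPU, looks up each sampled id's input embedding, and attaches it to the corresponding sequence output so decoding can keep feeding embeddings; the original sampler flag is restored afterwards. A simplified sketch of the attach step; Embedder, Sample, and Output are stand-ins for the model's embedding lookup and the sampler output structures:

```python
# Simplified sketch of attaching sampled-token embeddings to outputs.
# `Embedder`, `Sample`, and `Output` are illustrative stand-ins.
from dataclasses import dataclass, field
from typing import Optional

import torch


class Embedder:
    def __init__(self, vocab: int, hidden: int) -> None:
        self.table = torch.nn.Embedding(vocab, hidden)

    def get_input_embeddings(self, ids: torch.Tensor) -> torch.Tensor:
        return self.table(ids)


@dataclass
class Sample:
    output_embed: Optional[torch.Tensor] = None


@dataclass
class Output:
    samples: list = field(default_factory=lambda: [Sample()])


model = Embedder(vocab=100, hidden=8)
sampled_token_ids = torch.tensor([[3], [7]])           # [num_seqs, 1]
outputs = [Output(), Output()]

sampled_token_embeds = model.get_input_embeddings(
    sampled_token_ids.squeeze(1))                      # [num_seqs, hidden]
for token_embed, seq_group_output in zip(sampled_token_embeds, outputs):
    assert len(seq_group_output.samples) == 1
    seq_group_output.samples[0].output_embed = token_embed

assert outputs[0].samples[0].output_embed.shape == (8,)
```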
@ -1931,6 +2017,7 @@ class CUDAGraphRunner(nn.Module):
def capture( def capture(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
inputs_embeds: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
intermediate_inputs: Optional[IntermediateTensors], intermediate_inputs: Optional[IntermediateTensors],
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
@ -1947,6 +2034,7 @@ class CUDAGraphRunner(nn.Module):
for _ in range(_NUM_WARMUP_ITERS): for _ in range(_NUM_WARMUP_ITERS):
self.model( self.model(
input_ids=input_ids, input_ids=input_ids,
inputs_embeds=inputs_embeds,
positions=positions, positions=positions,
intermediate_tensors=intermediate_inputs, intermediate_tensors=intermediate_inputs,
**kwargs, **kwargs,
@ -1959,6 +2047,9 @@ class CUDAGraphRunner(nn.Module):
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
output_hidden_or_intermediate_states = self.model( output_hidden_or_intermediate_states = self.model(
input_ids=input_ids, input_ids=input_ids,
**({
"inputs_embeds": inputs_embeds,
} if inputs_embeds is not None else {}),
positions=positions, positions=positions,
intermediate_tensors=intermediate_inputs, intermediate_tensors=intermediate_inputs,
**kwargs, **kwargs,
@ -1986,6 +2077,9 @@ class CUDAGraphRunner(nn.Module):
self.input_buffers = { self.input_buffers = {
"input_ids": "input_ids":
input_ids, input_ids,
**({
"inputs_embeds": inputs_embeds,
} if inputs_embeds is not None else {}),
"positions": "positions":
positions, positions,
"kv_caches": "kv_caches":
@ -2006,6 +2100,7 @@ class CUDAGraphRunner(nn.Module):
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
inputs_embeds: Optional[torch.Tensor],
positions: torch.Tensor, positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
**kwargs, **kwargs,
@ -2020,6 +2115,9 @@ class CUDAGraphRunner(nn.Module):
# so the shape is not padded, we need to copy partial only # so the shape is not padded, we need to copy partial only
self.input_buffers["positions"][:positions.shape[0]].copy_( self.input_buffers["positions"][:positions.shape[0]].copy_(
positions, non_blocking=True) positions, non_blocking=True)
if inputs_embeds is not None:
self.input_buffers["inputs_embeds"][:inputs_embeds.shape[0]].copy_(
inputs_embeds, non_blocking=True)
if self.backend_name != "NO_ATTENTION": if self.backend_name != "NO_ATTENTION":
self.input_buffers["slot_mapping"].copy_( self.input_buffers["slot_mapping"].copy_(


@ -84,10 +84,17 @@ class PoolingModelRunner(
# explore how to leverage it. # explore how to leverage it.
if (prefill_meta is None and decode_meta is not None if (prefill_meta is None and decode_meta is not None
and decode_meta.use_cuda_graph): and decode_meta.use_cuda_graph):
assert model_input.input_tokens is not None if model_input.inputs_embeds is None:
graph_batch_size = model_input.input_tokens.shape[0] assert model_input.input_tokens is not None
model_executable = self.graph_runners[virtual_engine][ graph_batch_size = model_input.input_tokens.shape[0]
graph_batch_size] model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, False)])
else:
graph_batch_size = model_input.inputs_embeds.shape[0]
model_executable = (
self.graph_runners[model_input.virtual_engine][(
graph_batch_size, True)])
else: else:
model_executable = self.model model_executable = self.model