mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 05:25:20 +08:00
[Core] [Bugfix] Add Input Embeddings (#15428)
Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: 临景 <linjing.yx@alibaba-inc.com> Co-authored-by: Bryce1010 <bryceyx@gmail.com> Co-authored-by: Nan2018 <nan@protopia.ai> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
9e2de9b9e9
commit
cc2a77d7f1
@ -787,7 +787,7 @@ class VllmRunner:
|
||||
|
||||
def get_inputs(
|
||||
self,
|
||||
prompts: list[str],
|
||||
prompts: Union[list[str], list[torch.Tensor]],
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
@ -809,16 +809,18 @@ class VllmRunner:
|
||||
if audios is not None and (audio := audios[i]) is not None:
|
||||
multi_modal_data["audio"] = audio
|
||||
|
||||
inputs.append(
|
||||
TextPrompt(prompt=prompt,
|
||||
multi_modal_data=multi_modal_data
|
||||
if multi_modal_data else None))
|
||||
text_prompt_kwargs = {
|
||||
("prompt" if isinstance(prompt, str) else "prompt_embeds"):
|
||||
prompt,
|
||||
"multi_modal_data": multi_modal_data or None
|
||||
}
|
||||
inputs.append(TextPrompt(**text_prompt_kwargs))
|
||||
|
||||
return inputs
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: list[str],
|
||||
prompts: Union[list[str], list[torch.Tensor]],
|
||||
sampling_params: SamplingParams,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
@ -844,7 +846,7 @@ class VllmRunner:
|
||||
output_str = sample.text
|
||||
output_ids = list(sample.token_ids)
|
||||
req_sample_output_ids.append(prompt_ids + output_ids)
|
||||
req_sample_output_strs.append(prompt_str + output_str)
|
||||
req_sample_output_strs.append((prompt_str or "") + output_str)
|
||||
outputs.append((req_sample_output_ids, req_sample_output_strs))
|
||||
return outputs
|
||||
|
||||
@ -911,7 +913,7 @@ class VllmRunner:
|
||||
|
||||
def generate_greedy(
|
||||
self,
|
||||
prompts: list[str],
|
||||
prompts: Union[list[str], list[torch.Tensor]],
|
||||
max_tokens: int,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
|
||||
@ -2,16 +2,18 @@
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest # noqa
|
||||
import torch
|
||||
from torch import Use # noqa
|
||||
|
||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
from vllm.core.interfaces import AllocStatus
|
||||
from vllm.core.scheduler import Scheduler, SchedulingBudget
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import SequenceGroup
|
||||
from vllm.sequence import SequenceGroup, SequenceStatus
|
||||
|
||||
from .utils import (append_new_token, append_new_token_seq,
|
||||
append_new_token_seq_group, create_dummy_prompt,
|
||||
@ -968,3 +970,73 @@ def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
|
||||
), "A partial prefix of C (4 tokens) should be prefilled, with the "
|
||||
"remaining tokens fit into 3 token budget (4-1 from the seqA). It will "
|
||||
"then be rounded down to 2 tokens on block size, thus 6 tokens in total."
|
||||
|
||||
|
||||
def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds():
|
||||
"""
|
||||
Test that the scheduler does not schedule batches with prompt tokens and
|
||||
prompt embeddings co-mingled.
|
||||
"""
|
||||
block_size = 2
|
||||
max_seq_group = 3
|
||||
scheduler = initialize_scheduler(
|
||||
block_size=block_size,
|
||||
num_cpu_blocks=16,
|
||||
num_gpu_blocks=16,
|
||||
max_num_seqs=max_seq_group,
|
||||
max_model_len=100,
|
||||
enable_prefix_caching=True,
|
||||
)
|
||||
|
||||
# the odd indexed inputs should be passed in via embeddings,
|
||||
# evens via token_ids
|
||||
seq_length = 7
|
||||
embedding_size = 5
|
||||
num_seqs = 11
|
||||
seq_tokens: list[list[int]] = []
|
||||
seq_embeds: list[Optional[torch.Tensor]] = []
|
||||
for i in range(num_seqs):
|
||||
if i % 2:
|
||||
seq_tokens.append(list(range(seq_length)))
|
||||
seq_embeds.append(None)
|
||||
else:
|
||||
seq_tokens.append([0] * seq_length)
|
||||
seq_embeds.append(torch.rand(embedding_size))
|
||||
|
||||
seq_and_seq_groups = [
|
||||
create_dummy_prompt(f"{i}",
|
||||
prompt_tokens=seq_tokens[i],
|
||||
prompt_embeds=seq_embeds[i],
|
||||
block_size=block_size)
|
||||
for i in range(len(seq_tokens))
|
||||
]
|
||||
|
||||
for _, seq_group in seq_and_seq_groups:
|
||||
scheduler.add_seq_group(seq_group)
|
||||
|
||||
while not all(seq.is_finished() for seq, _ in seq_and_seq_groups):
|
||||
unfinished_seq_groups = [
|
||||
seq_group for _, seq_group in seq_and_seq_groups
|
||||
if not seq_group.is_finished()
|
||||
]
|
||||
_, out = schedule_and_update_computed_tokens(scheduler)
|
||||
assert len(out.scheduled_seq_groups) > 0
|
||||
batch_is_prompt_embeds = out.scheduled_seq_groups[
|
||||
0].seq_group.uses_prompt_embeds()
|
||||
expected_scheduled_seq_groups = [
|
||||
seq_group for seq_group in unfinished_seq_groups
|
||||
if seq_group.uses_prompt_embeds() == batch_is_prompt_embeds
|
||||
]
|
||||
|
||||
# We should have as many scheduled groups as possible, without mixing
|
||||
assert len(out.scheduled_seq_groups) == min(
|
||||
max_seq_group, len(expected_scheduled_seq_groups))
|
||||
assert all(scheduled_seq_group.seq_group.uses_prompt_embeds() ==
|
||||
batch_is_prompt_embeds
|
||||
for scheduled_seq_group in out.scheduled_seq_groups)
|
||||
|
||||
# Finish the scheduled groups
|
||||
for scheduled_seq_group in out.scheduled_seq_groups:
|
||||
for seq in scheduled_seq_group.seq_group.seqs:
|
||||
seq.status = SequenceStatus.FINISHED_STOPPED
|
||||
scheduler.free_finished_seq_groups()
|
||||
|
||||
@ -5,9 +5,11 @@ from collections import defaultdict
|
||||
from collections.abc import Sequence as GenericSequence
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
||||
from vllm.inputs import EncoderDecoderInputs, token_inputs
|
||||
from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
|
||||
SequenceGroupMetadata)
|
||||
@ -19,6 +21,7 @@ def create_dummy_prompt(
|
||||
block_size: Optional[int] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_tokens: Optional[list[int]] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
min_tokens: int = 0,
|
||||
max_tokens: int = 16,
|
||||
) -> tuple[Sequence, SequenceGroup]:
|
||||
@ -31,9 +34,13 @@ def create_dummy_prompt(
|
||||
prompt_tokens = list(range(prompt_length))
|
||||
|
||||
prompt_str = " ".join([str(t) for t in prompt_tokens])
|
||||
inputs = token_inputs(
|
||||
prompt_token_ids=prompt_tokens,
|
||||
prompt=prompt_str) if prompt_embeds is None else embeds_inputs(
|
||||
prompt_embeds=prompt_embeds)
|
||||
prompt = Sequence(
|
||||
int(request_id),
|
||||
inputs=token_inputs(prompt_tokens, prompt=prompt_str),
|
||||
inputs=inputs,
|
||||
block_size=block_size,
|
||||
)
|
||||
seq_group = SequenceGroup(
|
||||
|
||||
@ -1,4 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@ -110,6 +113,18 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
|
||||
prompt_embeds: Optional[list[torch.Tensor]] = [] if os.getenv(
|
||||
"VLLM_USE_V1") == "0" else None
|
||||
prompt_token_ids = []
|
||||
for prompt in example_prompts:
|
||||
token_ids = hf_model.tokenizer(prompt,
|
||||
return_tensors="pt").input_ids.to(
|
||||
hf_model.model.device)
|
||||
prompt_token_ids.append(token_ids)
|
||||
if prompt_embeds is not None:
|
||||
prompt_embeds.append(hf_model.model.get_input_embeddings()(
|
||||
token_ids).squeeze(0))
|
||||
|
||||
with vllm_runner(
|
||||
model,
|
||||
tokenizer_name=model_info.tokenizer or model,
|
||||
@ -119,6 +134,9 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
|
||||
) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, num_logprobs)
|
||||
if prompt_embeds is not None:
|
||||
vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs(
|
||||
prompt_embeds, max_tokens, num_logprobs)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
@ -126,6 +144,14 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
if prompt_embeds is not None:
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=vllm_outputs,
|
||||
outputs_1_lst=vllm_outputs_from_embeds,
|
||||
name_0="vllm",
|
||||
name_1="vllm_from_embeds",
|
||||
)
|
||||
|
||||
if use_rocm_aiter:
|
||||
# this is to ensure that vllm engine
|
||||
# has deallocated the memory before running the next
|
||||
|
||||
@ -31,8 +31,13 @@ def test_deepseek_mla_attn_backend_module():
|
||||
assert model_runner.attn_backend.__name__ == "TritonMLABackend"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
|
||||
def test_prepare_prompt(batch_size):
|
||||
@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
|
||||
@pytest.mark.parametrize("use_prompt_embeds", [True, False])
|
||||
def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch):
|
||||
if use_prompt_embeds:
|
||||
# Prompt Embeddings is only currently supported on V0
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
|
||||
model_runner = _create_model_runner(
|
||||
"facebook/opt-125m",
|
||||
max_num_batched_tokens=100000,
|
||||
@ -43,11 +48,20 @@ def test_prepare_prompt(batch_size):
|
||||
seq_lens: list[int] = []
|
||||
seq_group_metadata_list: list[SequenceGroupMetadata] = []
|
||||
block_tables = {0: [1]}
|
||||
expected_input_embeds_len = 0
|
||||
for i in range(batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
seq_len = i % (model_runner.block_size - 1) + 1
|
||||
seq_lens.append(seq_len)
|
||||
seq_data = SequenceData.from_seqs(range(seq_len))
|
||||
if use_prompt_embeds:
|
||||
seq_data = SequenceData.from_seqs(
|
||||
prompt_token_ids=[0] * seq_len,
|
||||
prompt_embeds=torch.rand(seq_len, 10),
|
||||
)
|
||||
expected_input_embeds_len += seq_len
|
||||
else:
|
||||
seq_data = SequenceData.from_seqs(prompt_token_ids=range(seq_len))
|
||||
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
@ -68,6 +82,7 @@ def test_prepare_prompt(batch_size):
|
||||
seq_group_metadata_list)
|
||||
input_tokens = model_input.input_tokens
|
||||
input_positions = model_input.input_positions
|
||||
input_embeds = model_input.inputs_embeds
|
||||
attn_metadata = model_input.attn_metadata
|
||||
return_seq_lens = model_input.seq_lens
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
@ -121,7 +136,11 @@ def test_prepare_prompt(batch_size):
|
||||
|
||||
assert len(input_tokens) == sum(seq_lens)
|
||||
assert len(input_positions) == sum(seq_lens)
|
||||
torch.testing.assert_close(input_tokens, input_positions)
|
||||
if expected_input_embeds_len == 0:
|
||||
torch.testing.assert_close(input_tokens, input_positions)
|
||||
assert input_embeds is None
|
||||
else:
|
||||
assert len(input_embeds) == expected_input_embeds_len
|
||||
|
||||
sampling_metadata = SamplingMetadata.prepare(
|
||||
seq_group_metadata_list,
|
||||
@ -145,8 +164,13 @@ def test_prepare_prompt(batch_size):
|
||||
torch.testing.assert_close(actual, expected)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
|
||||
def test_prepare_decode_cuda_graph(batch_size):
|
||||
@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
|
||||
@pytest.mark.parametrize("use_prompt_embeds", [True, False])
|
||||
def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch):
|
||||
if use_prompt_embeds:
|
||||
# Prompt Embeddings is only currently supported on V0
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
|
||||
model_runner = _create_model_runner(
|
||||
"facebook/opt-125m",
|
||||
seed=0,
|
||||
@ -164,10 +188,19 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
context_len = i % (model_runner.block_size - 1) + 1
|
||||
context_lens.append(context_len)
|
||||
seq_data = SequenceData.from_seqs(range(context_len))
|
||||
if use_prompt_embeds:
|
||||
seq_data = SequenceData.from_seqs(
|
||||
prompt_token_ids=[0] * context_len,
|
||||
prompt_embeds=torch.rand(context_len, 10),
|
||||
)
|
||||
output_embed = torch.rand(10)
|
||||
else:
|
||||
seq_data = SequenceData.from_seqs(
|
||||
prompt_token_ids=range(context_len))
|
||||
output_embed = None
|
||||
seq_data.update_num_computed_tokens(context_len)
|
||||
# Append one token ID since prefill is finished.
|
||||
seq_data.append_token_id(1, 0)
|
||||
seq_data.append_token_id(1, 0, output_embed)
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=False,
|
||||
@ -180,9 +213,12 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
|
||||
model_input = model_runner._prepare_model_input_tensors(
|
||||
seq_group_metadata_list)
|
||||
input_tokens, input_positions, attn_metadata, slot_mapping = (
|
||||
model_input.input_tokens, model_input.input_positions,
|
||||
model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
|
||||
input_tokens = model_input.input_tokens
|
||||
input_positions = model_input.input_positions
|
||||
input_embeds = model_input.inputs_embeds
|
||||
attn_metadata = model_input.attn_metadata
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
|
||||
assert len(slot_mapping) == len(input_tokens)
|
||||
|
||||
expected_bs = model_runner.vllm_config.pad_for_cudagraph(
|
||||
@ -227,7 +263,7 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
# block table's first index corresponds to each batch, meaning in
|
||||
# decoding it is each token.
|
||||
assert attn_metadata.block_tables.shape[0] == len(input_tokens)
|
||||
# Block table's second dim correspondsd to each token's block number.
|
||||
# Block table's second dim corresponds to each token's block number.
|
||||
# It is padded up to
|
||||
assert attn_metadata.block_tables.shape[1] == (
|
||||
model_runner.get_max_block_per_batch())
|
||||
@ -235,7 +271,12 @@ def test_prepare_decode_cuda_graph(batch_size):
|
||||
|
||||
assert len(input_tokens) == expected_bs
|
||||
assert len(input_positions) == expected_bs
|
||||
torch.allclose(input_tokens, input_positions)
|
||||
if use_prompt_embeds:
|
||||
expected_input_embeds_length = start_loc[-1]
|
||||
assert len(input_embeds) == expected_input_embeds_length
|
||||
assert expected_input_embeds_length <= expected_bs
|
||||
else:
|
||||
assert input_embeds is None
|
||||
|
||||
# Verify Sampling
|
||||
expected_selected_token_indices = []
|
||||
@ -266,25 +307,27 @@ def test_empty_seq_group():
|
||||
seq_group_metadata_list: list[SequenceGroupMetadata] = []
|
||||
model_input = model_runner._prepare_model_input_tensors(
|
||||
seq_group_metadata_list)
|
||||
input_tokens, input_positions, attn_metadata = (
|
||||
model_input.input_tokens,
|
||||
model_input.input_positions,
|
||||
model_input.attn_metadata,
|
||||
)
|
||||
|
||||
input_tokens = model_input.input_tokens
|
||||
input_positions = model_input.input_positions
|
||||
attn_metadata = model_input.attn_metadata
|
||||
|
||||
assert input_tokens is None
|
||||
assert input_positions is None
|
||||
assert attn_metadata is None
|
||||
|
||||
model_input = model_runner._prepare_model_input_tensors(
|
||||
seq_group_metadata_list)
|
||||
(input_tokens, input_positions, attn_metadata, return_seq_lens) = (
|
||||
model_input.input_tokens,
|
||||
model_input.input_positions,
|
||||
model_input.attn_metadata,
|
||||
model_input.seq_lens,
|
||||
)
|
||||
|
||||
input_tokens = model_input.input_tokens
|
||||
input_positions = model_input.input_positions
|
||||
input_embeds = model_input.inputs_embeds
|
||||
attn_metadata = model_input.attn_metadata
|
||||
return_seq_lens = model_input.seq_lens
|
||||
|
||||
assert input_tokens is None
|
||||
assert input_positions is None
|
||||
assert input_embeds is None
|
||||
assert attn_metadata is None
|
||||
assert return_seq_lens is None
|
||||
|
||||
@ -299,9 +342,15 @@ def distributed_init():
|
||||
ensure_model_parallel_initialized(1, 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", list(range(2, 128)))
|
||||
@pytest.mark.parametrize("batch_size", list(range(2, 128, 3)))
|
||||
@pytest.mark.parametrize("enforce_eager", [True, False])
|
||||
def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
|
||||
@pytest.mark.parametrize('use_prompt_embeds', [True, False])
|
||||
def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds,
|
||||
distributed_init, monkeypatch):
|
||||
if use_prompt_embeds:
|
||||
# Prompt Embeddings is only currently supported on V0
|
||||
monkeypatch.setenv("VLLM_USE_V1", "0")
|
||||
|
||||
model_runner = _create_model_runner(
|
||||
"facebook/opt-125m",
|
||||
seed=0,
|
||||
@ -320,11 +369,20 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
|
||||
block_tables = {0: [1]}
|
||||
prefill_batch_size = batch_size // 2
|
||||
decode_batch_size = batch_size - prefill_batch_size
|
||||
expected_input_embeds_len = 0
|
||||
for i in range(prefill_batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
seq_len = i % (model_runner.block_size - 1) + 1
|
||||
seq_lens.append(seq_len)
|
||||
seq_data = SequenceData.from_seqs(range(seq_len))
|
||||
if use_prompt_embeds:
|
||||
seq_data = SequenceData.from_seqs(
|
||||
prompt_token_ids=[0] * seq_len,
|
||||
prompt_embeds=torch.rand(seq_len, 10),
|
||||
)
|
||||
expected_input_embeds_len += seq_len
|
||||
else:
|
||||
seq_data = SequenceData.from_seqs(
|
||||
prompt_token_ids=range(seq_len), )
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
is_prompt=True,
|
||||
@ -340,8 +398,21 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
|
||||
for i in range(prefill_batch_size, batch_size):
|
||||
# make sure all tokens fit into one block
|
||||
context_len = i % (model_runner.block_size - 1) + 1
|
||||
seq_data = SequenceData.from_seqs(range(context_len))
|
||||
seq_data.append_token_id(1, 0)
|
||||
if use_prompt_embeds:
|
||||
seq_data = SequenceData.from_seqs(
|
||||
prompt_token_ids=[0] * context_len,
|
||||
prompt_embeds=torch.rand(context_len, 10),
|
||||
)
|
||||
output_embed = torch.rand(10)
|
||||
# This also iterates the expected input_embeds, because the model
|
||||
# needs both the input and output embeddings passed into together
|
||||
expected_input_embeds_len += 1
|
||||
else:
|
||||
seq_data = SequenceData.from_seqs(
|
||||
prompt_token_ids=range(context_len), )
|
||||
output_embed = None
|
||||
assert len(seq_data.prompt_token_ids) == context_len
|
||||
seq_data.append_token_id(1, 0, output_embed)
|
||||
seq_data.update_num_computed_tokens(context_len)
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=f"test_{i}",
|
||||
@ -355,11 +426,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
|
||||
decode_metadata_list.append(seq_group_metadata)
|
||||
|
||||
model_input = model_runner.prepare_model_input(seq_group_metadata_list)
|
||||
(input_tokens, input_positions, attn_metadata) = (
|
||||
model_input.input_tokens,
|
||||
model_input.input_positions,
|
||||
model_input.attn_metadata,
|
||||
)
|
||||
|
||||
input_tokens = model_input.input_tokens
|
||||
input_positions = model_input.input_positions
|
||||
input_embeds = model_input.inputs_embeds
|
||||
attn_metadata = model_input.attn_metadata
|
||||
|
||||
prefill_meta_actual = attn_metadata.prefill_metadata
|
||||
decode_meta_actual = attn_metadata.decode_metadata
|
||||
@ -369,6 +440,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
|
||||
assert attn_metadata.num_prefills == prefill_batch_size
|
||||
assert attn_metadata.num_decode_tokens == decode_batch_size
|
||||
assert attn_metadata.num_prefill_tokens == sum(seq_lens)
|
||||
if expected_input_embeds_len == 0:
|
||||
assert input_embeds is None
|
||||
else:
|
||||
assert len(input_embeds) == expected_input_embeds_len
|
||||
|
||||
# Verify attn metadata is consistent. We don't need to test individual
|
||||
# values here because they are tested above.
|
||||
|
||||
@ -367,9 +367,17 @@ class FlashInferState(AttentionState):
|
||||
# scheduled while CUDA graph mode is enabled. We don't run graph in that
|
||||
# case.
|
||||
if use_cuda_graph and is_decode:
|
||||
batch_size = model_input.input_tokens.shape[0]
|
||||
state = (self.runner.graph_runners[model_input.virtual_engine]
|
||||
[batch_size].attn_state)
|
||||
if model_input.inputs_embeds is None:
|
||||
batch_size = model_input.input_tokens.shape[0]
|
||||
state = (
|
||||
self.runner.graph_runners[model_input.virtual_engine][(
|
||||
batch_size, False)].attn_state)
|
||||
else:
|
||||
batch_size = model_input.inputs_embeds.shape[0]
|
||||
state = (
|
||||
self.runner.graph_runners[model_input.virtual_engine][(
|
||||
batch_size, True)].attn_state)
|
||||
|
||||
model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
|
||||
)
|
||||
model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
|
||||
|
||||
@ -1071,6 +1071,7 @@ class Scheduler:
|
||||
)
|
||||
ignored_seq_groups: List[SequenceGroup] = []
|
||||
seq_groups: List[ScheduledSequenceGroup] = []
|
||||
using_prompt_embeds: bool = False
|
||||
|
||||
waiting_queue = self.waiting
|
||||
|
||||
@ -1138,6 +1139,15 @@ class Scheduler:
|
||||
waiting_queue.popleft()
|
||||
continue
|
||||
|
||||
# We cannot mix sequence groups that use prompt embeds and
|
||||
# those that do not.
|
||||
if len(seq_groups) == 0:
|
||||
using_prompt_embeds = seq_group.uses_prompt_embeds()
|
||||
if using_prompt_embeds != seq_group.uses_prompt_embeds():
|
||||
leftover_waiting_sequences.appendleft(seq_group)
|
||||
waiting_queue.popleft()
|
||||
continue
|
||||
|
||||
lora_int_id = 0
|
||||
if self.lora_enabled:
|
||||
lora_int_id = seq_group.lora_int_id
|
||||
@ -1295,17 +1305,39 @@ class Scheduler:
|
||||
|
||||
# Merge lists
|
||||
num_prefill_groups = len(prefills.seq_groups)
|
||||
ignored_seq_groups_for_embeds = list[SequenceGroup]()
|
||||
if num_prefill_groups > 0:
|
||||
scheduled_seq_groups = prefills.seq_groups
|
||||
scheduled_seq_groups.extend(running_scheduled.decode_seq_groups)
|
||||
ignored_seq_groups_for_embeds.clear()
|
||||
else:
|
||||
scheduled_seq_groups = running_scheduled.decode_seq_groups
|
||||
if len(scheduled_seq_groups) > 0:
|
||||
using_prompt_embeds = scheduled_seq_groups[
|
||||
0].seq_group.uses_prompt_embeds()
|
||||
ignored_seq_groups_for_embeds.clear()
|
||||
indices_ignored = list[int]()
|
||||
for i, schedule_seq_group in enumerate(scheduled_seq_groups):
|
||||
if using_prompt_embeds !=\
|
||||
schedule_seq_group.seq_group.uses_prompt_embeds():
|
||||
ignored_seq_groups_for_embeds.append(
|
||||
schedule_seq_group.seq_group)
|
||||
indices_ignored.append(i)
|
||||
if len(ignored_seq_groups_for_embeds) > 0:
|
||||
scheduled_seq_groups = [
|
||||
group for i, group in enumerate(scheduled_seq_groups)
|
||||
if i not in indices_ignored
|
||||
]
|
||||
else:
|
||||
ignored_seq_groups_for_embeds.clear()
|
||||
|
||||
scheduled_seq_groups.extend(swapped_in.decode_seq_groups)
|
||||
|
||||
blocks_to_copy = running_scheduled.blocks_to_copy
|
||||
blocks_to_copy.extend(swapped_in.blocks_to_copy)
|
||||
|
||||
ignored_seq_groups = prefills.ignored_seq_groups
|
||||
ignored_seq_groups.extend(ignored_seq_groups_for_embeds)
|
||||
ignored_seq_groups.extend(swapped_in.infeasible_seq_groups)
|
||||
|
||||
return SchedulerOutputs(
|
||||
|
||||
@ -489,6 +489,14 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
|
||||
if (isinstance(prompt, dict)
|
||||
and prompt.get("prompt_embeds", None) is not None
|
||||
and not prompt.get("prompt_token_ids", None)):
|
||||
# We use the -2 dimension (instead of 0) in case a batched input
|
||||
# of batch size 1 is passed in.
|
||||
prompt["prompt_token_ids"] = [0
|
||||
] * prompt["prompt_embeds"].shape[-2]
|
||||
|
||||
if self.tokenizer is not None:
|
||||
tokenizer = await self.get_tokenizer_async(lora_request)
|
||||
self._validate_token_prompt(prompt, tokenizer=tokenizer)
|
||||
|
||||
@ -753,6 +753,12 @@ class LLMEngine:
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
|
||||
if (isinstance(prompt, dict)
|
||||
and prompt.get("prompt_embeds", None) is not None
|
||||
and not prompt.get("prompt_token_ids", None)):
|
||||
seq_len = prompt["prompt_embeds"].shape[0]
|
||||
prompt["prompt_token_ids"] = [0] * seq_len
|
||||
|
||||
if self.tokenizer is not None:
|
||||
self._validate_token_prompt(
|
||||
prompt,
|
||||
@ -1267,11 +1273,13 @@ class LLMEngine:
|
||||
if self.scheduler_config.is_multi_step:
|
||||
is_prefill_append = seq.data.get_num_uncomputed_tokens(
|
||||
) == 0
|
||||
seq.append_token_id(sample.output_token, sample.logprobs)
|
||||
seq.append_token_id(sample.output_token, sample.logprobs,
|
||||
sample.output_embed)
|
||||
if not is_prefill_append:
|
||||
seq_group.update_num_computed_tokens(1)
|
||||
else:
|
||||
seq.append_token_id(sample.output_token, sample.logprobs)
|
||||
seq.append_token_id(sample.output_token, sample.logprobs,
|
||||
sample.output_embed)
|
||||
|
||||
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
|
||||
"""Performs one decoding iteration and returns newly generated results.
|
||||
@ -2032,10 +2040,12 @@ class LLMEngine:
|
||||
tokenizer = (None if self.tokenizer is None else
|
||||
self.tokenizer.get_lora_tokenizer(lora_request))
|
||||
|
||||
prompt_ids = prompt_inputs["prompt_token_ids"]
|
||||
prompt_ids = prompt_inputs.get("prompt_token_ids", [])
|
||||
if not prompt_ids:
|
||||
if prompt_type == "encoder" and model_config.is_multimodal_model:
|
||||
pass # Mllama may have empty encoder inputs for text-only data
|
||||
if prompt_inputs["type"] == "embeds":
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"The {prompt_type} prompt cannot be empty")
|
||||
|
||||
|
||||
@ -167,6 +167,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
sampling_params: SamplingParams) -> None:
|
||||
output_token_ids = [sample.output_token for sample in valid_samples]
|
||||
output_logprobs = [sample.logprobs for sample in valid_samples]
|
||||
output_embeds = [sample.output_embed for sample in valid_samples]
|
||||
|
||||
# Truncate to max_tokens if necessary.
|
||||
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
|
||||
@ -190,11 +191,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
|
||||
# Incrementally append tokens to the sequence, as if we had only one new
|
||||
# token.
|
||||
for output_token_id, output_logprob in zip(output_token_ids,
|
||||
output_logprobs):
|
||||
for output_token_id, output_logprob, output_embed in zip(
|
||||
output_token_ids, output_logprobs, output_embeds):
|
||||
seq.append_token_id(
|
||||
token_id=output_token_id,
|
||||
logprobs=output_logprob,
|
||||
token_embed=output_embed,
|
||||
)
|
||||
|
||||
if is_prefill_sampled_token:
|
||||
|
||||
@ -119,7 +119,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
|
||||
sample = outputs.samples[0]
|
||||
seq = seq_group.first_seq
|
||||
if not is_async:
|
||||
seq.append_token_id(sample.output_token, sample.logprobs)
|
||||
seq.append_token_id(sample.output_token, sample.logprobs,
|
||||
sample.output_embed)
|
||||
if sampling_params.detokenize and self.detokenizer:
|
||||
new_char_count = self.detokenizer.decode_sequence_inplace(
|
||||
seq, sampling_params)
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
|
||||
from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
|
||||
ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
|
||||
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
|
||||
TokensPrompt, build_explicit_enc_dec_prompt,
|
||||
TokensPrompt, build_explicit_enc_dec_prompt, embeds_inputs,
|
||||
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
|
||||
from .registry import (DummyData, InputContext, InputProcessingContext,
|
||||
InputRegistry)
|
||||
@ -21,7 +21,9 @@ __all__ = [
|
||||
"SingletonPrompt",
|
||||
"ExplicitEncoderDecoderPrompt",
|
||||
"TokenInputs",
|
||||
"EmbedsInputs",
|
||||
"token_inputs",
|
||||
"embeds_inputs",
|
||||
"DecoderOnlyInputs",
|
||||
"EncoderDecoderInputs",
|
||||
"ProcessorInputs",
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast
|
||||
|
||||
import torch
|
||||
from typing_extensions import NotRequired, TypedDict, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -63,12 +64,20 @@ class TokensPrompt(TypedDict):
|
||||
"""
|
||||
|
||||
|
||||
SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
|
||||
class EmbedsPrompt(TypedDict):
|
||||
"""Schema for a prompt provided via token embeddings."""
|
||||
|
||||
prompt_embeds: torch.Tensor
|
||||
"""The embeddings of the prompt."""
|
||||
|
||||
|
||||
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
|
||||
"""
|
||||
Set of possible schemas for a single prompt:
|
||||
|
||||
- A text prompt (:class:`str` or :class:`TextPrompt`)
|
||||
- A tokenized prompt (:class:`TokensPrompt`)
|
||||
- An embeddings prompt (:class:`EmbedsPrompt`)
|
||||
|
||||
Note that "singleton" is as opposed to a data structure
|
||||
which encapsulates multiple prompts, i.e. of the sort
|
||||
@ -129,6 +138,7 @@ both decoder-only and encoder/decoder input types:
|
||||
|
||||
- A text prompt (:class:`str` or :class:`TextPrompt`)
|
||||
- A tokenized prompt (:class:`TokensPrompt`)
|
||||
- An embeddings prompt (:class:`EmbedsPrompt`)
|
||||
- A single data structure containing both an encoder and a decoder prompt
|
||||
(:class:`ExplicitEncoderDecoderPrompt`)
|
||||
"""
|
||||
@ -176,7 +186,27 @@ def token_inputs(
|
||||
return inputs
|
||||
|
||||
|
||||
DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputs"]
|
||||
class EmbedsInputs(TypedDict):
|
||||
"""Represents embeddings-based inputs."""
|
||||
|
||||
type: Literal["embeds"]
|
||||
"""The type of inputs."""
|
||||
|
||||
prompt_embeds: torch.Tensor
|
||||
"""The embeddings of the prompt."""
|
||||
|
||||
|
||||
def embeds_inputs(prompt_embeds: torch.Tensor) -> EmbedsInputs:
|
||||
"""Construct :class:`EmbedsInputs` from optional values."""
|
||||
inputs = EmbedsInputs(
|
||||
type="embeds",
|
||||
prompt_embeds=prompt_embeds,
|
||||
)
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
|
||||
"""
|
||||
The inputs in :class:`~vllm.LLMEngine` before they are
|
||||
passed to the model executor.
|
||||
@ -198,7 +228,7 @@ class EncoderDecoderInputs(TypedDict):
|
||||
"""The inputs for the decoder portion."""
|
||||
|
||||
|
||||
SingletonInputs = Union[TokenInputs, "MultiModalInputs"]
|
||||
SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
|
||||
"""
|
||||
A processed :class:`SingletonPrompt` which can be passed to
|
||||
:class:`vllm.sequence.Sequence`.
|
||||
|
||||
@ -6,8 +6,9 @@ from typing_extensions import TypeIs
|
||||
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
|
||||
SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)
|
||||
from .data import (EmbedsInputs, EmbedsPrompt, ExplicitEncoderDecoderPrompt,
|
||||
ProcessorInputs, PromptType, SingletonInputs,
|
||||
SingletonPrompt, TextPrompt, TokensPrompt)
|
||||
|
||||
|
||||
class ParsedText(TypedDict):
|
||||
@ -84,30 +85,69 @@ class ParsedTokensPrompt(TypedDict):
|
||||
content: TokensPrompt
|
||||
|
||||
|
||||
class ParsedEmbedsPrompt(TypedDict):
|
||||
type: Literal['embeds']
|
||||
content: EmbedsPrompt
|
||||
|
||||
|
||||
@overload
|
||||
def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
|
||||
...
|
||||
|
||||
|
||||
def parse_singleton_prompt(
|
||||
prompt: SingletonPrompt,
|
||||
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
|
||||
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
|
||||
ParsedEmbedsPrompt]:
|
||||
if isinstance(prompt, str):
|
||||
return ParsedStrPrompt(type="str", content=prompt)
|
||||
elif isinstance(prompt, dict):
|
||||
if "prompt_token_ids" in prompt:
|
||||
return ParsedTokensPrompt(type="tokens",
|
||||
content=prompt) # type: ignore
|
||||
# Type ignores are because mypy does not correctly infer the TypedDicts
|
||||
# Pyright does succeed.
|
||||
if "prompt_embeds" in prompt:
|
||||
return ParsedEmbedsPrompt(
|
||||
type="embeds", content=prompt) # type: ignore[typeddict-item]
|
||||
elif "prompt_token_ids" in prompt:
|
||||
return ParsedTokensPrompt(
|
||||
type="tokens", content=prompt) # type: ignore[typeddict-item]
|
||||
elif "prompt" in prompt:
|
||||
return ParsedTextPrompt(type="text", content=prompt)
|
||||
|
||||
raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
|
||||
raise TypeError(
|
||||
"inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt")
|
||||
|
||||
|
||||
def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
|
||||
return isinstance(prompt, dict) and "prompt_token_ids" in prompt
|
||||
|
||||
|
||||
def is_embeds_prompt(prompt: PromptType) -> TypeIs[EmbedsPrompt]:
|
||||
return isinstance(prompt, dict) and "prompt_embeds" in prompt
|
||||
|
||||
|
||||
def is_explicit_encoder_decoder_prompt(
|
||||
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
|
||||
return isinstance(prompt, dict) and "encoder_prompt" in prompt
|
||||
|
||||
|
||||
def is_embeds_inputs(inputs: SingletonInputs) -> TypeIs[EmbedsInputs]:
|
||||
return isinstance(inputs, dict) and inputs["type"] == "embeds"
|
||||
|
||||
|
||||
def split_enc_dec_inputs(
|
||||
inputs: ProcessorInputs,
|
||||
) -> tuple[Optional[SingletonInputs], SingletonInputs]:
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import Any, Optional, Union, cast
|
||||
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -15,9 +16,11 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
|
||||
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
|
||||
PromptType, SingletonInputs, SingletonPrompt, token_inputs)
|
||||
from .parse import (ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
|
||||
from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
|
||||
ProcessorInputs, PromptType, SingletonInputs,
|
||||
SingletonPrompt, TokenInputs, embeds_inputs, token_inputs)
|
||||
from .parse import (ParsedEmbedsPrompt, ParsedStrPrompt, ParsedTextPrompt,
|
||||
ParsedTokensPrompt, is_embeds_inputs,
|
||||
is_explicit_encoder_decoder_prompt, parse_singleton_prompt)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -328,6 +331,10 @@ class InputPreprocessor:
|
||||
* :class:`SingletonInputs` instance
|
||||
"""
|
||||
parsed = parse_singleton_prompt(prompt)
|
||||
|
||||
if parsed["type"] == "embeds":
|
||||
return self._process_prompt_embeds(parsed)
|
||||
|
||||
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
|
||||
self._get_prompt_data(parsed)
|
||||
|
||||
@ -359,6 +366,8 @@ class InputPreprocessor:
|
||||
cache_salt=cache_salt,
|
||||
)
|
||||
|
||||
assert_never(parsed)
|
||||
|
||||
async def _prompt_to_llm_inputs_async(
|
||||
self,
|
||||
prompt: SingletonPrompt,
|
||||
@ -369,6 +378,9 @@ class InputPreprocessor:
|
||||
"""Async version of :meth:`_extract_prompt_components`."""
|
||||
parsed = parse_singleton_prompt(prompt)
|
||||
|
||||
if parsed["type"] == "embeds":
|
||||
return self._process_prompt_embeds(parsed)
|
||||
|
||||
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
|
||||
self._get_prompt_data(parsed)
|
||||
|
||||
@ -399,10 +411,34 @@ class InputPreprocessor:
|
||||
cache_salt=cache_salt,
|
||||
)
|
||||
|
||||
def _process_prompt_embeds(self,
|
||||
parsed: ParsedEmbedsPrompt) -> EmbedsInputs:
|
||||
if envs.VLLM_USE_V1:
|
||||
raise ValueError("prompt_embeds is only available in V0.")
|
||||
|
||||
prompt_embeds_content = parsed["content"]
|
||||
|
||||
prompt_embeds = prompt_embeds_content["prompt_embeds"]
|
||||
|
||||
# prompt_embeds must be (seq_len, hidden_size), but if the user
|
||||
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
|
||||
# we can unambiguously process the intent by squeezing the batch
|
||||
# dimension.
|
||||
if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1:
|
||||
prompt_embeds = prompt_embeds.squeeze(dim=0)
|
||||
|
||||
if prompt_embeds.ndim != 2:
|
||||
raise ValueError(
|
||||
"prompt_embeds must be of shape (seq_len, hidden_size).")
|
||||
|
||||
return embeds_inputs(prompt_embeds=prompt_embeds)
|
||||
|
||||
assert_never(parsed)
|
||||
|
||||
def _build_enc_dec_llm_inputs(
|
||||
self,
|
||||
encoder_inputs: SingletonInputs,
|
||||
decoder_inputs: Optional[SingletonInputs],
|
||||
encoder_inputs: Union[TokenInputs, MultiModalInputs],
|
||||
decoder_inputs: Optional[Union[TokenInputs, MultiModalInputs]],
|
||||
) -> EncoderDecoderInputs:
|
||||
if (encoder_inputs["type"] == "token"
|
||||
or encoder_inputs["type"] == "multimodal"):
|
||||
@ -410,6 +446,9 @@ class InputPreprocessor:
|
||||
else:
|
||||
assert_never(encoder_inputs) # type: ignore[arg-type]
|
||||
|
||||
# Mypy does not correctly infer that EmbedsInputs is impossible
|
||||
assert "prompt_token_ids" in encoder_inputs
|
||||
|
||||
if decoder_inputs is None:
|
||||
if self.model_config.hf_config.model_type == "whisper":
|
||||
# For Whisper models, the text prompt should go to the decoder.
|
||||
@ -441,7 +480,8 @@ class InputPreprocessor:
|
||||
def _separate_enc_dec_inputs_from_mm_processor_outputs(
|
||||
self,
|
||||
inputs: SingletonInputs,
|
||||
decoder_inputs_to_override: Optional[SingletonInputs] = None,
|
||||
decoder_inputs_to_override: Optional[Union[TokenInputs,
|
||||
MultiModalInputs]] = None,
|
||||
) -> tuple[SingletonInputs, SingletonInputs]:
|
||||
"""
|
||||
For encoder/decoder models only:
|
||||
@ -540,6 +580,8 @@ class InputPreprocessor:
|
||||
# For multimodal model, override decoder prompt from processor
|
||||
# with explicit decoder prompt.
|
||||
if self.model_config.is_multimodal_model:
|
||||
assert decoder_inputs is None or not is_embeds_inputs(
|
||||
decoder_inputs)
|
||||
encoder_inputs, decoder_inputs = (
|
||||
self._separate_enc_dec_inputs_from_mm_processor_outputs(
|
||||
encoder_inputs, decoder_inputs))
|
||||
@ -555,9 +597,12 @@ class InputPreprocessor:
|
||||
inputs))
|
||||
else:
|
||||
encoder_inputs = inputs
|
||||
|
||||
decoder_inputs = None
|
||||
|
||||
# Mypy does not do type inference well with TypedDicts with Literal
|
||||
# values.
|
||||
assert not is_embeds_inputs(encoder_inputs)
|
||||
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
|
||||
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
|
||||
|
||||
async def _process_encoder_decoder_prompt_async(
|
||||
@ -590,6 +635,8 @@ class InputPreprocessor:
|
||||
# For multimodal model, override decoder prompt from processor
|
||||
# with explicit decoder prompt.
|
||||
if self.model_config.is_multimodal_model:
|
||||
assert decoder_inputs is None or not is_embeds_inputs(
|
||||
decoder_inputs)
|
||||
encoder_inputs, decoder_inputs = (
|
||||
self._separate_enc_dec_inputs_from_mm_processor_outputs(
|
||||
encoder_inputs, decoder_inputs))
|
||||
@ -605,9 +652,12 @@ class InputPreprocessor:
|
||||
inputs))
|
||||
else:
|
||||
encoder_inputs = inputs
|
||||
|
||||
decoder_inputs = None
|
||||
|
||||
# Mypy does not do type inference well with TypedDicts with Literal
|
||||
# values.
|
||||
assert not is_embeds_inputs(encoder_inputs)
|
||||
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
|
||||
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
|
||||
|
||||
def _build_decoder_only_llm_inputs(
|
||||
@ -617,10 +667,15 @@ class InputPreprocessor:
|
||||
) -> DecoderOnlyInputs:
|
||||
if (prompt_inputs["type"] == "token"
|
||||
or prompt_inputs["type"] == "multimodal"):
|
||||
# Mypy does not do type inference well with typedicts and Literal
|
||||
# values
|
||||
assert not is_embeds_inputs(prompt_inputs)
|
||||
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
|
||||
prompt_inputs["prompt_token_ids"],
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
elif (prompt_inputs["type"] == "embeds"):
|
||||
pass
|
||||
else:
|
||||
assert_never(prompt_inputs) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@ -110,6 +110,11 @@ class SamplerOutput(
|
||||
# 'broadcasted' to all other PP ranks for next step.
|
||||
sampled_token_ids_cpu: Optional[torch.Tensor] = None
|
||||
|
||||
# On-device tensor containing the sampled token embeddings (embeddings
|
||||
# corresponding to the sampled token ids). Used when prompt embeddings are
|
||||
# specified in lieu of prompt token ids or text.
|
||||
sampled_token_embeds: Optional[torch.Tensor] = None
|
||||
|
||||
# Spec decode metrics populated by workers.
|
||||
spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None
|
||||
|
||||
@ -183,7 +188,7 @@ class Sampler(nn.Module):
|
||||
|
||||
# Whether or not the SamplerOutput should have on-device tensors
|
||||
# containing the sampled token ids and probabilities. This is used by
|
||||
# speculative decoding.
|
||||
# speculative decoding and when prompt embeddings are specified.
|
||||
self.include_gpu_probs_tensor = False
|
||||
self.should_modify_greedy_probs_inplace = False
|
||||
|
||||
|
||||
@ -166,6 +166,9 @@ class SequenceData(msgspec.Struct,
|
||||
_output_token_ids: array = msgspec.field(
|
||||
default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))
|
||||
|
||||
_prompt_embeds: Optional[torch.Tensor] = None
|
||||
_output_embeds: Optional[torch.Tensor] = None
|
||||
|
||||
### The below fields should not be passed as an argument ###
|
||||
_cumulative_logprob: float = 0.0
|
||||
_prompt_token_ids_tuple: tuple[int,
|
||||
@ -176,6 +179,7 @@ class SequenceData(msgspec.Struct,
|
||||
_num_cached_tokens: int = 0
|
||||
_stage: SequenceStage = SequenceStage.PREFILL
|
||||
_cached_all_token_ids: list[int] = msgspec.field(default_factory=list)
|
||||
_cached_all_token_embeds: Optional[torch.Tensor] = None
|
||||
|
||||
# It is used to get delta input. It is reset when `get_delta_and_reset`
|
||||
# is called.
|
||||
@ -208,6 +212,8 @@ class SequenceData(msgspec.Struct,
|
||||
def from_seqs(
|
||||
prompt_token_ids: GenericSequence[int],
|
||||
output_token_ids: Optional[GenericSequence[int]] = None,
|
||||
*,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
) -> "SequenceData":
|
||||
"""
|
||||
Construct a :class:`SequenceData` instance from prompt and output
|
||||
@ -217,13 +223,15 @@ class SequenceData(msgspec.Struct,
|
||||
prompt_token_ids)
|
||||
|
||||
if output_token_ids is None:
|
||||
return SequenceData(prompt_token_ids_arr)
|
||||
return SequenceData(prompt_token_ids_arr,
|
||||
_prompt_embeds=prompt_embeds)
|
||||
|
||||
output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
output_token_ids)
|
||||
|
||||
return SequenceData(prompt_token_ids_arr,
|
||||
_output_token_ids=output_token_ids_arr)
|
||||
_output_token_ids=output_token_ids_arr,
|
||||
_prompt_embeds=prompt_embeds)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert self._prompt_token_ids.typecode == "l"
|
||||
@ -231,6 +239,8 @@ class SequenceData(msgspec.Struct,
|
||||
self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
|
||||
self._prompt_token_ids)
|
||||
self._update_cached_all_tokens()
|
||||
if self._prompt_embeds is not None:
|
||||
self._update_cached_all_token_embeds()
|
||||
|
||||
def _update_cached_all_tokens(self):
|
||||
assert isinstance(self._prompt_token_ids, array)
|
||||
@ -238,6 +248,13 @@ class SequenceData(msgspec.Struct,
|
||||
self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
|
||||
self._output_token_ids)
|
||||
|
||||
def _update_cached_all_token_embeds(self):
|
||||
assert isinstance(self._prompt_embeds, torch.Tensor)
|
||||
self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds
|
||||
if self._output_embeds is not None:
|
||||
self._cached_all_token_embeds = torch.cat(
|
||||
(self._cached_all_token_embeds, self._output_embeds), dim=0)
|
||||
|
||||
@property
|
||||
def cumulative_logprob(self) -> float:
|
||||
return self._cumulative_logprob
|
||||
@ -270,6 +287,15 @@ class SequenceData(msgspec.Struct,
|
||||
new_output_token_ids)
|
||||
self._update_cached_all_tokens()
|
||||
|
||||
@property
|
||||
def output_embeds(self) -> Optional[torch.Tensor]:
|
||||
return self._output_embeds
|
||||
|
||||
@output_embeds.setter
|
||||
def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None:
|
||||
self._output_token_embeds = new_output_token_embeds
|
||||
self._update_cached_all_token_embeds()
|
||||
|
||||
@property
|
||||
def output_token_ids_array(self) -> array:
|
||||
"""Return the prompt token ids in array type.
|
||||
@ -280,6 +306,15 @@ class SequenceData(msgspec.Struct,
|
||||
assert isinstance(self._output_token_ids, array)
|
||||
return self._output_token_ids
|
||||
|
||||
@property
|
||||
def prompt_embeds(self) -> Optional[torch.Tensor]:
|
||||
return self._prompt_embeds
|
||||
|
||||
@prompt_embeds.setter
|
||||
def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
|
||||
self._prompt_embeds = prompt_embeds
|
||||
self._update_cached_all_token_embeds()
|
||||
|
||||
@property
|
||||
def mrope_position_delta(self) -> Optional[int]:
|
||||
return self._mrope_position_delta
|
||||
@ -288,11 +323,28 @@ class SequenceData(msgspec.Struct,
|
||||
def mrope_position_delta(self, new_mrope_position_delta):
|
||||
self._mrope_position_delta = new_mrope_position_delta
|
||||
|
||||
def append_token_id(self, token_id: int, logprob: float) -> None:
|
||||
def append_token_id(self,
|
||||
token_id: int,
|
||||
logprob: float,
|
||||
token_embed: Optional[torch.Tensor] = None) -> None:
|
||||
self._output_token_ids.append(token_id)
|
||||
self._new_appended_tokens.append(token_id)
|
||||
self._cached_all_token_ids.append(token_id)
|
||||
self._cumulative_logprob += logprob
|
||||
if token_embed is not None:
|
||||
# Do not pass in with batch or sequence dimensions
|
||||
assert token_embed.ndim == 1
|
||||
token_embed = token_embed.detach().cpu().unsqueeze(0)
|
||||
if self._output_embeds is None:
|
||||
self._output_embeds = token_embed
|
||||
else:
|
||||
self._output_embeds = torch.cat(
|
||||
(self._output_embeds, token_embed), dim=0)
|
||||
assert self._cached_all_token_embeds is not None
|
||||
self._cached_all_token_embeds = torch.cat(
|
||||
(self._cached_all_token_embeds,
|
||||
token_embed.to(device=self._cached_all_token_embeds.device)),
|
||||
dim=0)
|
||||
|
||||
def get_len(self) -> int:
|
||||
return len(self._output_token_ids) + len(self._prompt_token_ids)
|
||||
@ -306,6 +358,9 @@ class SequenceData(msgspec.Struct,
|
||||
def get_token_ids(self) -> list[int]:
|
||||
return self._cached_all_token_ids
|
||||
|
||||
def get_token_embeddings(self) -> Optional[torch.Tensor]:
|
||||
return self._cached_all_token_embeds
|
||||
|
||||
def get_prefix_token_ids(
|
||||
self, num_tokens: int
|
||||
) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
|
||||
@ -387,6 +442,8 @@ class SequenceData(msgspec.Struct,
|
||||
def __repr__(self) -> str:
|
||||
return (f"SequenceData("
|
||||
f"prompt_token_ids={self._prompt_token_ids}, "
|
||||
f"prompt_embeds.shape="
|
||||
f"{getattr(self._prompt_embeds, 'shape', None)}, "
|
||||
f"output_token_ids={self.output_token_ids}, "
|
||||
f"cumulative_logprob={self.cumulative_logprob}, "
|
||||
f"get_num_computed_tokens={self.get_num_computed_tokens()})")
|
||||
@ -425,7 +482,10 @@ class Sequence:
|
||||
self.lora_request = lora_request
|
||||
self.prompt_adapter_request = prompt_adapter_request
|
||||
|
||||
self.data = SequenceData.from_seqs(self.prompt_token_ids)
|
||||
self.data = SequenceData.from_seqs(
|
||||
self.prompt_token_ids,
|
||||
prompt_embeds=self.inputs["prompt_embeds"]
|
||||
if self.inputs["type"] == "embeds" else None)
|
||||
self.output_logprobs: SampleLogprobs = []
|
||||
self.output_text = ""
|
||||
|
||||
@ -448,14 +508,20 @@ class Sequence:
|
||||
|
||||
@property
|
||||
def prompt(self) -> Optional[str]:
|
||||
if self.inputs["type"] == "embeds":
|
||||
return None
|
||||
return self.inputs.get("prompt")
|
||||
|
||||
@property
|
||||
def prompt_token_ids(self) -> list[int]:
|
||||
if self.inputs["type"] == "embeds":
|
||||
return [0] * len(self.inputs["prompt_embeds"])
|
||||
return self.inputs["prompt_token_ids"]
|
||||
|
||||
@property
|
||||
def token_type_ids(self) -> list[int]:
|
||||
if self.inputs["type"] == "embeds":
|
||||
return []
|
||||
return self.inputs.get("token_type_ids", [])
|
||||
|
||||
@property
|
||||
@ -554,11 +620,14 @@ class Sequence:
|
||||
"""Reset the sequence states for recomputation."""
|
||||
self.data.reset_state_for_recompute()
|
||||
|
||||
def append_token_id(self, token_id: int, logprobs: dict[int,
|
||||
Logprob]) -> None:
|
||||
def append_token_id(self,
|
||||
token_id: int,
|
||||
logprobs: dict[int, Logprob],
|
||||
token_embed: Optional[torch.Tensor] = None) -> None:
|
||||
assert token_id in logprobs
|
||||
self.output_logprobs.append(logprobs)
|
||||
self.data.append_token_id(token_id, logprobs[token_id].logprob)
|
||||
self.data.append_token_id(token_id, logprobs[token_id].logprob,
|
||||
token_embed)
|
||||
|
||||
def get_len(self) -> int:
|
||||
return self.data.get_len()
|
||||
@ -889,6 +958,10 @@ class SequenceGroup:
|
||||
f"sampling_params={self.sampling_params}, "
|
||||
f"num_seqs={len(self.seqs)})")
|
||||
|
||||
def uses_prompt_embeds(self) -> bool:
|
||||
"""Returns True if the sequence group uses input embeds."""
|
||||
return any(seq.data.prompt_embeds is not None for seq in self.seqs)
|
||||
|
||||
|
||||
class SequenceGroupMetadataDelta(
|
||||
msgspec.Struct,
|
||||
@ -1043,10 +1116,14 @@ class SequenceOutput(
|
||||
parent_seq_id: int
|
||||
output_token: int
|
||||
logprobs: dict[int, Logprob]
|
||||
output_embed: Optional[torch.Tensor] = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
output_embed_shape = \
|
||||
self.output_embed.shape if self.output_embed is not None else None
|
||||
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
|
||||
f"output_token={self.output_token}, "
|
||||
f"output_embed.shape={output_embed_shape}"
|
||||
f"logprobs={self.logprobs})")
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
|
||||
@ -201,6 +201,9 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
|
||||
if self.prompt_adapter_config is not None:
|
||||
raise ValueError("TP1DraftModelRunner has no support for "
|
||||
"prompt_adapter_config")
|
||||
if model_input.inputs_embeds is not None:
|
||||
raise ValueError("TP1DraftModelRunner has no support for "
|
||||
"inputs_embeds")
|
||||
if model_input.multi_modal_kwargs:
|
||||
raise ValueError(
|
||||
"TP1DraftModelRunner has no support for multi_modal_kwargs"
|
||||
@ -242,9 +245,16 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
|
||||
|
||||
# Get model
|
||||
if use_cuda_graph:
|
||||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = (self.graph_runners[model_input.virtual_engine]
|
||||
[graph_batch_size])
|
||||
if model_input.inputs_embeds is None:
|
||||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = (
|
||||
self.graph_runners[model_input.virtual_engine][(
|
||||
graph_batch_size, False)])
|
||||
else:
|
||||
graph_batch_size = model_input.inputs_embeds.shape[0]
|
||||
model_executable = (
|
||||
self.graph_runners[model_input.virtual_engine][(
|
||||
graph_batch_size, True)])
|
||||
|
||||
if previous_hidden_states is not None:
|
||||
hidden_states = torch.cat([
|
||||
@ -281,6 +291,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
|
||||
self.vllm_config):
|
||||
hidden_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
inputs_embeds=None,
|
||||
positions=model_input.input_positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
||||
|
||||
@ -282,7 +282,8 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
|
||||
else:
|
||||
count += 1
|
||||
|
||||
seq.append_token_id(token_id, token_logprob.logprob)
|
||||
seq.append_token_id(token_id, token_logprob.logprob,
|
||||
seq_output.output_embed)
|
||||
seq.update_num_computed_tokens(1)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -49,6 +49,7 @@ class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"inputs_embeds": self.inputs_embeds,
|
||||
"input_positions": self.input_positions,
|
||||
"encoder_input_tokens": self.encoder_input_tokens,
|
||||
"encoder_input_positions": self.encoder_input_positions,
|
||||
@ -172,10 +173,17 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
if (model_input.attn_metadata is not None
|
||||
and model_input.attn_metadata.prefill_metadata is None
|
||||
and model_input.attn_metadata.decode_metadata.use_cuda_graph):
|
||||
assert model_input.input_tokens is not None
|
||||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = self.graph_runners[
|
||||
model_input.virtual_engine][graph_batch_size]
|
||||
if model_input.inputs_embeds is None:
|
||||
assert model_input.input_tokens is not None
|
||||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = (
|
||||
self.graph_runners[model_input.virtual_engine][(
|
||||
graph_batch_size, False)])
|
||||
else:
|
||||
graph_batch_size = model_input.inputs_embeds.shape[0]
|
||||
model_executable = (
|
||||
self.graph_runners[model_input.virtual_engine][(
|
||||
graph_batch_size, True)])
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
@ -189,6 +197,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
model_input.virtual_engine):
|
||||
hidden_or_intermediate_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
inputs_embeds=model_input.inputs_embeds,
|
||||
positions=model_input.input_positions,
|
||||
encoder_input_ids=model_input.encoder_input_tokens,
|
||||
encoder_positions=model_input.encoder_input_positions,
|
||||
|
||||
@ -35,7 +35,8 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
|
||||
get_sampler)
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.model_executor.models import supports_lora, supports_multimodal
|
||||
@ -85,6 +86,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
|
||||
additional fields.
|
||||
"""
|
||||
input_tokens: Optional[torch.Tensor] = None
|
||||
inputs_embeds: Optional[torch.Tensor] = None
|
||||
input_positions: Optional[torch.Tensor] = None
|
||||
token_types: Optional[torch.Tensor] = None
|
||||
seq_lens: Optional[List[int]] = None
|
||||
@ -105,6 +107,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"inputs_embeds": self.inputs_embeds,
|
||||
"input_positions": self.input_positions,
|
||||
"lora_requests": self.lora_requests,
|
||||
"lora_mapping": self.lora_mapping,
|
||||
@ -155,6 +158,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"inputs_embeds": self.inputs_embeds,
|
||||
"input_positions": self.input_positions,
|
||||
"lora_requests": self.lora_requests,
|
||||
"lora_mapping": self.lora_mapping,
|
||||
@ -194,6 +198,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
|
||||
def simple_reinit(self):
|
||||
self.input_tokens[0].clear() # type: ignore
|
||||
self.inputs_embeds = None # type: ignore
|
||||
self.input_positions[0].clear() # type: ignore
|
||||
self.token_types[0].clear() # type: ignore
|
||||
self.mrope_input_positions = None # type: ignore
|
||||
@ -221,6 +226,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
|
||||
# Input tokens and positions.
|
||||
input_tokens: Optional[List[List[int]]] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
input_positions: Optional[List[List[int]]] = None,
|
||||
token_types: Optional[List[List[int]]] = None,
|
||||
mrope_input_positions: Optional[List[List[List[int]]]] = None,
|
||||
@ -282,6 +288,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
for seq_id in range(len(self.seq_ids)):
|
||||
self.input_tokens[seq_id].clear()
|
||||
|
||||
self.inputs_embeds = inputs_embeds
|
||||
|
||||
if input_positions:
|
||||
self.input_positions = input_positions
|
||||
else:
|
||||
@ -356,6 +364,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
|
||||
else:
|
||||
self.input_tokens = input_tokens or []
|
||||
self.inputs_embeds = inputs_embeds
|
||||
self.input_positions = input_positions or []
|
||||
self.token_types = token_types or []
|
||||
self.mrope_input_positions = mrope_input_positions or None
|
||||
@ -401,6 +410,26 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
self.lora_index_mapping = []
|
||||
self.lora_prompt_mapping = []
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"InterDataForSeqGroup("
|
||||
f"request_id={self.request_id}, "
|
||||
f"seq_ids={self.seq_ids}, "
|
||||
f"is_prompt={self.is_prompt}, "
|
||||
f"block_tables={self.block_tables}, "
|
||||
f"computed_block_nums={self.computed_block_nums}, "
|
||||
f"n_seqs={self.n_seqs}, "
|
||||
f"input_tokens={self.input_tokens}, "
|
||||
f"inputs_embeds.shape="
|
||||
f"{getattr(self.inputs_embeds, 'shape', None)}, "
|
||||
f"input_positions={self.input_positions}, "
|
||||
f"token_types={self.token_types}, "
|
||||
f"mrope_input_positions={self.mrope_input_positions}, "
|
||||
f"seq_lens={self.seq_lens}, "
|
||||
f"orig_seq_lens={self.orig_seq_lens}, "
|
||||
f"query_lens={self.query_lens}, "
|
||||
f"context_lens={self.context_lens}, "
|
||||
f"multi_modal_kwargs={self.multi_modal_kwargs}")
|
||||
|
||||
def gen_inter_data_builder(self, num_seqs: int):
|
||||
return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup(
|
||||
request_id="",
|
||||
@ -511,13 +540,21 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
context_len = seq_data.get_num_computed_tokens()
|
||||
|
||||
# Compute tokens.
|
||||
tokens = seq_data.get_token_ids()[context_len:seq_len]
|
||||
if seq_data.prompt_embeds is None:
|
||||
tokens = seq_data.get_token_ids()[context_len:seq_len]
|
||||
prompt_embeds = None
|
||||
else:
|
||||
tokens = [0] * (seq_len - context_len)
|
||||
prompt_embeds = seq_data.get_token_embeddings(
|
||||
)[context_len:seq_len]
|
||||
|
||||
token_types = seq_group_metadata.token_type_ids
|
||||
|
||||
inter_data.seq_lens[seq_idx] = seq_len
|
||||
inter_data.orig_seq_lens[seq_idx] = seq_len
|
||||
inter_data.context_lens[seq_idx] = context_len
|
||||
inter_data.input_tokens[seq_idx].extend(tokens)
|
||||
inter_data.inputs_embeds = prompt_embeds
|
||||
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
|
||||
inter_data.token_types[seq_idx].extend(
|
||||
token_types if token_types else [])
|
||||
@ -822,15 +859,29 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
create on-device tensors.
|
||||
"""
|
||||
# Combine and flatten intermediate data.
|
||||
input_tokens = []
|
||||
token_types = []
|
||||
input_tokens = list[int]()
|
||||
inputs_embeds_lst = list[torch.Tensor]()
|
||||
token_types = list[int]()
|
||||
for inter_data in self.inter_data_list:
|
||||
for cur_input_tokens in inter_data.input_tokens:
|
||||
input_tokens.extend(cur_input_tokens)
|
||||
for cur_token_types in inter_data.token_types:
|
||||
token_types.extend(cur_token_types)
|
||||
if inter_data.inputs_embeds is not None:
|
||||
inputs_embeds_lst.append(
|
||||
inter_data.inputs_embeds.to(
|
||||
dtype=self.runner.model_config.dtype,
|
||||
device=self.runner.device))
|
||||
inputs_embeds: Optional[torch.Tensor]
|
||||
if len(inputs_embeds_lst) == 0:
|
||||
inputs_embeds = None
|
||||
else:
|
||||
inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to(
|
||||
dtype=self.runner.model_config.dtype,
|
||||
device=self.runner.device)
|
||||
assert len(inputs_embeds) == len(input_tokens)
|
||||
|
||||
if not input_tokens:
|
||||
if not input_tokens and inputs_embeds is None:
|
||||
# This may happen when all prefill requests hit
|
||||
# prefix caching and there is no decode request.
|
||||
return self.model_input_cls()
|
||||
@ -980,6 +1031,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
|
||||
return self.model_input_cls(
|
||||
input_tokens=input_tokens_tensor,
|
||||
inputs_embeds=inputs_embeds,
|
||||
input_positions=input_positions_tensor,
|
||||
token_types=token_types_tensor,
|
||||
attn_metadata=attn_metadata,
|
||||
@ -1029,7 +1081,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
self.max_batchsize_to_capture = \
|
||||
self.vllm_config.compilation_config.max_capture_size
|
||||
|
||||
self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
|
||||
#
|
||||
self.graph_runners: List[Dict[Tuple[int, bool], CUDAGraphRunner]] = [
|
||||
{} for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
self.graph_memory_pool: Optional[Tuple[
|
||||
@ -1466,6 +1519,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
input_positions = torch.zeros(max_batch_size,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
inputs_embeds = torch.zeros(
|
||||
(max_batch_size, self.model_config.get_hidden_size()),
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device)
|
||||
if self.model_config.uses_mrope:
|
||||
input_positions = torch.tile(input_positions,
|
||||
(3, 1)).cuda(device=self.device)
|
||||
@ -1503,15 +1560,22 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
# memory usage of CUDA graph.
|
||||
for virtual_engine in range(
|
||||
self.parallel_config.pipeline_parallel_size):
|
||||
# Only rank 0 should print progress bar during capture
|
||||
cudagraph_capture_sizes = (tqdm(
|
||||
self.vllm_config.compilation_config.
|
||||
# We need to not only iterate over batch sizes, but also whether
|
||||
# to use inputs_embeds or not, hence we use the cartesian
|
||||
# product.
|
||||
cudagraph_capture_sizes = self.vllm_config.compilation_config\
|
||||
.cudagraph_capture_sizes
|
||||
cudagraph_inputs_embeds = (True, False)
|
||||
compilation_cases = itertools.product(
|
||||
cudagraph_capture_sizes,
|
||||
desc="Capturing CUDA graph shapes",
|
||||
) if get_tensor_model_parallel_rank() == 0 else
|
||||
self.vllm_config.compilation_config.
|
||||
cudagraph_capture_sizes)
|
||||
for batch_size in cudagraph_capture_sizes:
|
||||
cudagraph_inputs_embeds,
|
||||
)
|
||||
# Only rank 0 should print progress bar during capture
|
||||
if get_tensor_model_parallel_rank() == 0:
|
||||
compilation_cases = tqdm(
|
||||
list(compilation_cases),
|
||||
desc="Capturing CUDA graph shapes")
|
||||
for batch_size, use_inputs_embeds in compilation_cases:
|
||||
attn_metadata = (
|
||||
self.attn_state.graph_capture_get_metadata_for_batch(
|
||||
batch_size,
|
||||
@ -1542,6 +1606,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
capture_inputs = {
|
||||
"input_ids":
|
||||
input_tokens[:batch_size],
|
||||
"inputs_embeds":
|
||||
inputs_embeds[:batch_size]
|
||||
if use_inputs_embeds else None,
|
||||
"positions":
|
||||
input_positions[..., :batch_size],
|
||||
"intermediate_inputs":
|
||||
@ -1578,8 +1645,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
|
||||
virtual_engine):
|
||||
graph_runner.capture(**capture_inputs)
|
||||
self.graph_memory_pool = graph_runner.graph.pool()
|
||||
self.graph_runners[virtual_engine][batch_size] = (
|
||||
graph_runner)
|
||||
self.graph_runners[virtual_engine][(
|
||||
batch_size, use_inputs_embeds)] = graph_runner
|
||||
|
||||
if self.lora_config:
|
||||
self._remove_dummy_loras()
|
||||
@ -1711,8 +1778,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
||||
if prefill_meta is None and decode_meta.use_cuda_graph:
|
||||
assert model_input.input_tokens is not None
|
||||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = self.graph_runners[virtual_engine][
|
||||
graph_batch_size]
|
||||
use_inputs_embeds = model_input.inputs_embeds is not None
|
||||
model_executable = self.graph_runners[virtual_engine][(
|
||||
graph_batch_size, use_inputs_embeds)]
|
||||
if previous_hidden_states is not None:
|
||||
previous_hidden_states = torch.cat([
|
||||
previous_hidden_states,
|
||||
@ -1763,6 +1831,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
||||
self.vllm_config, virtual_engine):
|
||||
hidden_or_intermediate_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
inputs_embeds=model_input.inputs_embeds,
|
||||
positions=model_input.input_positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
||||
@ -1817,6 +1886,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
||||
model_input.async_callback()
|
||||
|
||||
# Sample the next token.
|
||||
assert isinstance(self.sampler, Sampler)
|
||||
orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor
|
||||
if model_input.inputs_embeds is not None:
|
||||
self.sampler.include_gpu_probs_tensor = True
|
||||
|
||||
output: SamplerOutput = self.sampler(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
@ -1838,6 +1912,18 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
||||
output.model_forward_time = (orig_model_forward_time +
|
||||
model_forward_time)
|
||||
|
||||
if model_input.inputs_embeds is not None:
|
||||
self.sampler.include_gpu_probs_tensor = \
|
||||
orig_include_gpu_probs_tensor
|
||||
if output.sampled_token_ids is not None:
|
||||
output.sampled_token_embeds = self.model.get_input_embeddings(
|
||||
output.sampled_token_ids.squeeze(1))
|
||||
|
||||
for token_embed, sequence_group_output in zip(
|
||||
output.sampled_token_embeds, output.outputs):
|
||||
assert len(sequence_group_output.samples) == 1
|
||||
sequence_group_output.samples[0].output_embed = token_embed
|
||||
|
||||
if self.return_hidden_states:
|
||||
# we only need to pass hidden states of most recent token
|
||||
assert model_input.sampling_metadata is not None
|
||||
@ -1931,6 +2017,7 @@ class CUDAGraphRunner(nn.Module):
|
||||
def capture(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
intermediate_inputs: Optional[IntermediateTensors],
|
||||
kv_caches: List[torch.Tensor],
|
||||
@ -1947,6 +2034,7 @@ class CUDAGraphRunner(nn.Module):
|
||||
for _ in range(_NUM_WARMUP_ITERS):
|
||||
self.model(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_inputs,
|
||||
**kwargs,
|
||||
@ -1959,6 +2047,9 @@ class CUDAGraphRunner(nn.Module):
|
||||
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
|
||||
output_hidden_or_intermediate_states = self.model(
|
||||
input_ids=input_ids,
|
||||
**({
|
||||
"inputs_embeds": inputs_embeds,
|
||||
} if inputs_embeds is not None else {}),
|
||||
positions=positions,
|
||||
intermediate_tensors=intermediate_inputs,
|
||||
**kwargs,
|
||||
@ -1986,6 +2077,9 @@ class CUDAGraphRunner(nn.Module):
|
||||
self.input_buffers = {
|
||||
"input_ids":
|
||||
input_ids,
|
||||
**({
|
||||
"inputs_embeds": inputs_embeds,
|
||||
} if inputs_embeds is not None else {}),
|
||||
"positions":
|
||||
positions,
|
||||
"kv_caches":
|
||||
@ -2006,6 +2100,7 @@ class CUDAGraphRunner(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors],
|
||||
**kwargs,
|
||||
@ -2020,6 +2115,9 @@ class CUDAGraphRunner(nn.Module):
|
||||
# so the shape is not padded, we need to copy partial only
|
||||
self.input_buffers["positions"][:positions.shape[0]].copy_(
|
||||
positions, non_blocking=True)
|
||||
if inputs_embeds is not None:
|
||||
self.input_buffers["inputs_embeds"][:inputs_embeds.shape[0]].copy_(
|
||||
inputs_embeds, non_blocking=True)
|
||||
|
||||
if self.backend_name != "NO_ATTENTION":
|
||||
self.input_buffers["slot_mapping"].copy_(
|
||||
|
||||
@ -84,10 +84,17 @@ class PoolingModelRunner(
|
||||
# explore how to leverage it.
|
||||
if (prefill_meta is None and decode_meta is not None
|
||||
and decode_meta.use_cuda_graph):
|
||||
assert model_input.input_tokens is not None
|
||||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = self.graph_runners[virtual_engine][
|
||||
graph_batch_size]
|
||||
if model_input.inputs_embeds is None:
|
||||
assert model_input.input_tokens is not None
|
||||
graph_batch_size = model_input.input_tokens.shape[0]
|
||||
model_executable = (
|
||||
self.graph_runners[model_input.virtual_engine][(
|
||||
graph_batch_size, False)])
|
||||
else:
|
||||
graph_batch_size = model_input.inputs_embeds.shape[0]
|
||||
model_executable = (
|
||||
self.graph_runners[model_input.virtual_engine][(
|
||||
graph_batch_size, True)])
|
||||
else:
|
||||
model_executable = self.model
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user