Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-14 14:26:07 +08:00)

[Core] [Bugfix] Add Input Embeddings (#15428)

Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: 临景 <linjing.yx@alibaba-inc.com>
Co-authored-by: Bryce1010 <bryceyx@gmail.com>
Co-authored-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>

This commit is contained in:
parent 9e2de9b9e9
commit cc2a77d7f1
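
This change allows a decoder-only request to be driven by precomputed input embeddings ("prompt_embeds") instead of prompt text or token ids, on the V0 engine only. The following is a minimal usage sketch, not taken from the diff, that mirrors the test added below; the model name, dtype handling, and generation settings are assumptions:

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

os.environ["VLLM_USE_V1"] = "0"  # prompt embeddings are V0-only per this change

from vllm import LLM, SamplingParams

model_name = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = AutoModelForCausalLM.from_pretrained(model_name)

token_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
# Shape (seq_len, hidden_size); a leading batch dimension of 1 is also
# accepted and squeezed away by the input preprocessor added in this commit.
prompt_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

llm = LLM(model=model_name)
# A cast such as prompt_embeds.to(torch.float16) may be needed so the
# embedding dtype matches the dtype vLLM loads the model with.
outputs = llm.generate({"prompt_embeds": prompt_embeds},
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)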
@@ -787,7 +787,7 @@ class VllmRunner:

     def get_inputs(
         self,
-        prompts: list[str],
+        prompts: Union[list[str], list[torch.Tensor]],
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
         audios: Optional[PromptAudioInput] = None,
@@ -809,16 +809,18 @@
             if audios is not None and (audio := audios[i]) is not None:
                 multi_modal_data["audio"] = audio

-            inputs.append(
-                TextPrompt(prompt=prompt,
-                           multi_modal_data=multi_modal_data
-                           if multi_modal_data else None))
+            text_prompt_kwargs = {
+                ("prompt" if isinstance(prompt, str) else "prompt_embeds"):
+                prompt,
+                "multi_modal_data": multi_modal_data or None
+            }
+            inputs.append(TextPrompt(**text_prompt_kwargs))

         return inputs

     def generate(
         self,
-        prompts: list[str],
+        prompts: Union[list[str], list[torch.Tensor]],
         sampling_params: SamplingParams,
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
@@ -844,7 +846,7 @@
                 output_str = sample.text
                 output_ids = list(sample.token_ids)
                 req_sample_output_ids.append(prompt_ids + output_ids)
-                req_sample_output_strs.append(prompt_str + output_str)
+                req_sample_output_strs.append((prompt_str or "") + output_str)
             outputs.append((req_sample_output_ids, req_sample_output_strs))
         return outputs

@@ -911,7 +913,7 @@

     def generate_greedy(
         self,
-        prompts: list[str],
+        prompts: Union[list[str], list[torch.Tensor]],
         max_tokens: int,
         images: Optional[PromptImageInput] = None,
         videos: Optional[PromptVideoInput] = None,
@@ -2,16 +2,18 @@

 import time
 from collections import deque
+from typing import Optional
 from unittest.mock import MagicMock

 import pytest  # noqa
+import torch
 from torch import Use  # noqa

 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.core.interfaces import AllocStatus
 from vllm.core.scheduler import Scheduler, SchedulingBudget
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SequenceGroup
+from vllm.sequence import SequenceGroup, SequenceStatus

 from .utils import (append_new_token, append_new_token_seq,
                     append_new_token_seq_group, create_dummy_prompt,
@@ -968,3 +970,73 @@ def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
     ), "A partial prefix of C (4 tokens) should be prefilled, with the "
     "remaining tokens fit into 3 token budget (4-1 from the seqA). It will "
     "then be rounded down to 2 tokens on block size, thus 6 tokens in total."
+
+
+def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds():
+    """
+    Test that the scheduler does not schedule batches with prompt tokens and
+    prompt embeddings co-mingled.
+    """
+    block_size = 2
+    max_seq_group = 3
+    scheduler = initialize_scheduler(
+        block_size=block_size,
+        num_cpu_blocks=16,
+        num_gpu_blocks=16,
+        max_num_seqs=max_seq_group,
+        max_model_len=100,
+        enable_prefix_caching=True,
+    )
+
+    # the odd indexed inputs should be passed in via embeddings,
+    # evens via token_ids
+    seq_length = 7
+    embedding_size = 5
+    num_seqs = 11
+    seq_tokens: list[list[int]] = []
+    seq_embeds: list[Optional[torch.Tensor]] = []
+    for i in range(num_seqs):
+        if i % 2:
+            seq_tokens.append(list(range(seq_length)))
+            seq_embeds.append(None)
+        else:
+            seq_tokens.append([0] * seq_length)
+            seq_embeds.append(torch.rand(embedding_size))
+
+    seq_and_seq_groups = [
+        create_dummy_prompt(f"{i}",
+                            prompt_tokens=seq_tokens[i],
+                            prompt_embeds=seq_embeds[i],
+                            block_size=block_size)
+        for i in range(len(seq_tokens))
+    ]
+
+    for _, seq_group in seq_and_seq_groups:
+        scheduler.add_seq_group(seq_group)
+
+    while not all(seq.is_finished() for seq, _ in seq_and_seq_groups):
+        unfinished_seq_groups = [
+            seq_group for _, seq_group in seq_and_seq_groups
+            if not seq_group.is_finished()
+        ]
+        _, out = schedule_and_update_computed_tokens(scheduler)
+        assert len(out.scheduled_seq_groups) > 0
+        batch_is_prompt_embeds = out.scheduled_seq_groups[
+            0].seq_group.uses_prompt_embeds()
+        expected_scheduled_seq_groups = [
+            seq_group for seq_group in unfinished_seq_groups
+            if seq_group.uses_prompt_embeds() == batch_is_prompt_embeds
+        ]
+
+        # We should have as many scheduled groups as possible, without mixing
+        assert len(out.scheduled_seq_groups) == min(
+            max_seq_group, len(expected_scheduled_seq_groups))
+        assert all(scheduled_seq_group.seq_group.uses_prompt_embeds() ==
+                   batch_is_prompt_embeds
+                   for scheduled_seq_group in out.scheduled_seq_groups)
+
+        # Finish the scheduled groups
+        for scheduled_seq_group in out.scheduled_seq_groups:
+            for seq in scheduled_seq_group.seq_group.seqs:
+                seq.status = SequenceStatus.FINISHED_STOPPED
+        scheduler.free_finished_seq_groups()
@@ -5,9 +5,11 @@ from collections import defaultdict
 from collections.abc import Sequence as GenericSequence
 from typing import Any, Optional

+import torch
+
 from vllm import SamplingParams
 from vllm.core.scheduler import Scheduler, SchedulerOutputs
-from vllm.inputs import EncoderDecoderInputs, token_inputs
+from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs
 from vllm.lora.request import LoRARequest
 from vllm.sequence import (Logprob, Sequence, SequenceGroup,
                            SequenceGroupMetadata)
@@ -19,6 +21,7 @@ def create_dummy_prompt(
     block_size: Optional[int] = None,
     lora_request: Optional[LoRARequest] = None,
     prompt_tokens: Optional[list[int]] = None,
+    prompt_embeds: Optional[torch.Tensor] = None,
     min_tokens: int = 0,
     max_tokens: int = 16,
 ) -> tuple[Sequence, SequenceGroup]:
@@ -31,9 +34,13 @@ def create_dummy_prompt(
         prompt_tokens = list(range(prompt_length))

     prompt_str = " ".join([str(t) for t in prompt_tokens])
+    inputs = token_inputs(
+        prompt_token_ids=prompt_tokens,
+        prompt=prompt_str) if prompt_embeds is None else embeds_inputs(
+            prompt_embeds=prompt_embeds)
     prompt = Sequence(
         int(request_id),
-        inputs=token_inputs(prompt_tokens, prompt=prompt_str),
+        inputs=inputs,
         block_size=block_size,
     )
     seq_group = SequenceGroup(
@@ -1,4 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
+import os
+from typing import Optional
+
 import pytest
 import torch

@@ -110,6 +113,18 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
         hf_outputs = hf_model.generate_greedy_logprobs_limit(
             example_prompts, max_tokens, num_logprobs)

+        prompt_embeds: Optional[list[torch.Tensor]] = [] if os.getenv(
+            "VLLM_USE_V1") == "0" else None
+        prompt_token_ids = []
+        for prompt in example_prompts:
+            token_ids = hf_model.tokenizer(prompt,
+                                           return_tensors="pt").input_ids.to(
+                                               hf_model.model.device)
+            prompt_token_ids.append(token_ids)
+            if prompt_embeds is not None:
+                prompt_embeds.append(hf_model.model.get_input_embeddings()(
+                    token_ids).squeeze(0))
+
     with vllm_runner(
             model,
             tokenizer_name=model_info.tokenizer or model,
@@ -119,6 +134,9 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)
+        if prompt_embeds is not None:
+            vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs(
+                prompt_embeds, max_tokens, num_logprobs)

     check_logprobs_close(
         outputs_0_lst=hf_outputs,
@@ -126,6 +144,14 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
         name_0="hf",
         name_1="vllm",
     )
+    if prompt_embeds is not None:
+        check_logprobs_close(
+            outputs_0_lst=vllm_outputs,
+            outputs_1_lst=vllm_outputs_from_embeds,
+            name_0="vllm",
+            name_1="vllm_from_embeds",
+        )
+
     if use_rocm_aiter:
         # this is to ensure that vllm engine
         # has deallocated the memory before running the next
@@ -31,8 +31,13 @@ def test_deepseek_mla_attn_backend_module():
     assert model_runner.attn_backend.__name__ == "TritonMLABackend"


-@pytest.mark.parametrize("batch_size", list(range(1, 257)))
-def test_prepare_prompt(batch_size):
+@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
+@pytest.mark.parametrize("use_prompt_embeds", [True, False])
+def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch):
+    if use_prompt_embeds:
+        # Prompt Embeddings is only currently supported on V0
+        monkeypatch.setenv("VLLM_USE_V1", "0")
+
     model_runner = _create_model_runner(
         "facebook/opt-125m",
         max_num_batched_tokens=100000,
@@ -43,11 +48,20 @@ def test_prepare_prompt(batch_size):
     seq_lens: list[int] = []
     seq_group_metadata_list: list[SequenceGroupMetadata] = []
     block_tables = {0: [1]}
+    expected_input_embeds_len = 0
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData.from_seqs(range(seq_len))
+        if use_prompt_embeds:
+            seq_data = SequenceData.from_seqs(
+                prompt_token_ids=[0] * seq_len,
+                prompt_embeds=torch.rand(seq_len, 10),
+            )
+            expected_input_embeds_len += seq_len
+        else:
+            seq_data = SequenceData.from_seqs(prompt_token_ids=range(seq_len))
+
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -68,6 +82,7 @@ def test_prepare_prompt(batch_size):
         seq_group_metadata_list)
     input_tokens = model_input.input_tokens
     input_positions = model_input.input_positions
+    input_embeds = model_input.inputs_embeds
     attn_metadata = model_input.attn_metadata
     return_seq_lens = model_input.seq_lens
     slot_mapping = attn_metadata.slot_mapping
@@ -121,7 +136,11 @@ def test_prepare_prompt(batch_size):

     assert len(input_tokens) == sum(seq_lens)
     assert len(input_positions) == sum(seq_lens)
-    torch.testing.assert_close(input_tokens, input_positions)
+    if expected_input_embeds_len == 0:
+        torch.testing.assert_close(input_tokens, input_positions)
+        assert input_embeds is None
+    else:
+        assert len(input_embeds) == expected_input_embeds_len

     sampling_metadata = SamplingMetadata.prepare(
         seq_group_metadata_list,
@@ -145,8 +164,13 @@ def test_prepare_prompt(batch_size):
     torch.testing.assert_close(actual, expected)


-@pytest.mark.parametrize("batch_size", list(range(1, 257)))
-def test_prepare_decode_cuda_graph(batch_size):
+@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
+@pytest.mark.parametrize("use_prompt_embeds", [True, False])
+def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch):
+    if use_prompt_embeds:
+        # Prompt Embeddings is only currently supported on V0
+        monkeypatch.setenv("VLLM_USE_V1", "0")
+
     model_runner = _create_model_runner(
         "facebook/opt-125m",
         seed=0,
@@ -164,10 +188,19 @@ def test_prepare_decode_cuda_graph(batch_size):
         # make sure all tokens fit into one block
         context_len = i % (model_runner.block_size - 1) + 1
         context_lens.append(context_len)
-        seq_data = SequenceData.from_seqs(range(context_len))
+        if use_prompt_embeds:
+            seq_data = SequenceData.from_seqs(
+                prompt_token_ids=[0] * context_len,
+                prompt_embeds=torch.rand(context_len, 10),
+            )
+            output_embed = torch.rand(10)
+        else:
+            seq_data = SequenceData.from_seqs(
+                prompt_token_ids=range(context_len))
+            output_embed = None
         seq_data.update_num_computed_tokens(context_len)
         # Append one token ID since prefill is finished.
-        seq_data.append_token_id(1, 0)
+        seq_data.append_token_id(1, 0, output_embed)
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
@@ -180,9 +213,12 @@ def test_prepare_decode_cuda_graph(batch_size):

     model_input = model_runner._prepare_model_input_tensors(
         seq_group_metadata_list)
-    input_tokens, input_positions, attn_metadata, slot_mapping = (
-        model_input.input_tokens, model_input.input_positions,
-        model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
+    input_tokens = model_input.input_tokens
+    input_positions = model_input.input_positions
+    input_embeds = model_input.inputs_embeds
+    attn_metadata = model_input.attn_metadata
+    slot_mapping = attn_metadata.slot_mapping

     assert len(slot_mapping) == len(input_tokens)

     expected_bs = model_runner.vllm_config.pad_for_cudagraph(
@@ -227,7 +263,7 @@ def test_prepare_decode_cuda_graph(batch_size):
     # block table's first index corresponds to each batch, meaning in
     # decoding it is each token.
     assert attn_metadata.block_tables.shape[0] == len(input_tokens)
-    # Block table's second dim correspondsd to each token's block number.
+    # Block table's second dim corresponds to each token's block number.
     # It is padded up to
     assert attn_metadata.block_tables.shape[1] == (
         model_runner.get_max_block_per_batch())
@@ -235,7 +271,12 @@ def test_prepare_decode_cuda_graph(batch_size):

     assert len(input_tokens) == expected_bs
     assert len(input_positions) == expected_bs
-    torch.allclose(input_tokens, input_positions)
+    if use_prompt_embeds:
+        expected_input_embeds_length = start_loc[-1]
+        assert len(input_embeds) == expected_input_embeds_length
+        assert expected_input_embeds_length <= expected_bs
+    else:
+        assert input_embeds is None

     # Verify Sampling
     expected_selected_token_indices = []
@@ -266,25 +307,27 @@ def test_empty_seq_group():
     seq_group_metadata_list: list[SequenceGroupMetadata] = []
     model_input = model_runner._prepare_model_input_tensors(
         seq_group_metadata_list)
-    input_tokens, input_positions, attn_metadata = (
-        model_input.input_tokens,
-        model_input.input_positions,
-        model_input.attn_metadata,
-    )
+    input_tokens = model_input.input_tokens
+    input_positions = model_input.input_positions
+    attn_metadata = model_input.attn_metadata
+
     assert input_tokens is None
     assert input_positions is None
     assert attn_metadata is None

     model_input = model_runner._prepare_model_input_tensors(
         seq_group_metadata_list)
-    (input_tokens, input_positions, attn_metadata, return_seq_lens) = (
-        model_input.input_tokens,
-        model_input.input_positions,
-        model_input.attn_metadata,
-        model_input.seq_lens,
-    )
+    input_tokens = model_input.input_tokens
+    input_positions = model_input.input_positions
+    input_embeds = model_input.inputs_embeds
+    attn_metadata = model_input.attn_metadata
+    return_seq_lens = model_input.seq_lens

     assert input_tokens is None
     assert input_positions is None
+    assert input_embeds is None
     assert attn_metadata is None
     assert return_seq_lens is None

@@ -299,9 +342,15 @@ def distributed_init():
     ensure_model_parallel_initialized(1, 1)


-@pytest.mark.parametrize("batch_size", list(range(2, 128)))
+@pytest.mark.parametrize("batch_size", list(range(2, 128, 3)))
 @pytest.mark.parametrize("enforce_eager", [True, False])
-def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
+@pytest.mark.parametrize('use_prompt_embeds', [True, False])
+def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds,
+                        distributed_init, monkeypatch):
+    if use_prompt_embeds:
+        # Prompt Embeddings is only currently supported on V0
+        monkeypatch.setenv("VLLM_USE_V1", "0")
+
     model_runner = _create_model_runner(
         "facebook/opt-125m",
         seed=0,
@@ -320,11 +369,20 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
     block_tables = {0: [1]}
     prefill_batch_size = batch_size // 2
     decode_batch_size = batch_size - prefill_batch_size
+    expected_input_embeds_len = 0
     for i in range(prefill_batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
         seq_lens.append(seq_len)
-        seq_data = SequenceData.from_seqs(range(seq_len))
+        if use_prompt_embeds:
+            seq_data = SequenceData.from_seqs(
+                prompt_token_ids=[0] * seq_len,
+                prompt_embeds=torch.rand(seq_len, 10),
+            )
+            expected_input_embeds_len += seq_len
+        else:
+            seq_data = SequenceData.from_seqs(
+                prompt_token_ids=range(seq_len), )
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=True,
@@ -340,8 +398,21 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
     for i in range(prefill_batch_size, batch_size):
         # make sure all tokens fit into one block
         context_len = i % (model_runner.block_size - 1) + 1
-        seq_data = SequenceData.from_seqs(range(context_len))
-        seq_data.append_token_id(1, 0)
+        if use_prompt_embeds:
+            seq_data = SequenceData.from_seqs(
+                prompt_token_ids=[0] * context_len,
+                prompt_embeds=torch.rand(context_len, 10),
+            )
+            output_embed = torch.rand(10)
+            # This also iterates the expected input_embeds, because the model
+            # needs both the input and output embeddings passed into together
+            expected_input_embeds_len += 1
+        else:
+            seq_data = SequenceData.from_seqs(
+                prompt_token_ids=range(context_len), )
+            output_embed = None
+        assert len(seq_data.prompt_token_ids) == context_len
+        seq_data.append_token_id(1, 0, output_embed)
         seq_data.update_num_computed_tokens(context_len)
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
@@ -355,11 +426,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
         decode_metadata_list.append(seq_group_metadata)

     model_input = model_runner.prepare_model_input(seq_group_metadata_list)
-    (input_tokens, input_positions, attn_metadata) = (
-        model_input.input_tokens,
-        model_input.input_positions,
-        model_input.attn_metadata,
-    )
+    input_tokens = model_input.input_tokens
+    input_positions = model_input.input_positions
+    input_embeds = model_input.inputs_embeds
+    attn_metadata = model_input.attn_metadata

     prefill_meta_actual = attn_metadata.prefill_metadata
     decode_meta_actual = attn_metadata.decode_metadata
@@ -369,6 +440,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
     assert attn_metadata.num_prefills == prefill_batch_size
     assert attn_metadata.num_decode_tokens == decode_batch_size
     assert attn_metadata.num_prefill_tokens == sum(seq_lens)
+    if expected_input_embeds_len == 0:
+        assert input_embeds is None
+    else:
+        assert len(input_embeds) == expected_input_embeds_len

     # Verify attn metadata is consistent. We don't need to test individual
     # values here because they are tested above.
@@ -367,9 +367,17 @@ class FlashInferState(AttentionState):
         # scheduled while CUDA graph mode is enabled. We don't run graph in that
         # case.
         if use_cuda_graph and is_decode:
-            batch_size = model_input.input_tokens.shape[0]
-            state = (self.runner.graph_runners[model_input.virtual_engine]
-                     [batch_size].attn_state)
+            if model_input.inputs_embeds is None:
+                batch_size = model_input.input_tokens.shape[0]
+                state = (
+                    self.runner.graph_runners[model_input.virtual_engine][(
+                        batch_size, False)].attn_state)
+            else:
+                batch_size = model_input.inputs_embeds.shape[0]
+                state = (
+                    self.runner.graph_runners[model_input.virtual_engine][(
+                        batch_size, True)].attn_state)
+
             model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
             )
             model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
@@ -1071,6 +1071,7 @@ class Scheduler:
         )
         ignored_seq_groups: List[SequenceGroup] = []
         seq_groups: List[ScheduledSequenceGroup] = []
+        using_prompt_embeds: bool = False

         waiting_queue = self.waiting

@@ -1138,6 +1139,15 @@ class Scheduler:
                 waiting_queue.popleft()
                 continue

+            # We cannot mix sequence groups that use prompt embeds and
+            # those that do not.
+            if len(seq_groups) == 0:
+                using_prompt_embeds = seq_group.uses_prompt_embeds()
+            if using_prompt_embeds != seq_group.uses_prompt_embeds():
+                leftover_waiting_sequences.appendleft(seq_group)
+                waiting_queue.popleft()
+                continue
+
             lora_int_id = 0
             if self.lora_enabled:
                 lora_int_id = seq_group.lora_int_id
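
The rule introduced above — a batch is either all prompt-embeds sequence groups or all token-id groups, never a mix — reduces to partitioning on SequenceGroup.uses_prompt_embeds(). A standalone sketch of the invariant (hypothetical helper, not part of the commit):

def batch_is_homogeneous(seq_groups) -> bool:
    # True when every group agrees with the first one on whether it uses
    # prompt embeddings, which is what the scheduler enforces above.
    if not seq_groups:
        return True
    first = seq_groups[0].uses_prompt_embeds()
    return all(group.uses_prompt_embeds() == first for group in seq_groups)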
@@ -1295,17 +1305,39 @@ class Scheduler:

         # Merge lists
         num_prefill_groups = len(prefills.seq_groups)
+        ignored_seq_groups_for_embeds = list[SequenceGroup]()
         if num_prefill_groups > 0:
             scheduled_seq_groups = prefills.seq_groups
             scheduled_seq_groups.extend(running_scheduled.decode_seq_groups)
+            ignored_seq_groups_for_embeds.clear()
         else:
             scheduled_seq_groups = running_scheduled.decode_seq_groups
+            if len(scheduled_seq_groups) > 0:
+                using_prompt_embeds = scheduled_seq_groups[
+                    0].seq_group.uses_prompt_embeds()
+                ignored_seq_groups_for_embeds.clear()
+                indices_ignored = list[int]()
+                for i, schedule_seq_group in enumerate(scheduled_seq_groups):
+                    if using_prompt_embeds !=\
+                        schedule_seq_group.seq_group.uses_prompt_embeds():
+                        ignored_seq_groups_for_embeds.append(
+                            schedule_seq_group.seq_group)
+                        indices_ignored.append(i)
+                if len(ignored_seq_groups_for_embeds) > 0:
+                    scheduled_seq_groups = [
+                        group for i, group in enumerate(scheduled_seq_groups)
+                        if i not in indices_ignored
+                    ]
+            else:
+                ignored_seq_groups_for_embeds.clear()
+
         scheduled_seq_groups.extend(swapped_in.decode_seq_groups)

         blocks_to_copy = running_scheduled.blocks_to_copy
         blocks_to_copy.extend(swapped_in.blocks_to_copy)

         ignored_seq_groups = prefills.ignored_seq_groups
+        ignored_seq_groups.extend(ignored_seq_groups_for_embeds)
         ignored_seq_groups.extend(swapped_in.infeasible_seq_groups)

         return SchedulerOutputs(
@@ -489,6 +489,14 @@ class _AsyncLLMEngine(LLMEngine):
         if arrival_time is None:
             arrival_time = time.time()

+        if (isinstance(prompt, dict)
+                and prompt.get("prompt_embeds", None) is not None
+                and not prompt.get("prompt_token_ids", None)):
+            # We use the -2 dimension (instead of 0) in case a batched input
+            # of batch size 1 is passed in.
+            prompt["prompt_token_ids"] = [0
+                                          ] * prompt["prompt_embeds"].shape[-2]
+
         if self.tokenizer is not None:
             tokenizer = await self.get_tokenizer_async(lora_request)
             self._validate_token_prompt(prompt, tokenizer=tokenizer)
@@ -753,6 +753,12 @@ class LLMEngine:
         if arrival_time is None:
             arrival_time = time.time()

+        if (isinstance(prompt, dict)
+                and prompt.get("prompt_embeds", None) is not None
+                and not prompt.get("prompt_token_ids", None)):
+            seq_len = prompt["prompt_embeds"].shape[0]
+            prompt["prompt_token_ids"] = [0] * seq_len
+
         if self.tokenizer is not None:
             self._validate_token_prompt(
                 prompt,
@@ -1267,11 +1273,13 @@ class LLMEngine:
                 if self.scheduler_config.is_multi_step:
                     is_prefill_append = seq.data.get_num_uncomputed_tokens(
                     ) == 0
-                    seq.append_token_id(sample.output_token, sample.logprobs)
+                    seq.append_token_id(sample.output_token, sample.logprobs,
+                                        sample.output_embed)
                     if not is_prefill_append:
                         seq_group.update_num_computed_tokens(1)
                 else:
-                    seq.append_token_id(sample.output_token, sample.logprobs)
+                    seq.append_token_id(sample.output_token, sample.logprobs,
+                                        sample.output_embed)

     def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
@@ -2032,10 +2040,12 @@ class LLMEngine:
         tokenizer = (None if self.tokenizer is None else
                      self.tokenizer.get_lora_tokenizer(lora_request))

-        prompt_ids = prompt_inputs["prompt_token_ids"]
+        prompt_ids = prompt_inputs.get("prompt_token_ids", [])
         if not prompt_ids:
             if prompt_type == "encoder" and model_config.is_multimodal_model:
                 pass  # Mllama may have empty encoder inputs for text-only data
+            if prompt_inputs["type"] == "embeds":
+                pass
             else:
                 raise ValueError(f"The {prompt_type} prompt cannot be empty")

@@ -167,6 +167,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
                            sampling_params: SamplingParams) -> None:
         output_token_ids = [sample.output_token for sample in valid_samples]
         output_logprobs = [sample.logprobs for sample in valid_samples]
+        output_embeds = [sample.output_embed for sample in valid_samples]

         # Truncate to max_tokens if necessary.
         remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
@@ -190,11 +191,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
         is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
         # Incrementally append tokens to the sequence, as if we had only one new
         # token.
-        for output_token_id, output_logprob in zip(output_token_ids,
-                                                   output_logprobs):
+        for output_token_id, output_logprob, output_embed in zip(
+                output_token_ids, output_logprobs, output_embeds):
             seq.append_token_id(
                 token_id=output_token_id,
                 logprobs=output_logprob,
+                token_embed=output_embed,
             )

         if is_prefill_sampled_token:
@@ -119,7 +119,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
         sample = outputs.samples[0]
         seq = seq_group.first_seq
         if not is_async:
-            seq.append_token_id(sample.output_token, sample.logprobs)
+            seq.append_token_id(sample.output_token, sample.logprobs,
+                                sample.output_embed)
         if sampling_params.detokenize and self.detokenizer:
             new_char_count = self.detokenizer.decode_sequence_inplace(
                 seq, sampling_params)
@@ -1,9 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0

-from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
+from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
                    ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
                    SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
-                   TokensPrompt, build_explicit_enc_dec_prompt,
+                   TokensPrompt, build_explicit_enc_dec_prompt, embeds_inputs,
                    to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
 from .registry import (DummyData, InputContext, InputProcessingContext,
                        InputRegistry)
@@ -21,7 +21,9 @@ __all__ = [
     "SingletonPrompt",
     "ExplicitEncoderDecoderPrompt",
     "TokenInputs",
+    "EmbedsInputs",
     "token_inputs",
+    "embeds_inputs",
     "DecoderOnlyInputs",
     "EncoderDecoderInputs",
     "ProcessorInputs",
@@ -2,6 +2,7 @@
 from collections.abc import Iterable
 from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast

+import torch
 from typing_extensions import NotRequired, TypedDict, TypeVar

 if TYPE_CHECKING:
@@ -63,12 +64,20 @@ class TokensPrompt(TypedDict):
     """


-SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
+class EmbedsPrompt(TypedDict):
+    """Schema for a prompt provided via token embeddings."""
+
+    prompt_embeds: torch.Tensor
+    """The embeddings of the prompt."""
+
+
+SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
 """
 Set of possible schemas for a single prompt:

 - A text prompt (:class:`str` or :class:`TextPrompt`)
 - A tokenized prompt (:class:`TokensPrompt`)
+- An embeddings prompt (:class:`EmbedsPrompt`)

 Note that "singleton" is as opposed to a data structure
 which encapsulates multiple prompts, i.e. of the sort
@@ -129,6 +138,7 @@ both decoder-only and encoder/decoder input types:

 - A text prompt (:class:`str` or :class:`TextPrompt`)
 - A tokenized prompt (:class:`TokensPrompt`)
+- An embeddings prompt (:class:`EmbedsPrompt`)
 - A single data structure containing both an encoder and a decoder prompt
   (:class:`ExplicitEncoderDecoderPrompt`)
 """
@@ -176,7 +186,27 @@ def token_inputs(
     return inputs


-DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputs"]
+class EmbedsInputs(TypedDict):
+    """Represents embeddings-based inputs."""
+
+    type: Literal["embeds"]
+    """The type of inputs."""
+
+    prompt_embeds: torch.Tensor
+    """The embeddings of the prompt."""
+
+
+def embeds_inputs(prompt_embeds: torch.Tensor) -> EmbedsInputs:
+    """Construct :class:`EmbedsInputs` from optional values."""
+    inputs = EmbedsInputs(
+        type="embeds",
+        prompt_embeds=prompt_embeds,
+    )
+
+    return inputs
+
+
+DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
 """
 The inputs in :class:`~vllm.LLMEngine` before they are
 passed to the model executor.
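
For reference, a small sketch (not part of the diff) of how the new constructor is used together with the export added to vllm/inputs/__init__.py above; the (4, 768) shape is an arbitrary example consistent with the (seq_len, hidden_size) convention:

import torch
from vllm.inputs import embeds_inputs

inputs = embeds_inputs(prompt_embeds=torch.rand(4, 768))
assert inputs["type"] == "embeds"
assert inputs["prompt_embeds"].shape == (4, 768)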
@@ -198,7 +228,7 @@ class EncoderDecoderInputs(TypedDict):
     """The inputs for the decoder portion."""


-SingletonInputs = Union[TokenInputs, "MultiModalInputs"]
+SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
 """
 A processed :class:`SingletonPrompt` which can be passed to
 :class:`vllm.sequence.Sequence`.
@@ -6,8 +6,9 @@ from typing_extensions import TypeIs

 from vllm.utils import is_list_of

-from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
-                   SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)
+from .data import (EmbedsInputs, EmbedsPrompt, ExplicitEncoderDecoderPrompt,
+                   ProcessorInputs, PromptType, SingletonInputs,
+                   SingletonPrompt, TextPrompt, TokensPrompt)


 class ParsedText(TypedDict):
@@ -84,30 +85,69 @@ class ParsedTokensPrompt(TypedDict):
     content: TokensPrompt


+class ParsedEmbedsPrompt(TypedDict):
+    type: Literal['embeds']
+    content: EmbedsPrompt
+
+
+@overload
+def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
+    ...
+
+
+@overload
+def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt:
+    ...
+
+
+@overload
+def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt:
+    ...
+
+
+@overload
+def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
+    ...
+
+
 def parse_singleton_prompt(
     prompt: SingletonPrompt,
-) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
+) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
+           ParsedEmbedsPrompt]:
     if isinstance(prompt, str):
         return ParsedStrPrompt(type="str", content=prompt)
     elif isinstance(prompt, dict):
-        if "prompt_token_ids" in prompt:
-            return ParsedTokensPrompt(type="tokens",
-                                      content=prompt)  # type: ignore
+        # Type ignores are because mypy does not correctly infer the TypedDicts
+        # Pyright does succeed.
+        if "prompt_embeds" in prompt:
+            return ParsedEmbedsPrompt(
+                type="embeds", content=prompt)  # type: ignore[typeddict-item]
+        elif "prompt_token_ids" in prompt:
+            return ParsedTokensPrompt(
+                type="tokens", content=prompt)  # type: ignore[typeddict-item]
         elif "prompt" in prompt:
             return ParsedTextPrompt(type="text", content=prompt)
-    raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
+    raise TypeError(
+        "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt")


 def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
     return isinstance(prompt, dict) and "prompt_token_ids" in prompt


+def is_embeds_prompt(prompt: PromptType) -> TypeIs[EmbedsPrompt]:
+    return isinstance(prompt, dict) and "prompt_embeds" in prompt
+
+
 def is_explicit_encoder_decoder_prompt(
         prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
     return isinstance(prompt, dict) and "encoder_prompt" in prompt


+def is_embeds_inputs(inputs: SingletonInputs) -> TypeIs[EmbedsInputs]:
+    return isinstance(inputs, dict) and inputs["type"] == "embeds"
+
+
 def split_enc_dec_inputs(
     inputs: ProcessorInputs,
 ) -> tuple[Optional[SingletonInputs], SingletonInputs]:
@@ -6,6 +6,7 @@ from typing import Any, Optional, Union, cast

 from typing_extensions import assert_never

+from vllm import envs
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -15,9 +16,11 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup

-from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
-                   PromptType, SingletonInputs, SingletonPrompt, token_inputs)
-from .parse import (ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
+from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
+                   ProcessorInputs, PromptType, SingletonInputs,
+                   SingletonPrompt, TokenInputs, embeds_inputs, token_inputs)
+from .parse import (ParsedEmbedsPrompt, ParsedStrPrompt, ParsedTextPrompt,
+                    ParsedTokensPrompt, is_embeds_inputs,
                     is_explicit_encoder_decoder_prompt, parse_singleton_prompt)

 logger = init_logger(__name__)
@@ -328,6 +331,10 @@ class InputPreprocessor:
         * :class:`SingletonInputs` instance
         """
         parsed = parse_singleton_prompt(prompt)

+        if parsed["type"] == "embeds":
+            return self._process_prompt_embeds(parsed)
+
         prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
             self._get_prompt_data(parsed)

@@ -359,6 +366,8 @@ class InputPreprocessor:
             cache_salt=cache_salt,
         )

+        assert_never(parsed)
+
     async def _prompt_to_llm_inputs_async(
         self,
         prompt: SingletonPrompt,
@@ -369,6 +378,9 @@ class InputPreprocessor:
         """Async version of :meth:`_extract_prompt_components`."""
         parsed = parse_singleton_prompt(prompt)

+        if parsed["type"] == "embeds":
+            return self._process_prompt_embeds(parsed)
+
         prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
             self._get_prompt_data(parsed)

@@ -399,10 +411,34 @@ class InputPreprocessor:
             cache_salt=cache_salt,
         )

+    def _process_prompt_embeds(self,
+                               parsed: ParsedEmbedsPrompt) -> EmbedsInputs:
+        if envs.VLLM_USE_V1:
+            raise ValueError("prompt_embeds is only available in V0.")
+
+        prompt_embeds_content = parsed["content"]
+
+        prompt_embeds = prompt_embeds_content["prompt_embeds"]
+
+        # prompt_embeds must be (seq_len, hidden_size), but if the user
+        # passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
+        # we can unambiguously process the intent by squeezing the batch
+        # dimension.
+        if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1:
+            prompt_embeds = prompt_embeds.squeeze(dim=0)
+
+        if prompt_embeds.ndim != 2:
+            raise ValueError(
+                "prompt_embeds must be of shape (seq_len, hidden_size).")
+
+        return embeds_inputs(prompt_embeds=prompt_embeds)
+
+        assert_never(parsed)
+
     def _build_enc_dec_llm_inputs(
         self,
-        encoder_inputs: SingletonInputs,
-        decoder_inputs: Optional[SingletonInputs],
+        encoder_inputs: Union[TokenInputs, MultiModalInputs],
+        decoder_inputs: Optional[Union[TokenInputs, MultiModalInputs]],
     ) -> EncoderDecoderInputs:
         if (encoder_inputs["type"] == "token"
                 or encoder_inputs["type"] == "multimodal"):
|
|||||||
else:
|
else:
|
||||||
assert_never(encoder_inputs) # type: ignore[arg-type]
|
assert_never(encoder_inputs) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
# Mypy does not correctly infer that EmbedsInputs is impossible
|
||||||
|
assert "prompt_token_ids" in encoder_inputs
|
||||||
|
|
||||||
if decoder_inputs is None:
|
if decoder_inputs is None:
|
||||||
if self.model_config.hf_config.model_type == "whisper":
|
if self.model_config.hf_config.model_type == "whisper":
|
||||||
# For Whisper models, the text prompt should go to the decoder.
|
# For Whisper models, the text prompt should go to the decoder.
|
||||||
@@ -441,7 +480,8 @@ class InputPreprocessor:
     def _separate_enc_dec_inputs_from_mm_processor_outputs(
         self,
         inputs: SingletonInputs,
-        decoder_inputs_to_override: Optional[SingletonInputs] = None,
+        decoder_inputs_to_override: Optional[Union[TokenInputs,
+                                                   MultiModalInputs]] = None,
     ) -> tuple[SingletonInputs, SingletonInputs]:
         """
         For encoder/decoder models only:
@@ -540,6 +580,8 @@ class InputPreprocessor:
             # For multimodal model, override decoder prompt from processor
             # with explicit decoder prompt.
             if self.model_config.is_multimodal_model:
+                assert decoder_inputs is None or not is_embeds_inputs(
+                    decoder_inputs)
                 encoder_inputs, decoder_inputs = (
                     self._separate_enc_dec_inputs_from_mm_processor_outputs(
                         encoder_inputs, decoder_inputs))
@@ -555,9 +597,12 @@ class InputPreprocessor:
                         inputs))
             else:
                 encoder_inputs = inputs

                 decoder_inputs = None

+        # Mypy does not do type inference well with TypedDicts with Literal
+        # values.
+        assert not is_embeds_inputs(encoder_inputs)
+        assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
         return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)

     async def _process_encoder_decoder_prompt_async(
@@ -590,6 +635,8 @@ class InputPreprocessor:
             # For multimodal model, override decoder prompt from processor
             # with explicit decoder prompt.
             if self.model_config.is_multimodal_model:
+                assert decoder_inputs is None or not is_embeds_inputs(
+                    decoder_inputs)
                 encoder_inputs, decoder_inputs = (
                     self._separate_enc_dec_inputs_from_mm_processor_outputs(
                         encoder_inputs, decoder_inputs))
@ -605,9 +652,12 @@ class InputPreprocessor:
|
|||||||
inputs))
|
inputs))
|
||||||
else:
|
else:
|
||||||
encoder_inputs = inputs
|
encoder_inputs = inputs
|
||||||
|
|
||||||
decoder_inputs = None
|
decoder_inputs = None
|
||||||
|
|
||||||
|
# Mypy does not do type inference well with TypedDicts with Literal
|
||||||
|
# values.
|
||||||
|
assert not is_embeds_inputs(encoder_inputs)
|
||||||
|
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
|
||||||
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
|
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
|
||||||
|
|
||||||
def _build_decoder_only_llm_inputs(
|
def _build_decoder_only_llm_inputs(
|
||||||
@ -617,10 +667,15 @@ class InputPreprocessor:
|
|||||||
) -> DecoderOnlyInputs:
|
) -> DecoderOnlyInputs:
|
||||||
if (prompt_inputs["type"] == "token"
|
if (prompt_inputs["type"] == "token"
|
||||||
or prompt_inputs["type"] == "multimodal"):
|
or prompt_inputs["type"] == "multimodal"):
|
||||||
|
# Mypy does not do type inference well with typedicts and Literal
|
||||||
|
# values
|
||||||
|
assert not is_embeds_inputs(prompt_inputs)
|
||||||
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
|
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
|
||||||
prompt_inputs["prompt_token_ids"],
|
prompt_inputs["prompt_token_ids"],
|
||||||
prompt_adapter_request=prompt_adapter_request,
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
)
|
)
|
||||||
|
elif (prompt_inputs["type"] == "embeds"):
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
assert_never(prompt_inputs) # type: ignore[arg-type]
|
assert_never(prompt_inputs) # type: ignore[arg-type]
|
||||||
|
|
||||||
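To make the shape requirement in the embeds branch above concrete, here is a minimal standalone sketch (a hypothetical helper, not part of vLLM) of the validation a prompt passed as embeddings has to satisfy before it is wrapped in embeds_inputs:

import torch


def validate_prompt_embeds(prompt_embeds: torch.Tensor) -> torch.Tensor:
    # Mirrors the check in the embeds branch: a prompt given as embeddings
    # must be a 2-D tensor laid out as (seq_len, hidden_size).
    if prompt_embeds.ndim != 2:
        raise ValueError(
            "prompt_embeds must be of shape (seq_len, hidden_size).")
    return prompt_embeds


# Example: a 5-token prompt with a made-up hidden size of 16.
validate_prompt_embeds(torch.randn(5, 16))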
@@ -110,6 +110,11 @@ class SamplerOutput(
     # 'broadcasted' to all other PP ranks for next step.
     sampled_token_ids_cpu: Optional[torch.Tensor] = None

+    # On-device tensor containing the sampled token embeddings (embeddings
+    # corresponding to the sampled token ids). Used when prompt embeddings are
+    # specified in lieu of prompt token ids or text.
+    sampled_token_embeds: Optional[torch.Tensor] = None
+
     # Spec decode metrics populated by workers.
     spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None

@@ -183,7 +188,7 @@ class Sampler(nn.Module):

         # Whether or not the SamplerOutput should have on-device tensors
         # containing the sampled token ids and probabilities. This is used by
-        # speculative decoding.
+        # speculative decoding and when prompt embeddings are specified.
         self.include_gpu_probs_tensor = False
         self.should_modify_greedy_probs_inplace = False
@@ -166,6 +166,9 @@ class SequenceData(msgspec.Struct,
     _output_token_ids: array = msgspec.field(
         default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))

+    _prompt_embeds: Optional[torch.Tensor] = None
+    _output_embeds: Optional[torch.Tensor] = None
+
     ### The below fields should not be passed as an argument ###
     _cumulative_logprob: float = 0.0
     _prompt_token_ids_tuple: tuple[int,

@@ -176,6 +179,7 @@ class SequenceData(msgspec.Struct,
     _num_cached_tokens: int = 0
     _stage: SequenceStage = SequenceStage.PREFILL
     _cached_all_token_ids: list[int] = msgspec.field(default_factory=list)
+    _cached_all_token_embeds: Optional[torch.Tensor] = None

     # It is used to get delta input. It is reset when `get_delta_and_reset`
     # is called.

@@ -208,6 +212,8 @@ class SequenceData(msgspec.Struct,
     def from_seqs(
         prompt_token_ids: GenericSequence[int],
         output_token_ids: Optional[GenericSequence[int]] = None,
+        *,
+        prompt_embeds: Optional[torch.Tensor] = None,
     ) -> "SequenceData":
         """
         Construct a :class:`SequenceData` instance from prompt and output

@@ -217,13 +223,15 @@ class SequenceData(msgspec.Struct,
                                      prompt_token_ids)

         if output_token_ids is None:
-            return SequenceData(prompt_token_ids_arr)
+            return SequenceData(prompt_token_ids_arr,
+                                _prompt_embeds=prompt_embeds)

         output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                                      output_token_ids)

         return SequenceData(prompt_token_ids_arr,
-                            _output_token_ids=output_token_ids_arr)
+                            _output_token_ids=output_token_ids_arr,
+                            _prompt_embeds=prompt_embeds)

     def __post_init__(self) -> None:
         assert self._prompt_token_ids.typecode == "l"

@@ -231,6 +239,8 @@ class SequenceData(msgspec.Struct,
         self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
             self._prompt_token_ids)
         self._update_cached_all_tokens()
+        if self._prompt_embeds is not None:
+            self._update_cached_all_token_embeds()

     def _update_cached_all_tokens(self):
         assert isinstance(self._prompt_token_ids, array)

@@ -238,6 +248,13 @@ class SequenceData(msgspec.Struct,
         self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
                                                      self._output_token_ids)

+    def _update_cached_all_token_embeds(self):
+        assert isinstance(self._prompt_embeds, torch.Tensor)
+        self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds
+        if self._output_embeds is not None:
+            self._cached_all_token_embeds = torch.cat(
+                (self._cached_all_token_embeds, self._output_embeds), dim=0)
+
     @property
     def cumulative_logprob(self) -> float:
         return self._cumulative_logprob

@@ -270,6 +287,15 @@ class SequenceData(msgspec.Struct,
                                         new_output_token_ids)
         self._update_cached_all_tokens()

+    @property
+    def output_embeds(self) -> Optional[torch.Tensor]:
+        return self._output_embeds
+
+    @output_embeds.setter
+    def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None:
+        self._output_token_embeds = new_output_token_embeds
+        self._update_cached_all_token_embeds()
+
     @property
     def output_token_ids_array(self) -> array:
         """Return the prompt token ids in array type.

@@ -280,6 +306,15 @@ class SequenceData(msgspec.Struct,
         assert isinstance(self._output_token_ids, array)
         return self._output_token_ids

+    @property
+    def prompt_embeds(self) -> Optional[torch.Tensor]:
+        return self._prompt_embeds
+
+    @prompt_embeds.setter
+    def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
+        self._prompt_embeds = prompt_embeds
+        self._update_cached_all_token_embeds()
+
     @property
     def mrope_position_delta(self) -> Optional[int]:
         return self._mrope_position_delta

@@ -288,11 +323,28 @@ class SequenceData(msgspec.Struct,
     def mrope_position_delta(self, new_mrope_position_delta):
         self._mrope_position_delta = new_mrope_position_delta

-    def append_token_id(self, token_id: int, logprob: float) -> None:
+    def append_token_id(self,
+                        token_id: int,
+                        logprob: float,
+                        token_embed: Optional[torch.Tensor] = None) -> None:
         self._output_token_ids.append(token_id)
         self._new_appended_tokens.append(token_id)
         self._cached_all_token_ids.append(token_id)
         self._cumulative_logprob += logprob
+        if token_embed is not None:
+            # Do not pass in with batch or sequence dimensions
+            assert token_embed.ndim == 1
+            token_embed = token_embed.detach().cpu().unsqueeze(0)
+            if self._output_embeds is None:
+                self._output_embeds = token_embed
+            else:
+                self._output_embeds = torch.cat(
+                    (self._output_embeds, token_embed), dim=0)
+            assert self._cached_all_token_embeds is not None
+            self._cached_all_token_embeds = torch.cat(
+                (self._cached_all_token_embeds,
+                 token_embed.to(device=self._cached_all_token_embeds.device)),
+                dim=0)

     def get_len(self) -> int:
         return len(self._output_token_ids) + len(self._prompt_token_ids)

@@ -306,6 +358,9 @@ class SequenceData(msgspec.Struct,
     def get_token_ids(self) -> list[int]:
         return self._cached_all_token_ids

+    def get_token_embeddings(self) -> Optional[torch.Tensor]:
+        return self._cached_all_token_embeds
+
     def get_prefix_token_ids(
             self, num_tokens: int
     ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
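Outside of SequenceData, the caching pattern that _cached_all_token_embeds follows can be illustrated with plain tensors: the prompt embeddings act as the base, and every sampled token's embedding (1-D, as enforced by the assert in append_token_id) is detached, moved to the CPU, given a sequence dimension, and concatenated. A rough sketch with assumed sizes:

import torch

hidden_size = 16
prompt_embeds = torch.randn(4, hidden_size)   # (seq_len, hidden_size)
cached_all_token_embeds = prompt_embeds       # starts as the prompt embeddings

for _ in range(3):                            # pretend three tokens were sampled
    token_embed = torch.randn(hidden_size)    # 1-D, no batch/sequence dims
    token_embed = token_embed.detach().cpu().unsqueeze(0)
    cached_all_token_embeds = torch.cat(
        (cached_all_token_embeds, token_embed), dim=0)

assert cached_all_token_embeds.shape == (7, hidden_size)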
@@ -387,6 +442,8 @@ class SequenceData(msgspec.Struct,
     def __repr__(self) -> str:
         return (f"SequenceData("
                 f"prompt_token_ids={self._prompt_token_ids}, "
+                f"prompt_embeds.shape="
+                f"{getattr(self._prompt_embeds, 'shape', None)}, "
                 f"output_token_ids={self.output_token_ids}, "
                 f"cumulative_logprob={self.cumulative_logprob}, "
                 f"get_num_computed_tokens={self.get_num_computed_tokens()})")

@@ -425,7 +482,10 @@ class Sequence:
         self.lora_request = lora_request
         self.prompt_adapter_request = prompt_adapter_request

-        self.data = SequenceData.from_seqs(self.prompt_token_ids)
+        self.data = SequenceData.from_seqs(
+            self.prompt_token_ids,
+            prompt_embeds=self.inputs["prompt_embeds"]
+            if self.inputs["type"] == "embeds" else None)
         self.output_logprobs: SampleLogprobs = []
         self.output_text = ""

@@ -448,14 +508,20 @@ class Sequence:

     @property
     def prompt(self) -> Optional[str]:
+        if self.inputs["type"] == "embeds":
+            return None
         return self.inputs.get("prompt")

     @property
     def prompt_token_ids(self) -> list[int]:
+        if self.inputs["type"] == "embeds":
+            return [0] * len(self.inputs["prompt_embeds"])
         return self.inputs["prompt_token_ids"]

     @property
     def token_type_ids(self) -> list[int]:
+        if self.inputs["type"] == "embeds":
+            return []
         return self.inputs.get("token_type_ids", [])

     @property

@@ -554,11 +620,14 @@ class Sequence:
         """Reset the sequence states for recomputation."""
         self.data.reset_state_for_recompute()

-    def append_token_id(self, token_id: int, logprobs: dict[int,
-                                                            Logprob]) -> None:
+    def append_token_id(self,
+                        token_id: int,
+                        logprobs: dict[int, Logprob],
+                        token_embed: Optional[torch.Tensor] = None) -> None:
         assert token_id in logprobs
         self.output_logprobs.append(logprobs)
-        self.data.append_token_id(token_id, logprobs[token_id].logprob)
+        self.data.append_token_id(token_id, logprobs[token_id].logprob,
+                                  token_embed)

     def get_len(self) -> int:
         return self.data.get_len()

@@ -889,6 +958,10 @@ class SequenceGroup:
                 f"sampling_params={self.sampling_params}, "
                 f"num_seqs={len(self.seqs)})")

+    def uses_prompt_embeds(self) -> bool:
+        """Returns True if the sequence group uses input embeds."""
+        return any(seq.data.prompt_embeds is not None for seq in self.seqs)
+

 class SequenceGroupMetadataDelta(
         msgspec.Struct,

@@ -1043,10 +1116,14 @@ class SequenceOutput(
     parent_seq_id: int
     output_token: int
     logprobs: dict[int, Logprob]
+    output_embed: Optional[torch.Tensor] = None

     def __repr__(self) -> str:
+        output_embed_shape = \
+            self.output_embed.shape if self.output_embed is not None else None
         return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
                 f"output_token={self.output_token}, "
+                f"output_embed.shape={output_embed_shape}"
                 f"logprobs={self.logprobs})")

     def __eq__(self, other: object) -> bool:
@@ -201,6 +201,9 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
         if self.prompt_adapter_config is not None:
             raise ValueError("TP1DraftModelRunner has no support for "
                              "prompt_adapter_config")
+        if model_input.inputs_embeds is not None:
+            raise ValueError("TP1DraftModelRunner has no support for "
+                             "inputs_embeds")
         if model_input.multi_modal_kwargs:
             raise ValueError(
                 "TP1DraftModelRunner has no support for multi_modal_kwargs"

@@ -242,9 +245,16 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):

         # Get model
         if use_cuda_graph:
-            graph_batch_size = model_input.input_tokens.shape[0]
-            model_executable = (self.graph_runners[model_input.virtual_engine]
-                                [graph_batch_size])
+            if model_input.inputs_embeds is None:
+                graph_batch_size = model_input.input_tokens.shape[0]
+                model_executable = (
+                    self.graph_runners[model_input.virtual_engine][(
+                        graph_batch_size, False)])
+            else:
+                graph_batch_size = model_input.inputs_embeds.shape[0]
+                model_executable = (
+                    self.graph_runners[model_input.virtual_engine][(
+                        graph_batch_size, True)])

         if previous_hidden_states is not None:
             hidden_states = torch.cat([

@@ -281,6 +291,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
                     self.vllm_config):
                 hidden_states = model_executable(
                     input_ids=model_input.input_tokens,
+                    inputs_embeds=None,
                     positions=model_input.input_positions,
                     intermediate_tensors=intermediate_tensors,
                     **MultiModalKwargs.as_kwargs(multi_modal_kwargs,

@@ -282,7 +282,8 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase):
                 else:
                     count += 1

-                seq.append_token_id(token_id, token_logprob.logprob)
+                seq.append_token_id(token_id, token_logprob.logprob,
+                                    seq_output.output_embed)
                 seq.update_num_computed_tokens(1)

     @staticmethod
@@ -49,6 +49,7 @@ class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata):
     def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
         tensor_dict = {
             "input_tokens": self.input_tokens,
+            "inputs_embeds": self.inputs_embeds,
             "input_positions": self.input_positions,
             "encoder_input_tokens": self.encoder_input_tokens,
             "encoder_input_positions": self.encoder_input_positions,

@@ -172,10 +173,17 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
         if (model_input.attn_metadata is not None
                 and model_input.attn_metadata.prefill_metadata is None
                 and model_input.attn_metadata.decode_metadata.use_cuda_graph):
-            assert model_input.input_tokens is not None
-            graph_batch_size = model_input.input_tokens.shape[0]
-            model_executable = self.graph_runners[
-                model_input.virtual_engine][graph_batch_size]
+            if model_input.inputs_embeds is None:
+                assert model_input.input_tokens is not None
+                graph_batch_size = model_input.input_tokens.shape[0]
+                model_executable = (
+                    self.graph_runners[model_input.virtual_engine][(
+                        graph_batch_size, False)])
+            else:
+                graph_batch_size = model_input.inputs_embeds.shape[0]
+                model_executable = (
+                    self.graph_runners[model_input.virtual_engine][(
+                        graph_batch_size, True)])
         else:
             model_executable = self.model

@@ -189,6 +197,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
                 model_input.virtual_engine):
             hidden_or_intermediate_states = model_executable(
                 input_ids=model_input.input_tokens,
+                inputs_embeds=model_input.inputs_embeds,
                 positions=model_input.input_positions,
                 encoder_input_ids=model_input.encoder_input_tokens,
                 encoder_positions=model_input.encoder_input_positions,
|
|||||||
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||||
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
|
from vllm.model_executor import SamplingMetadata, SamplingMetadataCache
|
||||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput,
|
||||||
|
get_sampler)
|
||||||
from vllm.model_executor.model_loader import get_model
|
from vllm.model_executor.model_loader import get_model
|
||||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||||
from vllm.model_executor.models import supports_lora, supports_multimodal
|
from vllm.model_executor.models import supports_lora, supports_multimodal
|
||||||
@ -85,6 +86,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
|
|||||||
additional fields.
|
additional fields.
|
||||||
"""
|
"""
|
||||||
input_tokens: Optional[torch.Tensor] = None
|
input_tokens: Optional[torch.Tensor] = None
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None
|
||||||
input_positions: Optional[torch.Tensor] = None
|
input_positions: Optional[torch.Tensor] = None
|
||||||
token_types: Optional[torch.Tensor] = None
|
token_types: Optional[torch.Tensor] = None
|
||||||
seq_lens: Optional[List[int]] = None
|
seq_lens: Optional[List[int]] = None
|
||||||
@ -105,6 +107,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
|
|||||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||||
tensor_dict = {
|
tensor_dict = {
|
||||||
"input_tokens": self.input_tokens,
|
"input_tokens": self.input_tokens,
|
||||||
|
"inputs_embeds": self.inputs_embeds,
|
||||||
"input_positions": self.input_positions,
|
"input_positions": self.input_positions,
|
||||||
"lora_requests": self.lora_requests,
|
"lora_requests": self.lora_requests,
|
||||||
"lora_mapping": self.lora_mapping,
|
"lora_mapping": self.lora_mapping,
|
||||||
@ -155,6 +158,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
|
|||||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||||
tensor_dict = {
|
tensor_dict = {
|
||||||
"input_tokens": self.input_tokens,
|
"input_tokens": self.input_tokens,
|
||||||
|
"inputs_embeds": self.inputs_embeds,
|
||||||
"input_positions": self.input_positions,
|
"input_positions": self.input_positions,
|
||||||
"lora_requests": self.lora_requests,
|
"lora_requests": self.lora_requests,
|
||||||
"lora_mapping": self.lora_mapping,
|
"lora_mapping": self.lora_mapping,
|
||||||
@ -194,6 +198,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
|
|
||||||
def simple_reinit(self):
|
def simple_reinit(self):
|
||||||
self.input_tokens[0].clear() # type: ignore
|
self.input_tokens[0].clear() # type: ignore
|
||||||
|
self.inputs_embeds = None # type: ignore
|
||||||
self.input_positions[0].clear() # type: ignore
|
self.input_positions[0].clear() # type: ignore
|
||||||
self.token_types[0].clear() # type: ignore
|
self.token_types[0].clear() # type: ignore
|
||||||
self.mrope_input_positions = None # type: ignore
|
self.mrope_input_positions = None # type: ignore
|
||||||
@ -221,6 +226,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
|
|
||||||
# Input tokens and positions.
|
# Input tokens and positions.
|
||||||
input_tokens: Optional[List[List[int]]] = None,
|
input_tokens: Optional[List[List[int]]] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
input_positions: Optional[List[List[int]]] = None,
|
input_positions: Optional[List[List[int]]] = None,
|
||||||
token_types: Optional[List[List[int]]] = None,
|
token_types: Optional[List[List[int]]] = None,
|
||||||
mrope_input_positions: Optional[List[List[List[int]]]] = None,
|
mrope_input_positions: Optional[List[List[List[int]]]] = None,
|
||||||
@ -282,6 +288,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
for seq_id in range(len(self.seq_ids)):
|
for seq_id in range(len(self.seq_ids)):
|
||||||
self.input_tokens[seq_id].clear()
|
self.input_tokens[seq_id].clear()
|
||||||
|
|
||||||
|
self.inputs_embeds = inputs_embeds
|
||||||
|
|
||||||
if input_positions:
|
if input_positions:
|
||||||
self.input_positions = input_positions
|
self.input_positions = input_positions
|
||||||
else:
|
else:
|
||||||
@ -356,6 +364,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
self.input_tokens = input_tokens or []
|
self.input_tokens = input_tokens or []
|
||||||
|
self.inputs_embeds = inputs_embeds
|
||||||
self.input_positions = input_positions or []
|
self.input_positions = input_positions or []
|
||||||
self.token_types = token_types or []
|
self.token_types = token_types or []
|
||||||
self.mrope_input_positions = mrope_input_positions or None
|
self.mrope_input_positions = mrope_input_positions or None
|
||||||
@ -401,6 +410,26 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
self.lora_index_mapping = []
|
self.lora_index_mapping = []
|
||||||
self.lora_prompt_mapping = []
|
self.lora_prompt_mapping = []
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (f"InterDataForSeqGroup("
|
||||||
|
f"request_id={self.request_id}, "
|
||||||
|
f"seq_ids={self.seq_ids}, "
|
||||||
|
f"is_prompt={self.is_prompt}, "
|
||||||
|
f"block_tables={self.block_tables}, "
|
||||||
|
f"computed_block_nums={self.computed_block_nums}, "
|
||||||
|
f"n_seqs={self.n_seqs}, "
|
||||||
|
f"input_tokens={self.input_tokens}, "
|
||||||
|
f"inputs_embeds.shape="
|
||||||
|
f"{getattr(self.inputs_embeds, 'shape', None)}, "
|
||||||
|
f"input_positions={self.input_positions}, "
|
||||||
|
f"token_types={self.token_types}, "
|
||||||
|
f"mrope_input_positions={self.mrope_input_positions}, "
|
||||||
|
f"seq_lens={self.seq_lens}, "
|
||||||
|
f"orig_seq_lens={self.orig_seq_lens}, "
|
||||||
|
f"query_lens={self.query_lens}, "
|
||||||
|
f"context_lens={self.context_lens}, "
|
||||||
|
f"multi_modal_kwargs={self.multi_modal_kwargs}")
|
||||||
|
|
||||||
def gen_inter_data_builder(self, num_seqs: int):
|
def gen_inter_data_builder(self, num_seqs: int):
|
||||||
return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup(
|
return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup(
|
||||||
request_id="",
|
request_id="",
|
||||||
@ -511,13 +540,21 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
context_len = seq_data.get_num_computed_tokens()
|
context_len = seq_data.get_num_computed_tokens()
|
||||||
|
|
||||||
# Compute tokens.
|
# Compute tokens.
|
||||||
tokens = seq_data.get_token_ids()[context_len:seq_len]
|
if seq_data.prompt_embeds is None:
|
||||||
|
tokens = seq_data.get_token_ids()[context_len:seq_len]
|
||||||
|
prompt_embeds = None
|
||||||
|
else:
|
||||||
|
tokens = [0] * (seq_len - context_len)
|
||||||
|
prompt_embeds = seq_data.get_token_embeddings(
|
||||||
|
)[context_len:seq_len]
|
||||||
|
|
||||||
token_types = seq_group_metadata.token_type_ids
|
token_types = seq_group_metadata.token_type_ids
|
||||||
|
|
||||||
inter_data.seq_lens[seq_idx] = seq_len
|
inter_data.seq_lens[seq_idx] = seq_len
|
||||||
inter_data.orig_seq_lens[seq_idx] = seq_len
|
inter_data.orig_seq_lens[seq_idx] = seq_len
|
||||||
inter_data.context_lens[seq_idx] = context_len
|
inter_data.context_lens[seq_idx] = context_len
|
||||||
inter_data.input_tokens[seq_idx].extend(tokens)
|
inter_data.input_tokens[seq_idx].extend(tokens)
|
||||||
|
inter_data.inputs_embeds = prompt_embeds
|
||||||
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
|
inter_data.input_positions[seq_idx].extend(range(context_len, seq_len))
|
||||||
inter_data.token_types[seq_idx].extend(
|
inter_data.token_types[seq_idx].extend(
|
||||||
token_types if token_types else [])
|
token_types if token_types else [])
|
||||||
@ -822,15 +859,29 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
|||||||
create on-device tensors.
|
create on-device tensors.
|
||||||
"""
|
"""
|
||||||
# Combine and flatten intermediate data.
|
# Combine and flatten intermediate data.
|
||||||
input_tokens = []
|
input_tokens = list[int]()
|
||||||
token_types = []
|
inputs_embeds_lst = list[torch.Tensor]()
|
||||||
|
token_types = list[int]()
|
||||||
for inter_data in self.inter_data_list:
|
for inter_data in self.inter_data_list:
|
||||||
for cur_input_tokens in inter_data.input_tokens:
|
for cur_input_tokens in inter_data.input_tokens:
|
||||||
input_tokens.extend(cur_input_tokens)
|
input_tokens.extend(cur_input_tokens)
|
||||||
for cur_token_types in inter_data.token_types:
|
for cur_token_types in inter_data.token_types:
|
||||||
token_types.extend(cur_token_types)
|
token_types.extend(cur_token_types)
|
||||||
|
if inter_data.inputs_embeds is not None:
|
||||||
|
inputs_embeds_lst.append(
|
||||||
|
inter_data.inputs_embeds.to(
|
||||||
|
dtype=self.runner.model_config.dtype,
|
||||||
|
device=self.runner.device))
|
||||||
|
inputs_embeds: Optional[torch.Tensor]
|
||||||
|
if len(inputs_embeds_lst) == 0:
|
||||||
|
inputs_embeds = None
|
||||||
|
else:
|
||||||
|
inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to(
|
||||||
|
dtype=self.runner.model_config.dtype,
|
||||||
|
device=self.runner.device)
|
||||||
|
assert len(inputs_embeds) == len(input_tokens)
|
||||||
|
|
||||||
if not input_tokens:
|
if not input_tokens and inputs_embeds is None:
|
||||||
# This may happen when all prefill requests hit
|
# This may happen when all prefill requests hit
|
||||||
# prefix caching and there is no decode request.
|
# prefix caching and there is no decode request.
|
||||||
return self.model_input_cls()
|
return self.model_input_cls()
|
||||||
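The flatten-and-concatenate step in build() above can be pictured with a toy example: each sequence contributes a chunk of embeddings, the chunks are stacked into one (total_tokens, hidden_size) tensor, and that tensor must line up one-to-one with the flattened placeholder token ids (all zeros for embeds-only prompts). A sketch with assumed shapes:

import torch

hidden_size = 16
# Per-sequence prompt embeddings gathered from the intermediate data.
per_seq_embeds = [torch.randn(3, hidden_size), torch.randn(5, hidden_size)]
# Placeholder token ids: one zero per embedded position, as in the builder.
input_tokens = [0] * 3 + [0] * 5

inputs_embeds = torch.cat(per_seq_embeds, dim=0)
assert len(inputs_embeds) == len(input_tokens)  # 8 rows, 8 placeholder ids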
@@ -980,6 +1031,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):

         return self.model_input_cls(
             input_tokens=input_tokens_tensor,
+            inputs_embeds=inputs_embeds,
             input_positions=input_positions_tensor,
             token_types=token_types_tensor,
             attn_metadata=attn_metadata,

@@ -1029,7 +1081,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.max_batchsize_to_capture = \
             self.vllm_config.compilation_config.max_capture_size

-        self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
+        #
+        self.graph_runners: List[Dict[Tuple[int, bool], CUDAGraphRunner]] = [
             {} for _ in range(self.parallel_config.pipeline_parallel_size)
         ]
         self.graph_memory_pool: Optional[Tuple[

@@ -1466,6 +1519,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         input_positions = torch.zeros(max_batch_size,
                                       dtype=torch.long,
                                       device=self.device)
+        inputs_embeds = torch.zeros(
+            (max_batch_size, self.model_config.get_hidden_size()),
+            dtype=self.model_config.dtype,
+            device=self.device)
         if self.model_config.uses_mrope:
             input_positions = torch.tile(input_positions,
                                          (3, 1)).cuda(device=self.device)

@@ -1503,15 +1560,22 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         # memory usage of CUDA graph.
         for virtual_engine in range(
                 self.parallel_config.pipeline_parallel_size):
-            # Only rank 0 should print progress bar during capture
-            cudagraph_capture_sizes = (tqdm(
-                self.vllm_config.compilation_config.
+            # We need to not only iterate over batch sizes, but also whether
+            # to use inputs_embeds or not, hence we use the cartesian
+            # product.
+            cudagraph_capture_sizes = self.vllm_config.compilation_config\
+                .cudagraph_capture_sizes
+            cudagraph_inputs_embeds = (True, False)
+            compilation_cases = itertools.product(
                 cudagraph_capture_sizes,
-                desc="Capturing CUDA graph shapes",
-            ) if get_tensor_model_parallel_rank() == 0 else
-                self.vllm_config.compilation_config.
-                cudagraph_capture_sizes)
-            for batch_size in cudagraph_capture_sizes:
+                cudagraph_inputs_embeds,
+            )
+            # Only rank 0 should print progress bar during capture
+            if get_tensor_model_parallel_rank() == 0:
+                compilation_cases = tqdm(
+                    list(compilation_cases),
+                    desc="Capturing CUDA graph shapes")
+            for batch_size, use_inputs_embeds in compilation_cases:
                 attn_metadata = (
                     self.attn_state.graph_capture_get_metadata_for_batch(
                         batch_size,

@@ -1542,6 +1606,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 capture_inputs = {
                     "input_ids":
                     input_tokens[:batch_size],
+                    "inputs_embeds":
+                    inputs_embeds[:batch_size]
+                    if use_inputs_embeds else None,
                     "positions":
                     input_positions[..., :batch_size],
                     "intermediate_inputs":

@@ -1578,8 +1645,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                         virtual_engine):
                     graph_runner.capture(**capture_inputs)
                 self.graph_memory_pool = graph_runner.graph.pool()
-                self.graph_runners[virtual_engine][batch_size] = (
-                    graph_runner)
+                self.graph_runners[virtual_engine][(
+                    batch_size, use_inputs_embeds)] = graph_runner

         if self.lora_config:
             self._remove_dummy_loras()
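Since each captured CUDA graph is now keyed by (batch_size, use_inputs_embeds), the capture loop walks the cartesian product of capture sizes and the two input modes. A toy version of the bookkeeping (the capture sizes are made up, and strings stand in for the captured runners):

import itertools

cudagraph_capture_sizes = [1, 2, 4, 8]        # assumed example sizes
cudagraph_inputs_embeds = (True, False)

graph_runners: dict[tuple[int, bool], str] = {}
for batch_size, use_inputs_embeds in itertools.product(
        cudagraph_capture_sizes, cudagraph_inputs_embeds):
    # The real loop captures a CUDAGraphRunner here; we only record the key.
    graph_runners[(batch_size, use_inputs_embeds)] = "captured graph"

# At execution time the lookup mirrors execute_model: the batch size plus
# whether the request arrived as embeddings selects the graph to replay.
assert (4, True) in graph_runners and (4, False) in graph_runners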
@@ -1711,8 +1778,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         if prefill_meta is None and decode_meta.use_cuda_graph:
             assert model_input.input_tokens is not None
             graph_batch_size = model_input.input_tokens.shape[0]
-            model_executable = self.graph_runners[virtual_engine][
-                graph_batch_size]
+            use_inputs_embeds = model_input.inputs_embeds is not None
+            model_executable = self.graph_runners[virtual_engine][(
+                graph_batch_size, use_inputs_embeds)]
             if previous_hidden_states is not None:
                 previous_hidden_states = torch.cat([
                     previous_hidden_states,

@@ -1763,6 +1831,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
                 self.vllm_config, virtual_engine):
             hidden_or_intermediate_states = model_executable(
                 input_ids=model_input.input_tokens,
+                inputs_embeds=model_input.inputs_embeds,
                 positions=model_input.input_positions,
                 intermediate_tensors=intermediate_tensors,
                 **MultiModalKwargs.as_kwargs(multi_modal_kwargs,

@@ -1817,6 +1886,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
             model_input.async_callback()

         # Sample the next token.
+        assert isinstance(self.sampler, Sampler)
+        orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor
+        if model_input.inputs_embeds is not None:
+            self.sampler.include_gpu_probs_tensor = True
+
         output: SamplerOutput = self.sampler(
             logits=logits,
             sampling_metadata=model_input.sampling_metadata,

@@ -1838,6 +1912,18 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
             output.model_forward_time = (orig_model_forward_time +
                                          model_forward_time)

+        if model_input.inputs_embeds is not None:
+            self.sampler.include_gpu_probs_tensor = \
+                orig_include_gpu_probs_tensor
+            if output.sampled_token_ids is not None:
+                output.sampled_token_embeds = self.model.get_input_embeddings(
+                    output.sampled_token_ids.squeeze(1))
+
+                for token_embed, sequence_group_output in zip(
+                        output.sampled_token_embeds, output.outputs):
+                    assert len(sequence_group_output.samples) == 1
+                    sequence_group_output.samples[0].output_embed = token_embed
+
         if self.return_hidden_states:
             # we only need to pass hidden states of most recent token
             assert model_input.sampling_metadata is not None
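The new post-sampling step re-embeds the sampled token ids so that decoding can continue in embedding space; self.model.get_input_embeddings(...) plays the role an ordinary nn.Embedding lookup would. A hedged, self-contained sketch of the same idea:

import torch
import torch.nn as nn

vocab_size, hidden_size = 32, 16
embed = nn.Embedding(vocab_size, hidden_size)   # stand-in for the model's
                                                # input embedding layer

sampled_token_ids = torch.tensor([[3], [17]])   # (num_seqs, 1), as sampled
sampled_token_embeds = embed(sampled_token_ids.squeeze(1))

# One embedding per sequence, ready to be attached to each SequenceOutput
# as its output_embed.
assert sampled_token_embeds.shape == (2, hidden_size)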
@@ -1931,6 +2017,7 @@ class CUDAGraphRunner(nn.Module):
     def capture(
         self,
         input_ids: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor],
         positions: torch.Tensor,
         intermediate_inputs: Optional[IntermediateTensors],
         kv_caches: List[torch.Tensor],

@@ -1947,6 +2034,7 @@ class CUDAGraphRunner(nn.Module):
         for _ in range(_NUM_WARMUP_ITERS):
             self.model(
                 input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
                 positions=positions,
                 intermediate_tensors=intermediate_inputs,
                 **kwargs,

@@ -1959,6 +2047,9 @@ class CUDAGraphRunner(nn.Module):
         with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
             output_hidden_or_intermediate_states = self.model(
                 input_ids=input_ids,
+                **({
+                    "inputs_embeds": inputs_embeds,
+                } if inputs_embeds is not None else {}),
                 positions=positions,
                 intermediate_tensors=intermediate_inputs,
                 **kwargs,

@@ -1986,6 +2077,9 @@ class CUDAGraphRunner(nn.Module):
         self.input_buffers = {
             "input_ids":
             input_ids,
+            **({
+                "inputs_embeds": inputs_embeds,
+            } if inputs_embeds is not None else {}),
             "positions":
             positions,
             "kv_caches":

@@ -2006,6 +2100,7 @@ class CUDAGraphRunner(nn.Module):
     def forward(
         self,
         input_ids: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor],
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors],
         **kwargs,

@@ -2020,6 +2115,9 @@ class CUDAGraphRunner(nn.Module):
         # so the shape is not padded, we need to copy partial only
         self.input_buffers["positions"][:positions.shape[0]].copy_(
             positions, non_blocking=True)
+        if inputs_embeds is not None:
+            self.input_buffers["inputs_embeds"][:inputs_embeds.shape[0]].copy_(
+                inputs_embeds, non_blocking=True)

         if self.backend_name != "NO_ATTENTION":
             self.input_buffers["slot_mapping"].copy_(
@@ -84,10 +84,17 @@ class PoolingModelRunner(
         # explore how to leverage it.
         if (prefill_meta is None and decode_meta is not None
                 and decode_meta.use_cuda_graph):
-            assert model_input.input_tokens is not None
-            graph_batch_size = model_input.input_tokens.shape[0]
-            model_executable = self.graph_runners[virtual_engine][
-                graph_batch_size]
+            if model_input.inputs_embeds is None:
+                assert model_input.input_tokens is not None
+                graph_batch_size = model_input.input_tokens.shape[0]
+                model_executable = (
+                    self.graph_runners[model_input.virtual_engine][(
+                        graph_batch_size, False)])
+            else:
+                graph_batch_size = model_input.inputs_embeds.shape[0]
+                model_executable = (
+                    self.graph_runners[model_input.virtual_engine][(
+                        graph_batch_size, True)])
         else:
             model_executable = self.model