diff --git a/tests/conftest.py b/tests/conftest.py index 14a88ca47505..571cca8eeccb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -787,7 +787,7 @@ class VllmRunner: def get_inputs( self, - prompts: list[str], + prompts: Union[list[str], list[torch.Tensor]], images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, @@ -809,16 +809,18 @@ class VllmRunner: if audios is not None and (audio := audios[i]) is not None: multi_modal_data["audio"] = audio - inputs.append( - TextPrompt(prompt=prompt, - multi_modal_data=multi_modal_data - if multi_modal_data else None)) + text_prompt_kwargs = { + ("prompt" if isinstance(prompt, str) else "prompt_embeds"): + prompt, + "multi_modal_data": multi_modal_data or None + } + inputs.append(TextPrompt(**text_prompt_kwargs)) return inputs def generate( self, - prompts: list[str], + prompts: Union[list[str], list[torch.Tensor]], sampling_params: SamplingParams, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, @@ -844,7 +846,7 @@ class VllmRunner: output_str = sample.text output_ids = list(sample.token_ids) req_sample_output_ids.append(prompt_ids + output_ids) - req_sample_output_strs.append(prompt_str + output_str) + req_sample_output_strs.append((prompt_str or "") + output_str) outputs.append((req_sample_output_ids, req_sample_output_strs)) return outputs @@ -911,7 +913,7 @@ class VllmRunner: def generate_greedy( self, - prompts: list[str], + prompts: Union[list[str], list[torch.Tensor]], max_tokens: int, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 8bd64923fe22..a5ba16898d89 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -2,16 +2,18 @@ import time from collections import deque +from typing import Optional from unittest.mock import MagicMock import pytest # noqa +import torch from torch import Use # noqa from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus from vllm.core.scheduler import Scheduler, SchedulingBudget from vllm.lora.request import LoRARequest -from vllm.sequence import SequenceGroup +from vllm.sequence import SequenceGroup, SequenceStatus from .utils import (append_new_token, append_new_token_seq, append_new_token_seq_group, create_dummy_prompt, @@ -968,3 +970,73 @@ def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching( ), "A partial prefix of C (4 tokens) should be prefilled, with the " "remaining tokens fit into 3 token budget (4-1 from the seqA). It will " "then be rounded down to 2 tokens on block size, thus 6 tokens in total." + + +def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds(): + """ + Test that the scheduler does not schedule batches with prompt tokens and + prompt embeddings co-mingled. 
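+    Sequences alternate between token-id prompts and prompt-embedding
+    prompts, and every scheduled batch must contain only one of the two.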
+ """ + block_size = 2 + max_seq_group = 3 + scheduler = initialize_scheduler( + block_size=block_size, + num_cpu_blocks=16, + num_gpu_blocks=16, + max_num_seqs=max_seq_group, + max_model_len=100, + enable_prefix_caching=True, + ) + + # the odd indexed inputs should be passed in via embeddings, + # evens via token_ids + seq_length = 7 + embedding_size = 5 + num_seqs = 11 + seq_tokens: list[list[int]] = [] + seq_embeds: list[Optional[torch.Tensor]] = [] + for i in range(num_seqs): + if i % 2: + seq_tokens.append(list(range(seq_length))) + seq_embeds.append(None) + else: + seq_tokens.append([0] * seq_length) + seq_embeds.append(torch.rand(embedding_size)) + + seq_and_seq_groups = [ + create_dummy_prompt(f"{i}", + prompt_tokens=seq_tokens[i], + prompt_embeds=seq_embeds[i], + block_size=block_size) + for i in range(len(seq_tokens)) + ] + + for _, seq_group in seq_and_seq_groups: + scheduler.add_seq_group(seq_group) + + while not all(seq.is_finished() for seq, _ in seq_and_seq_groups): + unfinished_seq_groups = [ + seq_group for _, seq_group in seq_and_seq_groups + if not seq_group.is_finished() + ] + _, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) > 0 + batch_is_prompt_embeds = out.scheduled_seq_groups[ + 0].seq_group.uses_prompt_embeds() + expected_scheduled_seq_groups = [ + seq_group for seq_group in unfinished_seq_groups + if seq_group.uses_prompt_embeds() == batch_is_prompt_embeds + ] + + # We should have as many scheduled groups as possible, without mixing + assert len(out.scheduled_seq_groups) == min( + max_seq_group, len(expected_scheduled_seq_groups)) + assert all(scheduled_seq_group.seq_group.uses_prompt_embeds() == + batch_is_prompt_embeds + for scheduled_seq_group in out.scheduled_seq_groups) + + # Finish the scheduled groups + for scheduled_seq_group in out.scheduled_seq_groups: + for seq in scheduled_seq_group.seq_group.seqs: + seq.status = SequenceStatus.FINISHED_STOPPED + scheduler.free_finished_seq_groups() diff --git a/tests/core/utils.py b/tests/core/utils.py index ea18b879a317..84b0426b470b 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -5,9 +5,11 @@ from collections import defaultdict from collections.abc import Sequence as GenericSequence from typing import Any, Optional +import torch + from vllm import SamplingParams from vllm.core.scheduler import Scheduler, SchedulerOutputs -from vllm.inputs import EncoderDecoderInputs, token_inputs +from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs from vllm.lora.request import LoRARequest from vllm.sequence import (Logprob, Sequence, SequenceGroup, SequenceGroupMetadata) @@ -19,6 +21,7 @@ def create_dummy_prompt( block_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, prompt_tokens: Optional[list[int]] = None, + prompt_embeds: Optional[torch.Tensor] = None, min_tokens: int = 0, max_tokens: int = 16, ) -> tuple[Sequence, SequenceGroup]: @@ -31,9 +34,13 @@ def create_dummy_prompt( prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) + inputs = token_inputs( + prompt_token_ids=prompt_tokens, + prompt=prompt_str) if prompt_embeds is None else embeds_inputs( + prompt_embeds=prompt_embeds) prompt = Sequence( int(request_id), - inputs=token_inputs(prompt_tokens, prompt=prompt_str), + inputs=inputs, block_size=block_size, ) seq_group = SequenceGroup( diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index ab2898ffb2d0..fcd3fa036cfd 
100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -1,4 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 +import os +from typing import Optional + import pytest import torch @@ -110,6 +113,18 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) + prompt_embeds: Optional[list[torch.Tensor]] = [] if os.getenv( + "VLLM_USE_V1") == "0" else None + prompt_token_ids = [] + for prompt in example_prompts: + token_ids = hf_model.tokenizer(prompt, + return_tensors="pt").input_ids.to( + hf_model.model.device) + prompt_token_ids.append(token_ids) + if prompt_embeds is not None: + prompt_embeds.append(hf_model.model.get_input_embeddings()( + token_ids).squeeze(0)) + with vllm_runner( model, tokenizer_name=model_info.tokenizer or model, @@ -119,6 +134,9 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) + if prompt_embeds is not None: + vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs( + prompt_embeds, max_tokens, num_logprobs) check_logprobs_close( outputs_0_lst=hf_outputs, @@ -126,6 +144,14 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, name_0="hf", name_1="vllm", ) + if prompt_embeds is not None: + check_logprobs_close( + outputs_0_lst=vllm_outputs, + outputs_1_lst=vllm_outputs_from_embeds, + name_0="vllm", + name_1="vllm_from_embeds", + ) + if use_rocm_aiter: # this is to ensure that vllm engine # has deallocated the memory before running the next diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index b8ba69b0dd8f..a1bdea687a85 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -31,8 +31,13 @@ def test_deepseek_mla_attn_backend_module(): assert model_runner.attn_backend.__name__ == "TritonMLABackend" -@pytest.mark.parametrize("batch_size", list(range(1, 257))) -def test_prepare_prompt(batch_size): +@pytest.mark.parametrize("batch_size", list(range(1, 257, 3))) +@pytest.mark.parametrize("use_prompt_embeds", [True, False]) +def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch): + if use_prompt_embeds: + # Prompt Embeddings is only currently supported on V0 + monkeypatch.setenv("VLLM_USE_V1", "0") + model_runner = _create_model_runner( "facebook/opt-125m", max_num_batched_tokens=100000, @@ -43,11 +48,20 @@ def test_prepare_prompt(batch_size): seq_lens: list[int] = [] seq_group_metadata_list: list[SequenceGroupMetadata] = [] block_tables = {0: [1]} + expected_input_embeds_len = 0 for i in range(batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData.from_seqs(range(seq_len)) + if use_prompt_embeds: + seq_data = SequenceData.from_seqs( + prompt_token_ids=[0] * seq_len, + prompt_embeds=torch.rand(seq_len, 10), + ) + expected_input_embeds_len += seq_len + else: + seq_data = SequenceData.from_seqs(prompt_token_ids=range(seq_len)) + seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -68,6 +82,7 @@ def test_prepare_prompt(batch_size): seq_group_metadata_list) input_tokens = model_input.input_tokens input_positions = model_input.input_positions + input_embeds = model_input.inputs_embeds attn_metadata = model_input.attn_metadata return_seq_lens = 
model_input.seq_lens slot_mapping = attn_metadata.slot_mapping @@ -121,7 +136,11 @@ def test_prepare_prompt(batch_size): assert len(input_tokens) == sum(seq_lens) assert len(input_positions) == sum(seq_lens) - torch.testing.assert_close(input_tokens, input_positions) + if expected_input_embeds_len == 0: + torch.testing.assert_close(input_tokens, input_positions) + assert input_embeds is None + else: + assert len(input_embeds) == expected_input_embeds_len sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, @@ -145,8 +164,13 @@ def test_prepare_prompt(batch_size): torch.testing.assert_close(actual, expected) -@pytest.mark.parametrize("batch_size", list(range(1, 257))) -def test_prepare_decode_cuda_graph(batch_size): +@pytest.mark.parametrize("batch_size", list(range(1, 257, 3))) +@pytest.mark.parametrize("use_prompt_embeds", [True, False]) +def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch): + if use_prompt_embeds: + # Prompt Embeddings is only currently supported on V0 + monkeypatch.setenv("VLLM_USE_V1", "0") + model_runner = _create_model_runner( "facebook/opt-125m", seed=0, @@ -164,10 +188,19 @@ def test_prepare_decode_cuda_graph(batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 context_lens.append(context_len) - seq_data = SequenceData.from_seqs(range(context_len)) + if use_prompt_embeds: + seq_data = SequenceData.from_seqs( + prompt_token_ids=[0] * context_len, + prompt_embeds=torch.rand(context_len, 10), + ) + output_embed = torch.rand(10) + else: + seq_data = SequenceData.from_seqs( + prompt_token_ids=range(context_len)) + output_embed = None seq_data.update_num_computed_tokens(context_len) # Append one token ID since prefill is finished. - seq_data.append_token_id(1, 0) + seq_data.append_token_id(1, 0, output_embed) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=False, @@ -180,9 +213,12 @@ def test_prepare_decode_cuda_graph(batch_size): model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - input_tokens, input_positions, attn_metadata, slot_mapping = ( - model_input.input_tokens, model_input.input_positions, - model_input.attn_metadata, model_input.attn_metadata.slot_mapping) + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + input_embeds = model_input.inputs_embeds + attn_metadata = model_input.attn_metadata + slot_mapping = attn_metadata.slot_mapping + assert len(slot_mapping) == len(input_tokens) expected_bs = model_runner.vllm_config.pad_for_cudagraph( @@ -227,7 +263,7 @@ def test_prepare_decode_cuda_graph(batch_size): # block table's first index corresponds to each batch, meaning in # decoding it is each token. assert attn_metadata.block_tables.shape[0] == len(input_tokens) - # Block table's second dim correspondsd to each token's block number. + # Block table's second dim corresponds to each token's block number. 
# It is padded up to assert attn_metadata.block_tables.shape[1] == ( model_runner.get_max_block_per_batch()) @@ -235,7 +271,12 @@ def test_prepare_decode_cuda_graph(batch_size): assert len(input_tokens) == expected_bs assert len(input_positions) == expected_bs - torch.allclose(input_tokens, input_positions) + if use_prompt_embeds: + expected_input_embeds_length = start_loc[-1] + assert len(input_embeds) == expected_input_embeds_length + assert expected_input_embeds_length <= expected_bs + else: + assert input_embeds is None # Verify Sampling expected_selected_token_indices = [] @@ -266,25 +307,27 @@ def test_empty_seq_group(): seq_group_metadata_list: list[SequenceGroupMetadata] = [] model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - input_tokens, input_positions, attn_metadata = ( - model_input.input_tokens, - model_input.input_positions, - model_input.attn_metadata, - ) + + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + attn_metadata = model_input.attn_metadata + assert input_tokens is None assert input_positions is None assert attn_metadata is None model_input = model_runner._prepare_model_input_tensors( seq_group_metadata_list) - (input_tokens, input_positions, attn_metadata, return_seq_lens) = ( - model_input.input_tokens, - model_input.input_positions, - model_input.attn_metadata, - model_input.seq_lens, - ) + + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + input_embeds = model_input.inputs_embeds + attn_metadata = model_input.attn_metadata + return_seq_lens = model_input.seq_lens + assert input_tokens is None assert input_positions is None + assert input_embeds is None assert attn_metadata is None assert return_seq_lens is None @@ -299,9 +342,15 @@ def distributed_init(): ensure_model_parallel_initialized(1, 1) -@pytest.mark.parametrize("batch_size", list(range(2, 128))) +@pytest.mark.parametrize("batch_size", list(range(2, 128, 3))) @pytest.mark.parametrize("enforce_eager", [True, False]) -def test_hybrid_batches(batch_size, enforce_eager, distributed_init): +@pytest.mark.parametrize('use_prompt_embeds', [True, False]) +def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds, + distributed_init, monkeypatch): + if use_prompt_embeds: + # Prompt Embeddings is only currently supported on V0 + monkeypatch.setenv("VLLM_USE_V1", "0") + model_runner = _create_model_runner( "facebook/opt-125m", seed=0, @@ -320,11 +369,20 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): block_tables = {0: [1]} prefill_batch_size = batch_size // 2 decode_batch_size = batch_size - prefill_batch_size + expected_input_embeds_len = 0 for i in range(prefill_batch_size): # make sure all tokens fit into one block seq_len = i % (model_runner.block_size - 1) + 1 seq_lens.append(seq_len) - seq_data = SequenceData.from_seqs(range(seq_len)) + if use_prompt_embeds: + seq_data = SequenceData.from_seqs( + prompt_token_ids=[0] * seq_len, + prompt_embeds=torch.rand(seq_len, 10), + ) + expected_input_embeds_len += seq_len + else: + seq_data = SequenceData.from_seqs( + prompt_token_ids=range(seq_len), ) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, @@ -340,8 +398,21 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): for i in range(prefill_batch_size, batch_size): # make sure all tokens fit into one block context_len = i % (model_runner.block_size - 1) + 1 - seq_data = SequenceData.from_seqs(range(context_len)) - 
seq_data.append_token_id(1, 0) + if use_prompt_embeds: + seq_data = SequenceData.from_seqs( + prompt_token_ids=[0] * context_len, + prompt_embeds=torch.rand(context_len, 10), + ) + output_embed = torch.rand(10) + # This also iterates the expected input_embeds, because the model + # needs both the input and output embeddings passed into together + expected_input_embeds_len += 1 + else: + seq_data = SequenceData.from_seqs( + prompt_token_ids=range(context_len), ) + output_embed = None + assert len(seq_data.prompt_token_ids) == context_len + seq_data.append_token_id(1, 0, output_embed) seq_data.update_num_computed_tokens(context_len) seq_group_metadata = SequenceGroupMetadata( request_id=f"test_{i}", @@ -355,11 +426,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): decode_metadata_list.append(seq_group_metadata) model_input = model_runner.prepare_model_input(seq_group_metadata_list) - (input_tokens, input_positions, attn_metadata) = ( - model_input.input_tokens, - model_input.input_positions, - model_input.attn_metadata, - ) + + input_tokens = model_input.input_tokens + input_positions = model_input.input_positions + input_embeds = model_input.inputs_embeds + attn_metadata = model_input.attn_metadata prefill_meta_actual = attn_metadata.prefill_metadata decode_meta_actual = attn_metadata.decode_metadata @@ -369,6 +440,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init): assert attn_metadata.num_prefills == prefill_batch_size assert attn_metadata.num_decode_tokens == decode_batch_size assert attn_metadata.num_prefill_tokens == sum(seq_lens) + if expected_input_embeds_len == 0: + assert input_embeds is None + else: + assert len(input_embeds) == expected_input_embeds_len # Verify attn metadata is consistent. We don't need to test individual # values here because they are tested above. diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index d92177d58a48..37b20d0739f7 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -367,9 +367,17 @@ class FlashInferState(AttentionState): # scheduled while CUDA graph mode is enabled. We don't run graph in that # case. if use_cuda_graph and is_decode: - batch_size = model_input.input_tokens.shape[0] - state = (self.runner.graph_runners[model_input.virtual_engine] - [batch_size].attn_state) + if model_input.inputs_embeds is None: + batch_size = model_input.input_tokens.shape[0] + state = ( + self.runner.graph_runners[model_input.virtual_engine][( + batch_size, False)].attn_state) + else: + batch_size = model_input.inputs_embeds.shape[0] + state = ( + self.runner.graph_runners[model_input.virtual_engine][( + batch_size, True)].attn_state) + model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper( ) model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 97d03d5e3b40..06d4ed470b20 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1071,6 +1071,7 @@ class Scheduler: ) ignored_seq_groups: List[SequenceGroup] = [] seq_groups: List[ScheduledSequenceGroup] = [] + using_prompt_embeds: bool = False waiting_queue = self.waiting @@ -1138,6 +1139,15 @@ class Scheduler: waiting_queue.popleft() continue + # We cannot mix sequence groups that use prompt embeds and + # those that do not. 
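+            # The first group scheduled in this pass fixes the batch type;
+            # groups of the other kind are put back on the waiting queue for
+            # a later step.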
+ if len(seq_groups) == 0: + using_prompt_embeds = seq_group.uses_prompt_embeds() + if using_prompt_embeds != seq_group.uses_prompt_embeds(): + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue + lora_int_id = 0 if self.lora_enabled: lora_int_id = seq_group.lora_int_id @@ -1295,17 +1305,39 @@ class Scheduler: # Merge lists num_prefill_groups = len(prefills.seq_groups) + ignored_seq_groups_for_embeds = list[SequenceGroup]() if num_prefill_groups > 0: scheduled_seq_groups = prefills.seq_groups scheduled_seq_groups.extend(running_scheduled.decode_seq_groups) + ignored_seq_groups_for_embeds.clear() else: scheduled_seq_groups = running_scheduled.decode_seq_groups + if len(scheduled_seq_groups) > 0: + using_prompt_embeds = scheduled_seq_groups[ + 0].seq_group.uses_prompt_embeds() + ignored_seq_groups_for_embeds.clear() + indices_ignored = list[int]() + for i, schedule_seq_group in enumerate(scheduled_seq_groups): + if using_prompt_embeds !=\ + schedule_seq_group.seq_group.uses_prompt_embeds(): + ignored_seq_groups_for_embeds.append( + schedule_seq_group.seq_group) + indices_ignored.append(i) + if len(ignored_seq_groups_for_embeds) > 0: + scheduled_seq_groups = [ + group for i, group in enumerate(scheduled_seq_groups) + if i not in indices_ignored + ] + else: + ignored_seq_groups_for_embeds.clear() + scheduled_seq_groups.extend(swapped_in.decode_seq_groups) blocks_to_copy = running_scheduled.blocks_to_copy blocks_to_copy.extend(swapped_in.blocks_to_copy) ignored_seq_groups = prefills.ignored_seq_groups + ignored_seq_groups.extend(ignored_seq_groups_for_embeds) ignored_seq_groups.extend(swapped_in.infeasible_seq_groups) return SchedulerOutputs( diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 6cc9b881464e..cb0902c3a5b8 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -489,6 +489,14 @@ class _AsyncLLMEngine(LLMEngine): if arrival_time is None: arrival_time = time.time() + if (isinstance(prompt, dict) + and prompt.get("prompt_embeds", None) is not None + and not prompt.get("prompt_token_ids", None)): + # We use the -2 dimension (instead of 0) in case a batched input + # of batch size 1 is passed in. 
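+            # Fill in dummy token ids matching the embeddings' sequence
+            # length so that downstream length accounting still works.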
+ prompt["prompt_token_ids"] = [0 + ] * prompt["prompt_embeds"].shape[-2] + if self.tokenizer is not None: tokenizer = await self.get_tokenizer_async(lora_request) self._validate_token_prompt(prompt, tokenizer=tokenizer) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0930bae02e41..142c8fe99b67 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -753,6 +753,12 @@ class LLMEngine: if arrival_time is None: arrival_time = time.time() + if (isinstance(prompt, dict) + and prompt.get("prompt_embeds", None) is not None + and not prompt.get("prompt_token_ids", None)): + seq_len = prompt["prompt_embeds"].shape[0] + prompt["prompt_token_ids"] = [0] * seq_len + if self.tokenizer is not None: self._validate_token_prompt( prompt, @@ -1267,11 +1273,13 @@ class LLMEngine: if self.scheduler_config.is_multi_step: is_prefill_append = seq.data.get_num_uncomputed_tokens( ) == 0 - seq.append_token_id(sample.output_token, sample.logprobs) + seq.append_token_id(sample.output_token, sample.logprobs, + sample.output_embed) if not is_prefill_append: seq_group.update_num_computed_tokens(1) else: - seq.append_token_id(sample.output_token, sample.logprobs) + seq.append_token_id(sample.output_token, sample.logprobs, + sample.output_embed) def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. @@ -2032,10 +2040,12 @@ class LLMEngine: tokenizer = (None if self.tokenizer is None else self.tokenizer.get_lora_tokenizer(lora_request)) - prompt_ids = prompt_inputs["prompt_token_ids"] + prompt_ids = prompt_inputs.get("prompt_token_ids", []) if not prompt_ids: if prompt_type == "encoder" and model_config.is_multimodal_model: pass # Mllama may have empty encoder inputs for text-only data + if prompt_inputs["type"] == "embeds": + pass else: raise ValueError(f"The {prompt_type} prompt cannot be empty") diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 126e7da70216..0f4c7517ebac 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -167,6 +167,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): sampling_params: SamplingParams) -> None: output_token_ids = [sample.output_token for sample in valid_samples] output_logprobs = [sample.logprobs for sample in valid_samples] + output_embeds = [sample.output_embed for sample in valid_samples] # Truncate to max_tokens if necessary. remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + @@ -190,11 +191,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor): is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0 # Incrementally append tokens to the sequence, as if we had only one new # token. 
- for output_token_id, output_logprob in zip(output_token_ids, - output_logprobs): + for output_token_id, output_logprob, output_embed in zip( + output_token_ids, output_logprobs, output_embeds): seq.append_token_id( token_id=output_token_id, logprobs=output_logprob, + token_embed=output_embed, ) if is_prefill_sampled_token: diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 4d96791a1f8a..b5b51bb25a86 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -119,7 +119,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): sample = outputs.samples[0] seq = seq_group.first_seq if not is_async: - seq.append_token_id(sample.output_token, sample.logprobs) + seq.append_token_id(sample.output_token, sample.logprobs, + sample.output_embed) if sampling_params.detokenize and self.detokenizer: new_char_count = self.detokenizer.decode_sequence_inplace( seq, sampling_params) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index ca706e202836..9914a9dcffcc 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 -from .data import (DecoderOnlyInputs, EncoderDecoderInputs, +from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs, - TokensPrompt, build_explicit_enc_dec_prompt, + TokensPrompt, build_explicit_enc_dec_prompt, embeds_inputs, to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts) from .registry import (DummyData, InputContext, InputProcessingContext, InputRegistry) @@ -21,7 +21,9 @@ __all__ = [ "SingletonPrompt", "ExplicitEncoderDecoderPrompt", "TokenInputs", + "EmbedsInputs", "token_inputs", + "embeds_inputs", "DecoderOnlyInputs", "EncoderDecoderInputs", "ProcessorInputs", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 167189ed108e..6a56d044c9f9 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -2,6 +2,7 @@ from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast +import torch from typing_extensions import NotRequired, TypedDict, TypeVar if TYPE_CHECKING: @@ -63,12 +64,20 @@ class TokensPrompt(TypedDict): """ -SingletonPrompt = Union[str, TextPrompt, TokensPrompt] +class EmbedsPrompt(TypedDict): + """Schema for a prompt provided via token embeddings.""" + + prompt_embeds: torch.Tensor + """The embeddings of the prompt.""" + + +SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt] """ Set of possible schemas for a single prompt: - A text prompt (:class:`str` or :class:`TextPrompt`) - A tokenized prompt (:class:`TokensPrompt`) +- An embeddings prompt (:class:`EmbedsPrompt`) Note that "singleton" is as opposed to a data structure which encapsulates multiple prompts, i.e. 
of the sort @@ -129,6 +138,7 @@ both decoder-only and encoder/decoder input types: - A text prompt (:class:`str` or :class:`TextPrompt`) - A tokenized prompt (:class:`TokensPrompt`) +- An embeddings prompt (:class:`EmbedsPrompt`) - A single data structure containing both an encoder and a decoder prompt (:class:`ExplicitEncoderDecoderPrompt`) """ @@ -176,7 +186,27 @@ def token_inputs( return inputs -DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputs"] +class EmbedsInputs(TypedDict): + """Represents embeddings-based inputs.""" + + type: Literal["embeds"] + """The type of inputs.""" + + prompt_embeds: torch.Tensor + """The embeddings of the prompt.""" + + +def embeds_inputs(prompt_embeds: torch.Tensor) -> EmbedsInputs: + """Construct :class:`EmbedsInputs` from optional values.""" + inputs = EmbedsInputs( + type="embeds", + prompt_embeds=prompt_embeds, + ) + + return inputs + + +DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"] """ The inputs in :class:`~vllm.LLMEngine` before they are passed to the model executor. @@ -198,7 +228,7 @@ class EncoderDecoderInputs(TypedDict): """The inputs for the decoder portion.""" -SingletonInputs = Union[TokenInputs, "MultiModalInputs"] +SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"] """ A processed :class:`SingletonPrompt` which can be passed to :class:`vllm.sequence.Sequence`. diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index 28e207de1fd3..397344e40230 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -6,8 +6,9 @@ from typing_extensions import TypeIs from vllm.utils import is_list_of -from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType, - SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt) +from .data import (EmbedsInputs, EmbedsPrompt, ExplicitEncoderDecoderPrompt, + ProcessorInputs, PromptType, SingletonInputs, + SingletonPrompt, TextPrompt, TokensPrompt) class ParsedText(TypedDict): @@ -84,30 +85,69 @@ class ParsedTokensPrompt(TypedDict): content: TokensPrompt +class ParsedEmbedsPrompt(TypedDict): + type: Literal['embeds'] + content: EmbedsPrompt + + +@overload +def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt: + ... + + +@overload +def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt: + ... + + +@overload +def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt: + ... + + +@overload +def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt: + ... + + def parse_singleton_prompt( prompt: SingletonPrompt, -) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: +) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt, + ParsedEmbedsPrompt]: if isinstance(prompt, str): return ParsedStrPrompt(type="str", content=prompt) elif isinstance(prompt, dict): - if "prompt_token_ids" in prompt: - return ParsedTokensPrompt(type="tokens", - content=prompt) # type: ignore + # Type ignores are because mypy does not correctly infer the TypedDicts + # Pyright does succeed. 
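+        # prompt_embeds is checked before prompt_token_ids, so a dict
+        # carrying both keys is treated as an embeds prompt.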
+ if "prompt_embeds" in prompt: + return ParsedEmbedsPrompt( + type="embeds", content=prompt) # type: ignore[typeddict-item] + elif "prompt_token_ids" in prompt: + return ParsedTokensPrompt( + type="tokens", content=prompt) # type: ignore[typeddict-item] elif "prompt" in prompt: return ParsedTextPrompt(type="text", content=prompt) - - raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") + raise TypeError( + "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt") def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]: return isinstance(prompt, dict) and "prompt_token_ids" in prompt +def is_embeds_prompt(prompt: PromptType) -> TypeIs[EmbedsPrompt]: + return isinstance(prompt, dict) and "prompt_embeds" in prompt + + def is_explicit_encoder_decoder_prompt( prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: return isinstance(prompt, dict) and "encoder_prompt" in prompt +def is_embeds_inputs(inputs: SingletonInputs) -> TypeIs[EmbedsInputs]: + return isinstance(inputs, dict) and inputs["type"] == "embeds" + + def split_enc_dec_inputs( inputs: ProcessorInputs, ) -> tuple[Optional[SingletonInputs], SingletonInputs]: diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 83e6907f8c49..5a9e3643dcad 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -6,6 +6,7 @@ from typing import Any, Optional, Union, cast from typing_extensions import assert_never +from vllm import envs from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -15,9 +16,11 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import TokenizerGroup -from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, - PromptType, SingletonInputs, SingletonPrompt, token_inputs) -from .parse import (ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt, +from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs, + ProcessorInputs, PromptType, SingletonInputs, + SingletonPrompt, TokenInputs, embeds_inputs, token_inputs) +from .parse import (ParsedEmbedsPrompt, ParsedStrPrompt, ParsedTextPrompt, + ParsedTokensPrompt, is_embeds_inputs, is_explicit_encoder_decoder_prompt, parse_singleton_prompt) logger = init_logger(__name__) @@ -328,6 +331,10 @@ class InputPreprocessor: * :class:`SingletonInputs` instance """ parsed = parse_singleton_prompt(prompt) + + if parsed["type"] == "embeds": + return self._process_prompt_embeds(parsed) + prompt_text, prompt_token_ids, token_type_ids, cache_salt = \ self._get_prompt_data(parsed) @@ -359,6 +366,8 @@ class InputPreprocessor: cache_salt=cache_salt, ) + assert_never(parsed) + async def _prompt_to_llm_inputs_async( self, prompt: SingletonPrompt, @@ -369,6 +378,9 @@ class InputPreprocessor: """Async version of :meth:`_extract_prompt_components`.""" parsed = parse_singleton_prompt(prompt) + if parsed["type"] == "embeds": + return self._process_prompt_embeds(parsed) + prompt_text, prompt_token_ids, token_type_ids, cache_salt = \ self._get_prompt_data(parsed) @@ -399,10 +411,34 @@ class InputPreprocessor: cache_salt=cache_salt, ) + def _process_prompt_embeds(self, + parsed: ParsedEmbedsPrompt) -> EmbedsInputs: + if envs.VLLM_USE_V1: + raise ValueError("prompt_embeds is only available in V0.") + + prompt_embeds_content = parsed["content"] + + prompt_embeds = prompt_embeds_content["prompt_embeds"] + + # 
prompt_embeds must be (seq_len, hidden_size), but if the user + # passes in a batch of size 1, i.e. (1, seq_len, hidden_size), + # we can unambiguously process the intent by squeezing the batch + # dimension. + if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1: + prompt_embeds = prompt_embeds.squeeze(dim=0) + + if prompt_embeds.ndim != 2: + raise ValueError( + "prompt_embeds must be of shape (seq_len, hidden_size).") + + return embeds_inputs(prompt_embeds=prompt_embeds) + + assert_never(parsed) + def _build_enc_dec_llm_inputs( self, - encoder_inputs: SingletonInputs, - decoder_inputs: Optional[SingletonInputs], + encoder_inputs: Union[TokenInputs, MultiModalInputs], + decoder_inputs: Optional[Union[TokenInputs, MultiModalInputs]], ) -> EncoderDecoderInputs: if (encoder_inputs["type"] == "token" or encoder_inputs["type"] == "multimodal"): @@ -410,6 +446,9 @@ class InputPreprocessor: else: assert_never(encoder_inputs) # type: ignore[arg-type] + # Mypy does not correctly infer that EmbedsInputs is impossible + assert "prompt_token_ids" in encoder_inputs + if decoder_inputs is None: if self.model_config.hf_config.model_type == "whisper": # For Whisper models, the text prompt should go to the decoder. @@ -441,7 +480,8 @@ class InputPreprocessor: def _separate_enc_dec_inputs_from_mm_processor_outputs( self, inputs: SingletonInputs, - decoder_inputs_to_override: Optional[SingletonInputs] = None, + decoder_inputs_to_override: Optional[Union[TokenInputs, + MultiModalInputs]] = None, ) -> tuple[SingletonInputs, SingletonInputs]: """ For encoder/decoder models only: @@ -540,6 +580,8 @@ class InputPreprocessor: # For multimodal model, override decoder prompt from processor # with explicit decoder prompt. if self.model_config.is_multimodal_model: + assert decoder_inputs is None or not is_embeds_inputs( + decoder_inputs) encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( encoder_inputs, decoder_inputs)) @@ -555,9 +597,12 @@ class InputPreprocessor: inputs)) else: encoder_inputs = inputs - decoder_inputs = None + # Mypy does not do type inference well with TypedDicts with Literal + # values. + assert not is_embeds_inputs(encoder_inputs) + assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs) return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) async def _process_encoder_decoder_prompt_async( @@ -590,6 +635,8 @@ class InputPreprocessor: # For multimodal model, override decoder prompt from processor # with explicit decoder prompt. if self.model_config.is_multimodal_model: + assert decoder_inputs is None or not is_embeds_inputs( + decoder_inputs) encoder_inputs, decoder_inputs = ( self._separate_enc_dec_inputs_from_mm_processor_outputs( encoder_inputs, decoder_inputs)) @@ -605,9 +652,12 @@ class InputPreprocessor: inputs)) else: encoder_inputs = inputs - decoder_inputs = None + # Mypy does not do type inference well with TypedDicts with Literal + # values. 
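+        # These asserts both guard against embeds inputs reaching the
+        # encoder/decoder path and narrow the union type for mypy.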
+ assert not is_embeds_inputs(encoder_inputs) + assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs) return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs) def _build_decoder_only_llm_inputs( @@ -617,10 +667,15 @@ class InputPreprocessor: ) -> DecoderOnlyInputs: if (prompt_inputs["type"] == "token" or prompt_inputs["type"] == "multimodal"): + # Mypy does not do type inference well with typedicts and Literal + # values + assert not is_embeds_inputs(prompt_inputs) prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter( prompt_inputs["prompt_token_ids"], prompt_adapter_request=prompt_adapter_request, ) + elif (prompt_inputs["type"] == "embeds"): + pass else: assert_never(prompt_inputs) # type: ignore[arg-type] diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 1ee1332ac45e..9368992b24fe 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -110,6 +110,11 @@ class SamplerOutput( # 'broadcasted' to all other PP ranks for next step. sampled_token_ids_cpu: Optional[torch.Tensor] = None + # On-device tensor containing the sampled token embeddings (embeddings + # corresponding to the sampled token ids). Used when prompt embeddings are + # specified in lieu of prompt token ids or text. + sampled_token_embeds: Optional[torch.Tensor] = None + # Spec decode metrics populated by workers. spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None @@ -183,7 +188,7 @@ class Sampler(nn.Module): # Whether or not the SamplerOutput should have on-device tensors # containing the sampled token ids and probabilities. This is used by - # speculative decoding. + # speculative decoding and when prompt embeddings are specified. self.include_gpu_probs_tensor = False self.should_modify_greedy_probs_inplace = False diff --git a/vllm/sequence.py b/vllm/sequence.py index a97409523c94..5bc9b8a6fc82 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -166,6 +166,9 @@ class SequenceData(msgspec.Struct, _output_token_ids: array = msgspec.field( default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) + _prompt_embeds: Optional[torch.Tensor] = None + _output_embeds: Optional[torch.Tensor] = None + ### The below fields should not be passed as an argument ### _cumulative_logprob: float = 0.0 _prompt_token_ids_tuple: tuple[int, @@ -176,6 +179,7 @@ class SequenceData(msgspec.Struct, _num_cached_tokens: int = 0 _stage: SequenceStage = SequenceStage.PREFILL _cached_all_token_ids: list[int] = msgspec.field(default_factory=list) + _cached_all_token_embeds: Optional[torch.Tensor] = None # It is used to get delta input. It is reset when `get_delta_and_reset` # is called. 
@@ -208,6 +212,8 @@ class SequenceData(msgspec.Struct, def from_seqs( prompt_token_ids: GenericSequence[int], output_token_ids: Optional[GenericSequence[int]] = None, + *, + prompt_embeds: Optional[torch.Tensor] = None, ) -> "SequenceData": """ Construct a :class:`SequenceData` instance from prompt and output @@ -217,13 +223,15 @@ class SequenceData(msgspec.Struct, prompt_token_ids) if output_token_ids is None: - return SequenceData(prompt_token_ids_arr) + return SequenceData(prompt_token_ids_arr, + _prompt_embeds=prompt_embeds) output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, output_token_ids) return SequenceData(prompt_token_ids_arr, - _output_token_ids=output_token_ids_arr) + _output_token_ids=output_token_ids_arr, + _prompt_embeds=prompt_embeds) def __post_init__(self) -> None: assert self._prompt_token_ids.typecode == "l" @@ -231,6 +239,8 @@ class SequenceData(msgspec.Struct, self._prompt_token_ids_tuple: tuple[int, ...] = tuple( self._prompt_token_ids) self._update_cached_all_tokens() + if self._prompt_embeds is not None: + self._update_cached_all_token_embeds() def _update_cached_all_tokens(self): assert isinstance(self._prompt_token_ids, array) @@ -238,6 +248,13 @@ class SequenceData(msgspec.Struct, self._cached_all_token_ids: list[int] = list(self._prompt_token_ids + self._output_token_ids) + def _update_cached_all_token_embeds(self): + assert isinstance(self._prompt_embeds, torch.Tensor) + self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds + if self._output_embeds is not None: + self._cached_all_token_embeds = torch.cat( + (self._cached_all_token_embeds, self._output_embeds), dim=0) + @property def cumulative_logprob(self) -> float: return self._cumulative_logprob @@ -270,6 +287,15 @@ class SequenceData(msgspec.Struct, new_output_token_ids) self._update_cached_all_tokens() + @property + def output_embeds(self) -> Optional[torch.Tensor]: + return self._output_embeds + + @output_embeds.setter + def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None: + self._output_token_embeds = new_output_token_embeds + self._update_cached_all_token_embeds() + @property def output_token_ids_array(self) -> array: """Return the prompt token ids in array type. 
@@ -280,6 +306,15 @@ class SequenceData(msgspec.Struct, assert isinstance(self._output_token_ids, array) return self._output_token_ids + @property + def prompt_embeds(self) -> Optional[torch.Tensor]: + return self._prompt_embeds + + @prompt_embeds.setter + def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None: + self._prompt_embeds = prompt_embeds + self._update_cached_all_token_embeds() + @property def mrope_position_delta(self) -> Optional[int]: return self._mrope_position_delta @@ -288,11 +323,28 @@ class SequenceData(msgspec.Struct, def mrope_position_delta(self, new_mrope_position_delta): self._mrope_position_delta = new_mrope_position_delta - def append_token_id(self, token_id: int, logprob: float) -> None: + def append_token_id(self, + token_id: int, + logprob: float, + token_embed: Optional[torch.Tensor] = None) -> None: self._output_token_ids.append(token_id) self._new_appended_tokens.append(token_id) self._cached_all_token_ids.append(token_id) self._cumulative_logprob += logprob + if token_embed is not None: + # Do not pass in with batch or sequence dimensions + assert token_embed.ndim == 1 + token_embed = token_embed.detach().cpu().unsqueeze(0) + if self._output_embeds is None: + self._output_embeds = token_embed + else: + self._output_embeds = torch.cat( + (self._output_embeds, token_embed), dim=0) + assert self._cached_all_token_embeds is not None + self._cached_all_token_embeds = torch.cat( + (self._cached_all_token_embeds, + token_embed.to(device=self._cached_all_token_embeds.device)), + dim=0) def get_len(self) -> int: return len(self._output_token_ids) + len(self._prompt_token_ids) @@ -306,6 +358,9 @@ class SequenceData(msgspec.Struct, def get_token_ids(self) -> list[int]: return self._cached_all_token_ids + def get_token_embeddings(self) -> Optional[torch.Tensor]: + return self._cached_all_token_embeds + def get_prefix_token_ids( self, num_tokens: int ) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]: @@ -387,6 +442,8 @@ class SequenceData(msgspec.Struct, def __repr__(self) -> str: return (f"SequenceData(" f"prompt_token_ids={self._prompt_token_ids}, " + f"prompt_embeds.shape=" + f"{getattr(self._prompt_embeds, 'shape', None)}, " f"output_token_ids={self.output_token_ids}, " f"cumulative_logprob={self.cumulative_logprob}, " f"get_num_computed_tokens={self.get_num_computed_tokens()})") @@ -425,7 +482,10 @@ class Sequence: self.lora_request = lora_request self.prompt_adapter_request = prompt_adapter_request - self.data = SequenceData.from_seqs(self.prompt_token_ids) + self.data = SequenceData.from_seqs( + self.prompt_token_ids, + prompt_embeds=self.inputs["prompt_embeds"] + if self.inputs["type"] == "embeds" else None) self.output_logprobs: SampleLogprobs = [] self.output_text = "" @@ -448,14 +508,20 @@ class Sequence: @property def prompt(self) -> Optional[str]: + if self.inputs["type"] == "embeds": + return None return self.inputs.get("prompt") @property def prompt_token_ids(self) -> list[int]: + if self.inputs["type"] == "embeds": + return [0] * len(self.inputs["prompt_embeds"]) return self.inputs["prompt_token_ids"] @property def token_type_ids(self) -> list[int]: + if self.inputs["type"] == "embeds": + return [] return self.inputs.get("token_type_ids", []) @property @@ -554,11 +620,14 @@ class Sequence: """Reset the sequence states for recomputation.""" self.data.reset_state_for_recompute() - def append_token_id(self, token_id: int, logprobs: dict[int, - Logprob]) -> None: + def append_token_id(self, + token_id: int, + logprobs: dict[int, Logprob], + 
token_embed: Optional[torch.Tensor] = None) -> None: assert token_id in logprobs self.output_logprobs.append(logprobs) - self.data.append_token_id(token_id, logprobs[token_id].logprob) + self.data.append_token_id(token_id, logprobs[token_id].logprob, + token_embed) def get_len(self) -> int: return self.data.get_len() @@ -889,6 +958,10 @@ class SequenceGroup: f"sampling_params={self.sampling_params}, " f"num_seqs={len(self.seqs)})") + def uses_prompt_embeds(self) -> bool: + """Returns True if the sequence group uses input embeds.""" + return any(seq.data.prompt_embeds is not None for seq in self.seqs) + class SequenceGroupMetadataDelta( msgspec.Struct, @@ -1043,10 +1116,14 @@ class SequenceOutput( parent_seq_id: int output_token: int logprobs: dict[int, Logprob] + output_embed: Optional[torch.Tensor] = None def __repr__(self) -> str: + output_embed_shape = \ + self.output_embed.shape if self.output_embed is not None else None return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " f"output_token={self.output_token}, " + f"output_embed.shape={output_embed_shape}" f"logprobs={self.logprobs})") def __eq__(self, other: object) -> bool: diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index 24095ef2a567..a6276c563394 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -201,6 +201,9 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase): if self.prompt_adapter_config is not None: raise ValueError("TP1DraftModelRunner has no support for " "prompt_adapter_config") + if model_input.inputs_embeds is not None: + raise ValueError("TP1DraftModelRunner has no support for " + "inputs_embeds") if model_input.multi_modal_kwargs: raise ValueError( "TP1DraftModelRunner has no support for multi_modal_kwargs" @@ -242,9 +245,16 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase): # Get model if use_cuda_graph: - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = (self.graph_runners[model_input.virtual_engine] - [graph_batch_size]) + if model_input.inputs_embeds is None: + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = ( + self.graph_runners[model_input.virtual_engine][( + graph_batch_size, False)]) + else: + graph_batch_size = model_input.inputs_embeds.shape[0] + model_executable = ( + self.graph_runners[model_input.virtual_engine][( + graph_batch_size, True)]) if previous_hidden_states is not None: hidden_states = torch.cat([ @@ -281,6 +291,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase): self.vllm_config): hidden_states = model_executable( input_ids=model_input.input_tokens, + inputs_embeds=None, positions=model_input.input_positions, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 1146606e9a13..de57403d1b50 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -282,7 +282,8 @@ class MultiStepWorker(ProposerWorkerBase, DelegateWorkerBase): else: count += 1 - seq.append_token_id(token_id, token_logprob.logprob) + seq.append_token_id(token_id, token_logprob.logprob, + seq_output.output_embed) seq.update_num_computed_tokens(1) @staticmethod diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 4df192a8727c..4864163b0de2 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -49,6 +49,7 @@ class 
EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, + "inputs_embeds": self.inputs_embeds, "input_positions": self.input_positions, "encoder_input_tokens": self.encoder_input_tokens, "encoder_input_positions": self.encoder_input_positions, @@ -172,10 +173,17 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): if (model_input.attn_metadata is not None and model_input.attn_metadata.prefill_metadata is None and model_input.attn_metadata.decode_metadata.use_cuda_graph): - assert model_input.input_tokens is not None - graph_batch_size = model_input.input_tokens.shape[0] - model_executable = self.graph_runners[ - model_input.virtual_engine][graph_batch_size] + if model_input.inputs_embeds is None: + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = ( + self.graph_runners[model_input.virtual_engine][( + graph_batch_size, False)]) + else: + graph_batch_size = model_input.inputs_embeds.shape[0] + model_executable = ( + self.graph_runners[model_input.virtual_engine][( + graph_batch_size, True)]) else: model_executable = self.model @@ -189,6 +197,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): model_input.virtual_engine): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, + inputs_embeds=model_input.inputs_embeds, positions=model_input.input_positions, encoder_input_ids=model_input.encoder_input_tokens, encoder_positions=model_input.encoder_input_positions, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 73e0eff9a8b7..85814e9af9e3 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -35,7 +35,8 @@ from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata, SamplingMetadataCache from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput, + get_sampler) from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models import supports_lora, supports_multimodal @@ -85,6 +86,7 @@ class ModelInputForGPU(ModelRunnerInputBase): additional fields. 
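+    When a request supplies prompt embeddings, inputs_embeds is populated
+    and input_tokens holds placeholder ids of the same length.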
""" input_tokens: Optional[torch.Tensor] = None + inputs_embeds: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None token_types: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None @@ -105,6 +107,7 @@ class ModelInputForGPU(ModelRunnerInputBase): def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, + "inputs_embeds": self.inputs_embeds, "input_positions": self.input_positions, "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, @@ -155,6 +158,7 @@ class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, + "inputs_embeds": self.inputs_embeds, "input_positions": self.input_positions, "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, @@ -194,6 +198,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): def simple_reinit(self): self.input_tokens[0].clear() # type: ignore + self.inputs_embeds = None # type: ignore self.input_positions[0].clear() # type: ignore self.token_types[0].clear() # type: ignore self.mrope_input_positions = None # type: ignore @@ -221,6 +226,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): # Input tokens and positions. input_tokens: Optional[List[List[int]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, input_positions: Optional[List[List[int]]] = None, token_types: Optional[List[List[int]]] = None, mrope_input_positions: Optional[List[List[List[int]]]] = None, @@ -282,6 +288,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): for seq_id in range(len(self.seq_ids)): self.input_tokens[seq_id].clear() + self.inputs_embeds = inputs_embeds + if input_positions: self.input_positions = input_positions else: @@ -356,6 +364,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): else: self.input_tokens = input_tokens or [] + self.inputs_embeds = inputs_embeds self.input_positions = input_positions or [] self.token_types = token_types or [] self.mrope_input_positions = mrope_input_positions or None @@ -401,6 +410,26 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): self.lora_index_mapping = [] self.lora_prompt_mapping = [] + def __repr__(self) -> str: + return (f"InterDataForSeqGroup(" + f"request_id={self.request_id}, " + f"seq_ids={self.seq_ids}, " + f"is_prompt={self.is_prompt}, " + f"block_tables={self.block_tables}, " + f"computed_block_nums={self.computed_block_nums}, " + f"n_seqs={self.n_seqs}, " + f"input_tokens={self.input_tokens}, " + f"inputs_embeds.shape=" + f"{getattr(self.inputs_embeds, 'shape', None)}, " + f"input_positions={self.input_positions}, " + f"token_types={self.token_types}, " + f"mrope_input_positions={self.mrope_input_positions}, " + f"seq_lens={self.seq_lens}, " + f"orig_seq_lens={self.orig_seq_lens}, " + f"query_lens={self.query_lens}, " + f"context_lens={self.context_lens}, " + f"multi_modal_kwargs={self.multi_modal_kwargs}") + def gen_inter_data_builder(self, num_seqs: int): return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup( request_id="", @@ -511,13 +540,21 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): context_len = seq_data.get_num_computed_tokens() # Compute tokens. 
- tokens = seq_data.get_token_ids()[context_len:seq_len] + if seq_data.prompt_embeds is None: + tokens = seq_data.get_token_ids()[context_len:seq_len] + prompt_embeds = None + else: + tokens = [0] * (seq_len - context_len) + prompt_embeds = seq_data.get_token_embeddings( + )[context_len:seq_len] + token_types = seq_group_metadata.token_type_ids inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.context_lens[seq_idx] = context_len inter_data.input_tokens[seq_idx].extend(tokens) + inter_data.inputs_embeds = prompt_embeds inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) inter_data.token_types[seq_idx].extend( token_types if token_types else []) @@ -822,15 +859,29 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): create on-device tensors. """ # Combine and flatten intermediate data. - input_tokens = [] - token_types = [] + input_tokens = list[int]() + inputs_embeds_lst = list[torch.Tensor]() + token_types = list[int]() for inter_data in self.inter_data_list: for cur_input_tokens in inter_data.input_tokens: input_tokens.extend(cur_input_tokens) for cur_token_types in inter_data.token_types: token_types.extend(cur_token_types) + if inter_data.inputs_embeds is not None: + inputs_embeds_lst.append( + inter_data.inputs_embeds.to( + dtype=self.runner.model_config.dtype, + device=self.runner.device)) + inputs_embeds: Optional[torch.Tensor] + if len(inputs_embeds_lst) == 0: + inputs_embeds = None + else: + inputs_embeds = torch.cat(inputs_embeds_lst, dim=0).to( + dtype=self.runner.model_config.dtype, + device=self.runner.device) + assert len(inputs_embeds) == len(input_tokens) - if not input_tokens: + if not input_tokens and inputs_embeds is None: # This may happen when all prefill requests hit # prefix caching and there is no decode request. return self.model_input_cls() @@ -980,6 +1031,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): return self.model_input_cls( input_tokens=input_tokens_tensor, + inputs_embeds=inputs_embeds, input_positions=input_positions_tensor, token_types=token_types_tensor, attn_metadata=attn_metadata, @@ -1029,7 +1081,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): self.max_batchsize_to_capture = \ self.vllm_config.compilation_config.max_capture_size - self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ + # + self.graph_runners: List[Dict[Tuple[int, bool], CUDAGraphRunner]] = [ {} for _ in range(self.parallel_config.pipeline_parallel_size) ] self.graph_memory_pool: Optional[Tuple[ @@ -1466,6 +1519,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): input_positions = torch.zeros(max_batch_size, dtype=torch.long, device=self.device) + inputs_embeds = torch.zeros( + (max_batch_size, self.model_config.get_hidden_size()), + dtype=self.model_config.dtype, + device=self.device) if self.model_config.uses_mrope: input_positions = torch.tile(input_positions, (3, 1)).cuda(device=self.device) @@ -1503,15 +1560,22 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): # memory usage of CUDA graph. for virtual_engine in range( self.parallel_config.pipeline_parallel_size): - # Only rank 0 should print progress bar during capture - cudagraph_capture_sizes = (tqdm( - self.vllm_config.compilation_config. + # We need to not only iterate over batch sizes, but also whether + # to use inputs_embeds or not, hence we use the cartesian + # product. 
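+            # A separate CUDA graph is captured for each (batch_size,
+            # use_inputs_embeds) pair so execution can look up the matching
+            # runner.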
+                cudagraph_capture_sizes = self.vllm_config.compilation_config\
+                    .cudagraph_capture_sizes
+                cudagraph_inputs_embeds = (True, False)
+                compilation_cases = itertools.product(
                     cudagraph_capture_sizes,
-                    desc="Capturing CUDA graph shapes",
-                ) if get_tensor_model_parallel_rank() == 0 else
-                    self.vllm_config.compilation_config.
-                    cudagraph_capture_sizes)
-                for batch_size in cudagraph_capture_sizes:
+                    cudagraph_inputs_embeds,
+                )
+                # Only rank 0 should print progress bar during capture
+                if get_tensor_model_parallel_rank() == 0:
+                    compilation_cases = tqdm(
+                        list(compilation_cases),
+                        desc="Capturing CUDA graph shapes")
+                for batch_size, use_inputs_embeds in compilation_cases:
                     attn_metadata = (
                         self.attn_state.graph_capture_get_metadata_for_batch(
                             batch_size,
@@ -1542,6 +1606,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                     capture_inputs = {
                         "input_ids":
                         input_tokens[:batch_size],
+                        "inputs_embeds":
+                        inputs_embeds[:batch_size]
+                        if use_inputs_embeds else None,
                         "positions":
                         input_positions[..., :batch_size],
                         "intermediate_inputs":
@@ -1578,8 +1645,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                                              virtual_engine):
                         graph_runner.capture(**capture_inputs)
                     self.graph_memory_pool = graph_runner.graph.pool()
-                    self.graph_runners[virtual_engine][batch_size] = (
-                        graph_runner)
+                    self.graph_runners[virtual_engine][(
+                        batch_size, use_inputs_embeds)] = graph_runner
 
         if self.lora_config:
             self._remove_dummy_loras()
@@ -1711,8 +1778,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
         if prefill_meta is None and decode_meta.use_cuda_graph:
             assert model_input.input_tokens is not None
             graph_batch_size = model_input.input_tokens.shape[0]
-            model_executable = self.graph_runners[virtual_engine][
-                graph_batch_size]
+            use_inputs_embeds = model_input.inputs_embeds is not None
+            model_executable = self.graph_runners[virtual_engine][(
+                graph_batch_size, use_inputs_embeds)]
             if previous_hidden_states is not None:
                 previous_hidden_states = torch.cat([
                     previous_hidden_states,
@@ -1763,6 +1831,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
                                  self.vllm_config, virtual_engine):
             hidden_or_intermediate_states = model_executable(
                 input_ids=model_input.input_tokens,
+                inputs_embeds=model_input.inputs_embeds,
                 positions=model_input.input_positions,
                 intermediate_tensors=intermediate_tensors,
                 **MultiModalKwargs.as_kwargs(multi_modal_kwargs,
@@ -1817,6 +1886,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
                 model_input.async_callback()
 
         # Sample the next token.
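+        # NOTE: When prompt embeddings are in use, the sampler must keep the
+        # sampled token ids on the GPU so they can be converted back into
+        # embeddings for the next decode step (handled further below).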
+        assert isinstance(self.sampler, Sampler)
+        orig_include_gpu_probs_tensor = self.sampler.include_gpu_probs_tensor
+        if model_input.inputs_embeds is not None:
+            self.sampler.include_gpu_probs_tensor = True
+
         output: SamplerOutput = self.sampler(
             logits=logits,
             sampling_metadata=model_input.sampling_metadata,
@@ -1838,6 +1912,18 @@
             output.model_forward_time = (orig_model_forward_time +
                                          model_forward_time)
 
+        if model_input.inputs_embeds is not None:
+            self.sampler.include_gpu_probs_tensor = \
+                orig_include_gpu_probs_tensor
+            if output.sampled_token_ids is not None:
+                output.sampled_token_embeds = self.model.get_input_embeddings(
+                    output.sampled_token_ids.squeeze(1))
+
+                for token_embed, sequence_group_output in zip(
+                        output.sampled_token_embeds, output.outputs):
+                    assert len(sequence_group_output.samples) == 1
+                    sequence_group_output.samples[0].output_embed = token_embed
+
         if self.return_hidden_states:
             # we only need to pass hidden states of most recent token
             assert model_input.sampling_metadata is not None
@@ -1931,6 +2017,7 @@ class CUDAGraphRunner(nn.Module):
     def capture(
         self,
         input_ids: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor],
         positions: torch.Tensor,
         intermediate_inputs: Optional[IntermediateTensors],
         kv_caches: List[torch.Tensor],
@@ -1947,6 +2034,7 @@
         for _ in range(_NUM_WARMUP_ITERS):
             self.model(
                 input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
                 positions=positions,
                 intermediate_tensors=intermediate_inputs,
                 **kwargs,
@@ -1959,6 +2047,9 @@
         with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
             output_hidden_or_intermediate_states = self.model(
                 input_ids=input_ids,
+                **({
+                    "inputs_embeds": inputs_embeds,
+                } if inputs_embeds is not None else {}),
                 positions=positions,
                 intermediate_tensors=intermediate_inputs,
                 **kwargs,
@@ -1986,6 +2077,9 @@
         self.input_buffers = {
             "input_ids":
             input_ids,
+            **({
+                "inputs_embeds": inputs_embeds,
+            } if inputs_embeds is not None else {}),
             "positions":
             positions,
             "kv_caches":
@@ -2006,6 +2100,7 @@
     def forward(
         self,
         input_ids: torch.Tensor,
+        inputs_embeds: Optional[torch.Tensor],
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors],
         **kwargs,
@@ -2020,6 +2115,9 @@
         # so the shape is not padded, we need to copy partial only
         self.input_buffers["positions"][:positions.shape[0]].copy_(
             positions, non_blocking=True)
+        if inputs_embeds is not None:
+            self.input_buffers["inputs_embeds"][:inputs_embeds.shape[0]].copy_(
+                inputs_embeds, non_blocking=True)
         if self.backend_name != "NO_ATTENTION":
             self.input_buffers["slot_mapping"].copy_(
diff --git a/vllm/worker/pooling_model_runner.py b/vllm/worker/pooling_model_runner.py
index cbd5e2060cad..fdb7353f2f9c 100644
--- a/vllm/worker/pooling_model_runner.py
+++ b/vllm/worker/pooling_model_runner.py
@@ -84,10 +84,17 @@ class PoolingModelRunner(
         # explore how to leverage it.
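+        # NOTE: CUDA graph runners are keyed by (batch_size, use_inputs_embeds)
+        # to match how the graphs were captured in the base GPU model runner.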
         if (prefill_meta is None and decode_meta is not None
                 and decode_meta.use_cuda_graph):
-            assert model_input.input_tokens is not None
-            graph_batch_size = model_input.input_tokens.shape[0]
-            model_executable = self.graph_runners[virtual_engine][
-                graph_batch_size]
+            if model_input.inputs_embeds is None:
+                assert model_input.input_tokens is not None
+                graph_batch_size = model_input.input_tokens.shape[0]
+                model_executable = (
+                    self.graph_runners[model_input.virtual_engine][(
+                        graph_batch_size, False)])
+            else:
+                graph_batch_size = model_input.inputs_embeds.shape[0]
+                model_executable = (
+                    self.graph_runners[model_input.virtual_engine][(
+                        graph_batch_size, True)])
         else:
             model_executable = self.model