diff --git a/requirements/common.txt b/requirements/common.txt index b8665104bd09a..7973da080c37d 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -24,7 +24,7 @@ outlines_core == 0.2.11 # required for outlines backend disk cache diskcache == 5.6.3 lark == 1.2.2 -xgrammar == 0.1.23; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" +xgrammar == 0.1.24; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index fba18f197074b..24b1c9a93126c 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -76,11 +76,6 @@ def test_models( model_executor: str, enable_prompt_embeds: bool, ) -> None: - - if enable_prompt_embeds and envs.is_set( - "VLLM_USE_V1") and envs.VLLM_USE_V1: - pytest.skip("enable_prompt_embeds is not supported in v1.") - if not envs.VLLM_USE_V1: if async_scheduling: pytest.skip("async_scheduling only supported in v1.") @@ -164,11 +159,6 @@ def test_models_distributed( extra_env: dict[str, str], enable_prompt_embeds: bool, ) -> None: - - if enable_prompt_embeds and envs.is_set( - "VLLM_USE_V1") and envs.VLLM_USE_V1: - pytest.skip("enable_prompt_embeds is not supported in v1.") - if test_suite != TARGET_TEST_SUITE: pytest.skip(f"Skip test for {test_suite}") diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py index 3d56291bc793c..0e3fc82f0c033 100644 --- a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py +++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py @@ -36,7 +36,6 @@ def default_server_args() -> list[str]: "--enforce-eager", # Prompt Embeds server args "--enable-prompt-embeds", - "--no-enable-chunked-prefill", ] diff --git a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py index eceaff672112f..8d974d56b4450 100644 --- a/tests/entrypoints/openai/test_response_api_with_harmony.py +++ b/tests/entrypoints/openai/test_response_api_with_harmony.py @@ -287,6 +287,57 @@ async def test_stateful_multi_turn(client: OpenAI, model_name: str): assert response3.status == "completed" +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_streaming_types(client: OpenAI, model_name: str): + prompts = [ + "tell me a story about a cat in 20 words", + ] + + # this links the "done" type with the "start" type + # so every "done" type should have a corresponding "start" type + # and every open block should be closed by the end of the stream + pairs_of_event_types = { + "response.completed": "response.created", + "response.output_item.done": "response.output_item.added", + "response.content_part.done": "response.content_part.added", + "response.output_text.done": "response.output_text.delta", + "response.web_search_call.done": "response.web_search_call.added", + "response.reasoning_text.done": "response.reasoning_text.delta", + "response.reasoning_part.done": "response.reasoning_part.added", + } + + for prompt in prompts: + response = await client.responses.create( + model=model_name, + input=prompt, + reasoning={"effort": "low"}, + 
tools=[], + stream=True, + background=False, + ) + + stack_of_event_types = [] + async for event in response: + if event.type == 'response.created': + stack_of_event_types.append(event.type) + elif event.type == 'response.completed': + assert stack_of_event_types[-1] == pairs_of_event_types[ + event.type] + stack_of_event_types.pop() + if event.type.endswith("added"): + stack_of_event_types.append(event.type) + elif event.type.endswith("delta"): + if stack_of_event_types[-1] == event.type: + continue + stack_of_event_types.append(event.type) + elif event.type.endswith("done"): + assert stack_of_event_types[-1] == pairs_of_event_types[ + event.type] + stack_of_event_types.pop() + assert len(stack_of_event_types) == 0 + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("background", [True, False]) @@ -343,7 +394,10 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool): assert event.item_id == current_item_id # verify content_index_id is correct - if event.type == "response.content_part.added": + if event.type in [ + "response.content_part.added", + "response.reasoning_part.added" + ]: assert event.content_index != current_content_index current_content_index = event.content_index elif event.type in [ diff --git a/tests/models/language/generation/test_common.py b/tests/models/language/generation/test_common.py index a5aa1e3f49743..c14e71cbdb96d 100644 --- a/tests/models/language/generation/test_common.py +++ b/tests/models/language/generation/test_common.py @@ -125,12 +125,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, # in parts of the operators pytest.skip(f"Skipping '{model}' model test with AITER kernel.") - # Note: can be removed when - # https://github.com/vllm-project/vllm/pull/24278 finished - if current_platform.is_cpu() and use_prompt_embeds: - pytest.skip("Skipping use_prompt_embeds=True with " - "V1-only CPU backend.") - with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) diff --git a/tests/v1/kv_offload/test_cpu.py b/tests/v1/kv_offload/test_cpu.py new file mode 100644 index 0000000000000..cdee7811d85b3 --- /dev/null +++ b/tests/v1/kv_offload/test_cpu.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Optional + +import numpy as np + +from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent, + PrepareStoreOutput) +from vllm.v1.kv_offload.backends.cpu import CPUBackend +from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager +from vllm.v1.kv_offload.mediums import CPULoadStoreSpec + + +@dataclass +class ExpectedPrepareStoreOutput: + block_hashes_to_store: list[int] + store_block_ids: list[int] + block_hashes_evicted: list[int] + + +def to_hashes(int_hashes: list[int]) -> list[BlockHash]: + return [BlockHash(str(i).encode()) for i in int_hashes] + + +def verify_store_output( + prepare_store_output: Optional[PrepareStoreOutput], + expected_prepare_store_output: ExpectedPrepareStoreOutput): + assert prepare_store_output is not None + assert (prepare_store_output.block_hashes_to_store == to_hashes( + expected_prepare_store_output.block_hashes_to_store)) + assert (prepare_store_output.block_hashes_evicted == to_hashes( + 
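As a standalone illustration of the invariant the harmony streaming test above enforces, the pairing table and stack walk can be distilled into a small checker. The `check_stream_nesting` helper below is hypothetical (not part of the test suite) and uses a trimmed copy of the pairing table:

```python
# Hypothetical, self-contained version of the nesting check in
# test_streaming_types: every "done"/"completed" event must close the
# most recently opened "created"/"added"/"delta" entry.
PAIRS = {
    "response.completed": "response.created",
    "response.output_item.done": "response.output_item.added",
    "response.content_part.done": "response.content_part.added",
    "response.output_text.done": "response.output_text.delta",
    "response.reasoning_text.done": "response.reasoning_text.delta",
    "response.reasoning_part.done": "response.reasoning_part.added",
}


def check_stream_nesting(event_types: list[str]) -> bool:
    stack: list[str] = []
    for event_type in event_types:
        if event_type.endswith(("created", "added")):
            stack.append(event_type)
        elif event_type.endswith("delta"):
            # consecutive deltas collapse into a single stack entry
            if not stack or stack[-1] != event_type:
                stack.append(event_type)
        elif event_type in PAIRS:
            if not stack or stack[-1] != PAIRS[event_type]:
                return False
            stack.pop()
    return not stack  # every opened entry must be closed


assert check_stream_nesting([
    "response.created",
    "response.output_item.added",
    "response.output_text.delta",
    "response.output_text.done",
    "response.output_item.done",
    "response.completed",
])
```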
expected_prepare_store_output.block_hashes_evicted)) + store_spec = prepare_store_output.store_spec + assert isinstance(store_spec, CPULoadStoreSpec) + expected_array = np.array(expected_prepare_store_output.store_block_ids, + dtype=np.int64) + assert np.array_equal(expected_array, store_spec.block_ids) + + +def verify_load_output(prepare_load_output: LoadStoreSpec, + expected_prepare_load_output: list[int]): + assert isinstance(prepare_load_output, CPULoadStoreSpec) + expected_array = np.array(expected_prepare_load_output, dtype=np.int64) + assert np.array_equal(expected_array, prepare_load_output.block_ids) + + +def verify_events(events: Iterable[OffloadingEvent], + block_size: int, + expected_stores: tuple[set[int], ...] = (), + expected_evictions: tuple[set[int], ...] = ()): + stores: list[set[BlockHash]] = [] + evictions: list[set[BlockHash]] = [] + for event in events: + assert event.medium == CPULoadStoreSpec.medium() + assert event.block_size == block_size + if event.removed: + evictions.append(set(event.block_hashes)) + else: + stores.append(set(event.block_hashes)) + + def to_hash_sets( + int_sets: tuple[set[int], ...]) -> tuple[set[BlockHash], ...]: + return tuple([set(to_hashes(list(int_set))) for int_set in int_sets]) + + assert tuple(evictions) == to_hash_sets(expected_evictions) + assert tuple(stores) == to_hash_sets(expected_stores) + + +def test_cpu_manager(): + """ + Tests LRUOffloadingManager with a CPUBackend. + """ + # initialize a CPU backend with a capacity of 4 blocks + block_size = 256 + cpu_backend = CPUBackend(block_size=block_size, num_blocks=4) + cpu_manager = LRUOffloadingManager(cpu_backend, enable_events=True) + + # prepare store [1, 2] + prepare_store_output = cpu_manager.prepare_store(to_hashes([1, 2])) + verify_store_output( + prepare_store_output, + ExpectedPrepareStoreOutput( + block_hashes_to_store=[1, 2], + store_block_ids=[0, 1], + block_hashes_evicted=[], + )) + + # lookup [1, 2] -> not ready + assert cpu_manager.lookup(to_hashes([1, 2])) == 0 + + # no events so far + assert list(cpu_manager.take_events()) == [] + + # complete store [1, 2] + cpu_manager.complete_store(to_hashes([1, 2])) + verify_events(cpu_manager.take_events(), + block_size=block_size, + expected_stores=({1, 2}, )) + + # lookup [1, 2] + assert cpu_manager.lookup(to_hashes([1])) == 1 + assert cpu_manager.lookup(to_hashes([1, 2])) == 2 + assert cpu_manager.lookup(to_hashes([1, 2, 3])) == 2 + + # prepare store [2, 3, 4, 5] -> evicts [1] + prepare_store_output = cpu_manager.prepare_store(to_hashes([2, 3, 4, 5])) + verify_store_output( + prepare_store_output, + ExpectedPrepareStoreOutput( + block_hashes_to_store=[3, 4, 5], + store_block_ids=[2, 3, 0], + block_hashes_evicted=[1], + )) + + # verify eviction event + verify_events(cpu_manager.take_events(), + block_size=block_size, + expected_evictions=({1}, )) + + # prepare store with no space + assert cpu_manager.prepare_store(to_hashes([1, 6])) is None + + # complete store [2, 3, 4, 5] + cpu_manager.complete_store(to_hashes([2, 3, 4, 5])) + + # prepare load [2, 3] + prepare_load_output = cpu_manager.prepare_load(to_hashes([2, 3])) + verify_load_output(prepare_load_output, [1, 2]) + + # prepare store with no space ([2, 3] is being loaded) + assert cpu_manager.prepare_store(to_hashes([6, 7, 8])) is None + + # complete load [2, 3] + cpu_manager.complete_load(to_hashes([2, 3])) + + # prepare store [6, 7, 8] -> evicts [2, 3, 4] (oldest) + prepare_store_output = cpu_manager.prepare_store(to_hashes([6, 7, 8])) + verify_store_output( + 
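The `lookup` assertions in this test rely on hits being counted only over the contiguous leading run of ready blocks, since offloaded KV data is only usable as a prefix. A minimal sketch of that semantic, with a plain `set` standing in for the backend state:

```python
# Illustrative prefix-hit counting, mirroring LRUOffloadingManager.lookup:
# stop at the first block that is missing or not yet ready.
def prefix_hits(block_hashes: list[bytes], ready: set[bytes]) -> int:
    hit_count = 0
    for block_hash in block_hashes:
        if block_hash not in ready:
            break
        hit_count += 1
    return hit_count


ready_blocks = {b"1", b"2"}
assert prefix_hits([b"1"], ready_blocks) == 1
assert prefix_hits([b"1", b"2", b"3"], ready_blocks) == 2
assert prefix_hits([b"3", b"1", b"2"], ready_blocks) == 0
```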
prepare_store_output, + ExpectedPrepareStoreOutput( + block_hashes_to_store=[6, 7, 8], + store_block_ids=[3, 2, 1], + block_hashes_evicted=[2, 3, 4], + )) + + # complete store [6, 7, 8] + cpu_manager.complete_store(to_hashes([6, 7, 8])) + + # touch [5, 6, 7] (move to end of LRU order) + cpu_manager.touch(to_hashes([5, 6, 7])) + + # prepare store [7, 9] -> evicts [8] (oldest following previous touch) + prepare_store_output = cpu_manager.prepare_store(to_hashes([9])) + verify_store_output( + prepare_store_output, + ExpectedPrepareStoreOutput( + block_hashes_to_store=[9], + store_block_ids=[1], + block_hashes_evicted=[8], + )) + + # complete store [7, 9] with failure + cpu_manager.complete_store(to_hashes([7, 9]), success=False) + + # assert [7] is still stored, but [9] is not + assert cpu_manager.lookup(to_hashes([7])) == 1 + assert cpu_manager.lookup(to_hashes([9])) == 0 + + verify_events(cpu_manager.take_events(), + block_size=block_size, + expected_stores=({3, 4, 5}, {6, 7, 8}), + expected_evictions=({2, 3, 4}, {8})) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index fb5beab77b270..63282c4253509 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1513,12 +1513,6 @@ class EngineArgs: recommend_to_remove=False) return False - # No text embedding inputs so far. - if self.enable_prompt_embeds: - _raise_or_fallback(feature_name="--enable-prompt-embeds", - recommend_to_remove=False) - return False - # No Mamba or Encoder-Decoder so far. if not model_config.is_v1_compatible: _raise_or_fallback(feature_name=model_config.architectures, @@ -1651,6 +1645,13 @@ class EngineArgs: "models in V0 and has been disabled.") self.enable_prefix_caching = False + if self.enable_prompt_embeds: + logger.warning( + "--enable-prompt-embeds and --enable-prefix-caching " + "are not supported together in V0. Prefix caching has " + "been disabled.") + self.enable_prefix_caching = False + # Set max_num_seqs to 256 for VLLM_V0. if self.max_num_seqs is None: self.max_num_seqs = 256 @@ -1664,6 +1665,17 @@ class EngineArgs: # For pooling tasks the default is False if model_config.runner_type != "pooling": self.enable_chunked_prefill = True + + # TODO: When prefix caching supports prompt embeds inputs, this + # check can be removed. + if (self.enable_prompt_embeds + and self.enable_prefix_caching is not False): + logger.warning( + "--enable-prompt-embeds and --enable-prefix-caching " + "are not supported together in V1. 
Prefix caching has " + "been disabled.") + self.enable_prefix_caching = False + if self.enable_prefix_caching is None: self.enable_prefix_caching = True else: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 7ad8e73d89d59..05d5d6d964dd3 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -31,6 +31,8 @@ from openai.types.responses import ( ResponseReasoningTextDeltaEvent, ResponseReasoningTextDoneEvent, ResponseStatus, ResponseWebSearchCallCompletedEvent, ResponseWebSearchCallInProgressEvent, ResponseWebSearchCallSearchingEvent) +from openai.types.responses.response_reasoning_item import ( + Content as ResponseReasoningTextContent) # Backward compatibility for OpenAI client versions try: # For older openai versions (< 1.100.0) @@ -260,26 +262,6 @@ ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam, ResponseReasoningItem, ResponseFunctionToolCall] -StreamingResponsesResponse: TypeAlias = Union[ - ResponseCreatedEvent, - ResponseInProgressEvent, - ResponseCompletedEvent, - ResponseOutputItemAddedEvent, - ResponseOutputItemDoneEvent, - ResponseContentPartAddedEvent, - ResponseContentPartDoneEvent, - ResponseReasoningTextDeltaEvent, - ResponseReasoningTextDoneEvent, - ResponseCodeInterpreterCallInProgressEvent, - ResponseCodeInterpreterCallCodeDeltaEvent, - ResponseWebSearchCallInProgressEvent, - ResponseWebSearchCallSearchingEvent, - ResponseWebSearchCallCompletedEvent, - ResponseCodeInterpreterCallCodeDoneEvent, - ResponseCodeInterpreterCallInterpretingEvent, - ResponseCodeInterpreterCallCompletedEvent, -] - class ResponsesRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation @@ -973,7 +955,6 @@ class CompletionRequest(OpenAIBaseModel): # https://platform.openai.com/docs/api-reference/completions/create model: Optional[str] = None prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None - prompt_embeds: Optional[Union[bytes, list[bytes]]] = None best_of: Optional[int] = None echo: Optional[bool] = False frequency_penalty: Optional[float] = 0.0 @@ -1009,6 +990,7 @@ class CompletionRequest(OpenAIBaseModel): # --8<-- [end:completion-sampling-params] # --8<-- [start:completion-extra-params] + prompt_embeds: Optional[Union[bytes, list[bytes]]] = None add_special_tokens: bool = Field( default=True, description=( @@ -1916,6 +1898,72 @@ class ResponsesResponse(OpenAIBaseModel): ) +# TODO: this code can be removed once +# https://github.com/openai/openai-python/issues/2634 has been resolved +class ResponseReasoningPartDoneEvent(OpenAIBaseModel): + content_index: int + """The index of the content part that is done.""" + + item_id: str + """The ID of the output item that the content part was added to.""" + + output_index: int + """The index of the output item that the content part was added to.""" + + part: ResponseReasoningTextContent + """The content part that is done.""" + + sequence_number: int + """The sequence number of this event.""" + + type: Literal["response.reasoning_part.done"] + """The type of the event. 
Always `response.reasoning_part.done`.""" + + +# TODO: this code can be removed once +# https://github.com/openai/openai-python/issues/2634 has been resolved +class ResponseReasoningPartAddedEvent(OpenAIBaseModel): + content_index: int + """The index of the content part that is done.""" + + item_id: str + """The ID of the output item that the content part was added to.""" + + output_index: int + """The index of the output item that the content part was added to.""" + + part: ResponseReasoningTextContent + """The content part that is done.""" + + sequence_number: int + """The sequence number of this event.""" + + type: Literal["response.reasoning_part.added"] + """The type of the event. Always `response.reasoning_part.added`.""" + + +StreamingResponsesResponse: TypeAlias = Union[ + ResponseCreatedEvent, + ResponseInProgressEvent, + ResponseCompletedEvent, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseReasoningTextDeltaEvent, + ResponseReasoningTextDoneEvent, + ResponseReasoningPartAddedEvent, + ResponseReasoningPartDoneEvent, + ResponseCodeInterpreterCallInProgressEvent, + ResponseCodeInterpreterCallCodeDeltaEvent, + ResponseWebSearchCallInProgressEvent, + ResponseWebSearchCallSearchingEvent, + ResponseWebSearchCallCompletedEvent, + ResponseCodeInterpreterCallCodeDoneEvent, + ResponseCodeInterpreterCallInterpretingEvent, + ResponseCodeInterpreterCallCompletedEvent, +] + BatchRequestInputBody = Union[ChatCompletionRequest, EmbeddingRequest, ScoreRequest, RerankRequest] diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index 469d74272b0e6..4894623aeac28 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -58,6 +58,8 @@ from vllm.entrypoints.openai.protocol import (DeltaMessage, ErrorResponse, InputTokensDetails, OutputTokensDetails, RequestResponseMetadata, + ResponseReasoningPartAddedEvent, + ResponseReasoningPartDoneEvent, ResponsesRequest, ResponsesResponse, ResponseUsage, StreamingResponsesResponse) @@ -1280,14 +1282,13 @@ class OpenAIServingResponses(OpenAIServing): # Deal with tool call here pass elif previous_item.channel == "analysis": + content = ResponseReasoningTextContent( + text=previous_item.content[0].text, + type="reasoning_text", + ) reasoning_item = ResponseReasoningItem( type="reasoning", - content=[ - ResponseReasoningTextContent( - text=previous_item.content[0].text, - type="reasoning_text", - ), - ], + content=[content], status="completed", id=current_item_id, summary=[], @@ -1301,6 +1302,15 @@ class OpenAIServingResponses(OpenAIServing): content_index=current_content_index, text=previous_item.content[0].text, )) + yield _increment_sequence_number_and_return( + ResponseReasoningPartDoneEvent( + type="response.reasoning_part.done", + sequence_number=-1, + item_id=current_item_id, + output_index=current_output_index, + content_index=current_content_index, + part=content, + )) yield _increment_sequence_number_and_return( ResponseOutputItemDoneEvent( type="response.output_item.done", @@ -1412,17 +1422,15 @@ class OpenAIServingResponses(OpenAIServing): )) current_content_index += 1 yield _increment_sequence_number_and_return( - ResponseContentPartAddedEvent( - type="response.content_part.added", + ResponseReasoningPartAddedEvent( + type="response.reasoning_part.added", sequence_number=-1, output_index=current_output_index, item_id=current_item_id, 
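Because `OpenAIBaseModel` is a pydantic model, the stop-gap events serialize exactly like their upstream `openai` counterparts. A sketch of constructing one (field values are illustrative; the exact JSON key order may differ):

```python
from openai.types.responses.response_reasoning_item import (
    Content as ResponseReasoningTextContent)

from vllm.entrypoints.openai.protocol import ResponseReasoningPartAddedEvent

event = ResponseReasoningPartAddedEvent(
    type="response.reasoning_part.added",
    sequence_number=7,
    item_id="msg_0",
    output_index=0,
    content_index=0,
    part=ResponseReasoningTextContent(text="", type="reasoning_text"),
)
# Serializes like any other streaming event in the union.
print(event.model_dump_json())
```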
content_index=current_content_index, - part=ResponseOutputText( - type="output_text", + part=ResponseReasoningTextContent( text="", - annotations=[], - logprobs=[], + type="reasoning_text", ), )) yield _increment_sequence_number_and_return( diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py index b1d7f24c2f18b..2770ddebc48ab 100644 --- a/vllm/model_executor/models/deepseek_eagle.py +++ b/vllm/model_executor/models/deepseek_eagle.py @@ -229,14 +229,15 @@ class EagleDeepseekV3ForCausalLM(DeepseekV3ForCausalLM): return logits def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + + def transform(inputs): + name, loaded_weight = inputs + if "lm_head" not in name: + name = "model." + name + return name, loaded_weight + loader = AutoWeightsLoader( self, skip_prefixes=None, ) - - model_weights = {} - for name, loaded_weight in weights: - if "lm_head" not in name: - name = "model." + name - model_weights[name] = loaded_weight - loader.load_weights(model_weights.items()) + loader.load_weights(map(transform, weights)) diff --git a/vllm/model_executor/models/llama4_eagle.py b/vllm/model_executor/models/llama4_eagle.py index ece490ff2f2a8..a203af53205cd 100644 --- a/vllm/model_executor/models/llama4_eagle.py +++ b/vllm/model_executor/models/llama4_eagle.py @@ -205,23 +205,21 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None: + + def transform(inputs): + name, loaded_weight = inputs + name, weight = self.permute_qk_weight_for_rotary( + name, loaded_weight) + if "lm_head" not in name: + name = "model." + name + return name, weight + loader = AutoWeightsLoader( self, # lm_head is tied with target model (Llama4ForCausalLM) skip_prefixes=(["lm_head."]), ) - - model_weights = {} - weights = [ - self.permute_qk_weight_for_rotary(name, loaded_weight) - for name, loaded_weight in weights - ] - for name, loaded_weight in weights: - if "lm_head" not in name: - name = "model." + name - model_weights[name] = loaded_weight - - loader.load_weights(model_weights.items()) + loader.load_weights(map(transform, weights)) def get_input_embeddings( self, diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py index a4933b77e3a53..dfae3c3ea5437 100644 --- a/vllm/model_executor/models/llama_eagle.py +++ b/vllm/model_executor/models/llama_eagle.py @@ -158,14 +158,15 @@ class EagleLlamaForCausalLM(LlamaForCausalLM): return self.model(input_ids, positions, hidden_states) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + + def transform(inputs): + name, loaded_weight = inputs + if "lm_head" not in name: + name = "model." + name + return name, loaded_weight + loader = AutoWeightsLoader( self, skip_prefixes=None, ) - - model_weights = {} - for name, loaded_weight in weights: - if "lm_head" not in name: - name = "model." 
+ name - model_weights[name] = loaded_weight - loader.load_weights(model_weights.items()) + loader.load_weights(map(transform, weights)) diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 221712ba9a338..03e5e5809b678 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -12,7 +12,6 @@ from torch.nn import Parameter from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.logger import init_logger -from vllm.model_executor.utils import _make_synced_weight_loader __all__ = [ "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", @@ -53,8 +52,9 @@ class BasevLLMParameter(Parameter): # This sometimes causes OOM errors during model loading. To avoid this, # we sync the param tensor after its weight loader is called. from vllm.platforms import current_platform - if current_platform.is_tpu(): - weight_loader = _make_synced_weight_loader(weight_loader) + if current_platform.use_sync_weight_loader(): + weight_loader = current_platform.make_synced_weight_loader( + weight_loader) self._weight_loader = weight_loader self.tp_rank = get_tensor_model_parallel_rank() diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 65436786f82ac..543918418953b 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -44,23 +44,12 @@ def set_weight_attrs( # TODO(woosuk): Remove this hack once we have a better solution. from vllm.platforms import current_platform - if current_platform.is_tpu() and key == "weight_loader": - value = _make_synced_weight_loader(value) + if current_platform.use_sync_weight_loader( + ) and key == "weight_loader": + value = current_platform.make_synced_weight_loader(value) setattr(weight, key, value) -def _make_synced_weight_loader(original_weight_loader): - - def _synced_weight_loader(param, *args, **kwargs): - out = original_weight_loader(param, *args, **kwargs) - # torch._sync doesn't support, is not needed for CPU tensors. - if param.device != torch.device("cpu"): - torch._sync(param) - return out - - return _synced_weight_loader - - def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]: parent_map = getattr(model, "packed_modules_mapping", None) parent_map = copy.deepcopy(parent_map) if parent_map is not None else {} diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 240e34e139cfe..e00c10fb66eeb 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -569,8 +569,8 @@ class MultiModalFieldConfig: Args: modality: The modality of the multi-modal item that uses this keyword argument. - slices: For each multi-modal item, the size of the slice that - is used to extract the data corresponding to it. + size_per_item: For each multi-modal item, the size of the slice + that is used to extract the data corresponding to it. dim: The dimension to slice, default to 0. Example: @@ -590,7 +590,7 @@ class MultiModalFieldConfig: ``` Given: - slices: [3, 4, 2] + size_per_item: [3, 4, 2] dim: 1 Input: diff --git a/vllm/multimodal/profiling.py b/vllm/multimodal/profiling.py index fbbc55d3524ca..9b463e212bb49 100644 --- a/vllm/multimodal/profiling.py +++ b/vllm/multimodal/profiling.py @@ -234,19 +234,6 @@ class MultiModalProfiler(Generic[_I]): prompt_token_ids = mm_inputs["prompt_token_ids"] total_len = len(prompt_token_ids) - # V0 does not support chunked prefill. 
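All three eagle loaders now use the same pattern: instead of materializing a dict that holds every renamed tensor at once, a `transform` callback is applied lazily with `map`, so each weight is renamed only as `AutoWeightsLoader` consumes it. A self-contained sketch of the pattern (names are illustrative):

```python
from collections.abc import Iterable, Iterator


def rename_lazily(
    weights: Iterable[tuple[str, object]]
) -> Iterator[tuple[str, object]]:

    # Each (name, tensor) pair is transformed only when the consumer pulls
    # it, so no second copy of the checkpoint mapping is built up front.
    def transform(item: tuple[str, object]) -> tuple[str, object]:
        name, weight = item
        if "lm_head" not in name:
            name = "model." + name
        return name, weight

    return map(transform, weights)


pairs = [("layers.0.w", 1), ("lm_head.weight", 2)]
assert list(rename_lazily(pairs)) == [("model.layers.0.w", 1),
                                      ("lm_head.weight", 2)]
```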
- if total_len > seq_len and not envs.VLLM_USE_V1: - # `max_num_batched_tokens` is defined by `SchedulerConfig` - logger.warning_once( - "The sequence length used for profiling (max_num_batched_tokens / max_num_seqs = %d) " # noqa: E501 - "is too short to hold the multi-modal embeddings in the worst case (%d tokens in total, out of which %s are reserved for multi-modal embeddings). " # noqa: E501 - "This may cause certain multi-modal inputs to fail during inference, even when the input text is short. " # noqa: E501 - "To avoid this, you should increase `max_model_len`, reduce `max_num_seqs`, and/or reduce `mm_counts`.", # noqa: E501 - seq_len, - total_len, - str(self._get_mm_num_tokens(mm_inputs)), - ) - if total_len < seq_len: prompt_token_ids.extend([0] * (seq_len - total_len)) @@ -270,22 +257,6 @@ class MultiModalProfiler(Generic[_I]): mm_counts=mm_counts, ) if max_tokens_per_item is not None: - if mm_counts is None: - total_mm_tokens = sum(max_tokens_per_item.values()) - else: - total_mm_tokens = sum(max_tokens_per_item[k] * mm_counts[k] - for k in max_tokens_per_item.keys() - & mm_counts.keys()) - if total_mm_tokens > seq_len: - logger.warning_once( - "The sequence length (%d) is smaller than the pre-defined" - " worst-case total number of multimodal tokens (%d). " - "This may cause certain multi-modal inputs to fail during " - "inference. To avoid this, you should increase " - "`max_model_len` or reduce `mm_counts`.", - seq_len, - total_mm_tokens, - ) return max_tokens_per_item mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index b308366fca282..f4e2ed72e2d7d 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -395,7 +395,9 @@ def group_mm_kwargs_by_modality( modality together into the same `MultiModalKwargs` instance. Args: - mm_inputs: List of `MultiModalKwargsItem`. + mm_kwargs: List of `MultiModalKwargsItem`. + device: The device to place the grouped tensors on. + pin_memory: Whether to pin memory for faster host-to-device transfer. Yields: A tuple `(modality, num_items, grouped_kwargs)`. diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 054d08c3a85be..53fc762dce540 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -594,6 +594,29 @@ class Platform: """ return False + @classmethod + def use_sync_weight_loader(cls) -> bool: + """ + Returns if the current platform needs to sync weight loader. + """ + return False + + @classmethod + def make_synced_weight_loader(cls, original_weight_loader): + """ + Wrap the original weight loader to make it synced. 
+        """
+        if not cls.use_sync_weight_loader():
+            return original_weight_loader
+
+        def _synced_weight_loader(param, *args, **kwargs):
+            out = original_weight_loader(param, *args, **kwargs)
+            if param.device != torch.device("cpu"):
+                torch._sync(param)
+            return out
+
+        return _synced_weight_loader
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index 6a061956d8141..4e4db116abca0 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -226,6 +226,10 @@ class TpuPlatform(Platform):
                 torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True)
             dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu()
 
+    @classmethod
+    def use_sync_weight_loader(cls) -> bool:
+        return True
+
 
 try:
     from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py
index f13381ecd9ff3..d4013a69e99fe 100644
--- a/vllm/utils/__init__.py
+++ b/vllm/utils/__init__.py
@@ -3443,3 +3443,30 @@ def decorate_logs(process_name: Optional[str] = None) -> None:
         pid = os.getpid()
     _add_prefix(sys.stdout, process_name, pid)
     _add_prefix(sys.stderr, process_name, pid)
+
+
+def length_from_prompt_token_ids_or_embeds(
+    prompt_token_ids: Optional[list[int]],
+    prompt_embeds: Optional[torch.Tensor],
+) -> int:
+    """Calculate the request length (in number of tokens) given either
+    prompt_token_ids or prompt_embeds.
+    """
+    prompt_token_len = None if prompt_token_ids is None else len(
+        prompt_token_ids)
+    prompt_embeds_len = \
+        None if prompt_embeds is None else len(prompt_embeds)
+
+    if prompt_token_len is None:
+        if prompt_embeds_len is None:
+            raise ValueError(
+                "Neither prompt_token_ids nor prompt_embeds were defined.")
+        return prompt_embeds_len
+    else:
+        if (prompt_embeds_len is not None
+                and prompt_embeds_len != prompt_token_len):
+            raise ValueError(
+                "Prompt token ids and prompt embeds had different lengths:"
+                f" prompt_token_ids={prompt_token_len}"
+                f" prompt_embeds={prompt_embeds_len}")
+        return prompt_token_len
diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py
index 3ec5b91bf2860..209fc2a4404f3 100644
--- a/vllm/v1/core/sched/output.py
+++ b/vllm/v1/core/sched/output.py
@@ -11,6 +11,7 @@ from vllm._bc_linter import bc_linter_include
 if TYPE_CHECKING:
     import numpy as np
     import numpy.typing as npt
+    import torch
 
     from vllm.distributed.kv_transfer.kv_connector.v1.base import (
         KVConnectorMetadata)
@@ -26,13 +27,14 @@ if TYPE_CHECKING:
 class NewRequestData:
 
     req_id: str
-    prompt_token_ids: list[int]
+    prompt_token_ids: Optional[list[int]]
     mm_features: list[MultiModalFeatureSpec]
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]
     block_ids: tuple[list[int], ...]
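For reference, the platform hook above replaces the old TPU-only `_make_synced_weight_loader` from `model_executor/utils.py`. A minimal standalone sketch of the wrapping idea, with a generic `sync` callable standing in for the platform-specific `torch._sync` (the stand-in names here are hypothetical):

```python
import torch


def make_synced_weight_loader(original_weight_loader,
                              sync=torch.cuda.synchronize):
    # Wrap a weight loader so pending device work is flushed after each
    # load, avoiding OOM from many queued, unmaterialized copies.
    def _synced_weight_loader(param, *args, **kwargs):
        out = original_weight_loader(param, *args, **kwargs)
        if param.device != torch.device("cpu"):
            sync()  # stand-in for the platform sync primitive
        return out

    return _synced_weight_loader


def copy_loader(param, loaded):
    param.data.copy_(loaded)


loader = make_synced_weight_loader(copy_loader)
p = torch.nn.Parameter(torch.zeros(2))
loader(p, torch.ones(2))  # CPU param, so no sync is triggered
```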
     num_computed_tokens: int
     lora_request: Optional[LoRARequest]
+    prompt_embeds: Optional[torch.Tensor] = None
 
     @classmethod
     def from_request(
@@ -49,9 +51,12 @@ class NewRequestData:
             block_ids=block_ids,
             num_computed_tokens=request.num_computed_tokens,
             lora_request=request.lora_request,
+            prompt_embeds=request.prompt_embeds,
         )
 
-    def __repr__(self):
+    def __repr__(self) -> str:
+        prompt_embeds_shape = (self.prompt_embeds.shape
+                               if self.prompt_embeds is not None else None)
         return (f"NewRequestData("
                 f"req_id={self.req_id},"
                 f"prompt_token_ids={self.prompt_token_ids},"
@@ -59,19 +64,26 @@ class NewRequestData:
                 f"sampling_params={self.sampling_params},"
                 f"block_ids={self.block_ids},"
                 f"num_computed_tokens={self.num_computed_tokens},"
-                f"lora_request={self.lora_request}"
+                f"lora_request={self.lora_request},"
+                f"prompt_embeds_shape={prompt_embeds_shape}"
                 ")")
 
     # Version of __repr__ with the prompt data obfuscated
-    def anon_repr(self):
+    def anon_repr(self) -> str:
+        prompt_token_ids_len = len(
+            self.prompt_token_ids
+        ) if self.prompt_token_ids is not None else None
+        prompt_embeds_shape = (self.prompt_embeds.shape
+                               if self.prompt_embeds is not None else None)
         return (f"NewRequestData("
                 f"req_id={self.req_id},"
-                f"prompt_token_ids_len={len(self.prompt_token_ids)},"
+                f"prompt_token_ids_len={prompt_token_ids_len},"
                 f"mm_features={self.mm_features},"
                 f"sampling_params={self.sampling_params},"
                 f"block_ids={self.block_ids},"
                 f"num_computed_tokens={self.num_computed_tokens},"
-                f"lora_request={self.lora_request}"
+                f"lora_request={self.lora_request},"
+                f"prompt_embeds_shape={prompt_embeds_shape}"
                 ")")
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index dec4abec519bd..345f5a464c2cc 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -47,7 +47,7 @@ class EngineCoreRequest(
         gc=False):  # type: ignore[call-arg]
 
     request_id: str
-    prompt_token_ids: list[int]
+    prompt_token_ids: Optional[list[int]]
     mm_features: Optional[list[MultiModalFeatureSpec]]
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]
@@ -56,6 +56,7 @@ class EngineCoreRequest(
     lora_request: Optional[LoRARequest]
     cache_salt: Optional[str]
     data_parallel_rank: Optional[int]
+    prompt_embeds: Optional[torch.Tensor] = None
 
     # Index of the client, used to ensure outputs are sent back to the same
     # client for this request when scaling out the front-end.
diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py
index cf4b06db843bd..8aa36d6a439c1 100644
--- a/vllm/v1/engine/detokenizer.py
+++ b/vllm/v1/engine/detokenizer.py
@@ -13,6 +13,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.logger import init_logger
 from vllm.transformers_utils.detokenizer_utils import (
     AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally)
+from vllm.utils import length_from_prompt_token_ids_or_embeds
 from vllm.v1.engine import EngineCoreRequest
 
 logger = init_logger(__name__)
@@ -179,11 +180,12 @@ class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
 
         self.tokenizer: Tokenizer = tokenizer._tokenizer
 
         # Find a safe place to start.
-        prompt_suffix = request.prompt_token_ids
+        prompt_token_ids = request.prompt_token_ids or []
+        prompt_suffix = prompt_token_ids
         prompt_len = len(prompt_suffix)
         if prompt_len > 4:
             for i in range(4, min(prompt_len + 1, 24)):
-                suffix = request.prompt_token_ids[-i:]
+                suffix = prompt_token_ids[-i:]
                 if '�' not in self.tokenizer.decode(suffix):
                     prompt_suffix = suffix
                     break
@@ -260,16 +262,25 @@ class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
         params = request.sampling_params
         assert params is not None
 
-        # Metadata for incremental detokenization.
-        self.tokens, self.prefix_offset, self.read_offset = (
-            convert_prompt_ids_to_tokens(
-                tokenizer=tokenizer,
-                prompt_ids=request.prompt_token_ids,
-                skip_special_tokens=params.skip_special_tokens,
-            ))
+        self.prompt_len = length_from_prompt_token_ids_or_embeds(
+            request.prompt_token_ids, request.prompt_embeds)
 
-        self.token_ids.extend(request.prompt_token_ids)
-        self.prompt_len = len(request.prompt_token_ids)
+        # Metadata for incremental detokenization.
+        if request.prompt_token_ids is not None:
+            self.tokens, self.prefix_offset, self.read_offset = (
+                convert_prompt_ids_to_tokens(
+                    tokenizer=tokenizer,
+                    prompt_ids=request.prompt_token_ids,
+                    skip_special_tokens=params.skip_special_tokens,
+                ))
+        else:
+            # Prompt embedding requests cannot be detokenized, in general.
+            self.tokens = [""] * self.prompt_len
+            self.prefix_offset = 0
+            self.read_offset = 0
+
+        self.token_ids.extend(request.prompt_token_ids
+                              or [0] * self.prompt_len)
 
         self.skip_special_tokens = params.skip_special_tokens
         self.spaces_between_special_tokens = (
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 5dad63988daa4..c17dc3e204ecd 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -14,6 +14,7 @@ from vllm.sampling_params import RequestOutputKind
 from vllm.tracing import (SpanAttributes, SpanKind, Tracer,
                           extract_trace_context)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import length_from_prompt_token_ids_or_embeds
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
 from vllm.v1.engine.detokenizer import IncrementalDetokenizer
 from vllm.v1.engine.logprobs import LogprobsProcessor
@@ -86,7 +87,8 @@ class RequestState:
         lora_name: Optional[str],
         output_kind: RequestOutputKind,
         prompt: Optional[str],
-        prompt_token_ids: list[int],
+        prompt_token_ids: Optional[list[int]],
+        prompt_embeds: Optional[torch.Tensor],
         logprobs_processor: Optional[LogprobsProcessor],
         detokenizer: Optional[IncrementalDetokenizer],
         max_tokens_param: Optional[int],
@@ -104,7 +106,9 @@ class RequestState:
         self.output_kind = output_kind
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
-        self.prompt_len = len(prompt_token_ids)
+        self.prompt_embeds = prompt_embeds
+        self.prompt_len = length_from_prompt_token_ids_or_embeds(
+            self.prompt_token_ids, self.prompt_embeds)
         self.logprobs_processor = logprobs_processor
         self.detokenizer = detokenizer
         self.max_tokens_param = max_tokens_param
@@ -165,6 +169,7 @@ class RequestState:
             output_kind=output_kind,
             prompt=prompt,
             prompt_token_ids=request.prompt_token_ids,
+            prompt_embeds=request.prompt_embeds,
             logprobs_processor=logprobs_processor,
             detokenizer=detokenizer,
             max_tokens_param=max_tokens_param,
@@ -223,6 +228,8 @@ class RequestState:
         first_output = outputs[0]
         if isinstance(first_output, PoolingOutput):
             assert len(outputs) == 1
+            # Prompt embeddings are currently not supported by pooling
+            # requests.
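The "safe place to start" scan in `FastIncrementalDetokenizer` above looks for the shortest recent suffix that decodes without U+FFFD, so incremental decoding never starts in the middle of a multi-byte character. A standalone sketch of the idea (`find_safe_suffix` is a hypothetical helper; `decode` would typically be `tokenizer.decode`):

```python
def find_safe_suffix(decode, prompt_token_ids: list[int]) -> list[int]:
    # Prefer the shortest suffix of 4..23 tokens that decodes cleanly
    # (no U+FFFD replacement character); otherwise fall back to the
    # full prompt.
    if len(prompt_token_ids) > 4:
        for i in range(4, min(len(prompt_token_ids) + 1, 24)):
            suffix = prompt_token_ids[-i:]
            if '\ufffd' not in decode(suffix):
                return suffix
    return prompt_token_ids


# Toy stand-in decoder that always decodes cleanly:
assert find_safe_suffix(lambda ids: "ok", [1, 2, 3, 4, 5]) == [2, 3, 4, 5]
```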
+ assert self.prompt_token_ids is not None return PoolingRequestOutput( request_id=request_id, outputs=first_output, @@ -236,10 +243,15 @@ class RequestState: else: prompt_logprobs = self.logprobs_processor.prompt_logprobs + # If prompt embeds were used, put placeholder prompt token ids + prompt_token_ids = self.prompt_token_ids + if prompt_token_ids is None and self.prompt_embeds is not None: + prompt_token_ids = [0] * len(self.prompt_embeds) + return RequestOutput( request_id=request_id, prompt=self.prompt, - prompt_token_ids=self.prompt_token_ids, + prompt_token_ids=prompt_token_ids, prompt_logprobs=prompt_logprobs, outputs=cast(list[CompletionOutput], outputs), finished=finished, @@ -469,6 +481,8 @@ class OutputProcessor: arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9) trace_context = extract_trace_context(engine_core_output.trace_headers) + prompt_length = length_from_prompt_token_ids_or_embeds( + req_state.prompt_token_ids, req_state.prompt_embeds) with (self.tracer.start_as_current_span( "llm_request", kind=SpanKind.SERVER, @@ -488,7 +502,7 @@ class OutputProcessor: span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE, queued_time) span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS, - len(req_state.prompt_token_ids)) + prompt_length) span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS, metrics.num_generation_tokens) span.set_attribute( @@ -544,7 +558,8 @@ class OutputProcessor: assert req_state.stats is not None iteration_stats.update_from_finished_request( finish_reason=finish_reason, - num_prompt_tokens=len(req_state.prompt_token_ids), + num_prompt_tokens=length_from_prompt_token_ids_or_embeds( + req_state.prompt_token_ids, req_state.prompt_embeds), max_tokens_param=req_state.max_tokens_param, req_stats=req_state.stats) self.lora_states.finish_request(req_state) diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 71f539583a1be..507e2cd3223fd 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -19,6 +19,7 @@ from vllm.multimodal.utils import argsort_mm_positions from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import EngineCoreRequest from vllm.v1.structured_output.backend_guidance import ( validate_guidance_grammar) @@ -390,6 +391,16 @@ class Processor: self._validate_model_inputs(processed_inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) + # Mypy does not always properly infer the types of some elements of + # discriminated unions of TypedDicts, because of how it handles + # inheritance of TypedDict. If we explicitly extract the items we want + # we can avoid type errors from using `dict.get` later in the method. + prompt_str: Optional[str] = None if decoder_inputs[ + "type"] == "embeds" else decoder_inputs.get("prompt") + prompt_token_ids = decoder_inputs[ + "prompt_token_ids"] if decoder_inputs["type"] != "embeds" else None + prompt_embeds = decoder_inputs["prompt_embeds"] if decoder_inputs[ + "type"] == "embeds" else None sampling_params = None pooling_params = None @@ -398,9 +409,10 @@ class Processor: sampling_params = params.clone() # If unset max tokens, then generate up to the max_model_len. 
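Both the placeholder token ids above and the `max_tokens` computation that follows lean on the `length_from_prompt_token_ids_or_embeds` helper added to `vllm.utils` in this change. A quick sketch of its contract:

```python
import torch

from vllm.utils import length_from_prompt_token_ids_or_embeds

embeds = torch.zeros(5, 4096)  # five "token slots" of embeddings
assert length_from_prompt_token_ids_or_embeds(None, embeds) == 5
assert length_from_prompt_token_ids_or_embeds([1, 2, 3], None) == 3
# Passing both with mismatched lengths raises ValueError, as does
# passing neither.
```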
if sampling_params.max_tokens is None: - sampling_params.max_tokens = ( - self.model_config.max_model_len - - len(decoder_inputs["prompt_token_ids"])) + seq_len = length_from_prompt_token_ids_or_embeds( + prompt_token_ids, prompt_embeds) + sampling_params.max_tokens = \ + self.model_config.max_model_len - seq_len sampling_params.update_from_generation_config( self.generation_config_fields, eos_token_id) if self.tokenizer is not None: @@ -430,9 +442,10 @@ class Processor: identifier=decoder_mm_hashes[modality][idx], mm_position=decoder_mm_positions[modality][idx])) - return decoder_inputs.get("prompt"), EngineCoreRequest( + return prompt_str, EngineCoreRequest( request_id=request_id, - prompt_token_ids=decoder_inputs["prompt_token_ids"], + prompt_token_ids=prompt_token_ids, + prompt_embeds=prompt_embeds, mm_features=mm_features, sampling_params=sampling_params, pooling_params=pooling_params, @@ -461,10 +474,17 @@ class Processor: ): model_config = self.model_config - prompt_ids = prompt_inputs["prompt_token_ids"] + prompt_ids = None if prompt_inputs[ + "type"] == "embeds" else prompt_inputs["prompt_token_ids"] + prompt_embeds = prompt_inputs["prompt_embeds"] if prompt_inputs[ + "type"] == "embeds" else None + prompt_len = length_from_prompt_token_ids_or_embeds( + prompt_ids, prompt_embeds) if not prompt_ids: if prompt_type == "encoder" and model_config.is_multimodal_model: pass # Mllama may have empty encoder inputs for text-only data + elif prompt_inputs["type"] == "embeds": + pass # Prompt embeds should not have prompt_ids. else: raise ValueError(f"The {prompt_type} prompt cannot be empty") @@ -472,7 +492,7 @@ class Processor: tokenizer = None else: tokenizer = self.tokenizer - max_input_id = max(prompt_ids, default=0) + max_input_id = max(prompt_ids or [], default=0) # NOTE: tokenizer.max_token_id is the tokenizer’s vocab size while # self.model_config.get_vocab_size() is the model’s vocab size. @@ -490,7 +510,7 @@ class Processor: f"Token id {max_input_id} is out of vocabulary") max_prompt_len = self.model_config.max_model_len - if len(prompt_ids) > max_prompt_len: + if prompt_len > max_prompt_len: if prompt_type == "encoder" and model_config.is_multimodal_model: mm_registry = self.input_preprocessor.mm_registry mm_processor = mm_registry.create_processor( @@ -514,7 +534,7 @@ class Processor: "number of text tokens.") raise ValueError( - f"The {prompt_type} prompt (length {len(prompt_ids)}) is " + f"The {prompt_type} prompt (length {prompt_len}) is " f"longer than the maximum model length of {max_prompt_len}. " f"{suggestion}") diff --git a/vllm/v1/kv_offload/backend.py b/vllm/v1/kv_offload/backend.py new file mode 100644 index 0000000000000..87a74200116bb --- /dev/null +++ b/vllm/v1/kv_offload/backend.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ctypes +from abc import ABC, abstractmethod +from collections.abc import Iterable + +from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.kv_offload.abstract import LoadStoreSpec + + +class BlockStatus(ctypes.Structure): + """ + Offloading status for a single block of KV data. + Holds the following information: + + ref_cnt - the current number of transfers using this block as a source. + A value of -1 indicates the block is not yet ready to be read. + load_store_spec - backend-specific information on how to actually + read/write the block. 
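The `ref_cnt` convention packs three states into one integer: -1 means a store is in flight, 0 means ready and evictable, and n > 0 means pinned by n in-flight reads. A standalone sketch of the transitions (illustrative, not the vLLM classes):

```python
class ToyBlock:
    """Illustrative ref_cnt state machine for one offloaded block."""

    def __init__(self):
        self.ref_cnt = -1          # allocated; store not yet complete

    def complete_store(self):
        self.ref_cnt = 0           # readable and evictable

    def begin_read(self):
        assert self.ref_cnt >= 0   # must be ready
        self.ref_cnt += 1          # pinned against eviction

    def end_read(self):
        assert self.ref_cnt > 0
        self.ref_cnt -= 1


block = ToyBlock()
block.complete_store()
block.begin_read()
assert block.ref_cnt == 1  # cannot be evicted during the read
block.end_read()
assert block.ref_cnt == 0  # evictable again
```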
+    """
+    _fields_ = [("ref_cnt", ctypes.c_int32)]
+
+    def __init__(self):
+        super().__init__()
+        # initialize block as "not ready" (ref_cnt = -1)
+        self.ref_cnt = -1
+
+    @property
+    def is_ready(self) -> bool:
+        """
+        Returns whether the block is ready to be read.
+        """
+        return self.ref_cnt >= 0
+
+
+class Backend(ABC):
+    """
+    An abstract class for allocating blocks and returning specs for
+    reading/writing KV blocks on some backend.
+    """
+
+    def __init__(self, block_size: int, medium: str):
+        self.block_size = block_size
+        self.medium = medium
+
+    @abstractmethod
+    def get_num_free_blocks(self) -> int:
+        """
+        Returns the current number of blocks that can be allocated.
+        """
+        pass
+
+    @abstractmethod
+    def allocate_blocks(self,
+                        block_hashes: list[BlockHash]) -> list[BlockStatus]:
+        """
+        Allocate space for writing blocks.
+        This method assumes there is enough space for allocation.
+        It is unsafe to use without checking get_num_free_blocks beforehand.
+
+        Args:
+            block_hashes: the hashes identifying the blocks to be written.
+
+        Returns:
+            A list of BlockStatus for the allocated blocks.
+            The ref_cnt of each returned item will be -1, meaning the block
+            is not yet ready to be read.
+        """
+        pass
+
+    @abstractmethod
+    def free(self, block: BlockStatus):
+        """
+        Free a previously allocated block.
+        You should only call this function with blocks returned by
+        allocate_blocks, and only once per block.
+
+        Args:
+            block: The block to be freed.
+        """
+        pass
+
+    def get_load_store_spec(self, block_hashes: Iterable[BlockHash],
+                            blocks: Iterable[BlockStatus]) -> LoadStoreSpec:
+        """
+        Get backend-specific information on how to read/write blocks.
+
+        Args:
+            block_hashes: the list of block hashes identifying the blocks.
+            blocks: the list of blocks.
+
+        Returns:
+            A LoadStoreSpec that can be used by a worker
+            to read/write the blocks.
+ """ + raise NotImplementedError diff --git a/vllm/v1/kv_offload/backends/cpu.py b/vllm/v1/kv_offload/backends/cpu.py new file mode 100644 index 0000000000000..eb1123d1d83ac --- /dev/null +++ b/vllm/v1/kv_offload/backends/cpu.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ctypes +from collections.abc import Iterable + +from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.kv_offload.abstract import LoadStoreSpec +from vllm.v1.kv_offload.backend import Backend, BlockStatus +from vllm.v1.kv_offload.mediums import CPULoadStoreSpec + + +class CPUBlockStatus(BlockStatus): + _fields_ = BlockStatus._fields_ + [("block_id", ctypes.c_int64) + ] # type: ignore + + def __init__(self, block_id: int): + super().__init__() + self.block_id = block_id + + +class CPUBackend(Backend): + + def __init__(self, block_size: int, num_blocks: int): + super().__init__(block_size=block_size, + medium=CPULoadStoreSpec.medium()) + + self.num_blocks: int = num_blocks + self.num_allocated_blocks: int = 0 + self.allocated_blocks_free_list: list[int] = [] + + def get_num_free_blocks(self): + return (len(self.allocated_blocks_free_list) + self.num_blocks - + self.num_allocated_blocks) + + def allocate_blocks(self, + block_hashes: list[BlockHash]) -> list[BlockStatus]: + num_fresh_blocks = min(len(block_hashes), + self.num_blocks - self.num_allocated_blocks) + num_reused_blocks = len(block_hashes) - num_fresh_blocks + assert len(self.allocated_blocks_free_list) >= num_reused_blocks + + # allocate fresh blocks + blocks: list[BlockStatus] = [] + for _ in range(num_fresh_blocks): + blocks.append(CPUBlockStatus(self.num_allocated_blocks)) + self.num_allocated_blocks += 1 + + # allocate reused blocks + for _ in range(num_reused_blocks): + block_id = self.allocated_blocks_free_list.pop() + blocks.append(CPUBlockStatus(block_id)) + + return blocks + + def free(self, block: BlockStatus): + assert isinstance(block, CPUBlockStatus) + self.allocated_blocks_free_list.append(block.block_id) + + def get_load_store_spec(self, block_hashes: Iterable[BlockHash], + blocks: Iterable[BlockStatus]) -> LoadStoreSpec: + return CPULoadStoreSpec([block.block_id for block in blocks]) diff --git a/vllm/v1/kv_offload/lru_manager.py b/vllm/v1/kv_offload/lru_manager.py new file mode 100644 index 0000000000000..18d3b1d637b32 --- /dev/null +++ b/vllm/v1/kv_offload/lru_manager.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import OrderedDict +from collections.abc import Iterable +from typing import Optional + +from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent, + OffloadingManager, PrepareStoreOutput) +from vllm.v1.kv_offload.backend import Backend, BlockStatus + + +class LRUOffloadingManager(OffloadingManager): + """ + An OffloadingManager with a pluggable backend, which evicts blocks by LRU. 
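The manager below gets LRU behavior from a plain `OrderedDict`: `move_to_end` implements `touch`, and iterating from the front yields eviction candidates in least-recently-used order. A minimal sketch of the idiom:

```python
from collections import OrderedDict

lru: OrderedDict[str, int] = OrderedDict()
lru["a"] = 1
lru["b"] = 2
lru["c"] = 3

lru.move_to_end("a")      # "touch": a becomes most recently used
victim = next(iter(lru))  # the front of the dict is the LRU entry
assert victim == "b"
lru.pop(victim)           # evict it
assert list(lru) == ["c", "a"]
```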
+ """ + + def __init__(self, backend: Backend, enable_events: bool = False): + self.backend: Backend = backend + # block_hash -> BlockStatus + self.blocks: OrderedDict[BlockHash, BlockStatus] = OrderedDict() + self.events: Optional[list[OffloadingEvent]] = \ + [] if enable_events else None + + def lookup(self, block_hashes: Iterable[BlockHash]) -> int: + hit_count = 0 + for block_hash in block_hashes: + block = self.blocks.get(block_hash) + if block is None or not block.is_ready: + break + hit_count += 1 + return hit_count + + def prepare_load(self, block_hashes: Iterable[BlockHash]) -> LoadStoreSpec: + blocks = [] + for block_hash in block_hashes: + block = self.blocks[block_hash] + assert block.is_ready + block.ref_cnt += 1 + blocks.append(block) + + return self.backend.get_load_store_spec(block_hashes, blocks) + + def touch(self, block_hashes: Iterable[BlockHash]): + for block_hash in reversed(list(block_hashes)): + if self.blocks.get(block_hash): + self.blocks.move_to_end(block_hash) + + def complete_load(self, block_hashes: Iterable[BlockHash]): + for block_hash in block_hashes: + block = self.blocks[block_hash] + assert block.ref_cnt > 0 + block.ref_cnt -= 1 + + def prepare_store( + self, + block_hashes: Iterable[BlockHash]) -> Optional[PrepareStoreOutput]: + # filter out blocks that are already stored + block_hashes_to_store = [ + block_hash for block_hash in block_hashes + if block_hash not in self.blocks + ] + + num_blocks_to_evict = (len(block_hashes_to_store) - + self.backend.get_num_free_blocks()) + + # build list of blocks to evict + to_evict = [] + if num_blocks_to_evict > 0: + for block_hash, block in self.blocks.items(): + if block.ref_cnt == 0: + to_evict.append(block_hash) + num_blocks_to_evict -= 1 + if num_blocks_to_evict == 0: + break + else: + # we could not evict enough blocks + return None + + # evict blocks + for block_hash in to_evict: + self.backend.free(self.blocks.pop(block_hash)) + + if to_evict and self.events is not None: + self.events.append( + OffloadingEvent(block_hashes=to_evict, + block_size=self.backend.block_size, + medium=self.backend.medium, + removed=True)) + + blocks = self.backend.allocate_blocks(block_hashes_to_store) + assert len(blocks) == len(block_hashes_to_store) + + for block_hash, block in zip(block_hashes_to_store, blocks): + self.blocks[block_hash] = block + + # build store specs for allocated blocks + store_spec = self.backend.get_load_store_spec(block_hashes_to_store, + blocks) + + return PrepareStoreOutput(block_hashes_to_store=block_hashes_to_store, + store_spec=store_spec, + block_hashes_evicted=to_evict) + + def complete_store(self, + block_hashes: Iterable[BlockHash], + success: bool = True): + stored_block_hashes: list[BlockHash] = [] + if success: + for block_hash in block_hashes: + block = self.blocks[block_hash] + if not block.is_ready: + block.ref_cnt = 0 + stored_block_hashes.append(block_hash) + else: + for block_hash in block_hashes: + block = self.blocks[block_hash] + if not block.is_ready: + self.backend.free(block) + del self.blocks[block_hash] + + if stored_block_hashes and self.events is not None: + self.events.append( + OffloadingEvent(block_hashes=stored_block_hashes, + block_size=self.backend.block_size, + medium=self.backend.medium, + removed=False)) + + def take_events(self) -> Iterable[OffloadingEvent]: + if self.events is not None: + yield from self.events + self.events.clear() diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 145af788d2372..ff10fa00c1cf6 100644 --- a/vllm/v1/request.py +++ 
b/vllm/v1/request.py @@ -7,9 +7,12 @@ from collections.abc import Mapping from functools import partial from typing import TYPE_CHECKING, Any, Callable, Optional, Union +import torch + from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams +from vllm.utils import length_from_prompt_token_ids_or_embeds from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType, EngineCoreRequest, FinishReason) from vllm.v1.structured_output.request import StructuredOutputRequest @@ -25,12 +28,13 @@ class Request: def __init__( self, request_id: str, - prompt_token_ids: list[int], + prompt_token_ids: Optional[list[int]], sampling_params: Optional[SamplingParams], pooling_params: Optional[PoolingParams], eos_token_id: Optional[int], client_index: int = 0, arrival_time: Optional[float] = None, + prompt_embeds: Optional[torch.Tensor] = None, mm_features: Optional[list[MultiModalFeatureSpec]] = None, lora_request: Optional["LoRARequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None, @@ -79,9 +83,13 @@ class Request: "sampling_params and pooling_params can't both be unset") self.prompt_token_ids = prompt_token_ids - self.num_prompt_tokens = len(self.prompt_token_ids) + self.prompt_embeds = prompt_embeds + self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( + prompt_token_ids, prompt_embeds) self._output_token_ids: list[int] = [] - self._all_token_ids: list[int] = self.prompt_token_ids.copy() + self._all_token_ids: list[int] = self.prompt_token_ids.copy( + ) if self.prompt_token_ids is not None else [0 + ] * self.num_prompt_tokens self.num_output_placeholders = 0 # Used in async scheduling. self.spec_token_ids: list[int] = [] self.num_computed_tokens = 0 @@ -123,6 +131,7 @@ class Request: request_id=request.request_id, client_index=request.client_index, prompt_token_ids=request.prompt_token_ids, + prompt_embeds=request.prompt_embeds, mm_features=request.mm_features, sampling_params=request.sampling_params, pooling_params=request.pooling_params, diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index df944873bcaf3..10cad5b530716 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -243,7 +243,7 @@ class AdapterLogitsProcessor(LogitsProcessor): def _new_state( self, params: SamplingParams, - prompt_ids: list[int], + prompt_ids: Optional[list[int]], output_ids: list[int], ) -> Optional[partial[torch.Tensor]]: """Return state representation for new request diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 60f9c0bdb6313..fc655d993cb4c 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -187,7 +187,8 @@ class MinTokensLogitsProcessor(LogitsProcessor): @staticmethod def add_request( - params: SamplingParams, _: list[int], output_tok_ids: list[int] + params: SamplingParams, _: Optional[list[int]], + output_tok_ids: list[int] ) -> Optional[tuple[int, Sequence[int], set[int]]]: min_tokens = params.min_tokens if not min_tokens or len(output_tok_ids) >= min_tokens: @@ -234,7 +235,8 @@ class MinTokensLogitsProcessor(LogitsProcessor): def process_dict_updates( req_entries: dict[int, T], batch_update: Optional[BatchUpdate], - new_state: Callable[[SamplingParams, list[int], list[int]], Optional[T]] + new_state: Callable[[SamplingParams, Optional[list[int]], 
diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py
index 60f9c0bdb6313..fc655d993cb4c 100644
--- a/vllm/v1/sample/logits_processor/builtin.py
+++ b/vllm/v1/sample/logits_processor/builtin.py
@@ -187,7 +187,8 @@ class MinTokensLogitsProcessor(LogitsProcessor):
 
     @staticmethod
     def add_request(
-        params: SamplingParams, _: list[int], output_tok_ids: list[int]
+        params: SamplingParams, _: Optional[list[int]],
+        output_tok_ids: list[int]
     ) -> Optional[tuple[int, Sequence[int], set[int]]]:
         min_tokens = params.min_tokens
         if not min_tokens or len(output_tok_ids) >= min_tokens:
@@ -234,7 +235,8 @@ class MinTokensLogitsProcessor(LogitsProcessor):
 
 
 def process_dict_updates(
     req_entries: dict[int, T], batch_update: Optional[BatchUpdate],
-    new_state: Callable[[SamplingParams, list[int], list[int]], Optional[T]]
+    new_state: Callable[[SamplingParams, Optional[list[int]], list[int]],
+                        Optional[T]]
 ) -> bool:
     """Utility function to update dict state for sparse LogitsProcessors."""
diff --git a/vllm/v1/sample/logits_processor/interface.py b/vllm/v1/sample/logits_processor/interface.py
index 04027359909a6..a84afc2f347a0 100644
--- a/vllm/v1/sample/logits_processor/interface.py
+++ b/vllm/v1/sample/logits_processor/interface.py
@@ -26,7 +26,7 @@ RemovedRequest = int
 
 # (index, params, prompt_tok_ids, output_tok_ids) tuples for new
 # requests added to the batch.
-AddedRequest = tuple[int, SamplingParams, list[int], list[int]]
+AddedRequest = tuple[int, SamplingParams, Optional[list[int]], list[int]]
 
 # (index 1, index 2, directionality) tuples representing
 # one-way moves or two-way swaps of requests in batch
diff --git a/vllm/v1/serial_utils.py b/vllm/v1/serial_utils.py
index c8375d6f15517..50c1470c67edc 100644
--- a/vllm/v1/serial_utils.py
+++ b/vllm/v1/serial_utils.py
@@ -174,7 +174,7 @@ class MsgpackEncoder:
     ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
         assert self.aux_buffers is not None
         # view the tensor as a contiguous 1D array of bytes
-        arr = obj.flatten().contiguous().view(torch.uint8).numpy()
+        arr = obj.flatten().contiguous().cpu().view(torch.uint8).numpy()
         if obj.nbytes < self.size_threshold:
             # Smaller tensors are encoded inline, just like ndarrays.
             data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
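The one-call change in serial_utils.py is load-bearing for this feature: prompt embeddings can arrive as CUDA tensors, and numpy cannot view device memory, so the tensor must land on the host before the uint8 reinterpretation. A minimal CPU-only demonstration of the byte-view round trip; the rebuild step is illustrative, not vLLM's actual decoder.

# Byte-view round trip behind MsgpackEncoder's tensor encoding (sketch).
import torch

t = torch.arange(4, dtype=torch.float32)
# .cpu() is a no-op here but required for CUDA tensors.
buf = t.flatten().contiguous().cpu().view(torch.uint8).numpy()
assert buf.nbytes == t.numel() * t.element_size()  # 16 raw bytes

rebuilt = torch.frombuffer(bytearray(buf), dtype=torch.float32)
assert torch.equal(rebuilt.reshape(t.shape), t)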
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 6717622efb801..79a392337574f 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -13,7 +13,7 @@ from vllm.lora.request import LoRARequest
 from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.utils import swap_dict_values
+from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
 from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
@@ -29,7 +29,7 @@ from vllm.v1.worker.block_table import MultiGroupBlockTable
 class CachedRequestState:
 
     req_id: str
-    prompt_token_ids: list[int]
+    prompt_token_ids: Optional[list[int]]
     mm_features: list[MultiModalFeatureSpec]
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]
@@ -43,9 +43,11 @@ class CachedRequestState:
 
     mrope_position_delta: Optional[int] = None
     lora_request: Optional[LoRARequest] = None
+    prompt_embeds: Optional[torch.Tensor] = None
 
     def __post_init__(self):
-        self.num_prompt_tokens = len(self.prompt_token_ids)
+        self.num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
+            self.prompt_token_ids, self.prompt_embeds)
 
     @property
     def num_tokens(self) -> int:
@@ -63,6 +65,10 @@ class CachedRequestState:
 
     def get_token_id(self, idx: int) -> int:
         if idx < self.num_prompt_tokens:
+            if self.prompt_token_ids is None:
+                raise ValueError(
+                    f"Tried to access token index {idx}, but that token was "
+                    "provided via prompt_embeds, and its ID is unknown.")
             return self.prompt_token_ids[idx]
         elif idx - self.num_prompt_tokens < len(self.output_token_ids):
             return self.output_token_ids[idx - self.num_prompt_tokens]
@@ -109,6 +115,14 @@ class InputBatch:
             pin_memory=False,
         )
         self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
+        self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
+                                        device="cpu",
+                                        dtype=bool,
+                                        pin_memory=False)
+        # Store prompt embeddings per request to avoid OOM from large upfront
+        # allocation if max_model_len is big.
+        # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
+        self.req_prompt_embeds: dict[int, torch.Tensor] = {}
         self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
         self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
@@ -310,15 +324,23 @@ class InputBatch:
         self.req_id_to_index[req_id] = req_index
 
         # Copy the prompt token ids and output token ids.
-        num_prompt_tokens = len(request.prompt_token_ids)
+        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
+            request.prompt_token_ids, request.prompt_embeds)
         self.num_prompt_tokens[req_index] = num_prompt_tokens
-        self.token_ids_cpu[
-            req_index, :num_prompt_tokens] = request.prompt_token_ids
         start_idx = num_prompt_tokens
         end_idx = start_idx + len(request.output_token_ids)
+        if request.prompt_token_ids is not None:
+            self.token_ids_cpu[
+                req_index, :num_prompt_tokens] = request.prompt_token_ids
+            self.is_token_ids[req_index, :num_prompt_tokens] = True
+        else:
+            self.is_token_ids[req_index, :num_prompt_tokens] = False
+        if request.prompt_embeds is not None:
+            self.req_prompt_embeds[req_index] = request.prompt_embeds
         self.token_ids_cpu[req_index,
                            start_idx:end_idx] = request.output_token_ids
-        # Number of token ids in token_ids_cpu.
+        self.is_token_ids[req_index, start_idx:end_idx] = True
+        # Number of token ids in prompt (token_ids_cpu or prompt_embeds).
         # NOTE(woosuk): This may include spec decode tokens.
         self.num_tokens[req_index] = request.num_tokens
         # Number of tokens without spec decode tokens.
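The comment in __init__ above points at a real sizing cliff: a dense (max_num_reqs, max_model_len, hidden_size) buffer for prompt embeddings would dwarf the int32 token_ids_cpu array, while the dict only pays for prompts actually sent as embeddings. Back-of-envelope numbers, with sizes that are illustrative rather than vLLM defaults:

# Why req_prompt_embeds is a dict rather than a preallocated tensor.
import torch

max_num_reqs, max_model_len, hidden_size = 256, 32768, 4096
bytes_fp16 = 2
dense_gib = max_num_reqs * max_model_len * hidden_size * bytes_fp16 / 2**30
assert dense_gib == 64.0  # ~64 GiB reserved up front, mostly unused

# The dict stores exactly what each request provided:
req_prompt_embeds: dict[int, torch.Tensor] = {}
req_prompt_embeds[0] = torch.zeros(17, hidden_size, dtype=torch.float16)
assert req_prompt_embeds[0].nbytes == 17 * hidden_size * bytes_fp16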
@@ -503,6 +525,20 @@ class InputBatch:
         self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
         self.token_ids_cpu[i2, ...] = tmp
 
+        self.is_token_ids[[i1, i2], ...] = self.is_token_ids[[i2, i1], ...]
+
+        # Swap prompt embeddings if they exist
+        embeds_i1 = self.req_prompt_embeds.get(i1)
+        embeds_i2 = self.req_prompt_embeds.get(i2)
+        if embeds_i1 is not None:
+            self.req_prompt_embeds[i2] = embeds_i1
+        else:
+            self.req_prompt_embeds.pop(i2, None)
+        if embeds_i2 is not None:
+            self.req_prompt_embeds[i1] = embeds_i2
+        else:
+            self.req_prompt_embeds.pop(i1, None)
+
         self.block_table.swap_row(i1, i2)
 
         self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
@@ -592,6 +628,11 @@ class InputBatch:
             num_tokens = self.num_tokens[last_req_index]
             self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                 last_req_index, :num_tokens]
+            self.is_token_ids[empty_index, :num_tokens] = self.is_token_ids[
+                last_req_index, :num_tokens]
+            if last_req_index in self.req_prompt_embeds:
+                self.req_prompt_embeds[
+                    empty_index] = self.req_prompt_embeds.pop(last_req_index)
             self.num_tokens[empty_index] = num_tokens
             self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
                 last_req_index]
diff --git a/vllm/v1/worker/tpu_input_batch.py b/vllm/v1/worker/tpu_input_batch.py
index dfa54d0ad83b6..4cd0ac352de0f 100644
--- a/vllm/v1/worker/tpu_input_batch.py
+++ b/vllm/v1/worker/tpu_input_batch.py
@@ -9,7 +9,7 @@ import torch
 
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingType
-from vllm.utils import swap_dict_values
+from vllm.utils import length_from_prompt_token_ids_or_embeds, swap_dict_values
 from vllm.v1.outputs import LogprobsTensors
 from vllm.v1.worker.block_table import MultiGroupBlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState
@@ -213,7 +213,9 @@ class InputBatch:
         self.req_id_to_index[req_id] = req_index
 
         # Copy the prompt token ids and output token ids.
-        num_prompt_tokens = len(request.prompt_token_ids)
+        num_prompt_tokens = length_from_prompt_token_ids_or_embeds(
+            request.prompt_token_ids, request.prompt_embeds)
+        # TODO: copy prompt_embeds
         self.num_prompt_tokens[req_index] = num_prompt_tokens
         self.token_ids_cpu[
             req_index, :num_prompt_tokens] = request.prompt_token_ids
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 43f12912707f1..01a8e5c3f0dba 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -387,6 +387,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.requests[req_id] = CachedRequestState(
             req_id=req_id,
             prompt_token_ids=new_req_data.prompt_token_ids,
+            prompt_embeds=new_req_data.prompt_embeds,
             mm_features=new_req_data.mm_features,
             sampling_params=sampling_params,
             pooling_params=None,
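The swap in swap_states() above has to handle entries that may exist for neither, one, or both indices, which is why a plain tuple swap would either KeyError or leave stale leftovers. A more compact, equivalent pattern based on pop-and-reinsert; swap_optional_entries is an illustrative helper, not what the diff uses:

# Pop-and-reinsert swap for a sparse dict, equivalent to the branchy
# version in swap_states() above.
import torch


def swap_optional_entries(d: dict[int, torch.Tensor], i1: int, i2: int):
    v1, v2 = d.pop(i1, None), d.pop(i2, None)
    if v1 is not None:
        d[i2] = v1
    if v2 is not None:
        d[i1] = v2


d = {1: torch.ones(2)}
swap_optional_entries(d, 1, 2)
assert 2 in d and 1 not in d  # entry moved, nothing stale left behind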