diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index ef8f1b2e17b47..e0edb3e883ed6 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -2,14 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod -from collections.abc import Sequence from pathlib import Path -from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar - -if TYPE_CHECKING: - from vllm.sequence import SequenceGroupMetadata - -from .inputs import MultiModalKwargs, PlaceholderRange +from typing import Generic, NamedTuple, TypeVar _T = TypeVar("_T") @@ -53,120 +47,6 @@ class MultiModalPlaceholderMap: self.dest_ranges = [] self.dest_len = 0 - @classmethod - def from_seq_group( - cls, seq_group: "SequenceGroupMetadata", positions: range - ) -> tuple[MultiModalKwargs, dict[str, "MultiModalPlaceholderMap"]]: - """ - Returns the multi-modal items that intersect with the portion of a - prompt (``seq_group``) represented by ``positions``, as well as a - ``MultiModalPlaceholderMap`` that relates the multi-modal embedding - vectors to their corresponding placeholders. - - Examples: - - ``` - Prompt: |AAAA BBBB What's in these images?| - Positions: |.................................| - - images = [A, B] - src_ranges = [(0, 4), (4, 8)] - dest_ranges = [(0, 4), (5, 9)] - - Prompt: |AAAA BBBB What's in these images?| - Positions: | ..... | - - images = [A, B] - src_ranges = [(2, 4), (4, 6)] - dest_ranges = [(0, 2), (3, 5)] - - Prompt: |AAAA BBBB What's in these images?| - Positions: | ......... | - - images = [B] - src_ranges = [(0, 4)] - dest_ranges = [(0, 4)] - - Prompt: |AAAA BBBB What's in these images?| - Positions: | .......................| - - images = [] - src_ranges = [] - dest_ranges = [] - ``` - """ - seq_mm_data = seq_group.multi_modal_data - seq_mm_placeholders = seq_group.multi_modal_placeholders - - if not seq_mm_data or not seq_mm_placeholders: - return MultiModalKwargs(), {} - - placeholder_maps = dict[str, MultiModalPlaceholderMap]() - - for modality, placeholders in seq_mm_placeholders.items(): - placeholder_map = MultiModalPlaceholderMap() - - if positions: - placeholder_map.append_items_from_seq_group( - positions, - # Dummy, since we don't care about intersecting items - [None] * len(placeholders), - placeholders, - ) - - placeholder_maps[modality] = placeholder_map - - return seq_mm_data, placeholder_maps - - def append_items_from_seq_group( - self, - positions: range, - multi_modal_items: list[_T], - multi_modal_placeholders: Sequence[PlaceholderRange], - ) -> list[_T]: - """ - Adds the multi-modal items that intersect ```positions`` to this - placeholder map and returns the intersecting items. - """ - intersecting_items = [] - - if len(multi_modal_items) != len(multi_modal_placeholders): - raise ValueError( - "Multi-modal placeholders and items must have the same length." - ) - for placeholder_dict, mm_item in zip(multi_modal_placeholders, - multi_modal_items): - placeholder = range( - placeholder_dict.offset, - placeholder_dict.offset + placeholder_dict.length, - ) - intersection = range( - max(positions.start, placeholder.start), - min(positions.stop, placeholder.stop), - ) - - if not intersection: - # Skip this multi-modal item. - continue - - token_embedding_range = range( - intersection.start - positions.start, - intersection.stop - positions.start, - ) - - multimodal_embedding_range = range( - intersection.start - placeholder.start + self.src_len, - intersection.stop - placeholder.start + self.src_len, - ) - - intersecting_items.append(mm_item) - self.dest_ranges.append(token_embedding_range) - self.src_ranges.append(multimodal_embedding_range) - self.src_len += len(placeholder) - - self.dest_len += len(positions) - return intersecting_items - def extend(self, other: "MultiModalPlaceholderMap"): """ Adds the placeholders from another ``MultiModalPlaceholderMap`` to this diff --git a/vllm/outputs.py b/vllm/outputs.py index 64bcfd472f2ad..4d8206bb2d830 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time from collections.abc import MutableSequence from collections.abc import Sequence as GenericSequence from dataclasses import dataclass @@ -14,9 +13,7 @@ from vllm.logger import init_logger from vllm.logprobs import PromptLogprobs, SampleLogprobs from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalPlaceholderDict -from vllm.sampling_params import RequestOutputKind -from vllm.sequence import (RequestMetrics, SequenceGroup, SequenceGroupBase, - SequenceStatus) +from vllm.sequence import RequestMetrics logger = init_logger(__name__) @@ -171,170 +168,6 @@ class RequestOutput: else: self.outputs.append(next_completion) - @classmethod - def from_seq_group( - cls, seq_group: SequenceGroup, use_cache: bool, - seq_id_to_seq_group: dict[str, SequenceGroupBase] - ) -> Optional["RequestOutput"]: - finished = seq_group.is_finished() - - if seq_group.request_id in seq_id_to_seq_group: - group: SequenceGroupBase = seq_id_to_seq_group[ - seq_group.request_id] - assembled_seq_group = group.maybe_assemble_group(seq_group) - if finished: - group.finish_seq(seq_group) - if assembled_seq_group is None: - return None - - # clear finished seq in seq_id_to_seq_group - if len(group.to_be_finished) == 0: - for sub_request_id in list(group.seq_id_to_index.keys()): - if sub_request_id in seq_id_to_seq_group: - del seq_id_to_seq_group[sub_request_id] - - return cls.from_seq_group(assembled_seq_group, use_cache, - seq_id_to_seq_group) - - sampling_params = seq_group.sampling_params - if sampling_params is None: - raise ValueError( - "Sampling parameters are missing for a CompletionRequest.") - - if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and ( - not finished): - return None - - # Init cache (if needed) - if use_cache and seq_group.cached_request_output is None: - seq_group.cached_request_output = RequestOutput( # type: ignore - request_id="", - prompt=None, - prompt_token_ids=[], - prompt_logprobs=None, - outputs=[], - finished=False) - - top_n_seqs = seq_group.get_seqs() - - # Create the outputs. - # NOTE: We need omit logprobs here explicitly because the sequence - # always has the logprobs of the sampled tokens even if the - # logprobs are not requested. - include_logprobs = sampling_params.logprobs is not None - text_buffer_length = sampling_params.output_text_buffer_length - delta = sampling_params.output_kind == RequestOutputKind.DELTA - - outputs = [] - include_prompt = True - # num_cached_tokens should be the same for all the sequences - num_cached_tokens = None - for i, seq in enumerate(top_n_seqs): - output_text = seq.get_output_text_to_return( - text_buffer_length, delta) - - output_token_ids = seq.get_output_token_ids_to_return(delta) - num_output_tokens = 1 if isinstance(output_token_ids, - int) else len(output_token_ids) - num_cached_tokens = seq.data.get_num_cached_tokens() - - output_logprobs = seq.output_logprobs if include_logprobs else None - - if delta: - # Slice logprobs delta if applicable - if output_logprobs: - # num_output_tokens can be 0 when n > 1 and request finishes - # before the others - if num_output_tokens > 0: - output_logprobs = output_logprobs[-num_output_tokens:] - else: - output_logprobs = None - # Don't include prompt if this is after the first output - # containing decode token ids - if include_prompt and seq.get_output_len() > num_output_tokens: - include_prompt = False - - if use_cache: - # Get cached output object - cached_outputs = seq_group.cached_request_output.outputs # type: ignore - if i >= len(cached_outputs): - cached_outputs.append( - CompletionOutput(index=i, - text="", - token_ids=[], - cumulative_logprob=None, - logprobs=None, - finish_reason=None, - stop_reason=None)) - output = cached_outputs[i] - - # Init cached output object - assert output.index == i - output.text = output_text - - if isinstance(output_token_ids, int): - output.token_ids.clear() - output.token_ids.append(output_token_ids) - else: - output.token_ids = output_token_ids - - output.cumulative_logprob = seq.get_cumulative_logprob() \ - if include_logprobs else None - output.logprobs = output_logprobs - output.finish_reason = SequenceStatus.get_finished_reason( - seq.status) - output.stop_reason = seq.stop_reason - - else: - output = CompletionOutput( - top_n_seqs.index(seq), output_text, [output_token_ids] - if isinstance(output_token_ids, int) else output_token_ids, - seq.get_cumulative_logprob() if include_logprobs else None, - output_logprobs, - SequenceStatus.get_finished_reason(seq.status), - seq.stop_reason) - - outputs.append(output) - - # Every sequence in the sequence group should have the same prompt. - if include_prompt: - prompt = seq_group.prompt - prompt_token_ids = seq_group.prompt_token_ids - encoder_prompt = seq_group.encoder_prompt - encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids - prompt_logprobs = seq_group.prompt_logprobs - else: - prompt = None - prompt_token_ids = None - encoder_prompt = None - encoder_prompt_token_ids = None - prompt_logprobs = None - finished_time = time.time() if finished else None - seq_group.set_finished_time(finished_time) - - init_kwargs = { - "request_id": seq_group.request_id, - "prompt": prompt, - "prompt_token_ids": prompt_token_ids, - "prompt_logprobs": prompt_logprobs, - "outputs": outputs, - "finished": finished, - "metrics": seq_group.metrics, - "lora_request": seq_group.lora_request, - "encoder_prompt": encoder_prompt, - "encoder_prompt_token_ids": encoder_prompt_token_ids, - "num_cached_tokens": num_cached_tokens, - "multi_modal_placeholders": seq_group.multi_modal_placeholders - } - - if use_cache: - request_output = seq_group.cached_request_output - request_output.__init__(**init_kwargs) # type: ignore - else: - request_output = cls(**init_kwargs) # type: ignore - - return request_output - def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " f"prompt={self.prompt!r}, " @@ -371,19 +204,6 @@ class PoolingRequestOutput(Generic[_O]): self.finished = finished self.outputs = outputs - @staticmethod - def from_seq_group(seq_group: SequenceGroup) -> "PoolingRequestOutput": - pooled_data = seq_group.pooled_data - assert pooled_data is not None - - data = pooled_data.to(dtype=torch.float32, device="cpu") - output = PoolingOutput(data) - prompt_token_ids = seq_group.prompt_token_ids - finished = seq_group.is_finished() - - return PoolingRequestOutput(seq_group.request_id, output, - prompt_token_ids, finished) - def __repr__(self): return (f"{type(self).__name__}(request_id={self.request_id!r}, " f"outputs={self.outputs!r}, " @@ -391,19 +211,6 @@ class PoolingRequestOutput(Generic[_O]): f"finished={self.finished})") -class RequestOutputFactory: - - @staticmethod - def create(seq_group: SequenceGroup, - seq_id_to_seq_group: dict[str, SequenceGroupBase], - use_cache: bool = False): - if seq_group.pooled_data is not None: - return PoolingRequestOutput.from_seq_group(seq_group) - else: - return RequestOutput.from_seq_group(seq_group, use_cache, - seq_id_to_seq_group) - - @dataclass class EmbeddingOutput: """The output data of one embedding output of a request.