[V0 Deprecation] Remove from_seq_group methods (#25330)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon 2025-09-20 21:10:48 -07:00 committed by GitHub
parent 035fd2bd2c
commit 12dbd834cf
2 changed files with 2 additions and 315 deletions

@@ -2,14 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Generic, NamedTuple, TypeVar
if TYPE_CHECKING:
from vllm.sequence import SequenceGroupMetadata
from .inputs import MultiModalKwargs, PlaceholderRange
from typing import Generic, NamedTuple, TypeVar
_T = TypeVar("_T")
@@ -53,120 +47,6 @@ class MultiModalPlaceholderMap:
self.dest_ranges = []
self.dest_len = 0
@classmethod
def from_seq_group(
cls, seq_group: "SequenceGroupMetadata", positions: range
) -> tuple[MultiModalKwargs, dict[str, "MultiModalPlaceholderMap"]]:
"""
Returns the multi-modal items that intersect with the portion of a
prompt (``seq_group``) represented by ``positions``, as well as a
``MultiModalPlaceholderMap`` that relates the multi-modal embedding
vectors to their corresponding placeholders.
Examples:
```
Prompt: |AAAA BBBB What's in these images?|
Positions: |.................................|
images = [A, B]
src_ranges = [(0, 4), (4, 8)]
dest_ranges = [(0, 4), (5, 9)]
Prompt: |AAAA BBBB What's in these images?|
Positions: | ..... |
images = [A, B]
src_ranges = [(2, 4), (4, 6)]
dest_ranges = [(0, 2), (3, 5)]
Prompt: |AAAA BBBB What's in these images?|
Positions: | ......... |
images = [B]
src_ranges = [(0, 4)]
dest_ranges = [(0, 4)]
Prompt: |AAAA BBBB What's in these images?|
Positions: | .......................|
images = []
src_ranges = []
dest_ranges = []
```
"""
seq_mm_data = seq_group.multi_modal_data
seq_mm_placeholders = seq_group.multi_modal_placeholders
if not seq_mm_data or not seq_mm_placeholders:
return MultiModalKwargs(), {}
placeholder_maps = dict[str, MultiModalPlaceholderMap]()
for modality, placeholders in seq_mm_placeholders.items():
placeholder_map = MultiModalPlaceholderMap()
if positions:
placeholder_map.append_items_from_seq_group(
positions,
# Dummy, since we don't care about intersecting items
[None] * len(placeholders),
placeholders,
)
placeholder_maps[modality] = placeholder_map
return seq_mm_data, placeholder_maps
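For reference, a hypothetical call site for the removed classmethod (variable names such as `chunk_start` and `chunk_size` are assumptions, not vLLM code): it restricted a request's multi-modal items to the slice of the prompt being prefilled in the current step.
```python
# Hypothetical usage sketch (assumed names, for illustration only).
mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
    seq_group_metadata,                            # a SequenceGroupMetadata
    range(chunk_start, chunk_start + chunk_size),  # positions scheduled this step
)
```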
def append_items_from_seq_group(
self,
positions: range,
multi_modal_items: list[_T],
multi_modal_placeholders: Sequence[PlaceholderRange],
) -> list[_T]:
"""
Adds the multi-modal items that intersect ``positions`` to this
placeholder map and returns the intersecting items.
"""
intersecting_items = []
if len(multi_modal_items) != len(multi_modal_placeholders):
raise ValueError(
"Multi-modal placeholders and items must have the same length."
)
for placeholder_dict, mm_item in zip(multi_modal_placeholders,
multi_modal_items):
placeholder = range(
placeholder_dict.offset,
placeholder_dict.offset + placeholder_dict.length,
)
intersection = range(
max(positions.start, placeholder.start),
min(positions.stop, placeholder.stop),
)
if not intersection:
# Skip this multi-modal item.
continue
token_embedding_range = range(
intersection.start - positions.start,
intersection.stop - positions.start,
)
multimodal_embedding_range = range(
intersection.start - placeholder.start + self.src_len,
intersection.stop - placeholder.start + self.src_len,
)
intersecting_items.append(mm_item)
self.dest_ranges.append(token_embedding_range)
self.src_ranges.append(multimodal_embedding_range)
self.src_len += len(placeholder)
self.dest_len += len(positions)
return intersecting_items
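The docstring examples above follow directly from this intersection arithmetic. A minimal standalone sketch (hypothetical helper, not part of vLLM) reproduces the removed mapping and can be checked against the second example:
```python
# Standalone sketch (assumed helper, not vLLM API) of the removed intersection
# arithmetic. Placeholders are given as (offset, length) pairs.
def map_placeholders(positions: range,
                     placeholders: list[tuple[int, int]]):
    src_ranges: list[range] = []   # indices into the multi-modal embeddings
    dest_ranges: list[range] = []  # indices into the token slice for `positions`
    src_len = 0
    for offset, length in placeholders:
        placeholder = range(offset, offset + length)
        overlap = range(max(positions.start, placeholder.start),
                        min(positions.stop, placeholder.stop))
        if not overlap:  # an empty range is falsy: item not in this chunk
            continue
        dest_ranges.append(range(overlap.start - positions.start,
                                 overlap.stop - positions.start))
        src_ranges.append(range(overlap.start - placeholder.start + src_len,
                                overlap.stop - placeholder.start + src_len))
        src_len += len(placeholder)
    return src_ranges, dest_ranges

# Second docstring example: chunk covering positions 2..6 of the prompt,
# image A at offset 0 (length 4), image B at offset 5 (length 4).
print(map_placeholders(range(2, 7), [(0, 4), (5, 4)]))
# -> ([range(2, 4), range(4, 6)], [range(0, 2), range(3, 5)])
```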
def extend(self, other: "MultiModalPlaceholderMap"):
"""
Adds the placeholders from another ``MultiModalPlaceholderMap`` to this

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
from collections.abc import MutableSequence
from collections.abc import Sequence as GenericSequence
from dataclasses import dataclass
@@ -14,9 +13,7 @@ from vllm.logger import init_logger
from vllm.logprobs import PromptLogprobs, SampleLogprobs
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalPlaceholderDict
from vllm.sampling_params import RequestOutputKind
from vllm.sequence import (RequestMetrics, SequenceGroup, SequenceGroupBase,
SequenceStatus)
from vllm.sequence import RequestMetrics
logger = init_logger(__name__)
@@ -171,170 +168,6 @@ class RequestOutput:
else:
self.outputs.append(next_completion)
@classmethod
def from_seq_group(
cls, seq_group: SequenceGroup, use_cache: bool,
seq_id_to_seq_group: dict[str, SequenceGroupBase]
) -> Optional["RequestOutput"]:
finished = seq_group.is_finished()
if seq_group.request_id in seq_id_to_seq_group:
group: SequenceGroupBase = seq_id_to_seq_group[
seq_group.request_id]
assembled_seq_group = group.maybe_assemble_group(seq_group)
if finished:
group.finish_seq(seq_group)
if assembled_seq_group is None:
return None
# clear finished seq in seq_id_to_seq_group
if len(group.to_be_finished) == 0:
for sub_request_id in list(group.seq_id_to_index.keys()):
if sub_request_id in seq_id_to_seq_group:
del seq_id_to_seq_group[sub_request_id]
return cls.from_seq_group(assembled_seq_group, use_cache,
seq_id_to_seq_group)
sampling_params = seq_group.sampling_params
if sampling_params is None:
raise ValueError(
"Sampling parameters are missing for a CompletionRequest.")
if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
not finished):
return None
# Init cache (if needed)
if use_cache and seq_group.cached_request_output is None:
seq_group.cached_request_output = RequestOutput( # type: ignore
request_id="",
prompt=None,
prompt_token_ids=[],
prompt_logprobs=None,
outputs=[],
finished=False)
top_n_seqs = seq_group.get_seqs()
# Create the outputs.
# NOTE: We need to omit logprobs here explicitly because the sequence
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
include_logprobs = sampling_params.logprobs is not None
text_buffer_length = sampling_params.output_text_buffer_length
delta = sampling_params.output_kind == RequestOutputKind.DELTA
outputs = []
include_prompt = True
# num_cached_tokens should be the same for all the sequences
num_cached_tokens = None
for i, seq in enumerate(top_n_seqs):
output_text = seq.get_output_text_to_return(
text_buffer_length, delta)
output_token_ids = seq.get_output_token_ids_to_return(delta)
num_output_tokens = 1 if isinstance(output_token_ids,
int) else len(output_token_ids)
num_cached_tokens = seq.data.get_num_cached_tokens()
output_logprobs = seq.output_logprobs if include_logprobs else None
if delta:
# Slice logprobs delta if applicable
if output_logprobs:
# num_output_tokens can be 0 when n > 1 and request finishes
# before the others
if num_output_tokens > 0:
output_logprobs = output_logprobs[-num_output_tokens:]
else:
output_logprobs = None
# Don't include prompt if this is after the first output
# containing decode token ids
if include_prompt and seq.get_output_len() > num_output_tokens:
include_prompt = False
if use_cache:
# Get cached output object
cached_outputs = seq_group.cached_request_output.outputs # type: ignore
if i >= len(cached_outputs):
cached_outputs.append(
CompletionOutput(index=i,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None))
output = cached_outputs[i]
# Init cached output object
assert output.index == i
output.text = output_text
if isinstance(output_token_ids, int):
output.token_ids.clear()
output.token_ids.append(output_token_ids)
else:
output.token_ids = output_token_ids
output.cumulative_logprob = seq.get_cumulative_logprob() \
if include_logprobs else None
output.logprobs = output_logprobs
output.finish_reason = SequenceStatus.get_finished_reason(
seq.status)
output.stop_reason = seq.stop_reason
else:
output = CompletionOutput(
top_n_seqs.index(seq), output_text, [output_token_ids]
if isinstance(output_token_ids, int) else output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None,
output_logprobs,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason)
outputs.append(output)
# Every sequence in the sequence group should have the same prompt.
if include_prompt:
prompt = seq_group.prompt
prompt_token_ids = seq_group.prompt_token_ids
encoder_prompt = seq_group.encoder_prompt
encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids
prompt_logprobs = seq_group.prompt_logprobs
else:
prompt = None
prompt_token_ids = None
encoder_prompt = None
encoder_prompt_token_ids = None
prompt_logprobs = None
finished_time = time.time() if finished else None
seq_group.set_finished_time(finished_time)
init_kwargs = {
"request_id": seq_group.request_id,
"prompt": prompt,
"prompt_token_ids": prompt_token_ids,
"prompt_logprobs": prompt_logprobs,
"outputs": outputs,
"finished": finished,
"metrics": seq_group.metrics,
"lora_request": seq_group.lora_request,
"encoder_prompt": encoder_prompt,
"encoder_prompt_token_ids": encoder_prompt_token_ids,
"num_cached_tokens": num_cached_tokens,
"multi_modal_placeholders": seq_group.multi_modal_placeholders
}
if use_cache:
request_output = seq_group.cached_request_output
request_output.__init__(**init_kwargs) # type: ignore
else:
request_output = cls(**init_kwargs) # type: ignore
return request_output
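One of the subtler pieces removed here is the DELTA-mode logprobs slice. A minimal sketch of that step alone, assuming logprobs are a per-token list (hypothetical helper, not vLLM API):
```python
from typing import Optional

# Hedged sketch of the removed delta slicing: keep only the logprobs for the
# tokens produced in this step. With n > 1, a sequence that already finished
# may emit zero new tokens, in which case no logprobs are returned.
def slice_delta_logprobs(output_logprobs: Optional[list[dict]],
                         num_output_tokens: int) -> Optional[list[dict]]:
    if not output_logprobs:
        return None
    if num_output_tokens <= 0:
        return None
    return output_logprobs[-num_output_tokens:]
```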
def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, "
f"prompt={self.prompt!r}, "
@@ -371,19 +204,6 @@ class PoolingRequestOutput(Generic[_O]):
self.finished = finished
self.outputs = outputs
@staticmethod
def from_seq_group(seq_group: SequenceGroup) -> "PoolingRequestOutput":
pooled_data = seq_group.pooled_data
assert pooled_data is not None
data = pooled_data.to(dtype=torch.float32, device="cpu")
output = PoolingOutput(data)
prompt_token_ids = seq_group.prompt_token_ids
finished = seq_group.is_finished()
return PoolingRequestOutput(seq_group.request_id, output,
prompt_token_ids, finished)
def __repr__(self):
return (f"{type(self).__name__}(request_id={self.request_id!r}, "
f"outputs={self.outputs!r}, "
@@ -391,19 +211,6 @@ class PoolingRequestOutput(Generic[_O]):
f"finished={self.finished})")
class RequestOutputFactory:
@staticmethod
def create(seq_group: SequenceGroup,
seq_id_to_seq_group: dict[str, SequenceGroupBase],
use_cache: bool = False):
if seq_group.pooled_data is not None:
return PoolingRequestOutput.from_seq_group(seq_group)
else:
return RequestOutput.from_seq_group(seq_group, use_cache,
seq_id_to_seq_group)
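For context, a hypothetical V0 call site for the factory (variable names assumed, for illustration only); note that `create` could return `None`, e.g. for `FINAL_ONLY` requests that had not finished yet:
```python
# Hypothetical usage sketch (assumed names, not taken from vLLM).
request_outputs = []
for seq_group in scheduled_seq_groups:
    out = RequestOutputFactory.create(seq_group, seq_id_to_seq_group,
                                      use_cache=True)
    if out is not None:
        request_outputs.append(out)
```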
@dataclass
class EmbeddingOutput:
"""The output data of one embedding output of a request.