mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 19:07:08 +08:00
[Misc] Simplify PoolerOutput and move to v1/outputs (#25629)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
a676e668ee
commit
755ed7b05b
@ -15,10 +15,10 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput
|
from vllm.sequence import ExecuteModelRequest
|
||||||
from vllm.tasks import SupportedTask
|
from vllm.tasks import SupportedTask
|
||||||
from vllm.utils import make_async
|
from vllm.utils import make_async
|
||||||
from vllm.v1.outputs import SamplerOutput
|
from vllm.v1.outputs import PoolerOutput, SamplerOutput
|
||||||
from vllm.worker.worker_base import WorkerBase
|
from vllm.worker.worker_base import WorkerBase
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|||||||
@ -16,9 +16,9 @@ from vllm.config import ModelConfig, PoolerConfig, get_current_vllm_config
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.models.adapters import _load_st_projector
|
from vllm.model_executor.models.adapters import _load_st_projector
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
|
|
||||||
from vllm.tasks import PoolingTask
|
from vllm.tasks import PoolingTask
|
||||||
from vllm.utils import current_stream, resolve_obj_by_qualname
|
from vllm.utils import resolve_obj_by_qualname
|
||||||
|
from vllm.v1.outputs import PoolerOutput
|
||||||
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
|
from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -190,19 +190,6 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
|
|||||||
return PoolerClassify()
|
return PoolerClassify()
|
||||||
|
|
||||||
|
|
||||||
def build_output(
|
|
||||||
all_data: Union[torch.Tensor, list[torch.Tensor]], ) -> PoolerOutput:
|
|
||||||
# Pooling models D2H & synchronize occurs here
|
|
||||||
if isinstance(all_data, list):
|
|
||||||
all_data = [d.to("cpu", non_blocking=True) for d in all_data]
|
|
||||||
else:
|
|
||||||
all_data = all_data.to("cpu", non_blocking=True)
|
|
||||||
current_stream().synchronize()
|
|
||||||
|
|
||||||
all_outputs = [PoolingSequenceGroupOutput(data) for data in all_data]
|
|
||||||
return PoolerOutput(outputs=all_outputs)
|
|
||||||
|
|
||||||
|
|
||||||
class PoolingMethod(nn.Module, ABC):
|
class PoolingMethod(nn.Module, ABC):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -556,7 +543,7 @@ class SimplePooler(Pooler):
|
|||||||
) -> PoolerOutput:
|
) -> PoolerOutput:
|
||||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||||
pooled_data = self.head(pooled_data, pooling_metadata)
|
pooled_data = self.head(pooled_data, pooling_metadata)
|
||||||
return build_output(pooled_data)
|
return pooled_data
|
||||||
|
|
||||||
|
|
||||||
class StepPooler(Pooler):
|
class StepPooler(Pooler):
|
||||||
@ -607,7 +594,7 @@ class StepPooler(Pooler):
|
|||||||
) -> PoolerOutput:
|
) -> PoolerOutput:
|
||||||
pooled_data = self.extract_states(hidden_states, pooling_metadata)
|
pooled_data = self.extract_states(hidden_states, pooling_metadata)
|
||||||
pooled_data = self.head(pooled_data, pooling_metadata)
|
pooled_data = self.head(pooled_data, pooling_metadata)
|
||||||
return build_output(pooled_data)
|
return pooled_data
|
||||||
|
|
||||||
|
|
||||||
class ClassifierPooler(Pooler):
|
class ClassifierPooler(Pooler):
|
||||||
@ -678,7 +665,7 @@ class ClassifierPooler(Pooler):
|
|||||||
]
|
]
|
||||||
|
|
||||||
# scores shape: [batchsize, num_labels]
|
# scores shape: [batchsize, num_labels]
|
||||||
return build_output(scores)
|
return scores
|
||||||
|
|
||||||
|
|
||||||
class DispatchPooler(Pooler):
|
class DispatchPooler(Pooler):
|
||||||
@ -708,7 +695,7 @@ class DispatchPooler(Pooler):
|
|||||||
) -> PoolerOutput:
|
) -> PoolerOutput:
|
||||||
poolers_by_task = self.poolers_by_task
|
poolers_by_task = self.poolers_by_task
|
||||||
|
|
||||||
outputs = list[PoolingSequenceGroupOutput]()
|
outputs = list[torch.Tensor]()
|
||||||
offset = 0
|
offset = 0
|
||||||
for task, group in groupby(get_tasks(pooling_metadata)):
|
for task, group in groupby(get_tasks(pooling_metadata)):
|
||||||
if not (pooler := poolers_by_task.get(task)):
|
if not (pooler := poolers_by_task.get(task)):
|
||||||
@ -722,10 +709,10 @@ class DispatchPooler(Pooler):
|
|||||||
pooling_metadata[offset:offset + num_items],
|
pooling_metadata[offset:offset + num_items],
|
||||||
)
|
)
|
||||||
|
|
||||||
outputs.extend(group_output.outputs)
|
outputs.extend(group_output)
|
||||||
offset += num_items
|
offset += num_items
|
||||||
|
|
||||||
return PoolerOutput(outputs)
|
return outputs
|
||||||
|
|
||||||
def extra_repr(self) -> str:
|
def extra_repr(self) -> str:
|
||||||
s = f"supported_task={self.get_supported_tasks()}"
|
s = f"supported_task={self.get_supported_tasks()}"
|
||||||
|
|||||||
@ -12,12 +12,12 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
|
from vllm.model_executor.layers.pooler import (DispatchPooler, Pooler,
|
||||||
PoolerHead, PoolerNormalize,
|
PoolerHead, PoolerNormalize,
|
||||||
PoolingParamsUpdate,
|
PoolingParamsUpdate,
|
||||||
build_output, get_prompt_lens,
|
get_prompt_lens,
|
||||||
get_prompt_token_ids)
|
get_prompt_token_ids)
|
||||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||||
from vllm.sequence import PoolerOutput
|
|
||||||
from vllm.tasks import PoolingTask
|
from vllm.tasks import PoolingTask
|
||||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||||
|
from vllm.v1.outputs import PoolerOutput
|
||||||
from vllm.v1.pool.metadata import PoolingMetadata
|
from vllm.v1.pool.metadata import PoolingMetadata
|
||||||
|
|
||||||
from .interfaces_base import default_pooling_type
|
from .interfaces_base import default_pooling_type
|
||||||
@ -212,7 +212,7 @@ class GritLMPooler(Pooler):
|
|||||||
) -> PoolerOutput:
|
) -> PoolerOutput:
|
||||||
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
pooled_data = self.pooling(hidden_states, pooling_metadata)
|
||||||
pooled_data = self.head(pooled_data, pooling_metadata)
|
pooled_data = self.head(pooled_data, pooling_metadata)
|
||||||
return build_output(pooled_data)
|
return pooled_data
|
||||||
|
|
||||||
|
|
||||||
@default_pooling_type("MEAN")
|
@default_pooling_type("MEAN")
|
||||||
|
|||||||
@ -11,7 +11,6 @@ if TYPE_CHECKING:
|
|||||||
from vllm.v1.worker.kv_connector_model_runner_mixin import (
|
from vllm.v1.worker.kv_connector_model_runner_mixin import (
|
||||||
KVConnectorOutput)
|
KVConnectorOutput)
|
||||||
else:
|
else:
|
||||||
LoRARequest = Any
|
|
||||||
KVConnectorOutput = Any
|
KVConnectorOutput = Any
|
||||||
|
|
||||||
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
|
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
|
||||||
@ -48,29 +47,6 @@ class RequestMetrics:
|
|||||||
model_execute_time: Optional[float] = None
|
model_execute_time: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
class PoolingSequenceGroupOutput(
|
|
||||||
msgspec.Struct,
|
|
||||||
omit_defaults=True, # type: ignore[call-arg]
|
|
||||||
array_like=True, # type: ignore[call-arg]
|
|
||||||
):
|
|
||||||
"""The model output associated with a pooling sequence group."""
|
|
||||||
# Annotated as Any to be compatible with msgspec
|
|
||||||
# The actual type is in SequenceGroup.pooled_data
|
|
||||||
data: Any
|
|
||||||
|
|
||||||
def get_data_nbytes(self) -> int:
|
|
||||||
data: torch.Tensor = self.data
|
|
||||||
return data.nbytes
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f"PoolingSequenceGroupOutput(data={self.data}"
|
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
|
||||||
if not isinstance(other, PoolingSequenceGroupOutput):
|
|
||||||
raise NotImplementedError()
|
|
||||||
return self.data == other.data
|
|
||||||
|
|
||||||
|
|
||||||
# cannot use msgspec.Struct here because Dynamo does not support it
|
# cannot use msgspec.Struct here because Dynamo does not support it
|
||||||
@dataclass
|
@dataclass
|
||||||
class IntermediateTensors:
|
class IntermediateTensors:
|
||||||
@ -119,30 +95,6 @@ class IntermediateTensors:
|
|||||||
return f"IntermediateTensors(tensors={self.tensors})"
|
return f"IntermediateTensors(tensors={self.tensors})"
|
||||||
|
|
||||||
|
|
||||||
class PoolerOutput(
|
|
||||||
msgspec.Struct,
|
|
||||||
omit_defaults=True, # type: ignore[call-arg]
|
|
||||||
array_like=True): # type: ignore[call-arg]
|
|
||||||
"""The output from a pooling operation in the pooling model."""
|
|
||||||
outputs: list[PoolingSequenceGroupOutput]
|
|
||||||
|
|
||||||
def get_data_nbytes(self) -> int:
|
|
||||||
return sum(o.get_data_nbytes() for o in self.outputs)
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> PoolingSequenceGroupOutput:
|
|
||||||
return self.outputs[idx]
|
|
||||||
|
|
||||||
def __setitem__(self, idx: int, value: PoolingSequenceGroupOutput):
|
|
||||||
self.outputs[idx] = value
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.outputs)
|
|
||||||
|
|
||||||
def __eq__(self, other: object):
|
|
||||||
return isinstance(other,
|
|
||||||
self.__class__) and self.outputs == other.outputs
|
|
||||||
|
|
||||||
|
|
||||||
class ExecuteModelRequest(
|
class ExecuteModelRequest(
|
||||||
msgspec.Struct,
|
msgspec.Struct,
|
||||||
array_like=True, # type: ignore[call-arg]
|
array_like=True, # type: ignore[call-arg]
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, NamedTuple, Optional
|
from typing import TYPE_CHECKING, NamedTuple, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -65,6 +65,11 @@ class LogprobsTensors(NamedTuple):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# [num_reqs, <dynamic>]
|
||||||
|
# The shape of each element depends on the pooler used
|
||||||
|
PoolerOutput = Union[torch.Tensor, list[torch.Tensor]]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SamplerOutput:
|
class SamplerOutput:
|
||||||
|
|
||||||
|
|||||||
@ -52,13 +52,14 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
|
|||||||
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
||||||
from vllm.pooling_params import PoolingParams
|
from vllm.pooling_params import PoolingParams
|
||||||
from vllm.sampling_params import SamplingType
|
from vllm.sampling_params import SamplingType
|
||||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
|
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
|
||||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||||
GiB_bytes, cdiv, check_use_alibi, get_dtype_size,
|
GiB_bytes, cdiv, check_use_alibi, get_dtype_size,
|
||||||
is_pin_memory_available,
|
is_pin_memory_available,
|
||||||
length_from_prompt_token_ids_or_embeds, round_up,
|
length_from_prompt_token_ids_or_embeds, round_up,
|
||||||
supports_dynamo)
|
supports_dynamo)
|
||||||
|
from vllm.utils.jsontree import json_map_leaves
|
||||||
from vllm.v1.attention.backends.flash_attn import AttentionMetadata
|
from vllm.v1.attention.backends.flash_attn import AttentionMetadata
|
||||||
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
|
||||||
from vllm.v1.attention.backends.utils import (
|
from vllm.v1.attention.backends.utils import (
|
||||||
@ -79,7 +80,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
|
|||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
||||||
DraftTokenIds, LogprobsLists, LogprobsTensors,
|
DraftTokenIds, LogprobsLists, LogprobsTensors,
|
||||||
ModelRunnerOutput, SamplerOutput)
|
ModelRunnerOutput, PoolerOutput, SamplerOutput)
|
||||||
from vllm.v1.pool.metadata import PoolingMetadata
|
from vllm.v1.pool.metadata import PoolingMetadata
|
||||||
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
|
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
@ -1823,15 +1824,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
device=hidden_states.device)
|
device=hidden_states.device)
|
||||||
seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs]
|
seq_lens_cpu = self.seq_lens.cpu[:self.input_batch.num_reqs]
|
||||||
|
|
||||||
# Pooling models D2H & synchronize occurs in pooler.py:build_output
|
model = cast(VllmModelForPooling, self.model)
|
||||||
raw_pooler_output = self.model.pooler(
|
raw_pooler_output: PoolerOutput = model.pooler(
|
||||||
hidden_states=hidden_states, pooling_metadata=pooling_metadata)
|
hidden_states=hidden_states,
|
||||||
|
pooling_metadata=pooling_metadata,
|
||||||
|
)
|
||||||
|
raw_pooler_output = json_map_leaves(
|
||||||
|
lambda x: x.to("cpu", non_blocking=True),
|
||||||
|
raw_pooler_output,
|
||||||
|
)
|
||||||
|
self._sync_device()
|
||||||
|
|
||||||
pooler_output: list[Optional[torch.Tensor]] = []
|
pooler_output: list[Optional[torch.Tensor]] = []
|
||||||
for raw_output, seq_len, prompt_len in zip(
|
for raw_output, seq_len, prompt_len in zip(
|
||||||
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
|
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
|
||||||
|
|
||||||
output = raw_output.data if seq_len == prompt_len else None
|
output = raw_output if seq_len == prompt_len else None
|
||||||
pooler_output.append(output)
|
pooler_output.append(output)
|
||||||
|
|
||||||
return ModelRunnerOutput(
|
return ModelRunnerOutput(
|
||||||
@ -3233,7 +3241,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
for task in self.get_supported_pooling_tasks():
|
for task in self.get_supported_pooling_tasks():
|
||||||
# Run a full batch with each task to ensure none of them OOMs
|
# Run a full batch with each task to ensure none of them OOMs
|
||||||
output = self._dummy_pooler_run_task(hidden_states, task)
|
output = self._dummy_pooler_run_task(hidden_states, task)
|
||||||
output_size[task] = output.get_data_nbytes()
|
output_size[task] = sum(o.nbytes for o in output)
|
||||||
del output # Allow GC
|
del output # Allow GC
|
||||||
|
|
||||||
max_task = max(output_size.items(), key=lambda x: x[1])[0]
|
max_task = max(output_size.items(), key=lambda x: x[1])[0]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user