[Frontend][4/N] Improve all pooling task | Add plugin pooling task (#26973)
Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Christian Pinto <christian.pinto@ibm.com>
parent fe2016de2d
commit 3fa2c12185
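At a glance, this change adds a new "plugin" pooling task that hands prompt and parameter validation over to the configured IO Processor plugin. The following is a rough usage sketch assembled from the example files updated in this diff; the `img_prompt` payload is plugin-specific and shown here only as a placeholder:

```python
from vllm import LLM

# Model name, plugin name, and constructor kwargs are taken from the example
# files changed in this commit.
llm = LLM(
    model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    model_impl="terratorch",
    io_processor_plugin="terratorch_segmentation",
    enable_mm_embeds=True,
)

# Placeholder plugin payload; the actual structure is defined by the plugin's
# parse_request implementation. A dict with a "data" key is routed through the
# IO Processor path.
img_prompt = {"data": {"image_url": "https://example.com/tile.tif"}}

# The new "plugin" pooling task routes validation to the IO Processor plugin.
pooler_output = llm.encode(img_prompt, pooling_task="plugin")
print(pooler_output[0].outputs)
```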
@@ -13,7 +13,6 @@ IOProcessorInput = TypeVar("IOProcessorInput")
 IOProcessorOutput = TypeVar("IOProcessorOutput")

 class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):

     def __init__(self, vllm_config: VllmConfig):
         self.vllm_config = vllm_config

@@ -49,13 +48,24 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
         request_id: str | None = None,
         **kwargs,
     ) -> IOProcessorOutput:
-        collected_output = [item async for i, item in model_output]
         # We cannot guarantee outputs are returned in the same order they were
         # fed to vLLM.
         # Let's sort them by id before post_processing
+        sorted_output = sorted(
+            [(i, item) async for i, item in model_output], key=lambda output: output[0]
+        )
+        collected_output = [output[1] for output in sorted_output]
         return self.post_process(collected_output, request_id, **kwargs)

     @abstractmethod
     def parse_request(self, request: Any) -> IOProcessorInput:
         raise NotImplementedError

+    def validate_or_generate_params(
+        self, params: SamplingParams | PoolingParams | None = None
+    ) -> SamplingParams | PoolingParams:
+        return params or PoolingParams()
+
     @abstractmethod
     def output_to_response(
         self, plugin_output: IOProcessorOutput
@@ -66,10 +76,10 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
 The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods.
 The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference.
 The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output.

+The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParameters`/`PoolingParameters` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters.
 The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/openai/serving_pooling.py).

-An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin). Please, also refer to our online ([examples/online_serving/prithvi_geospatial_mae.py](../../examples/online_serving/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py)) inference examples.
+An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/online_serving/prithvi_geospatial_mae.py](../../examples/online_serving/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py)) inference examples.

 ## Using an IO Processor plugin

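For orientation (not part of the diff above), here is a minimal sketch of what a concrete plugin might look like, based on the `IOProcessor` interface and methods described in this docs change. The `GeotiffRequest`/`GeotiffOutput` types, the `MySegmentationProcessor` class, the exact `pre_process`/`post_process` signatures, and the import path are illustrative assumptions:

```python
from dataclasses import dataclass
from typing import Any

from vllm.config import VllmConfig
from vllm.plugins.io_processors.interface import IOProcessor  # path assumed


@dataclass
class GeotiffRequest:   # hypothetical validated plugin input
    image_url: str


@dataclass
class GeotiffOutput:    # hypothetical plugin output
    b64_tiff: str


class MySegmentationProcessor(IOProcessor[GeotiffRequest, GeotiffOutput]):
    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)

    def parse_request(self, request: Any) -> GeotiffRequest:
        # Validate the raw user payload and turn it into the plugin input.
        return GeotiffRequest(image_url=request["data"]["image_url"])

    def pre_process(self, prompt: GeotiffRequest, **kwargs):
        # Build regular vLLM model prompts (e.g. multi-modal embeddings)
        # from the validated plugin input.
        ...

    def post_process(self, model_output, request_id=None, **kwargs) -> GeotiffOutput:
        # Convert the collected PoolingRequestOutput objects into the
        # plugin's own output format.
        ...

    def output_to_response(self, plugin_output: GeotiffOutput):
        # Online serving only: wrap the plugin output in an IOProcessorResponse
        # so the API server can return it.
        ...
```

The base class already provides `validate_or_generate_params` (returning default `PoolingParams` when none are supplied) and an async wrapper that sorts outputs by request id before calling `post_process`, so a plugin only needs to override those when it has custom behaviour.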
@@ -64,7 +64,7 @@ class PrithviMAE:
         }

         prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
-        outputs = self.model.encode(prompt, use_tqdm=False)
+        outputs = self.model.encode(prompt, pooling_task="plugin", use_tqdm=False)

         return outputs[0].outputs.data

@@ -6,14 +6,14 @@ import os
 import torch

 from vllm import LLM
-from vllm.pooling_params import PoolingParams

 # This example shows how to perform an offline inference that generates
 # multimodal data. In this specific case this example will take a geotiff
 # image as input, process it using the multimodal data processor, and
 # perform inference.
-# Requirement - install plugin at:
-# https://github.com/christian-pinto/prithvi_io_processor_plugin
+# Requirements:
+# - install TerraTorch v1.1 (or later):
+#   pip install terratorch>=v1.1


 def main():
@@ -36,16 +36,12 @@ def main():
         # to avoid the model going OOM.
         # The maximum number depends on the available GPU memory
         max_num_seqs=32,
-        io_processor_plugin="prithvi_to_tiff",
+        io_processor_plugin="terratorch_segmentation",
         model_impl="terratorch",
         enable_mm_embeds=True,
     )

-    pooling_params = PoolingParams(task="token_classify", activation=False)
-    pooler_output = llm.encode(
-        img_prompt,
-        pooling_params=pooling_params,
-    )
+    pooler_output = llm.encode(img_prompt, pooling_task="plugin")
     output = pooler_output[0].outputs

     print(output)

@@ -11,14 +11,14 @@ import requests
 # image as input, process it using the multimodal data processor, and
 # perform inference.
 # Requirements :
-# - install plugin at:
-#   https://github.com/christian-pinto/prithvi_io_processor_plugin
+# - install TerraTorch v1.1 (or later):
+#   pip install terratorch>=v1.1
 # - start vllm in serving mode with the below args
 #   --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
 #   --model-impl terratorch
 #   --task embed --trust-remote-code
 #   --skip-tokenizer-init --enforce-eager
-#   --io-processor-plugin prithvi_to_tiff
+#   --io-processor-plugin terratorch_segmentation
 #   --enable-mm-embeds

@@ -35,7 +35,6 @@ def main():
         },
         "priority": 0,
         "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
-        "softmax": False,
     }

     ret = requests.post(server_endpoint, json=request_payload_url)

@@ -40,7 +40,7 @@ def _run_test(
         max_num_seqs=32,
         default_torch_num_threads=1,
     ) as vllm_model:
-        vllm_model.llm.encode(prompt, pooling_task="token_classify")
+        vllm_model.llm.encode(prompt, pooling_task="plugin")


 MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]

@@ -368,9 +368,9 @@ class PrithviMultimodalDataProcessor(IOProcessor):
         out_format = "b64_json"

         for output in model_output:
-            y_hat = output.outputs.data.argmax(dim=1)
+            y_hat = output.outputs.data.argmax(dim=0)
             pred = torch.nn.functional.interpolate(
-                y_hat.unsqueeze(1).float(),
+                y_hat[None, None, ...].float(),
                 size=self.img_size,
                 mode="nearest",
             )

@@ -9,7 +9,6 @@ from tests.utils import RemoteOpenAIServer
 from vllm.config import VllmConfig
 from vllm.entrypoints.openai.protocol import IOProcessorResponse
 from vllm.plugins.io_processors import get_io_processor
-from vllm.pooling_params import PoolingParams

 MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"

@@ -94,8 +93,6 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
         out_data_format="b64_json",
     )

-    pooling_params = PoolingParams(activation=False)
-
     with vllm_runner(
         model_name,
         runner="pooling",
@@ -109,9 +106,7 @@
         model_impl="terratorch",
         io_processor_plugin="prithvi_to_tiff",
     ) as llm_runner:
-        pooler_output = llm_runner.get_llm().encode(
-            img_prompt, pooling_params=pooling_params, pooling_task="token_classify"
-        )
+        pooler_output = llm_runner.get_llm().encode(img_prompt, pooling_task="plugin")
         output = pooler_output[0].outputs

         # verify the output is formatted as expected for this plugin

@@ -1024,19 +1024,6 @@ class LLM:
             "pooling model."
         )

-        if pooling_task not in self.supported_tasks:
-            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
-
-        if pooling_params is None:
-            # Use default pooling params.
-            pooling_params = PoolingParams()
-
-        for param in as_iter(pooling_params):
-            param.verify(pooling_task, model_config)
-            # for backwards compatibility
-            if truncate_prompt_tokens is not None:
-                param.truncate_prompt_tokens = truncate_prompt_tokens
-
         io_processor_prompt = False
         if isinstance(prompts, dict) and "data" in prompts:
             io_processor_prompt = True
@@ -1054,6 +1041,34 @@ class LLM:
             # obtain the actual model prompts from the pre-processor
             prompts = self.io_processor.pre_process(prompt=validated_prompt)

+        if io_processor_prompt:
+            assert self.io_processor is not None
+            if is_list_of(pooling_params, PoolingParams):
+                validated_pooling_params: list[PoolingParams] = []
+                for param in as_iter(pooling_params):
+                    validated_pooling_params.append(
+                        self.io_processor.validate_or_generate_params(param)
+                    )
+                pooling_params = validated_pooling_params
+            else:
+                assert not isinstance(pooling_params, Sequence)
+                pooling_params = self.io_processor.validate_or_generate_params(
+                    pooling_params
+                )
+        else:
+            if pooling_params is None:
+                # Use default pooling params.
+                pooling_params = PoolingParams()
+
+            if pooling_task not in self.supported_tasks:
+                raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
+
+            for param in as_iter(pooling_params):
+                param.verify(pooling_task, model_config)
+                # for backwards compatibility
+                if truncate_prompt_tokens is not None:
+                    param.truncate_prompt_tokens = truncate_prompt_tokens
+
         self._validate_and_add_requests(
             prompts=prompts,
             params=pooling_params,

@@ -1748,7 +1748,12 @@ async def init_app_state(
                 log_error_stack=args.log_error_stack,
             )
         )
-        if ("token_embed" in supported_tasks or "token_classify" in supported_tasks)
+        if (
+            any(
+                task in supported_tasks
+                for task in ["token_embed", "token_classify", "plugin"]
+            )
+        )
         else None
     )
     state.openai_serving_embedding = (

@@ -1707,11 +1707,6 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
     if the served model does not use priority scheduling.
     """
     data: T
     """
     When using IOProcessor plugins, the actual input is processed
     by the plugin itself. Hence, we use a generic type for the request data
     """
-    activation: bool = False
-
     encoding_format: EncodingFormat = "float"
     embed_dtype: EmbedDType = Field(
@@ -1732,7 +1727,7 @@
     )

     def to_pooling_params(self):
-        return PoolingParams(task="token_classify", activation=self.activation)
+        return PoolingParams()


 class IOProcessorResponse(OpenAIBaseModel, Generic[T]):

@@ -32,7 +32,7 @@ from vllm.entrypoints.renderer import RenderConfig
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.logger import init_logger
 from vllm.outputs import PoolingRequestOutput
-from vllm.tasks import SupportedTask
+from vllm.tasks import PoolingTask, SupportedTask
 from vllm.utils.async_utils import merge_async_iterators
 from vllm.utils.serial_utils import (
     EmbedDType,
@@ -161,12 +161,21 @@ class OpenAIServingPooling(OpenAIServing):
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
         try:
-            pooling_params = request.to_pooling_params()
+            if is_io_processor_request:
+                assert self.io_processor is not None and isinstance(
+                    request, IOProcessorRequest
+                )
+                pooling_params = self.io_processor.validate_or_generate_params()
+            else:
+                pooling_params = request.to_pooling_params()

+            pooling_task: PoolingTask
             if "token_embed" in self.supported_tasks:
                 pooling_task = "token_embed"
             elif "token_classify" in self.supported_tasks:
                 pooling_task = "token_classify"
+            elif "plugin" in self.supported_tasks:
+                pooling_task = "plugin"
             else:
                 return self.create_error_response(
                     f"pooling_task must be one of {self.supported_tasks}."

@@ -414,6 +414,18 @@ class Pooler(nn.Module, ABC):
         raise NotImplementedError


+class DummyPooler(Pooler):
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return {"plugin", "score"}
+
+    def forward(
+        self,
+        hidden_states: list[torch.Tensor] | torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        return hidden_states
+
+
 class PoolerHead(nn.Module):
     def __init__(self, activation: PoolerActivation) -> None:
         super().__init__()

@@ -34,7 +34,7 @@ from transformers import BatchFeature
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.logger import init_logger
-from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
+from vllm.model_executor.layers.pooler import DispatchPooler, DummyPooler
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.utils import AutoWeightsLoader
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -249,9 +249,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None

-        self.pooler = DispatchPooler(
-            {"token_classify": Pooler.for_token_classify(pooler_config)}
-        )
+        self.pooler = DispatchPooler({"plugin": DummyPooler()})

     def get_input_embeddings(
         self,

@@ -9,6 +9,8 @@ from vllm.config import VllmConfig
 from vllm.entrypoints.openai.protocol import IOProcessorResponse
 from vllm.inputs.data import PromptType
 from vllm.outputs import PoolingRequestOutput
+from vllm.pooling_params import PoolingParams
+from vllm.sampling_params import SamplingParams

 IOProcessorInput = TypeVar("IOProcessorInput")
 IOProcessorOutput = TypeVar("IOProcessorOutput")
@@ -63,6 +65,11 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
     def parse_request(self, request: Any) -> IOProcessorInput:
         raise NotImplementedError

+    def validate_or_generate_params(
+        self, params: SamplingParams | PoolingParams | None = None
+    ) -> SamplingParams | PoolingParams:
+        return params or PoolingParams()
+
     @abstractmethod
     def output_to_response(
         self, plugin_output: IOProcessorOutput

@@ -84,6 +84,11 @@ class PoolingParams(
             msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
             raise ValueError(msg)

+        # plugin task uses io_processor.parse_request to verify inputs,
+        # skipping PoolingParams verify
+        if self.task == "plugin":
+            return
+
         # NOTE: Task validation needs to be done against the model instance,
         # which is not available in model config. So, it's not included
         # in this method

@@ -5,7 +5,9 @@ from typing import Literal, get_args
 GenerationTask = Literal["generate", "transcription"]
 GENERATION_TASKS = get_args(GenerationTask)

-PoolingTask = Literal["embed", "classify", "score", "token_embed", "token_classify"]
+PoolingTask = Literal[
+    "embed", "classify", "score", "token_embed", "token_classify", "plugin"
+]
 POOLING_TASKS = get_args(PoolingTask)

 SupportedTask = Literal[GenerationTask, PoolingTask]