[Frontend][4/N] Improve all pooling task | Add plugin pooling task (#26973)

Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Christian Pinto <christian.pinto@ibm.com>
wang.yuqi 2025-10-23 22:46:18 +08:00 committed by GitHub
parent fe2016de2d
commit 3fa2c12185
16 changed files with 102 additions and 54 deletions

View File

@ -13,7 +13,6 @@ IOProcessorInput = TypeVar("IOProcessorInput")
IOProcessorOutput = TypeVar("IOProcessorOutput")
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
@ -49,13 +48,24 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
request_id: str | None = None,
**kwargs,
) -> IOProcessorOutput:
collected_output = [item async for i, item in model_output]
# We cannot guarantee outputs are returned in the same order they were
# fed to vLLM.
# Let's sort them by id before post_processing
sorted_output = sorted(
[(i, item) async for i, item in model_output], key=lambda output: output[0]
)
collected_output = [output[1] for output in sorted_output]
return self.post_process(collected_output, request_id, **kwargs)
@abstractmethod
def parse_request(self, request: Any) -> IOProcessorInput:
raise NotImplementedError
def validate_or_generate_params(
self, params: SamplingParams | PoolingParams | None = None
) -> SamplingParams | PoolingParams:
return params or PoolingParams()
@abstractmethod
def output_to_response(
self, plugin_output: IOProcessorOutput
@ -66,10 +76,10 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods.
The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference.
The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output.
The `validate_or_generate_params` method is used by the plugin to validate any `SamplingParams`/`PoolingParams` received with the user request, or to generate new ones if none are specified. It always returns the validated/generated parameters.
The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/openai/serving_pooling.py).
An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin). Please, also refer to our online ([examples/online_serving/prithvi_geospatial_mae.py](../../examples/online_serving/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py)) inference examples.
An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/online_serving/prithvi_geospatial_mae.py](../../examples/online_serving/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py)) inference examples.
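For orientation, here is a rough sketch of what a plugin subclass can look like, based on the methods described above. The import path, the plugin input/output types, and the method bodies are illustrative assumptions, not part of this change:

```python
from typing import Any

from vllm.config import VllmConfig
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors import IOProcessor  # assumed import path
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams

# Hypothetical plugin input/output types, used only for this sketch.
PluginInput = dict[str, Any]
PluginOutput = dict[str, Any]


class MyGeotiffPlugin(IOProcessor[PluginInput, PluginOutput]):
    def parse_request(self, request: Any) -> PluginInput:
        # Validate the raw user request and convert it into the plugin input.
        return {"image_url": request["data"]}

    def pre_process(
        self, prompt: PluginInput, request_id: str | None = None, **kwargs
    ):
        # Turn the validated plugin input into vLLM model prompts.
        ...

    def post_process(
        self,
        model_output: list[PoolingRequestOutput],
        request_id: str | None = None,
        **kwargs,
    ) -> PluginOutput:
        # Convert the pooling outputs into the plugin-specific output format.
        ...

    def validate_or_generate_params(
        self, params: SamplingParams | PoolingParams | None = None
    ) -> SamplingParams | PoolingParams:
        # Accept user-supplied params, or fall back to a plain PoolingParams().
        return params or PoolingParams()

    def output_to_response(self, plugin_output: PluginOutput):
        # Online serving only: wrap the plugin output into an
        # IOProcessorResponse for the API server.
        ...
```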
## Using an IO Processor plugin

View File

@ -64,7 +64,7 @@ class PrithviMAE:
}
prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
outputs = self.model.encode(prompt, use_tqdm=False)
outputs = self.model.encode(prompt, pooling_task="plugin", use_tqdm=False)
return outputs[0].outputs.data

View File

@ -6,14 +6,14 @@ import os
import torch
from vllm import LLM
from vllm.pooling_params import PoolingParams
# This example shows how to perform offline inference that generates
# multimodal data. In this specific case, the example will take a geotiff
# image as input, process it using the multimodal data processor, and
# perform inference.
# Requirement - install plugin at:
# https://github.com/christian-pinto/prithvi_io_processor_plugin
# Requirements:
# - install TerraTorch v1.1 (or later):
# pip install terratorch>=v1.1
def main():
@ -36,16 +36,12 @@ def main():
# to avoid the model going OOM.
# The maximum number depends on the available GPU memory
max_num_seqs=32,
io_processor_plugin="prithvi_to_tiff",
io_processor_plugin="terratorch_segmentation",
model_impl="terratorch",
enable_mm_embeds=True,
)
pooling_params = PoolingParams(task="token_classify", activation=False)
pooler_output = llm.encode(
img_prompt,
pooling_params=pooling_params,
)
pooler_output = llm.encode(img_prompt, pooling_task="plugin")
output = pooler_output[0].outputs
print(output)

View File

@ -11,14 +11,14 @@ import requests
# image as input, process it using the multimodal data processor, and
# perform inference.
# Requirements:
# - install plugin at:
# https://github.com/christian-pinto/prithvi_io_processor_plugin
# - install TerraTorch v1.1 (or later):
# pip install terratorch>=v1.1
# - start vllm in serving mode with the below args
# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
# --model-impl terratorch
# --task embed --trust-remote-code
# --skip-tokenizer-init --enforce-eager
# --io-processor-plugin prithvi_to_tiff
# --io-processor-plugin terratorch_segmentation
# --enable-mm-embeds
@ -35,7 +35,6 @@ def main():
},
"priority": 0,
"model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
"softmax": False,
}
ret = requests.post(server_endpoint, json=request_payload_url)

View File

@ -40,7 +40,7 @@ def _run_test(
max_num_seqs=32,
default_torch_num_threads=1,
) as vllm_model:
vllm_model.llm.encode(prompt, pooling_task="token_classify")
vllm_model.llm.encode(prompt, pooling_task="plugin")
MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]

View File

@ -368,9 +368,9 @@ class PrithviMultimodalDataProcessor(IOProcessor):
out_format = "b64_json"
for output in model_output:
y_hat = output.outputs.data.argmax(dim=1)
y_hat = output.outputs.data.argmax(dim=0)
pred = torch.nn.functional.interpolate(
y_hat.unsqueeze(1).float(),
y_hat[None, None, ...].float(),
size=self.img_size,
mode="nearest",
)

View File

@ -9,7 +9,6 @@ from tests.utils import RemoteOpenAIServer
from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"
@ -94,8 +93,6 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
out_data_format="b64_json",
)
pooling_params = PoolingParams(activation=False)
with vllm_runner(
model_name,
runner="pooling",
@ -109,9 +106,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
model_impl="terratorch",
io_processor_plugin="prithvi_to_tiff",
) as llm_runner:
pooler_output = llm_runner.get_llm().encode(
img_prompt, pooling_params=pooling_params, pooling_task="token_classify"
)
pooler_output = llm_runner.get_llm().encode(img_prompt, pooling_task="plugin")
output = pooler_output[0].outputs
# verify the output is formatted as expected for this plugin

View File

@ -1024,19 +1024,6 @@ class LLM:
"pooling model."
)
if pooling_task not in self.supported_tasks:
raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
for param in as_iter(pooling_params):
param.verify(pooling_task, model_config)
# for backwards compatibility
if truncate_prompt_tokens is not None:
param.truncate_prompt_tokens = truncate_prompt_tokens
io_processor_prompt = False
if isinstance(prompts, dict) and "data" in prompts:
io_processor_prompt = True
@ -1054,6 +1041,34 @@ class LLM:
# obtain the actual model prompts from the pre-processor
prompts = self.io_processor.pre_process(prompt=validated_prompt)
if io_processor_prompt:
assert self.io_processor is not None
if is_list_of(pooling_params, PoolingParams):
validated_pooling_params: list[PoolingParams] = []
for param in as_iter(pooling_params):
validated_pooling_params.append(
self.io_processor.validate_or_generate_params(param)
)
pooling_params = validated_pooling_params
else:
assert not isinstance(pooling_params, Sequence)
pooling_params = self.io_processor.validate_or_generate_params(
pooling_params
)
else:
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
if pooling_task not in self.supported_tasks:
raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
for param in as_iter(pooling_params):
param.verify(pooling_task, model_config)
# for backwards compatibility
if truncate_prompt_tokens is not None:
param.truncate_prompt_tokens = truncate_prompt_tokens
self._validate_and_add_requests(
prompts=prompts,
params=pooling_params,

View File

@ -1748,7 +1748,12 @@ async def init_app_state(
log_error_stack=args.log_error_stack,
)
)
if ("token_embed" in supported_tasks or "token_classify" in supported_tasks)
if (
any(
task in supported_tasks
for task in ["token_embed", "token_classify", "plugin"]
)
)
else None
)
state.openai_serving_embedding = (

View File

@ -1707,11 +1707,6 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
if the served model does not use priority scheduling.
"""
data: T
"""
When using IOProcessor plugins, the actual input is processed
by the plugin itself. Hence, we use a generic type for the request data.
"""
activation: bool = False
encoding_format: EncodingFormat = "float"
embed_dtype: EmbedDType = Field(
@ -1732,7 +1727,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
)
def to_pooling_params(self):
return PoolingParams(task="token_classify", activation=self.activation)
return PoolingParams()
class IOProcessorResponse(OpenAIBaseModel, Generic[T]):

View File

@ -32,7 +32,7 @@ from vllm.entrypoints.renderer import RenderConfig
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.tasks import SupportedTask
from vllm.tasks import PoolingTask, SupportedTask
from vllm.utils.async_utils import merge_async_iterators
from vllm.utils.serial_utils import (
EmbedDType,
@ -161,12 +161,21 @@ class OpenAIServingPooling(OpenAIServing):
# Schedule the request and get the result generator.
generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
try:
pooling_params = request.to_pooling_params()
if is_io_processor_request:
assert self.io_processor is not None and isinstance(
request, IOProcessorRequest
)
pooling_params = self.io_processor.validate_or_generate_params()
else:
pooling_params = request.to_pooling_params()
pooling_task: PoolingTask
if "token_embed" in self.supported_tasks:
pooling_task = "token_embed"
elif "token_classify" in self.supported_tasks:
pooling_task = "token_classify"
elif "plugin" in self.supported_tasks:
pooling_task = "plugin"
else:
return self.create_error_response(
f"pooling_task must be one of {self.supported_tasks}."

View File

@ -414,6 +414,18 @@ class Pooler(nn.Module, ABC):
raise NotImplementedError
class DummyPooler(Pooler):
def get_supported_tasks(self) -> Set[PoolingTask]:
return {"plugin", "score"}
def forward(
self,
hidden_states: list[torch.Tensor] | torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
return hidden_states
class PoolerHead(nn.Module):
def __init__(self, activation: PoolerActivation) -> None:
super().__init__()

View File

@ -34,7 +34,7 @@ from transformers import BatchFeature
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.logger import init_logger
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
from vllm.model_executor.layers.pooler import DispatchPooler, DummyPooler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import AutoWeightsLoader
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -249,9 +249,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.pooler = DispatchPooler(
{"token_classify": Pooler.for_token_classify(pooler_config)}
)
self.pooler = DispatchPooler({"plugin": DummyPooler()})
def get_input_embeddings(
self,

View File

@ -9,6 +9,8 @@ from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
IOProcessorInput = TypeVar("IOProcessorInput")
IOProcessorOutput = TypeVar("IOProcessorOutput")
@ -63,6 +65,11 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
def parse_request(self, request: Any) -> IOProcessorInput:
raise NotImplementedError
def validate_or_generate_params(
self, params: SamplingParams | PoolingParams | None = None
) -> SamplingParams | PoolingParams:
return params or PoolingParams()
@abstractmethod
def output_to_response(
self, plugin_output: IOProcessorOutput

View File

@ -84,6 +84,11 @@ class PoolingParams(
msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
raise ValueError(msg)
# plugin task uses io_processor.parse_request to verify inputs,
# skipping PoolingParams verify
if self.task == "plugin":
return
# NOTE: Task validation needs to be done against the model instance,
# which is not available in model config. So, it's not included
# in this method
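As a quick illustration of the behavior added here (a minimal sketch; it assumes `model_config` is optional in `verify()`, whereas the call site in `LLM.encode` shown earlier in this diff passes the model config):

```python
from vllm.pooling_params import PoolingParams

# "plugin" task: verify() returns right after the task check, because input
# validation is delegated to the IO processor plugin (parse_request /
# validate_or_generate_params) rather than to PoolingParams.
PoolingParams(task="plugin").verify("plugin")

# Any other pooling task still runs the remaining checks in verify();
# model-dependent checks are skipped here because no model config is passed.
PoolingParams().verify("embed")
```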

View File

@ -5,7 +5,9 @@ from typing import Literal, get_args
GenerationTask = Literal["generate", "transcription"]
GENERATION_TASKS = get_args(GenerationTask)
PoolingTask = Literal["embed", "classify", "score", "token_embed", "token_classify"]
PoolingTask = Literal[
"embed", "classify", "score", "token_embed", "token_classify", "plugin"
]
POOLING_TASKS = get_args(PoolingTask)
SupportedTask = Literal[GenerationTask, PoolingTask]