diff --git a/docs/design/io_processor_plugins.md b/docs/design/io_processor_plugins.md index 8e5d5249409c6..ee474b5a7b997 100644 --- a/docs/design/io_processor_plugins.md +++ b/docs/design/io_processor_plugins.md @@ -64,9 +64,9 @@ The `parse_request` method is used for validating the user prompt and converting The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference. The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output. -The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/io_processor_pooling` serving endpoint is [here](../../vllm/entrypoints/openai/serving_pooling_with_io_plugin.py). +The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/io_processor_pooling` serving endpoint is available in `vllm/entrypoints/openai/serving_pooling_with_io_plugin.py`. -An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin). Please, also refer to our [online](../../examples/online_serving/prithvi_geospatial_mae.py) and [offline](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py) inference examples. +An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin). Please also refer to our online (`examples/online_serving/prithvi_geospatial_mae.py`) and offline (`examples/offline_inference/prithvi_geospatial_mae_io_processor.py`) inference examples. 
## Using an IO Processor plugin diff --git a/examples/online_serving/prithvi_geospatial_mae.py b/examples/online_serving/prithvi_geospatial_mae.py index cbd34f461362c..31301e0042cf4 100644 --- a/examples/online_serving/prithvi_geospatial_mae.py +++ b/examples/online_serving/prithvi_geospatial_mae.py @@ -33,6 +33,7 @@ def main(): }, "priority": 0, "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM", + "softmax": False, } ret = requests.post(server_endpoint, json=request_payload_url) diff --git a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py index d49a50b7a309f..0ebaafda94dc5 100644 --- a/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py +++ b/tests/plugins/prithvi_io_processor_plugin/prithvi_io_processor/prithvi_processor.py @@ -8,7 +8,7 @@ import datetime import os import tempfile import urllib.request -from collections.abc import AsyncGenerator, Sequence +from collections.abc import Sequence from typing import Any, Optional, Union import albumentations @@ -359,14 +359,6 @@ class PrithviMultimodalDataProcessor(IOProcessor): return prompts - async def pre_process_async( - self, - prompt: IOProcessorInput, - request_id: Optional[str] = None, - **kwargs, - ) -> Union[PromptType, Sequence[PromptType]]: - return self.pre_process(prompt, request_id, **kwargs) - def post_process( self, model_output: Sequence[PoolingRequestOutput], @@ -421,15 +413,6 @@ class PrithviMultimodalDataProcessor(IOProcessor): data=out_data, request_id=request_id) - async def post_process_async( - self, - model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]], - request_id: Optional[str] = None, - **kwargs, - ) -> IOProcessorOutput: - collected_output = [item async for i, item in model_output] - return self.post_process(collected_output, request_id, **kwargs) - class 
PrithviMultimodalDataProcessorIndia(PrithviMultimodalDataProcessor): diff --git a/tests/plugins_tests/test_io_processor_plugins.py b/tests/plugins_tests/test_io_processor_plugins.py index 00fe429445d7d..b2fbef2ee25cb 100644 --- a/tests/plugins_tests/test_io_processor_plugins.py +++ b/tests/plugins_tests/test_io_processor_plugins.py @@ -113,6 +113,7 @@ async def test_prithvi_mae_plugin_online( }, "priority": 0, "model": model_name, + "softmax": False } ret = requests.post( diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 00b72f74cec85..30c3a82696155 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1424,9 +1424,10 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]): When using plugins IOProcessor plugins, the actual input is processed by the plugin itself. Hence, we use a generic type for the request data """ + softmax: bool = True def to_pooling_params(self): - return PoolingParams(task="encode") + return PoolingParams(task="encode", softmax=self.softmax) class IOProcessorResponse(OpenAIBaseModel, Generic[T]): diff --git a/vllm/plugins/io_processors/interface.py b/vllm/plugins/io_processors/interface.py index 5c73188d5df51..62b224cac5e53 100644 --- a/vllm/plugins/io_processors/interface.py +++ b/vllm/plugins/io_processors/interface.py @@ -49,7 +49,12 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): request_id: Optional[str] = None, **kwargs, ) -> IOProcessorOutput: - collected_output = [item async for i, item in model_output] + # We cannot guarantee outputs are returned in the same order they were + # fed to vLLM. 
+ # Let's sort them by id before post_processing + sorted_output = sorted([(i, item) async for i, item in model_output], + key=lambda output: output[0]) + collected_output = [output[1] for output in sorted_output] return self.post_process(collected_output, request_id, **kwargs) @abstractmethod @@ -59,4 +64,4 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]): @abstractmethod def output_to_response( self, plugin_output: IOProcessorOutput) -> IOProcessorResponse: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError