[Frontend][4/N] Improve all pooling task | Add plugin pooling task (#26973)
Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Christian Pinto <christian.pinto@ibm.com>
commit 3fa2c12185 (parent: fe2016de2d)
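The change introduces a dedicated "plugin" pooling task so IO Processor plugins no longer have to reuse "token_classify". Roughly, the intended offline call pattern (distilled from the updated example further down in this diff; the model and plugin names are the ones used there and assume TerraTorch >= 1.1 is installed) looks like:

from vllm import LLM

llm = LLM(
    model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    model_impl="terratorch",
    io_processor_plugin="terratorch_segmentation",
    enable_mm_embeds=True,
)

# The prompt is a plugin-defined dict; the "data" key is what marks it as an
# IO Processor prompt (its contents here are a placeholder).
img_prompt = {"data": ...}

pooler_output = llm.encode(img_prompt, pooling_task="plugin")
print(pooler_output[0].outputs)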
@@ -13,7 +13,6 @@ IOProcessorInput = TypeVar("IOProcessorInput")
 IOProcessorOutput = TypeVar("IOProcessorOutput")


 class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):

     def __init__(self, vllm_config: VllmConfig):
         self.vllm_config = vllm_config

@@ -49,13 +48,24 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
         request_id: str | None = None,
         **kwargs,
     ) -> IOProcessorOutput:
-        collected_output = [item async for i, item in model_output]
+        # We cannot guarantee outputs are returned in the same order they were
+        # fed to vLLM.
+        # Let's sort them by id before post_processing
+        sorted_output = sorted(
+            [(i, item) async for i, item in model_output], key=lambda output: output[0]
+        )
+        collected_output = [output[1] for output in sorted_output]
         return self.post_process(collected_output, request_id, **kwargs)

     @abstractmethod
     def parse_request(self, request: Any) -> IOProcessorInput:
         raise NotImplementedError

+    def validate_or_generate_params(
+        self, params: SamplingParams | PoolingParams | None = None
+    ) -> SamplingParams | PoolingParams:
+        return params or PoolingParams()
+
     @abstractmethod
     def output_to_response(
         self, plugin_output: IOProcessorOutput
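The new sorted(...) collection guards against vLLM's merged async output stream yielding results out of input order. A standalone illustration of the same pattern (the generator below is a stand-in for model_output, not vLLM code):

import asyncio


async def fake_model_output():
    # Stand-in for vLLM's merged stream: each item is
    # (original_request_index, result), and arrival order differs from input order.
    for i, item in [(2, "out-2"), (0, "out-0"), (1, "out-1")]:
        yield i, item


async def collect():
    sorted_output = sorted(
        [(i, item) async for i, item in fake_model_output()],
        key=lambda output: output[0],
    )
    return [output[1] for output in sorted_output]


print(asyncio.run(collect()))  # ['out-0', 'out-1', 'out-2']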
@@ -66,10 +76,10 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):

 The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods.
 The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference.
 The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output.
+The `validate_or_generate_params` method is used for validating with the plugin any `SamplingParameters`/`PoolingParameters` received with the user request, or to generate new ones if none are specified. The function always returns the validated/generated parameters.
 The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/pooling` serving endpoint is available here [vllm/entrypoints/openai/serving_pooling.py](../../vllm/entrypoints/openai/serving_pooling.py).

-An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin). Please, also refer to our online ([examples/online_serving/prithvi_geospatial_mae.py](../../examples/online_serving/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py)) inference examples.
+An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/IBM/terratorch/tree/main/terratorch/vllm/plugins/segmentation). Please, also refer to our online ([examples/online_serving/prithvi_geospatial_mae.py](../../examples/online_serving/prithvi_geospatial_mae.py)) and offline ([examples/offline_inference/prithvi_geospatial_mae_io_processor.py](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py)) inference examples.

 ## Using an IO Processor plugin

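Putting the pieces above together, a skeleton plugin against this interface might look roughly as follows. The class, its input/output types, and the prompt layout are illustrative placeholders (not the actual TerraTorch plugin); only the method names and signatures come from the interface shown in this diff, and the import path of the interface module is assumed.

from typing import Any

from vllm.config import VllmConfig
from vllm.plugins.io_processors.interface import IOProcessor  # path assumed
from vllm.pooling_params import PoolingParams

MyInput = dict    # hypothetical plugin input type
MyOutput = dict   # hypothetical plugin output type


class MyIOProcessor(IOProcessor[MyInput, MyOutput]):
    def parse_request(self, request: Any) -> MyInput:
        # Validate the raw user payload; reject anything the plugin cannot handle.
        if not isinstance(request, dict) or "data" not in request:
            raise ValueError("expected a dict with a 'data' key")
        return request

    def pre_process(self, prompt: MyInput, request_id=None, **kwargs):
        # Turn the plugin input into ordinary vLLM prompts.
        return [{"prompt_token_ids": [1], "multi_modal_data": prompt["data"]}]

    def post_process(self, model_output, request_id=None, **kwargs) -> MyOutput:
        # Convert the PoolingRequestOutput objects into the plugin's own format.
        return {"tensors": [out.outputs.data for out in model_output]}

    def validate_or_generate_params(self, params=None) -> PoolingParams:
        # The plugin owns parameter validation for the "plugin" task.
        return params or PoolingParams(task="plugin")

    def output_to_response(self, plugin_output: MyOutput):
        # Only needed for online serving: wrap plugin_output into an
        # IOProcessorResponse here.
        raise NotImplementedError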
@@ -64,7 +64,7 @@ class PrithviMAE:
         }

         prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
-        outputs = self.model.encode(prompt, use_tqdm=False)
+        outputs = self.model.encode(prompt, pooling_task="plugin", use_tqdm=False)

         return outputs[0].outputs.data

@@ -6,14 +6,14 @@ import os
 import torch

 from vllm import LLM
-from vllm.pooling_params import PoolingParams

 # This example shows how to perform an offline inference that generates
 # multimodal data. In this specific case this example will take a geotiff
 # image as input, process it using the multimodal data processor, and
 # perform inference.
-# Requirement - install plugin at:
-# https://github.com/christian-pinto/prithvi_io_processor_plugin
+# Requirements:
+# - install TerraTorch v1.1 (or later):
+#   pip install terratorch>=v1.1


 def main():
@@ -36,16 +36,12 @@ def main():
         # to avoid the model going OOM.
         # The maximum number depends on the available GPU memory
         max_num_seqs=32,
-        io_processor_plugin="prithvi_to_tiff",
+        io_processor_plugin="terratorch_segmentation",
        model_impl="terratorch",
        enable_mm_embeds=True,
    )

-    pooling_params = PoolingParams(task="token_classify", activation=False)
-    pooler_output = llm.encode(
-        img_prompt,
-        pooling_params=pooling_params,
-    )
+    pooler_output = llm.encode(img_prompt, pooling_task="plugin")
    output = pooler_output[0].outputs

    print(output)
@@ -11,14 +11,14 @@ import requests
 # image as input, process it using the multimodal data processor, and
 # perform inference.
 # Requirements :
-# - install plugin at:
-# https://github.com/christian-pinto/prithvi_io_processor_plugin
+# - install TerraTorch v1.1 (or later):
+#   pip install terratorch>=v1.1
 # - start vllm in serving mode with the below args
 # --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
 # --model-impl terratorch
 # --task embed --trust-remote-code
 # --skip-tokenizer-init --enforce-eager
-# --io-processor-plugin prithvi_to_tiff
+# --io-processor-plugin terratorch_segmentation
 # --enable-mm-embeds


@@ -35,7 +35,6 @@ def main():
         },
         "priority": 0,
         "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
-        "softmax": False,
     }

     ret = requests.post(server_endpoint, json=request_payload_url)
@@ -40,7 +40,7 @@ def _run_test(
         max_num_seqs=32,
         default_torch_num_threads=1,
     ) as vllm_model:
-        vllm_model.llm.encode(prompt, pooling_task="token_classify")
+        vllm_model.llm.encode(prompt, pooling_task="plugin")


 MODELS = ["mgazz/Prithvi-EO-2.0-300M-TL-Sen1Floods11"]
@@ -368,9 +368,9 @@ class PrithviMultimodalDataProcessor(IOProcessor):
         out_format = "b64_json"

         for output in model_output:
-            y_hat = output.outputs.data.argmax(dim=1)
+            y_hat = output.outputs.data.argmax(dim=0)
             pred = torch.nn.functional.interpolate(
-                y_hat.unsqueeze(1).float(),
+                y_hat[None, None, ...].float(),
                 size=self.img_size,
                 mode="nearest",
             )
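For context on the indexing change: argmax(dim=0) plus [None, None, ...] assumes each output now carries a single unbatched (num_classes, H, W) tensor, producing the (1, 1, H, W) layout that interpolate expects. A toy shape check (all sizes made up for illustration):

import torch

logits = torch.randn(2, 16, 16)      # assumed (num_classes, H, W) output
y_hat = logits.argmax(dim=0)         # -> (H, W) class map
pred = torch.nn.functional.interpolate(
    y_hat[None, None, ...].float(),  # -> (1, 1, H, W)
    size=(32, 32),
    mode="nearest",
)
print(pred.shape)  # torch.Size([1, 1, 32, 32])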
@@ -9,7 +9,6 @@ from tests.utils import RemoteOpenAIServer
 from vllm.config import VllmConfig
 from vllm.entrypoints.openai.protocol import IOProcessorResponse
 from vllm.plugins.io_processors import get_io_processor
-from vllm.pooling_params import PoolingParams

 MODEL_NAME = "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11"

@@ -94,8 +93,6 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
         out_data_format="b64_json",
     )

-    pooling_params = PoolingParams(activation=False)
-
     with vllm_runner(
         model_name,
         runner="pooling",
@@ -109,9 +106,7 @@ def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
         model_impl="terratorch",
         io_processor_plugin="prithvi_to_tiff",
     ) as llm_runner:
-        pooler_output = llm_runner.get_llm().encode(
-            img_prompt, pooling_params=pooling_params, pooling_task="token_classify"
-        )
+        pooler_output = llm_runner.get_llm().encode(img_prompt, pooling_task="plugin")
         output = pooler_output[0].outputs

         # verify the output is formatted as expected for this plugin
@@ -1024,19 +1024,6 @@ class LLM:
                 "pooling model."
             )

-        if pooling_task not in self.supported_tasks:
-            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
-
-        if pooling_params is None:
-            # Use default pooling params.
-            pooling_params = PoolingParams()
-
-        for param in as_iter(pooling_params):
-            param.verify(pooling_task, model_config)
-            # for backwards compatibility
-            if truncate_prompt_tokens is not None:
-                param.truncate_prompt_tokens = truncate_prompt_tokens
-
         io_processor_prompt = False
         if isinstance(prompts, dict) and "data" in prompts:
             io_processor_prompt = True
@@ -1054,6 +1041,34 @@ class LLM:
             # obtain the actual model prompts from the pre-processor
             prompts = self.io_processor.pre_process(prompt=validated_prompt)

+        if io_processor_prompt:
+            assert self.io_processor is not None
+            if is_list_of(pooling_params, PoolingParams):
+                validated_pooling_params: list[PoolingParams] = []
+                for param in as_iter(pooling_params):
+                    validated_pooling_params.append(
+                        self.io_processor.validate_or_generate_params(param)
+                    )
+                pooling_params = validated_pooling_params
+            else:
+                assert not isinstance(pooling_params, Sequence)
+                pooling_params = self.io_processor.validate_or_generate_params(
+                    pooling_params
+                )
+        else:
+            if pooling_params is None:
+                # Use default pooling params.
+                pooling_params = PoolingParams()
+
+        if pooling_task not in self.supported_tasks:
+            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")
+
+        for param in as_iter(pooling_params):
+            param.verify(pooling_task, model_config)
+            # for backwards compatibility
+            if truncate_prompt_tokens is not None:
+                param.truncate_prompt_tokens = truncate_prompt_tokens
+
         self._validate_and_add_requests(
             prompts=prompts,
             params=pooling_params,
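The net effect of the two LLM.encode() hunks above: parameter defaulting and validation now happen after prompt pre-processing, and for IO Processor prompts the plugin (rather than a bare PoolingParams() default) supplies and validates the parameters. In simplified form (a sketch of the control flow, not the vLLM code itself; the helper name is made up):

from vllm.pooling_params import PoolingParams


def resolve_pooling_params(pooling_params, io_processor_prompt, io_processor):
    # Simplified model of the new branch in LLM.encode().
    if io_processor_prompt:
        if isinstance(pooling_params, list):
            return [io_processor.validate_or_generate_params(p) for p in pooling_params]
        return io_processor.validate_or_generate_params(pooling_params)
    # Non-plugin requests keep the old behaviour: default params if none given.
    return pooling_params if pooling_params is not None else PoolingParams()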
@@ -1748,7 +1748,12 @@ async def init_app_state(
             log_error_stack=args.log_error_stack,
         )
         )
-        if ("token_embed" in supported_tasks or "token_classify" in supported_tasks)
+        if (
+            any(
+                task in supported_tasks
+                for task in ["token_embed", "token_classify", "plugin"]
+            )
+        )
         else None
     )
     state.openai_serving_embedding = (
@@ -1707,11 +1707,6 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
     if the served model does not use priority scheduling.
     """
     data: T
-    """
-    When using plugins IOProcessor plugins, the actual input is processed
-    by the plugin itself. Hence, we use a generic type for the request data
-    """
-    activation: bool = False

     encoding_format: EncodingFormat = "float"
     embed_dtype: EmbedDType = Field(
@@ -1732,7 +1727,7 @@ class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
     )

     def to_pooling_params(self):
-        return PoolingParams(task="token_classify", activation=self.activation)
+        return PoolingParams()


 class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
@@ -32,7 +32,7 @@ from vllm.entrypoints.renderer import RenderConfig
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.logger import init_logger
 from vllm.outputs import PoolingRequestOutput
-from vllm.tasks import SupportedTask
+from vllm.tasks import PoolingTask, SupportedTask
 from vllm.utils.async_utils import merge_async_iterators
 from vllm.utils.serial_utils import (
     EmbedDType,
@@ -161,12 +161,21 @@ class OpenAIServingPooling(OpenAIServing):
         # Schedule the request and get the result generator.
         generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
         try:
+            if is_io_processor_request:
+                assert self.io_processor is not None and isinstance(
+                    request, IOProcessorRequest
+                )
+                pooling_params = self.io_processor.validate_or_generate_params()
+            else:
                 pooling_params = request.to_pooling_params()

+            pooling_task: PoolingTask
             if "token_embed" in self.supported_tasks:
                 pooling_task = "token_embed"
             elif "token_classify" in self.supported_tasks:
                 pooling_task = "token_classify"
+            elif "plugin" in self.supported_tasks:
+                pooling_task = "plugin"
             else:
                 return self.create_error_response(
                     f"pooling_task must be one of {self.supported_tasks}."
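The endpoint's task selection above boils down to a fixed preference order with "plugin" as the new fallback; schematically (a simplified helper for illustration, not the actual serving code):

def choose_pooling_task(supported_tasks: set[str]) -> str:
    # Mirrors the if/elif chain in OpenAIServingPooling: prefer token-level
    # tasks, fall back to "plugin" when that is all the pooler advertises.
    for task in ("token_embed", "token_classify", "plugin"):
        if task in supported_tasks:
            return task
    raise ValueError(f"pooling_task must be one of {supported_tasks}.")


print(choose_pooling_task({"plugin", "score"}))  # plugin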
@@ -414,6 +414,18 @@ class Pooler(nn.Module, ABC):
         raise NotImplementedError


+class DummyPooler(Pooler):
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return {"plugin", "score"}
+
+    def forward(
+        self,
+        hidden_states: list[torch.Tensor] | torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> PoolerOutput:
+        return hidden_states
+
+
 class PoolerHead(nn.Module):
     def __init__(self, activation: PoolerActivation) -> None:
         super().__init__()
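DummyPooler is effectively an identity pooler: it advertises the "plugin" task and hands the raw hidden states straight through, leaving all interpretation to the IO Processor plugin. A toy equivalent outside vLLM:

import torch


class IdentityPooler(torch.nn.Module):
    # Toy stand-in for DummyPooler: returns hidden states unchanged so the
    # IO Processor plugin receives the model's raw output.
    def forward(self, hidden_states, pooling_metadata=None):
        return hidden_states


hs = torch.randn(4, 8)
assert torch.equal(IdentityPooler()(hs), hs)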
@@ -34,7 +34,7 @@ from transformers import BatchFeature
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.logger import init_logger
-from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
+from vllm.model_executor.layers.pooler import DispatchPooler, DummyPooler
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.utils import AutoWeightsLoader
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -249,9 +249,7 @@ class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None

-        self.pooler = DispatchPooler(
-            {"token_classify": Pooler.for_token_classify(pooler_config)}
-        )
+        self.pooler = DispatchPooler({"plugin": DummyPooler()})

     def get_input_embeddings(
         self,
@@ -9,6 +9,8 @@ from vllm.config import VllmConfig
 from vllm.entrypoints.openai.protocol import IOProcessorResponse
 from vllm.inputs.data import PromptType
 from vllm.outputs import PoolingRequestOutput
+from vllm.pooling_params import PoolingParams
+from vllm.sampling_params import SamplingParams

 IOProcessorInput = TypeVar("IOProcessorInput")
 IOProcessorOutput = TypeVar("IOProcessorOutput")
@@ -63,6 +65,11 @@ class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
     def parse_request(self, request: Any) -> IOProcessorInput:
         raise NotImplementedError

+    def validate_or_generate_params(
+        self, params: SamplingParams | PoolingParams | None = None
+    ) -> SamplingParams | PoolingParams:
+        return params or PoolingParams()
+
     @abstractmethod
     def output_to_response(
         self, plugin_output: IOProcessorOutput
@@ -84,6 +84,11 @@ class PoolingParams(
             msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
             raise ValueError(msg)

+        # plugin task uses io_processor.parse_request to verify inputs,
+        # skipping PoolingParams verify
+        if self.task == "plugin":
+            return
+
         # NOTE: Task validation needs to done against the model instance,
         # which is not available in model config. So, it's not included
         # in this method
@@ -5,7 +5,9 @@ from typing import Literal, get_args
 GenerationTask = Literal["generate", "transcription"]
 GENERATION_TASKS = get_args(GenerationTask)

-PoolingTask = Literal["embed", "classify", "score", "token_embed", "token_classify"]
+PoolingTask = Literal[
+    "embed", "classify", "score", "token_embed", "token_classify", "plugin"
+]
 POOLING_TASKS = get_args(PoolingTask)

 SupportedTask = Literal[GenerationTask, PoolingTask]