Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-09-01 01:45:29 -07:00
commit 4c2a337e67
49 changed files with 1874 additions and 313 deletions

View File

@ -566,8 +566,7 @@ steps:
- tests/models/multimodal
commands:
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
- pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
- pytest -v -s models/multimodal/processing/test_tensor_schema.py
- pytest -v -s models/multimodal/processing
- label: Multi-Modal Models Test (Standard)
mirror_hardwares: [amdexperimental]
@ -770,6 +769,11 @@ steps:
- pytest -v -s plugins_tests/test_platform_plugins.py
- pip uninstall vllm_add_dummy_platform -y
# end platform plugin tests
# begin io_processor plugins test, all the code in between uses the prithvi_io_processor plugin
- pip install -e ./plugins/prithvi_io_processor_plugin
- pytest -v -s plugins_tests/test_io_processor_plugins.py
- pip uninstall prithvi_io_processor_plugin -y
# end io_processor plugins test
# other tests continue here:
- pytest -v -s plugins_tests/test_scheduler_plugins.py
- pip install -e ./plugins/vllm_add_dummy_model

View File

@ -0,0 +1,78 @@
# IO Processor Plugins
IO Processor plugins are a feature that allows pre and post processing of the model input and output for pooling models. The idea is that users are allowed to pass a custom input to vLLM that is converted into one or more model prompts and fed to the model `encode` method. One potential use-case of such plugins is that of using vLLM for generating multi-modal data. Say users feed an image to vLLM and get an image in output.
When performing an inference with IO Processor plugins, the prompt type is defined by the plugin and the same applies to the final request output. vLLM does not perform any validation of input/output data, and it is up to the plugin to ensure the correct data is being fed to the model and returned to the user. As of now these plugins support only pooling models and can be triggered via the `encode` method in `LLM` and `AsyncLLM`, or in online serving mode via the `/pooling` endpoint.
## Writing an IO Processor Plugin
IO Processor plugins implement the `IOProcessor` interface (<gh-file:vllm/plugins/io_processors/interface.py>):
```python
IOProcessorInput = TypeVar('IOProcessorInput')
IOProcessorOutput = TypeVar('IOProcessorOutput')
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
@abstractmethod
def pre_process(
self,
prompt: IOProcessorInput,
request_id: Optional[str] = None,
**kwargs,
) -> Union[PromptType, Sequence[PromptType]]:
raise NotImplementedError
async def pre_process_async(
self,
prompt: IOProcessorInput,
request_id: Optional[str] = None,
**kwargs,
) -> Union[PromptType, Sequence[PromptType]]:
return self.pre_process(prompt, request_id, **kwargs)
@abstractmethod
def post_process(self,
model_output: Sequence[PoolingRequestOutput],
request_id: Optional[str] = None,
**kwargs) -> IOProcessorOutput:
raise NotImplementedError
async def post_process_async(
self,
model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
request_id: Optional[str] = None,
**kwargs,
) -> IOProcessorOutput:
collected_output = [item async for i, item in model_output]
return self.post_process(collected_output, request_id, **kwargs)
@abstractmethod
def parse_request(self, request: Any) -> IOProcessorInput:
raise NotImplementedError
@abstractmethod
def output_to_response(
self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
raise NotImplementedError
```
The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods.
The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference.
The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output.
The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/io_processor_pooling` serving endpoint is [here](../../vllm/entrypoints/openai/serving_pooling_with_io_plugin.py).
An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin). Please, also refer to our [online](../../examples/online_serving/prithvi_geospatial_mae.py) and [offline](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py) inference examples.
## Using an IO Processor plugin
IO Processor plugins are loaded at engine startup and there are two methods for specifying the name of the plugin to be loaded:
1. Via vLLM's `EngineArgs`: setting the `io_processor_plugin` argument in the `EngineArgs` used to initialize the `AsyncLLM`. The same can be achieved by passing the `io_processor_plugin` argument to `LLM` in offline mode, or by passing the `--io-processor-plugin` argument in serving mode.
2. Via the model HF configuration: adding an `io_processor_plugin` field to the model config (config.json).
The order also determines method priority. i.e., setting the plugin name via `EngineArgs` will override any plugin name specified in the model HF config (config.json).

View File

@ -49,6 +49,8 @@ Every plugin has three parts:
- **Platform plugins** (with group name `vllm.platform_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree platforms into vLLM. The plugin function should return `None` when the platform is not supported in the current environment, or the platform class's fully qualified name when the platform is supported.
- **IO Processor plugins** (with group name `vllm.io_processor_plugins`): The primary use case for these plugins is to register custom pre/post processing of the model prompt and model output for pooling models. The plugin function returns the fully qualified name of the IOProcessor class.
## Guidelines for Writing Plugins
- **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes.

View File

@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import os
import torch
from vllm import LLM
from vllm.pooling_params import PoolingParams
# This example shows how to perform an offline inference that generates
# multimodal data. In this specific case this example will take a geotiff
# image as input, process it using the multimodal data processor, and
# perform inference.
# Requirement - install the plugin at:
# https://github.com/christian-pinto/prithvi_io_processor_plugin
def main():
    """Run an offline Prithvi inference through the IO Processor plugin.

    Downloads a geotiff image by URL, lets the plugin pre/post-process it
    around a pooling ("encode") run, and writes the resulting tiff
    prediction (returned base64-encoded) to the current directory.
    """
    # The Prithvi plugin works in half precision; set the default dtype
    # before the model is instantiated.
    torch.set_default_dtype(torch.float16)
    image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif"  # noqa: E501
    # Plugin-defined prompt: these fields are interpreted by the IO Processor
    # plugin, not by vLLM itself.
    img_prompt = dict(
        data=image_url,
        data_format="url",
        image_format="tiff",
        out_data_format="b64_json",
    )
    llm = LLM(
        model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
        skip_tokenizer_init=True,
        trust_remote_code=True,
        enforce_eager=True,
        # Limit the maximum number of parallel requests
        # to avoid the model going OOM.
        # The maximum number depends on the available GPU memory
        max_num_seqs=32,
        io_processor_plugin="prithvi_to_tiff_india",
    )
    pooling_params = PoolingParams(task="encode", softmax=False)
    pooler_output = llm.encode(
        img_prompt,
        pooling_params=pooling_params,
    )
    # The plugin returns a single ImageRequestOutput-like object whose
    # `data` field holds the base64-encoded geotiff.
    output = pooler_output[0].outputs
    print(output)

    decoded_data = base64.b64decode(output.data)
    file_path = os.path.join(os.getcwd(), "offline_prediction.tiff")
    with open(file_path, "wb") as f:
        f.write(decoded_data)
    print(f"Output file path: {file_path}")


if __name__ == "__main__":
    main()

View File

@ -27,10 +27,12 @@ class BlockStored(KVCacheEvent):
token_ids: list[int]
block_size: int
lora_id: Optional[int]
medium: Optional[str]
class BlockRemoved(KVCacheEvent):
block_hashes: list[int]
medium: Optional[str]
class AllBlocksCleared(KVCacheEvent):

View File

@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import os
import requests
# This example shows how to perform an online inference that generates
# multimodal data. In this specific case this example will take a geotiff
# image as input, process it using the multimodal data processor, and
# perform inference.
# Requirements:
# - install plugin at:
# https://github.com/christian-pinto/prithvi_io_processor_plugin
# - start vllm in serving mode with the below args
# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
# --task embed --trust-remote-code
# --skip-tokenizer-init --enforce-eager
# --io-processor-plugin prithvi_to_tiff_india
def main():
    """Send an IO Processor plugin request to a running vLLM server.

    Posts a plugin-defined payload to the `/pooling` endpoint and writes
    the base64-encoded tiff prediction from the response to the current
    directory. Assumes a server was started as described in the header.
    """
    image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif"  # noqa: E501
    server_endpoint = "http://localhost:8000/pooling"

    # The nested "data" dict is the plugin-defined prompt; the outer fields
    # follow the server's pooling request schema.
    request_payload_url = {
        "data": {
            "data": image_url,
            "data_format": "url",
            "image_format": "tiff",
            "out_data_format": "b64_json",
        },
        "priority": 0,
        "model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    }
    ret = requests.post(server_endpoint, json=request_payload_url)
    print(f"response.status_code: {ret.status_code}")
    print(f"response.reason:{ret.reason}")
    response = ret.json()
    # response["data"] is the plugin output; its "data" field carries the
    # base64-encoded geotiff.
    decoded_image = base64.b64decode(response["data"]["data"])
    out_path = os.path.join(os.getcwd(), "online_prediction.tiff")
    with open(out_path, "wb") as f:
        f.write(decoded_image)


if __name__ == "__main__":
    main()

View File

@ -1120,6 +1120,9 @@ class VllmRunner:
return self.llm.llm_engine.collective_rpc(_apply_model)
def get_llm(self) -> LLM:
return self.llm
def __enter__(self):
return self

View File

@ -1,30 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from collections.abc import Iterable
from contextlib import contextmanager
from functools import partial
from typing import Any, Union
from unittest.mock import patch
import numpy as np
import pytest
import torch.nn as nn
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from PIL import Image
from vllm.config import ModelConfig
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.distributed import (cleanup_dist_env_and_memory,
init_distributed_environment,
initialize_model_parallel)
from vllm.inputs import InputProcessingContext
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalKwargs)
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
from vllm.utils import GiB_bytes, is_list_of, set_default_torch_num_threads
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.engine.core import EngineCore as V1EngineCore
from vllm.utils import is_list_of
from ....conftest import VllmRunner
from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
from ...utils import dummy_hf_overrides
@ -137,6 +138,27 @@ def create_batched_mm_kwargs(
return group_mm_kwargs_by_modality(items)
@contextmanager
def initialize_dummy_model(model_cls: nn.Module, model_config: ModelConfig):
    """Yield a model instance built inside a minimal 1-GPU distributed setup.

    Initializes a single-rank distributed environment (file-based rendezvous,
    NCCL backend), constructs the model under the given config, and tears the
    environment down on exit.
    """
    # NOTE(review): mkstemp() returns (fd, path); the fd is never closed and
    # the file is never removed — minor leak in a test helper.
    temp_file = tempfile.mkstemp()[1]
    init_distributed_environment(
        world_size=1,
        rank=0,
        distributed_init_method=f"file://{temp_file}",
        local_rank=0,
        backend="nccl",
    )
    initialize_model_parallel(tensor_model_parallel_size=1)
    vllm_config = VllmConfig(model_config=model_config)
    with set_current_vllm_config(vllm_config=vllm_config):
        # Build with the config's dtype as torch default so dummy weights
        # come out in the expected precision.
        with set_default_torch_dtype(model_config.dtype):
            model = model_cls(vllm_config=vllm_config)
            yield model

    del model
    cleanup_dist_env_and_memory()
def get_model_id_to_test(
model_arch_list: Iterable[str]) -> list[tuple[str, str]]:
filtered_results = []
@ -155,8 +177,7 @@ def get_model_id_to_test(
@pytest.mark.parametrize(
"model_arch, model_id",
get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys()))
def test_model_tensor_schema(model_arch: str, model_id: str,
vllm_runner: type[VllmRunner], monkeypatch):
def test_model_tensor_schema(model_arch: str, model_id: str):
if model_arch in ARCH_TO_SKIP:
pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}")
if model_id in REPO_ID_TO_SKIP:
@ -177,14 +198,20 @@ def test_model_tensor_schema(model_arch: str, model_id: str,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
hf_overrides=hf_overrides_fn,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
if not any(
hasattr(model_cls, f"_parse_and_validate_{m}_input")
for m in ["image", "video", "audio"]):
inputs_parse_methods = []
for attr_name in dir(model_cls):
attr = getattr(model_cls, attr_name)
if hasattr(attr, "__annotations__"):
return_type = attr.__annotations__.get("return", None)
if return_type is not None and "Input" in str(return_type):
inputs_parse_methods.append(attr_name)
if not any(inputs_parse_methods):
pytest.skip(f"{model_arch} does not support tensor schema validation.")
ctx = InputProcessingContext(
@ -197,68 +224,13 @@ def test_model_tensor_schema(model_arch: str, model_id: str,
modality: 3 if limit is None else limit
for modality, limit in supported_mm_limits.items()
}
model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
processor = factories.build_processor(ctx, cache=None)
# Avoid calling model.forward()
def _initialize_kv_caches_v0(self) -> None:
self.cache_config.num_gpu_blocks = 0
self.cache_config.num_cpu_blocks = 0
def _initialize_kv_caches_v1(self, vllm_config):
kv_cache_specs = self.model_executor.get_kv_cache_specs()
scheduler_kv_cache_config = get_kv_cache_config(
vllm_config,
kv_cache_specs[0],
10 * GiB_bytes,
)
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
return 1, 0, scheduler_kv_cache_config
with (patch.object(V0LLMEngine, "_initialize_kv_caches",
_initialize_kv_caches_v0),
patch.object(V1EngineCore, "_initialize_kv_caches",
_initialize_kv_caches_v1), monkeypatch.context() as m):
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
if model_info.v0_only:
m.setenv("VLLM_USE_V1", "0")
# TODO(Isotr0py): Can we avoid initializing engine?
with (
set_default_torch_num_threads(1),
vllm_runner(
model_id,
tokenizer_name=model_info.tokenizer,
tokenizer_mode=model_info.tokenizer_mode,
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
max_model_len=model_info.max_model_len,
load_format="dummy",
hf_overrides=hf_overrides_fn,
limit_mm_per_prompt=limit_mm_per_prompt,
enforce_eager=True,
) as vllm_model,
):
model_config = vllm_model.llm.llm_engine.model_config
llm_engine = vllm_model.llm.llm_engine
if hasattr(llm_engine, "processor"):
# v1 processor
mm_registry = llm_engine.processor.mm_registry
else:
# v0 input_preprocessor
mm_registry = llm_engine.input_preprocessor.mm_registry
processor = mm_registry.create_processor(model_config)
def validate_model_input(model, modality: str,
mm_kwargs: MultiModalKwargs):
method_name = f"_parse_and_validate_{modality}_input"
if hasattr(model, method_name):
getattr(model, method_name)(**mm_kwargs)
for modality, _, mm_kwargs in create_batched_mm_kwargs(
model_config, processor):
valid_func = partial(validate_model_input,
modality=modality,
mm_kwargs=mm_kwargs)
vllm_model.apply_model(valid_func)
with initialize_dummy_model(model_cls, model_config) as model:
for modality, _, mm_kwargs in create_batched_mm_kwargs(
model_config, processor):
for method_name in inputs_parse_methods:
print(f"Testing `{method_name}` with modality={modality} "
f"and mm_kwargs{list(mm_kwargs.keys())}")
getattr(model, method_name)(modality=modality, **mm_kwargs)

View File

@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
def register_prithvi_india():
    """vLLM entry point: fully qualified name of the India IO processor."""
    qualified_name = "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorIndia"  # noqa: E501
    return qualified_name
def register_prithvi_valencia():
    """vLLM entry point: fully qualified name of the Valencia IO processor."""
    qualified_name = "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorValencia"  # noqa: E501
    return qualified_name

View File

@ -0,0 +1,449 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import base64
import datetime
import os
import tempfile
import urllib.request
from collections.abc import AsyncGenerator, Sequence
from typing import Any, Optional, Union
import albumentations
import numpy as np
import rasterio
import regex as re
import torch
from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import (IOProcessorRequest,
IOProcessorResponse)
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import (IOProcessor,
IOProcessorInput,
IOProcessorOutput)
from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput
logger = init_logger(__name__)
NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
OFFSET = 0
PERCENTILE = 99
DEFAULT_INPUT_INDICES = [0, 1, 2, 3, 4, 5]
datamodule_config: DataModuleConfig = {
"bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
"batch_size":
16,
"constant_scale":
0.0001,
"data_root":
"/dccstor/geofm-finetuning/datasets/sen1floods11",
"drop_last":
True,
"no_data_replace":
0.0,
"no_label_replace":
-1,
"num_workers":
8,
"test_transform": [
albumentations.Resize(always_apply=False,
height=448,
interpolation=1,
p=1,
width=448),
albumentations.pytorch.ToTensorV2(transpose_mask=False,
always_apply=True,
p=1.0),
],
}
def save_geotiff(image: torch.Tensor, meta: dict,
                 out_format: str) -> str | bytes:
    """Save a multi-band image as a Geotiff file.

    Args:
        image: array-like with shape (bands, height, width).
        meta: rasterio meta dict used when writing the file.
        out_format: "path" to write `prediction.tiff` in the CWD and return
            its path, or "b64_json" to return the base64-encoded file bytes.

    Returns:
        The output file path ("path") or base64-encoded content ("b64_json").

    Raises:
        ValueError: if *out_format* is not recognized.
    """
    if out_format == "path":
        # Write to a fixed file name in the current working directory.
        file_path = os.path.join(os.getcwd(), "prediction.tiff")
        with rasterio.open(file_path, "w", **meta) as dest:
            for i in range(image.shape[0]):
                # rasterio band indices are 1-based.
                dest.write(image[i, :, :], i + 1)
        return file_path
    elif out_format == "b64_json":
        with tempfile.NamedTemporaryFile() as tmpfile:
            with rasterio.open(tmpfile.name, "w", **meta) as dest:
                for i in range(image.shape[0]):
                    dest.write(image[i, :, :], i + 1)
            # rasterio wrote through the file *name*; this handle is still at
            # offset 0, so read() returns the whole file.
            file_data = tmpfile.read()
        return base64.b64encode(file_data)
    else:
        raise ValueError("Unknown output format")
def _convert_np_uint8(float_image: torch.Tensor):
image = float_image.numpy() * 255.0
image = image.astype(dtype=np.uint8)
return image
def read_geotiff(
    file_path: Optional[str] = None,
    path_type: Optional[str] = None,
    file_data: Optional[bytes] = None,
) -> tuple[torch.Tensor, dict, tuple[float, float] | None]:
    """Read all bands of a geotiff supplied as path, URL, base64 or raw bytes.

    Args:
        file_path: file reference, interpreted according to *path_type*
            ("path": local path, "url": remote URL, "b64_json": base64
            payload). Ignored when *file_data* is given.
        path_type: how to interpret *file_path*.
        file_data: raw geotiff bytes; takes precedence over *file_path*.

    Returns:
        Image array with shape (bands, height, width), the rasterio meta
        dict, and the (lng, lat) coordinates if available (else None).

    Raises:
        Exception: when no input is provided or the parameter combination
            is invalid.
    """
    if all(x is None for x in [file_path, path_type, file_data]):
        raise Exception("All input fields to read_geotiff are None")

    # Normalize every input form to either in-memory bytes or a local path.
    write_to_file: Optional[bytes] = None
    path: Optional[str] = None
    if file_data is not None:
        write_to_file = file_data
    elif file_path is not None and path_type == "url":
        resp = urllib.request.urlopen(file_path)
        write_to_file = resp.read()
    elif file_path is not None and path_type == "path":
        path = file_path
    elif file_path is not None and path_type == "b64_json":
        write_to_file = base64.b64decode(file_path)
    else:
        raise Exception("Wrong combination of parameters to read_geotiff")

    with tempfile.NamedTemporaryFile() as tmpfile:
        if write_to_file:
            tmpfile.write(write_to_file)
            # BUGFIX: flush so the buffered bytes are on disk before
            # rasterio re-opens the file by name.
            tmpfile.flush()
            path_to_use = tmpfile.name
        else:
            path_to_use = path

        with rasterio.open(path_to_use) as src:
            img = src.read()
            meta = src.meta
            try:
                coords = src.lnglat()
            except Exception:
                # Not every raster carries CRS info; coords are optional.
                coords = None

    return img, meta, coords
def load_image(
    data: Union[list[str]],
    path_type: str,
    mean: Optional[list[float]] = None,
    std: Optional[list[float]] = None,
    indices: Optional[Union[list[int], None]] = None,
):
    """Build an input example by loading the images listed in *data*.

    Args:
        data: list of image references (paths, URLs or base64 payloads),
            interpreted according to *path_type*.
        path_type: how each entry of *data* should be read
            ("path", "url" or "b64_json").
        mean: list containing mean values for each band in the images.
        std: list containing std values for each band in the images.
        indices: band indices to keep from each image (None keeps all).

    Returns:
        np.array containing the created example,
        temporal coordinates extracted from file names,
        location coordinates read from the files,
        and the list of per-image meta-info dicts.
    """
    imgs = []
    metas = []
    temporal_coords = []
    location_coords = []

    for file in data:
        img, meta, coords = read_geotiff(file_path=file, path_type=path_type)

        # Rescaling (don't normalize on nodata)
        img = np.moveaxis(img, 0, -1)  # channels last for rescaling
        if indices is not None:
            img = img[..., indices]
        if mean is not None and std is not None:
            img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)

        imgs.append(img)
        metas.append(meta)
        if coords is not None:
            location_coords.append(coords)

        try:
            # Extract (year, day-of-year) from a YYYYMMDDTHHMMSS-style (or
            # YYYYdddTHHMMSS) timestamp embedded in the file name, if any.
            match = re.search(r"(\d{7,8}T\d{6})", file)
            if match:
                year = int(match.group(1)[:4])
                julian_day = match.group(1).split("T")[0][4:]
                if len(julian_day) == 3:
                    # Already a 3-digit julian day.
                    julian_day = int(julian_day)
                else:
                    # MMDD form: convert to day-of-year.
                    julian_day = (datetime.datetime.strptime(
                        julian_day, "%m%d").timetuple().tm_yday)
                temporal_coords.append([year, julian_day])
        except Exception:
            logger.exception("Could not extract timestamp for %s", file)

    imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
    imgs = np.moveaxis(imgs, -1, 0).astype("float32")  # C, num_frames, H, W
    imgs = np.expand_dims(imgs, axis=0)  # add batch dim
    return imgs, temporal_coords, location_coords, metas
class PrithviMultimodalDataProcessor(IOProcessor):
    """IO Processor plugin for the Prithvi geospatial model.

    Converts a geotiff image reference into windowed model prompts
    (pre-processing) and reassembles the per-window pooling outputs into a
    single geotiff prediction (post-processing).

    NOTE(review): per-request state (`meta_data`, `h1`, `w1`, `original_h/w`)
    is kept on the instance, so concurrent requests through one processor
    would clobber each other — confirm the engine serializes requests.
    """

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)
        # terratorch data module providing the standardization transforms
        # applied to each image window.
        self.datamodule = Sen1Floods11NonGeoDataModule(
            data_root=datamodule_config["data_root"],
            batch_size=datamodule_config["batch_size"],
            num_workers=datamodule_config["num_workers"],
            bands=datamodule_config["bands"],
            drop_last=datamodule_config["drop_last"],
            test_transform=datamodule_config["test_transform"],
        )
        # Side length of the square windows fed to the model.
        self.img_size = 512
        # Window-grid dimensions of the last pre-processed image.
        self.h1 = 1
        self.w1 = 1
        # Size of the last input image before padding.
        self.original_h = 512
        self.original_w = 512
        self.batch_size = 1
        # rasterio meta dict of the last pre-processed image.
        self.meta_data = None
        # Maps request_id -> per-request options (e.g. output format).
        # NOTE(review): entries are never removed, so this grows unboundedly.
        self.requests_cache: dict[str, dict[str, Any]] = {}
        # Band indices selected from the input image.
        self.indices = DEFAULT_INPUT_INDICES

    def parse_request(self, request: Any) -> IOProcessorInput:
        """Validate a raw request (dict or IOProcessorRequest) into an
        ImagePrompt.

        Raises:
            ValueError: when the request cannot be interpreted.
        """
        if type(request) is dict:
            image_prompt = ImagePrompt(**request)
            return image_prompt
        if isinstance(request, IOProcessorRequest):
            if not hasattr(request, "data"):
                raise ValueError(
                    "missing 'data' field in OpenAIBaseModel Request")
            request_data = request.data
            if type(request_data) is dict:
                return ImagePrompt(**request_data)
            else:
                raise ValueError("Unable to parse the request data")
        raise ValueError("Unable to parse request")

    def output_to_response(
            self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
        """Wrap the plugin output in the API-server response envelope."""
        return IOProcessorResponse(
            request_id=plugin_output.request_id,
            data=plugin_output,
        )

    def pre_process(
        self,
        prompt: IOProcessorInput,
        request_id: Optional[str] = None,
        **kwargs,
    ) -> Union[PromptType, Sequence[PromptType]]:
        """Load the image described by *prompt* and split it into one model
        prompt per img_size x img_size window."""
        image_data = dict(prompt)
        if request_id:
            # Remember the requested output format for post_process.
            self.requests_cache[request_id] = {
                "out_format": image_data["out_data_format"],
            }
        input_data, temporal_coords, location_coords, meta_data = load_image(
            data=[image_data["data"]],
            indices=self.indices,
            path_type=image_data["data_format"],
        )
        self.meta_data = meta_data[0]
        if input_data.mean() > 1:
            input_data = input_data / 10000  # Convert to range 0-1
        self.original_h, self.original_w = input_data.shape[-2:]
        # Reflect-pad so both spatial dims are multiples of img_size.
        pad_h = (self.img_size -
                 (self.original_h % self.img_size)) % self.img_size
        pad_w = (self.img_size -
                 (self.original_w % self.img_size)) % self.img_size
        input_data = np.pad(
            input_data,
            ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)),
            mode="reflect",
        )
        batch = torch.tensor(input_data)
        # Tile the padded image into non-overlapping img_size windows.
        windows = batch.unfold(3, self.img_size,
                               self.img_size).unfold(4, self.img_size,
                                                     self.img_size)
        self.h1, self.w1 = windows.shape[3:5]
        windows = rearrange(
            windows,
            "b c t h1 w1 h w -> (b h1 w1) c t h w",
            h=self.img_size,
            w=self.img_size,
        )
        # Split into batches if number of windows > batch_size
        num_batches = (windows.shape[0] // self.batch_size
                       if windows.shape[0] > self.batch_size else 1)
        windows = torch.tensor_split(windows, num_batches, dim=0)
        if temporal_coords:
            temporal_coords = torch.tensor(temporal_coords).unsqueeze(0)
        else:
            temporal_coords = None
        if location_coords:
            location_coords = torch.tensor(location_coords[0]).unsqueeze(0)
        else:
            location_coords = None
        prompts = []
        for window in windows:
            # Apply standardization
            window = self.datamodule.test_transform(
                image=window.squeeze().numpy().transpose(1, 2, 0))
            window = self.datamodule.aug(window)["image"]
            # NOTE(review): location_coords.to(...) raises AttributeError
            # when the source image has no coordinates (location_coords is
            # None) — confirm inputs always carry CRS info.
            prompts.append({
                "prompt_token_ids": [1],
                "multi_modal_data": {
                    "pixel_values": window.to(torch.float16)[0],
                    "location_coords": location_coords.to(torch.float16),
                },
            })
        return prompts

    async def pre_process_async(
        self,
        prompt: IOProcessorInput,
        request_id: Optional[str] = None,
        **kwargs,
    ) -> Union[PromptType, Sequence[PromptType]]:
        """Async wrapper delegating to the synchronous pre_process."""
        return self.pre_process(prompt, request_id, **kwargs)

    def post_process(
        self,
        model_output: Sequence[PoolingRequestOutput],
        request_id: Optional[str] = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Reassemble per-window model outputs into one geotiff prediction."""
        pred_imgs_list = []
        # Fall back to base64 output if the request was not cached.
        if request_id and (request_id in self.requests_cache):
            out_format = self.requests_cache[request_id]["out_format"]
        else:
            out_format = "b64_json"
        for output in model_output:
            # Class prediction per pixel, upsampled back to window size.
            y_hat = output.outputs.data.argmax(dim=1)
            pred = torch.nn.functional.interpolate(
                y_hat.unsqueeze(1).float(),
                size=self.img_size,
                mode="nearest",
            )
            pred_imgs_list.append(pred)
        pred_imgs: torch.Tensor = torch.concat(pred_imgs_list, dim=0)
        # Build images from patches
        pred_imgs = rearrange(
            pred_imgs,
            "(b h1 w1) c h w -> b c (h1 h) (w1 w)",
            h=self.img_size,
            w=self.img_size,
            b=1,
            c=1,
            h1=self.h1,
            w1=self.w1,
        )
        # Cut padded area back to original size
        pred_imgs = pred_imgs[..., :self.original_h, :self.original_w]
        # Squeeze (batch size 1)
        pred_imgs = pred_imgs[0]
        if not self.meta_data:
            raise ValueError("No metadata available for the current task")
        # Single uint8 band, LZW-compressed, 0 as the nodata value.
        self.meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
        out_data = save_geotiff(_convert_np_uint8(pred_imgs), self.meta_data,
                                out_format)
        return ImageRequestOutput(type=out_format,
                                  format="tiff",
                                  data=out_data,
                                  request_id=request_id)

    async def post_process_async(
        self,
        model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
        request_id: Optional[str] = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Drain the async output stream, then delegate to post_process."""
        collected_output = [item async for i, item in model_output]
        return self.post_process(collected_output, request_id, **kwargs)
class PrithviMultimodalDataProcessorIndia(PrithviMultimodalDataProcessor):
    """Prithvi processor variant with band indices for the India sample."""

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)
        # Band selection specific to the India sample image.
        self.indices = [1, 2, 3, 8, 11, 12]
class PrithviMultimodalDataProcessorValencia(PrithviMultimodalDataProcessor):
    """Prithvi processor variant with band indices for the Valencia sample."""

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)
        # Band selection specific to the Valencia sample image.
        self.indices = [0, 1, 2, 3, 4, 5]

View File

@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Literal, Optional, TypedDict, Union
import albumentations
from pydantic import BaseModel
class DataModuleConfig(TypedDict):
    """Configuration schema for building the Sen1Floods11 data module."""

    # Spectral band names, in the order expected by the model.
    bands: list[str]
    batch_size: int
    # Multiplicative scale applied to raw pixel values.
    constant_scale: float
    data_root: str
    drop_last: bool
    # Replacement values for missing data / labels.
    no_data_replace: float
    no_label_replace: int
    num_workers: int
    # Albumentations transforms applied at test time.
    test_transform: list[
        albumentations.core.transforms_interface.BasicTransform]
class ImagePrompt(BaseModel):
    """Plugin-defined prompt describing a single input image."""

    data_format: Literal["b64_json", "bytes", "url"]
    """
    How the `data` field is encoded: base64 string, raw bytes, or URL.
    """

    image_format: str
    """
    This is the image format (e.g., jpeg, png, etc.)
    """

    # Format in which the output data should be returned.
    out_data_format: Literal["b64_json", "url"]

    data: Any
    """
    Input image data
    """


MultiModalPromptType = Union[ImagePrompt]
class ImageRequestOutput(BaseModel):
    """
    The output data of an image request to vLLM.

    Args:
        type (str): How `data` is delivered: "path" or "b64_json".
        format (str): The image format (e.g., jpeg, png, etc.)
        data (str): The resulting data: a file path or base64-encoded
            file content, depending on `type`.
        request_id (Optional[str]): Id of the originating request, if any.
    """

    type: Literal["path", "b64_json"]
    format: str
    data: str
    request_id: Optional[str] = None

View File

@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from setuptools import setup
setup(
    name="prithvi_io_processor_plugin",
    version="0.1",
    packages=["prithvi_io_processor"],
    # Register the processors under the "vllm.io_processor_plugins" entry
    # point group so vLLM can discover them by name at engine startup.
    entry_points={
        "vllm.io_processor_plugins": [
            "prithvi_to_tiff_india = prithvi_io_processor:register_prithvi_india",  # noqa: E501
            "prithvi_to_tiff_valencia = prithvi_io_processor:register_prithvi_valencia",  # noqa: E501
        ]
    },
)

View File

@ -1,12 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
Since this module is V0 only, set VLLM_USE_V1=0 for
all tests in the module.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')

View File

@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import pytest
import requests
from tests.utils import RemoteOpenAIServer
from vllm.config import VllmConfig
from vllm.entrypoints.llm import LLM
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
def test_loading_missing_plugin():
    """Requesting a plugin name that is not installed raises ValueError."""
    vllm_config = VllmConfig()
    with pytest.raises(ValueError):
        get_io_processor(vllm_config, "wrong_plugin")
def test_loading_engine_with_wrong_plugin():
    """Engine construction fails fast when the plugin name is unknown."""
    with pytest.raises(ValueError):
        LLM(
            model=MODEL_NAME,
            skip_tokenizer_init=True,
            trust_remote_code=True,
            enforce_eager=True,
            # Limit the maximum number of parallel requests
            # to avoid the model going OOM in CI.
            max_num_seqs=32,
            io_processor_plugin="wrong_plugin",
        )
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
    """End-to-end offline run through the prithvi_to_tiff_valencia plugin:
    the plugin output must expose the ImageRequestOutput fields and carry
    valid base64 data."""
    # Plugin-defined prompt; fields are interpreted by the plugin.
    img_prompt = dict(
        data=image_url,
        data_format="url",
        image_format="tiff",
        out_data_format="b64_json",
    )

    pooling_params = PoolingParams(task="encode", softmax=False)

    with vllm_runner(
            model_name,
            runner="pooling",
            skip_tokenizer_init=True,
            trust_remote_code=True,
            enforce_eager=True,
            # Limit the maximum number of parallel requests
            # to avoid the model going OOM in CI.
            max_num_seqs=1,
            io_processor_plugin="prithvi_to_tiff_valencia",
    ) as llm_runner:
        pooler_output = llm_runner.get_llm().encode(
            img_prompt,
            pooling_params=pooling_params,
        )
        output = pooler_output[0].outputs

    # verify the output is formatted as expected for this plugin
    assert all(
        hasattr(output, attr)
        for attr in ["type", "format", "data", "request_id"])

    # We just check that the output is a valid base64 string.
    # Raises an exception and fails the test if the string is corrupted.
    base64.b64decode(output.data)
@pytest.fixture(scope="module")
def server():
    """Start an OpenAI-compatible vLLM server with the Prithvi IO processor
    plugin loaded, shared by all tests in this module."""
    cli_args = [
        "--runner", "pooling",
        "--enforce-eager",
        "--trust-remote-code",
        "--skip-tokenizer-init",
        # Keep parallelism low so the model does not go OOM in CI.
        "--max-num-seqs", "32",
        "--io-processor-plugin", "prithvi_to_tiff_valencia",
    ]
    with RemoteOpenAIServer(MODEL_NAME, cli_args) as remote_server:
        yield remote_server
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_prithvi_mae_plugin_online(
    server: RemoteOpenAIServer,
    model_name: str,
):
    """End-to-end online inference via the /pooling endpoint using the
    Prithvi IO processor plugin.

    The inner "data" payload format is defined by the plugin itself; vLLM
    only validates the outer IOProcessorRequest/IOProcessorResponse envelope.
    """
    request_payload_url = {
        "data": {
            "data": image_url,
            "data_format": "url",
            "image_format": "tiff",
            "out_data_format": "b64_json",
        },
        "priority": 0,
        "model": model_name,
    }
    ret = requests.post(
        server.url_for("pooling"),
        json=request_payload_url,
    )
    # Fail fast with a clear HTTP error instead of an opaque JSON/validation
    # failure if the server rejected the request.
    ret.raise_for_status()
    response = ret.json()

    # verify the request response is in the correct format
    assert (parsed_response := IOProcessorResponse(**response))

    # Verify the output is formatted as expected for this plugin. Check key
    # *presence* rather than truthiness so that legitimately falsy values
    # (e.g. an empty request_id string) do not cause spurious failures.
    plugin_data = parsed_response.data
    for attr in ("type", "format", "data", "request_id"):
        assert attr in plugin_data, f"missing '{attr}' in plugin output"

    # We just check that the output is a valid base64 string.
    # Raises an exception and fails the test if the string is corrupted.
    base64.b64decode(plugin_data["data"])

View File

@ -7,6 +7,15 @@ import torch
from vllm.plugins import load_general_plugins
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
"""
Since this module is V0 only, set VLLM_USE_V1=0 for
all tests in the module.
"""
monkeypatch.setenv('VLLM_USE_V1', '0')
def test_platform_plugins():
# simulate workload by running an example
import runpy

View File

@ -2,12 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from collections.abc import Generator
from typing import Optional
import pytest
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage, FunctionCall,
ToolCall)
from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.transformers_utils.detokenizer import detokenize_incrementally
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
# Use a common model that is likely to be available
MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"
@ -36,6 +41,56 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
assert actual_tool_call.function == expected_tool_call.function
def stream_delta_message_generator(
xlam_tool_parser: xLAMToolParser,
xlam_tokenizer: AnyTokenizer,
model_output: str,
request: Optional[ChatCompletionRequest] = None,
) -> Generator[DeltaMessage, None, None]:
all_token_ids = xlam_tokenizer.encode(model_output,
add_special_tokens=False)
previous_text = ""
previous_tokens = None
prefix_offset = 0
read_offset = 0
for i, delta_token in enumerate(all_token_ids):
delta_token_ids = [delta_token]
previous_token_ids = all_token_ids[:i]
current_token_ids = all_token_ids[:i + 1]
(new_tokens, delta_text, new_prefix_offset,
new_read_offset) = (detokenize_incrementally(
tokenizer=xlam_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=False,
spaces_between_special_tokens=True,
))
current_text = previous_text + delta_text
delta_message = xlam_tool_parser.extract_tool_calls_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
delta_token_ids,
request=request,
)
if delta_message:
yield delta_message
previous_text = current_text
previous_tokens = (previous_tokens +
new_tokens if previous_tokens else new_tokens)
prefix_offset = new_prefix_offset
read_offset = new_read_offset
def test_extract_tool_calls_no_tools(xlam_tool_parser):
model_output = "This is a test"
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
@ -51,6 +106,7 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
"single_tool_with_think_tag",
"single_tool_with_json_code_block",
"single_tool_with_tool_calls_tag",
"single_tool_with_tool_call_xml_tags",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
@ -118,6 +174,20 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
],
"I'll check the weather for you.",
),
(
"""I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
],
"I'll help you check the weather.",
),
],
)
def test_extract_tool_calls(xlam_tool_parser, model_output,
@ -245,3 +315,147 @@ def test_streaming_with_list_structure(xlam_tool_parser):
assert hasattr(result, "tool_calls")
assert len(result.tool_calls) == 1
assert result.tool_calls[0].function.name == "get_current_weather"
@pytest.mark.parametrize(
ids=[
"parallel_tool_calls",
"single_tool_with_think_tag",
"single_tool_with_json_code_block",
"single_tool_with_tool_calls_tag",
"single_tool_with_tool_call_xml_tags",
],
argnames=["model_output", "expected_tool_calls", "expected_content"],
argvalues=[
(
"""[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
)),
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Orlando",
"state": "FL",
"unit": "fahrenheit",
}),
)),
],
"",
),
(
"""<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
],
"<think>I'll help you with that.</think>",
),
(
"""```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
],
"",
),
(
"""[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
],
"",
),
(
"""I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501
[
ToolCall(function=FunctionCall(
name="get_current_weather",
arguments=json.dumps({
"city": "Dallas",
"state": "TX",
"unit": "fahrenheit",
}),
))
],
"I can help with that.",
),
],
)
def test_extract_tool_calls_streaming_incremental(
xlam_tool_parser,
xlam_tokenizer,
model_output,
expected_tool_calls,
expected_content,
):
"""Verify the XLAM Parser streaming behavior by verifying each chunk is as expected.""" # noqa: E501
request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])
chunks = []
for delta_message in stream_delta_message_generator(
xlam_tool_parser, xlam_tokenizer, model_output, request):
chunks.append(delta_message)
# Should have multiple chunks
assert len(chunks) >= 3
# Should have a chunk with tool header (id, name, type) for the first tool call # noqa: E501
header_found = False
expected_first_tool = expected_tool_calls[0]
for chunk in chunks:
if chunk.tool_calls and chunk.tool_calls[0].id:
header_found = True
assert (chunk.tool_calls[0].function.name ==
expected_first_tool.function.name)
assert chunk.tool_calls[0].type == "function"
# Arguments may be empty initially or None
if chunk.tool_calls[0].function.arguments is not None:
# If present, should be empty string initially
assert chunk.tool_calls[0].function.arguments == ""
break
assert header_found
# Should have chunks with incremental arguments
arg_chunks = []
for chunk in chunks:
if (chunk.tool_calls and chunk.tool_calls[0].function.arguments
and chunk.tool_calls[0].function.arguments != ""
and chunk.tool_calls[0].index ==
0 # Only collect arguments from the first tool call
):
arg_chunks.append(chunk.tool_calls[0].function.arguments)
# Arguments should be streamed incrementally
assert len(arg_chunks) > 1
# Concatenated arguments should form valid JSON for the first tool call
full_args = "".join(arg_chunks)
parsed_args = json.loads(full_args)
expected_args = json.loads(expected_first_tool.function.arguments)
assert parsed_args == expected_args

View File

@ -501,6 +501,8 @@ class ModelConfig:
logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None
"""One or more logits processors' fully-qualified class names or class
definitions"""
io_processor_plugin: Optional[str] = None
"""IOProcessor plugin name to load at model startup"""
def compute_hash(self) -> str:
"""

View File

@ -40,16 +40,21 @@ class KVCacheEvent(
"""Base class for all KV cache-related events"""
MEDIUM_GPU = "GPU"
class BlockStored(KVCacheEvent):
block_hashes: list[int]
parent_block_hash: Optional[int]
token_ids: list[int]
block_size: int
lora_id: Optional[int]
medium: Optional[str]
class BlockRemoved(KVCacheEvent):
block_hashes: list[int]
medium: Optional[str]
class AllBlocksCleared(KVCacheEvent):

View File

@ -19,6 +19,8 @@ The class provides the following primitives:
Returns whether KV cache should be freed now or will be
freed asynchronously and optionally returns KV transfer
params.
take_events() - returns new KV events that were collected
by the connector since the last call.
Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata.
@ -34,6 +36,7 @@ The class provides the following primitives:
import enum
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
import torch
@ -45,6 +48,7 @@ from vllm.v1.outputs import KVConnectorOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
@ -313,6 +317,15 @@ class KVConnectorBase_V1(ABC):
"""
return False, None
def take_events(self) -> Iterable["KVCacheEvent"]:
"""
Take the KV cache events from the connector.
Yields:
New KV cache events since the last call.
"""
return ()
@classmethod
def get_required_kvcache_layout(
cls, vllm_config: "VllmConfig") -> Optional[str]:

View File

@ -1,12 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from collections.abc import Iterable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm.config import KVTransferConfig, VllmConfig
from vllm.distributed.kv_events import KVCacheEvent
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
@ -208,6 +210,10 @@ class MultiConnector(KVConnectorBase_V1):
return async_saves > 0, kv_txfer_params
def take_events(self) -> Iterable[KVCacheEvent]:
for c in self._connectors:
yield from c.take_events()
@classmethod
def get_required_kvcache_layout(
cls, vllm_config: "VllmConfig") -> Optional[str]:

View File

@ -364,6 +364,7 @@ class EngineArgs:
disable_mm_preprocessor_cache: bool = False # DEPRECATED
mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
io_processor_plugin: Optional[str] = None
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
# LoRA fields
enable_lora: bool = False
@ -577,6 +578,8 @@ class EngineArgs:
**model_kwargs["override_attention_dtype"])
model_group.add_argument("--logits-processors",
**model_kwargs["logits_processors"])
model_group.add_argument("--io-processor-plugin",
**model_kwargs["io_processor_plugin"])
# Model loading arguments
load_kwargs = get_kwargs(LoadConfig)
@ -993,6 +996,7 @@ class EngineArgs:
model_impl=self.model_impl,
override_attention_dtype=self.override_attention_dtype,
logits_processors=self.logits_processors,
io_processor_plugin=self.io_processor_plugin,
)
def validate_tensorizer_args(self):
@ -1432,17 +1436,6 @@ class EngineArgs:
recommend_to_remove=True)
return False
# Triton v3.3 has f16 conversion regression issue on Turing and Volta,
# which broke fp16 inference
# see: https://github.com/triton-lang/triton/issues/6698
if (current_platform.is_cuda()
and not current_platform.has_device_capability(80)
and model_config.dtype == torch.float16):
_raise_or_fallback(
feature_name="Compute Capability < 8.0 with FP16",
recommend_to_remove=False)
return False
if self.kv_cache_dtype != "auto":
supported = current_platform.is_kv_cache_dtype_supported(
self.kv_cache_dtype, model_config)

View File

@ -15,6 +15,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
from vllm.plugins.io_processors.interface import IOProcessor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
@ -267,6 +268,9 @@ class EngineClient(ABC):
"""Get the appropriate tokenizer for the request"""
...
async def get_io_processor(self) -> IOProcessor:
raise NotImplementedError
@abstractmethod
async def is_tracing_enabled(self) -> bool:
...

View File

@ -37,13 +37,15 @@ from vllm.entrypoints.score_utils import (ScoreContentPartParam,
# yapf: enable
from vllm.entrypoints.utils import (_validate_truncation_size,
log_non_default_args)
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
TokensPrompt)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput,
ScoringRequestOutput)
from vllm.plugins.io_processors import get_io_processor
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
SamplingParams)
@ -284,6 +286,11 @@ class LLM:
self.supported_tasks = supported_tasks
# Load the Input/Output processor plugin if any
io_processor_plugin = self.llm_engine.model_config.io_processor_plugin
self.io_processor = get_io_processor(self.llm_engine.vllm_config,
io_processor_plugin)
def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
@ -833,7 +840,7 @@ class LLM:
def encode(
self,
prompts: Union[PromptType, Sequence[PromptType]],
prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
@ -915,6 +922,22 @@ class LLM:
if truncate_prompt_tokens is not None:
param.truncate_prompt_tokens = truncate_prompt_tokens
io_processor_prompt = False
if isinstance(prompts, dict) and "data" in prompts:
io_processor_prompt = True
if self.io_processor is None:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
"to the documentation and to the "
"'prithvi_geospatial_mae_io_processor' "
"offline inference example for more details.")
# Validate the request data is valid for the loaded plugin
validated_prompt = self.io_processor.parse_request(prompts)
# obtain the actual model prompts from the pre-processor
prompts = self.io_processor.pre_process(prompt=validated_prompt)
self._validate_and_add_requests(
prompts=prompts,
params=pooling_params,
@ -923,8 +946,24 @@ class LLM:
)
outputs = self._run_engine(use_tqdm=use_tqdm)
return self.engine_class.validate_outputs(outputs,
PoolingRequestOutput)
model_outputs = self.engine_class.validate_outputs(
outputs, PoolingRequestOutput)
if io_processor_prompt:
# get the post-processed model outputs
assert self.io_processor is not None
processed_outputs = self.io_processor.post_process(
model_output=model_outputs)
return [
PoolingRequestOutput[Any](request_id="",
outputs=processed_outputs,
prompt_token_ids=[],
finished=True)
]
else:
return model_outputs
def embed(
self,

View File

@ -64,6 +64,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
EmbeddingRequest,
EmbeddingResponse, ErrorInfo,
ErrorResponse,
IOProcessorResponse,
LoadLoRAAdapterRequest,
PoolingRequest, PoolingResponse,
RerankRequest, RerankResponse,
@ -795,7 +796,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.error.code)
elif isinstance(generator, PoolingResponse):
elif isinstance(generator, (PoolingResponse, IOProcessorResponse)):
return JSONResponse(content=generator.model_dump())
assert_never(generator)
@ -1782,7 +1783,7 @@ async def init_app_state(
) if "generate" in supported_tasks else None
state.openai_serving_pooling = OpenAIServingPooling(
engine_client,
model_config,
vllm_config,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,

View File

@ -6,7 +6,8 @@
import json
import time
from http import HTTPStatus
from typing import Annotated, Any, ClassVar, Literal, Optional, Union
from typing import (Annotated, Any, ClassVar, Generic, Literal, Optional,
TypeVar, Union)
import regex as re
import torch
@ -1405,7 +1406,46 @@ EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
PoolingCompletionRequest = EmbeddingCompletionRequest
PoolingChatRequest = EmbeddingChatRequest
PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest]
T = TypeVar("T")
class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
model: Optional[str] = None
priority: int = Field(default=0)
"""
The priority of the request (lower means earlier handling;
default: 0). Any priority other than 0 will raise an error
if the served model does not use priority scheduling.
"""
data: T
"""
When using IOProcessor plugins, the actual input is processed
by the plugin itself. Hence, we use a generic type for the request data
"""
def to_pooling_params(self):
return PoolingParams(task="encode")
class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
request_id: Optional[str] = None
"""
The request_id associated with this response
"""
created_at: int = Field(default_factory=lambda: int(time.time()))
data: T
"""
When using IOProcessor plugins, the actual output is generated
by the plugin itself. Hence, we use a generic type for the response data
"""
PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest,
IOProcessorRequest]
class ScoreRequest(OpenAIBaseModel):

View File

@ -49,9 +49,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
EmbeddingCompletionRequest,
EmbeddingRequest,
EmbeddingResponse, ErrorInfo,
ErrorResponse, PoolingResponse,
RerankRequest, ResponsesRequest,
ScoreRequest, ScoreResponse,
ErrorResponse,
IOProcessorRequest,
PoolingResponse, RerankRequest,
ResponsesRequest, ScoreRequest,
ScoreResponse,
TokenizeChatRequest,
TokenizeCompletionRequest,
TokenizeResponse,
@ -89,7 +91,7 @@ ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
TokenizeChatRequest]
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest,
ResponsesRequest]
ResponsesRequest, IOProcessorRequest]
AnyResponse = Union[
CompletionResponse,

View File

@ -4,7 +4,7 @@
import asyncio
import base64
import time
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Sequence
from typing import Final, Literal, Optional, Union, cast
import jinja2
@ -13,19 +13,25 @@ import torch
from fastapi import Request
from typing_extensions import assert_never
from vllm.config import ModelConfig
from vllm.config import VllmConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
# yapf: disable
from vllm.entrypoints.openai.protocol import (ErrorResponse,
IOProcessorRequest,
IOProcessorResponse,
PoolingChatRequest,
PoolingCompletionRequest,
PoolingRequest, PoolingResponse,
PoolingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
# yapf: enable
from vllm.entrypoints.openai.serving_engine import OpenAIServing, RequestPrompt
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.plugins.io_processors import get_io_processor
from vllm.utils import merge_async_iterators
logger = init_logger(__name__)
@ -52,7 +58,7 @@ class OpenAIServingPooling(OpenAIServing):
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
vllm_config: VllmConfig,
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
@ -61,19 +67,21 @@ class OpenAIServingPooling(OpenAIServing):
log_error_stack: bool = False,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
model_config=vllm_config.model_config,
models=models,
request_logger=request_logger,
log_error_stack=log_error_stack)
self.chat_template = chat_template
self.chat_template_content_format: Final = chat_template_content_format
io_processor_plugin = self.model_config.io_processor_plugin
self.io_processor = get_io_processor(vllm_config, io_processor_plugin)
async def create_pooling(
self,
request: PoolingRequest,
raw_request: Optional[Request] = None,
) -> Union[PoolingResponse, ErrorResponse]:
) -> Union[PoolingResponse, IOProcessorResponse, ErrorResponse]:
"""
See https://platform.openai.com/docs/api-reference/embeddings/create
for the API specification. This API mimics the OpenAI Embedding API.
@ -82,20 +90,13 @@ class OpenAIServingPooling(OpenAIServing):
if error_check_ret is not None:
return error_check_ret
encoding_format = request.encoding_format
if request.dimensions is not None:
return self.create_error_response(
"dimensions is currently not supported")
model_name = self._get_model_name(request.model)
request_id = f"pool-{self._base_request_id(raw_request)}"
created_time = int(time.time())
truncate_prompt_tokens = request.truncate_prompt_tokens
is_io_processor_request = isinstance(request, IOProcessorRequest)
try:
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens)
lora_request = self._maybe_get_adapters(request)
if self.model_config.skip_tokenizer_init:
@ -104,7 +105,32 @@ class OpenAIServingPooling(OpenAIServing):
tokenizer = await self.engine_client.get_tokenizer(lora_request
)
if isinstance(request, PoolingChatRequest):
if getattr(request, "dimensions", None) is not None:
return self.create_error_response(
"dimensions is currently not supported")
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
None)
truncate_prompt_tokens = _validate_truncation_size(
self.max_model_len, truncate_prompt_tokens)
if is_io_processor_request:
if self.io_processor is None:
raise ValueError(
"No IOProcessor plugin installed. Please refer "
"to the documentation and to the "
"'prithvi_geospatial_mae_io_processor' "
"offline inference example for more details.")
validated_prompt = self.io_processor.parse_request(request)
engine_prompts = await self.io_processor.pre_process_async(
prompt=validated_prompt, request_id=request_id)
request_prompts: Sequence[RequestPrompt] = [
""
] * len(engine_prompts)
elif isinstance(request, PoolingChatRequest):
(
_,
request_prompts,
@ -122,7 +148,7 @@ class OpenAIServingPooling(OpenAIServing):
continue_final_message=False,
add_special_tokens=request.add_special_tokens,
)
else:
elif isinstance(request, PoolingCompletionRequest):
(request_prompts,
engine_prompts) = await self._preprocess_completion(
request,
@ -130,6 +156,9 @@ class OpenAIServingPooling(OpenAIServing):
request.input,
add_special_tokens=request.add_special_tokens,
)
else:
raise ValueError(
f"Unsupported request of type {type(request)}")
except (ValueError, TypeError, jinja2.TemplateError) as e:
logger.exception("Error in preprocessing prompt inputs")
return self.create_error_response(str(e))
@ -171,6 +200,16 @@ class OpenAIServingPooling(OpenAIServing):
result_generator = merge_async_iterators(*generators)
if is_io_processor_request:
assert self.io_processor is not None
output = await self.io_processor.post_process_async(
model_output=result_generator,
request_id=request_id,
)
return self.io_processor.output_to_response(output)
assert isinstance(request,
(PoolingCompletionRequest, PoolingChatRequest))
num_prompts = len(engine_prompts)
# Non-streaming response
@ -190,7 +229,7 @@ class OpenAIServingPooling(OpenAIServing):
request_id,
created_time,
model_name,
encoding_format,
request.encoding_format,
)
except asyncio.CancelledError:
return self.create_error_response("Client disconnected")

View File

@ -186,11 +186,31 @@ class xLAMToolParser(ToolParser):
"""
Extract tool calls for streaming mode.
"""
# Simplify detection: if it begins with "[" treat it as a function call
is_function_call = (current_text.strip().startswith("["))
# First, check for a definitive start of a tool call block.
# This prevents premature parsing of incomplete output.
stripped_text = current_text.strip()
preprocessed_content, preprocessed_tool_calls = (
self.preprocess_model_output(current_text))
# If not a function call, return normal content
if not is_function_call:
# For JSON code blocks, we need to detect them earlier, even if incomplete
has_potential_json_block = ("```json" in current_text
or "```\n[" in current_text
or "[TOOL_CALLS]" in current_text
or "<tool_call>" in current_text)
is_tool_call_block = (
stripped_text.startswith("[")
or stripped_text.startswith("<tool_call>")
or stripped_text.startswith("[TOOL_CALLS]") or
# Check if we have thinking tags with JSON-like content following
("</think>[" in current_text) or
# Check if the text contains a JSON array after preprocessing
preprocessed_tool_calls is not None or
# For JSON code blocks, detect early if we see enough structure
(has_potential_json_block and '"name"' in current_text
and '"arguments"' in current_text))
if not is_tool_call_block:
return DeltaMessage(content=delta_text)
try:
@ -204,7 +224,10 @@ class xLAMToolParser(ToolParser):
# Try parsing as JSON to check for complete tool calls
try:
parsed_tools = json.loads(current_text)
# Use preprocessed tool calls if available
tool_calls_text = (preprocessed_tool_calls if
preprocessed_tool_calls else current_text)
parsed_tools = json.loads(tool_calls_text)
if isinstance(parsed_tools, list):
# Update our tool array for next time
self.prev_tool_call_arr = parsed_tools
@ -257,13 +280,40 @@ class xLAMToolParser(ToolParser):
return delta
# Use regex to identify tool calls in the output
# Use preprocessed tool calls text for better parsing, but also try to extract from incomplete JSON blocks
search_text = (preprocessed_tool_calls
if preprocessed_tool_calls else current_text)
# For JSON code blocks that aren't complete yet, try to extract the JSON content
if not preprocessed_tool_calls and has_potential_json_block:
# Try to extract the JSON array from within the code block
json_match = re.search(r"```(?:json)?\s*([\s\S]*?)(?:```|$)",
current_text)
if json_match:
potential_json = json_match.group(1).strip()
# Use this as search text even if it's incomplete
if potential_json.startswith("[") and (
'"name"' in potential_json
and '"arguments"' in potential_json):
search_text = potential_json
# Try to find complete tool names first
name_pattern = r'"name"\s*:\s*"([^"]+)"'
name_matches = list(re.finditer(name_pattern, current_text))
name_matches = list(re.finditer(name_pattern, search_text))
tool_count = len(name_matches)
# If no tools found yet, return
# If no complete tool names found, check for partial tool names
if tool_count == 0:
return None
# Check if we're in the middle of parsing a tool name
partial_name_pattern = r'"name"\s*:\s*"([^"]*)'
partial_matches = list(
re.finditer(partial_name_pattern, search_text))
if partial_matches:
# We have a partial tool name - not ready to emit yet
return None
else:
# No tools found at all
return None
# Ensure our state arrays are large enough
while len(self.streaming_state["sent_tools"]) < tool_count:
@ -332,7 +382,7 @@ class xLAMToolParser(ToolParser):
# First, check for the empty arguments case: "arguments": {}
empty_args_pattern = (
r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}')
empty_args_match = re.search(empty_args_pattern, current_text)
empty_args_match = re.search(empty_args_pattern, search_text)
# Check if this tool has empty arguments
if empty_args_match and empty_args_match.start() > 0:
@ -376,7 +426,7 @@ class xLAMToolParser(ToolParser):
# Extract arguments for current tool using regex for non-empty arguments
args_pattern = r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
args_matches = list(re.finditer(args_pattern, current_text))
args_matches = list(re.finditer(args_pattern, search_text))
if current_idx < len(args_matches):
args_text = args_matches[current_idx].group(1)
@ -384,17 +434,25 @@ class xLAMToolParser(ToolParser):
# Handle transition between tools
is_last_tool = current_idx == tool_count - 1
# Find where the arguments for our current tool end
if not is_last_tool:
# If we have more tools after this one, try to find the complete argument block
next_tool_pos = current_text.find(
"},{", args_matches[current_idx].start())
if next_tool_pos != -1:
args_end_pos = (next_tool_pos + 1
) # +1 to include the '}'
args_text = (current_text[args_matches[current_idx]
.start():args_end_pos].
split('"arguments":')[1].strip())
# For multiple tools, extract only the arguments for the current tool
if tool_count > 1:
# Parse the entire JSON structure to properly extract arguments for each tool
try:
parsed_tools = json.loads(search_text)
if isinstance(
parsed_tools,
list) and current_idx < len(parsed_tools):
current_tool = parsed_tools[current_idx]
if isinstance(current_tool.get("arguments"),
dict):
args_text = json.dumps(
current_tool["arguments"])
else:
args_text = str(
current_tool.get("arguments", "{}"))
except (json.JSONDecodeError, KeyError, IndexError):
# Fallback to regex-based extraction
pass
# If arguments haven't been sent yet
sent_args = self.streaming_state["sent_tools"][
@ -419,7 +477,7 @@ class xLAMToolParser(ToolParser):
index=current_idx,
function=DeltaFunctionCall(
arguments="{").model_dump(
exclude_none=True), # type: ignore
exclude_none=True), # type: ignore
)
])
return delta

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
@ -18,6 +18,7 @@ target model.
"""
__all__ = [
"DataPrompt",
"TextPrompt",
"TokensPrompt",
"PromptType",

View File

@ -95,6 +95,16 @@ class EmbedsPrompt(TypedDict):
"""
class DataPrompt(TypedDict):
"""Represents generic inputs handled by IO processor plugins."""
data: Any
"""The input data"""
data_format: str
"""The input data format"""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
"""
Set of possible schemas for a single prompt:

View File

@ -37,6 +37,7 @@ class GPTQConfig(QuantizationConfig):
desc_act: bool,
lm_head_quantized: bool,
dynamic: dict[str, dict[str, Union[int, bool]]],
autoround_version: str = "",
) -> None:
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
@ -74,6 +75,9 @@ class GPTQConfig(QuantizationConfig):
"Currently, only 2/3/4/8-bit weight quantization is "
f"supported for GPTQ, but got {self.weight_bits} bits.")
# used to identify GPTQ model quantized by autoround
self.autoround_version = autoround_version
def __repr__(self) -> str:
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
@ -108,8 +112,10 @@ class GPTQConfig(QuantizationConfig):
desc_act = cls.get_from_keys(config, ["desc_act"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
autoround_version = cls.get_from_keys_or(config, ["autoround_version"],
default="")
return cls(weight_bits, group_size, desc_act, lm_head_quantized,
dynamic)
dynamic, autoround_version)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str

View File

@ -119,6 +119,9 @@ class GPTQMarlinConfig(QuantizationConfig):
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
# used to identify GPTQ model quantized by autoround
self.autoround_version = full_config.get("autoround_version", "")
def __repr__(self) -> str:
return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
f"group_size={self.group_size}, "

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import deepcopy
from fractions import Fraction
from typing import Optional, Union
import regex as re
@ -29,7 +30,7 @@ def override_config(config: QuantizationConfig, prefix: str):
if isinstance(desc_act, bool):
config.desc_act = desc_act
config.pack_factor = 32 // config.weight_bits # packed into int32
config.pack_factor = Fraction(32, config.weight_bits) # packed into int32
if config.get_name() == "gptq_marlin":
is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
if isinstance(is_sym, bool):

View File

@ -549,7 +549,7 @@ class GraniteSpeechForConditionalGeneration(
raise ValueError("Only audio modality is supported")
def __init__(self, *, vllm_config: VllmConfig, prefix: str):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config

View File

@ -1371,7 +1371,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
output_tensor[i, :t.size(0)] = t
return output_tensor
def _parse_and_validate_image_input(self, **kwargs: object):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[MllamaImagePixelInputs]:
# tensor with the same shape will be batched together by
# MultiModalKwargs.batch, so pixel_values here can be:
# - list[torch.Tensor]:

View File

@ -209,7 +209,7 @@ class OvisImagePatchInputs(TypedDict):
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
"""
inducator_tokens: torch.Tensor
indicator_tokens: torch.Tensor
"""
Shape:
`(batch_size * (num_patches + 1))`

View File

@ -3,7 +3,7 @@
""" PyTorch Ovis model."""
from collections.abc import Iterable, Mapping
from functools import partial
from typing import Optional, Union
from typing import Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
@ -50,6 +50,27 @@ IMAGE_PAD_TOKEN_ID_MAP = {
}
class OvisVideoPatchInputs(TypedDict):
    """Flattened per-frame patch inputs for Ovis video processing.

    Produced by ``_parse_and_validate_video_input`` and consumed by
    ``_process_image_input``.
    """
    type: Literal["video_patches"]

    flat_data: torch.Tensor
    """
    Shape:
    `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
    """

    indicator_tokens: torch.Tensor
    """
    Shape:
    `(batch_size * (num_patches + 1))`
    """

    patches_per_image: list[int]
    """
    List of number of total patches for each frame in the video.
    This is used to restore the first two dimensions of `flat_data`.
    """

    grids: torch.Tensor
    """
    Flattened grid metadata for the video frames, as passed by
    `_parse_and_validate_video_input` (`flatten_bn(flatten_bn(grids),
    concat=True)`). NOTE(review): exact shape depends on the processor
    output — confirm against the Ovis2.5 processor. Declared here because
    the constructor call site passes this key.
    """
def _ovis2_5_field_config():
return dict(pixel_values=MultiModalFieldConfig.batched("image"),
grids=MultiModalFieldConfig.batched("image"),
@ -429,17 +450,11 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
self.make_empty_intermediate_tensors = (
self.get_language_model().make_empty_intermediate_tensors)
def _parse_and_validate_visual_input(
self, is_video,
**kwargs: object) -> Optional[OvisImagePatchInputs]:
if is_video:
pixel_values = kwargs.pop("video_pixel_values", None)
indicator_tokens = kwargs.pop("video_indicator_tokens", None)
grids = kwargs.pop("video_grids", None)
else:
pixel_values = kwargs.pop("pixel_values", None)
indicator_tokens = kwargs.pop("indicator_tokens", None)
grids = kwargs.pop("grids", None)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
pixel_values = kwargs.pop("pixel_values", None)
indicator_tokens = kwargs.pop("indicator_tokens", None)
grids = kwargs.pop("grids", None)
if pixel_values is None and indicator_tokens is None:
return None
@ -466,8 +481,40 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
raise AssertionError("This line should be unreachable.")
    def _parse_and_validate_video_input(
            self, **kwargs: object) -> Optional[OvisVideoPatchInputs]:
        """Pop video tensors from ``kwargs`` and validate their types.

        Returns None when no video inputs are present in this batch,
        otherwise an ``OvisVideoPatchInputs`` whose batch dimensions have
        been flattened for the vision tower.

        Raises:
            ValueError: if ``video_pixel_values`` or
                ``video_indicator_tokens`` has an unexpected type.

        NOTE: return annotation corrected — this path constructs
        ``OvisVideoPatchInputs``, not ``OvisImagePatchInputs``.
        """
        pixel_values = kwargs.pop("video_pixel_values", None)
        indicator_tokens = kwargs.pop("video_indicator_tokens", None)
        grids = kwargs.pop("video_grids", None)

        # No video keys in this batch.
        if pixel_values is None and indicator_tokens is None:
            return None

        if pixel_values is not None and indicator_tokens is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(indicator_tokens, (torch.Tensor, list)):
                raise ValueError("Incorrect type of indicator_tokens. "
                                 f"Got type: {type(indicator_tokens)}")

            return OvisVideoPatchInputs(
                type="video_patches",
                flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
                # Patches per frame, derived from the ViT hidden stride so
                # flat_data's leading dims can be restored downstream.
                patches_per_image=[
                    x.shape[0] // (self.config.vit_config.hidden_stride**2)
                    for x in flatten_bn(pixel_values)
                ],
                indicator_tokens=flatten_bn(flatten_bn(indicator_tokens),
                                            concat=True),
                grids=flatten_bn(flatten_bn(grids), concat=True),
            )

        raise AssertionError("This line should be unreachable.")
def _process_image_input(
self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings:
self, image_input: Union[OvisImagePatchInputs, OvisVideoPatchInputs]
) -> MultiModalEmbeddings:
image_patches_flat = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"]
indicator_tokens = image_input["indicator_tokens"]
@ -500,21 +547,44 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
torch.cat(vision_embeddings_per_image, dim=0))
return tuple(vision_embeddings)
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
embeddings = []
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
modalities = {}
# NOTE: _parse_and_validate_visual_input has side-effects and pops
# keys from kwargs. We process images first, then videos.
image_input = self._parse_and_validate_visual_input(False, **kwargs)
if image_input:
embeddings.extend(self._process_image_input(image_input))
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for input_key in kwargs:
if input_key in ("pixel_values", "indicator_tokens",
"grids") and "images" not in modalities:
modalities["images"] = self._parse_and_validate_image_input(
**kwargs)
if input_key in ("video_pixel_values", "video_indicator_tokens",
"video_grids") and "videos" not in modalities:
modalities["videos"] = self._parse_and_validate_video_input(
**kwargs)
video_input = self._parse_and_validate_visual_input(True, **kwargs)
if video_input:
embeddings.extend(self._process_image_input(video_input))
return modalities
return tuple(embeddings) if embeddings else None
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
if not modalities:
return []
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
multimodal_embeddings += vision_embeddings
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_image_input(video_input)
multimodal_embeddings += video_embeddings
return multimodal_embeddings
def get_input_embeddings(
self,

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, TypedDict, Union
from typing import Annotated, Any, Literal, Optional, Union
import numpy as np
import torch
@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
@ -615,50 +616,90 @@ class Phi4MMAudioEmbedding(nn.Module):
return loaded_params
class Phi4MMImagePixelInputs(TypedDict):
class Phi4MMImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- p: Number of patches (1 + num_patches)
- c: Number of channels (3)
- h: Height of each image patch
- w: Width of each image patch
- nc: Number of crops
- H_mask: Height of attention mask
- W_mask: Width of attention mask
"""
type: Literal["pixel_values"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
data: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}
), # may be different per batch and image
]
image_sizes: Annotated[
torch.Tensor,
TensorShape("bn", 2), # (height, width)
]
num_img_tokens: Annotated[
list[int],
TensorShape("bn"),
]
image_attention_mask: Annotated[
torch.Tensor,
TensorShape("bn", "nc", 32, 32), # H_mask, W_mask
]
class Phi4MMImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- f: Image feature size
- h: Hidden size (must match language model backbone)
"""
image_sizes: torch.Tensor
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(height, width)` format.
"""
num_img_tokens: list[int]
"""Shape: `(batch_size * num_images)`"""
image_attention_mask: torch.Tensor
"""Shape: `(batch_size * num_images, H_mask, W_mask)`"""
class Phi4MMImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
data: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "f", "h"),
]
class Phi4MMAudioFeatureInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of audios
- f: Number of Mel filterbank bins (80)
- t: Time frames (M)
"""
class Phi4MMAudioFeatureInputs(TypedDict):
type: Literal["audio_features"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size * num_audios, 80, M)"""
data: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "t", 80, dynamic_dims={"t"}),
]
class Phi4MMAudioEmbeddingInputs(TypedDict):
class Phi4MMAudioEmbeddingInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- n: Number of audios
- f: Audio feature size
- h: Hidden size (must match language model backbone)
"""
type: Literal["audio_embeds"]
data: NestedTensors
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
data: Annotated[
NestedTensors,
TensorShape("b", "n", "f", "h"),
]
Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs]
@ -1170,18 +1211,10 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
return None
if audio_features is not None:
if not isinstance(audio_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio features. "
f"Got type: {type(audio_features)}")
return Phi4MMAudioFeatureInputs(type="audio_features",
data=flatten_bn(audio_features))
if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio embeds. "
f"Got type: {type(audio_embeds)}")
return Phi4MMAudioEmbeddingInputs(type="audio_embeds",
data=audio_embeds)
@ -1259,7 +1292,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
elif isinstance(image_sizes, torch.Tensor):
image_sizes = image_sizes.flatten(0, 1)
else:
raise ValueError("Incorrect image_attention_mask inputs")
raise ValueError("Incorrect image_sizes inputs")
if isinstance(num_img_tokens, list):
num_img_tokens = [
@ -1269,7 +1302,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
elif isinstance(num_img_tokens, torch.Tensor):
num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
else:
raise ValueError("Incorrect image_attention_mask inputs")
raise ValueError("Incorrect num_img_tokens inputs")
return Phi4MMImagePixelInputs(
type="pixel_values",

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, TypedDict, Union
from typing import Annotated, Any, Literal, Optional, Union
import numpy as np
import torch
@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .idefics2_vision_model import Idefics2VisionTransformer
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
@ -391,41 +392,71 @@ class Phi4MMImageEncoder(nn.Module):
return img_set_tensor
class Phi4MMImagePixelInputs(TypedDict):
class Phi4MMImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- p: Number of patches (1 + num_patches)
- c: Number of channels (3)
- h: Height of each image patch
- w: Width of each image patch
- nc: Number of crops
- H_mask: Height of attention mask
- W_mask: Width of attention mask
"""
type: Literal["pixel_values"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""
Shape:
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
data: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}
), # may be different per batch and image
]
image_sizes: Annotated[
torch.Tensor,
TensorShape("bn", 2), # (height, width)
]
num_img_tokens: Annotated[
list[int],
TensorShape("bn"),
]
image_attention_mask: Annotated[
torch.Tensor,
TensorShape("bn", "nc", 32, 32), # H_mask, W_mask
]
class Phi4MMAudioFeatureInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of audios
- t: Time frames (M)
"""
image_sizes: torch.Tensor
"""
Shape: `(batch_size * num_images, 2)`
This should be in `(height, width)` format.
"""
num_img_tokens: list[int]
"""Shape: `(batch_size * num_images)`"""
image_attention_mask: torch.Tensor
"""Shape: `(batch_size * num_images, H_mask, W_mask)`"""
class Phi4MMAudioFeatureInputs(TypedDict):
type: Literal["audio_features"]
data: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size * num_audios, 80, M)"""
data: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "t", 80, dynamic_dims={"t"}),
]
class Phi4MMAudioEmbeddingInputs(TypedDict):
class Phi4MMAudioEmbeddingInputs(TensorSchema):
"""
Dimensions:
- b: Batch size
- n: Number of audios
- f: Audio feature size
- h: Hidden size (must match language model backbone)
"""
type: Literal["audio_embeds"]
data: NestedTensors
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
data: Annotated[
NestedTensors,
TensorShape("b", "n", "f", "h"),
]
Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs]
@ -985,18 +1016,10 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
return None
if audio_features is not None:
if not isinstance(audio_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio features. "
f"Got type: {type(audio_features)}")
return Phi4MMAudioFeatureInputs(type="audio_features",
data=flatten_bn(audio_features))
if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio embeds. "
f"Got type: {type(audio_embeds)}")
return Phi4MMAudioEmbeddingInputs(type="audio_embeds",
data=audio_embeds)
@ -1031,8 +1054,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
]
return audio_embeds
def _parse_and_validate_image_input(self,
**kwargs: object) -> Optional[dict]:
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Phi4MMImagePixelInputs]:
input_image_embeds: NestedTensors = kwargs.get("input_image_embeds")
if input_image_embeds is None:
return None
@ -1074,7 +1097,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
elif isinstance(image_sizes, torch.Tensor):
image_sizes = image_sizes.flatten(0, 1)
else:
raise ValueError("Incorrect image_attention_mask inputs")
raise ValueError("Incorrect image_sizes inputs")
if isinstance(num_img_tokens, list):
num_img_tokens = [
@ -1084,7 +1107,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
elif isinstance(num_img_tokens, torch.Tensor):
num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
else:
raise ValueError("Incorrect image_attention_mask inputs")
raise ValueError("Incorrect num_img_tokens inputs")
return Phi4MMImagePixelInputs(
type="pixel_values",

View File

@ -159,9 +159,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid gate quantization.
# See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
# seems to avoid gate quantization while AutoRound does.
# See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4,
# and https://huggingface.co/jart25/Qwen3-Coder-30B-A3B-Instruct-Int4-gptq
if isinstance(
quant_config,
(GPTQConfig,
GPTQMarlinConfig)) and not quant_config.autoround_version:
return None
return quant_config

View File

@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import logging
from typing import Optional
from vllm.config import VllmConfig
from vllm.plugins import load_plugins_by_group
from vllm.plugins.io_processors.interface import IOProcessor
from vllm.utils import resolve_obj_by_qualname
logger = logging.getLogger(__name__)
def get_io_processor(
        vllm_config: VllmConfig,
        plugin_from_init: Optional[str] = None) -> IOProcessor | None:
    """Resolve and instantiate the IOProcessor plugin for the current model.

    IO processors are loaded as plugins under the
    'vllm.io_processor_plugins' entry-point group. Similar to platform
    plugins, each plugin registers a function that returns the qualified
    class name of the processor to install.

    Args:
        vllm_config: The vLLM config; its model ``hf_config`` may carry a
            custom ``io_processor_plugin`` field naming the plugin.
        plugin_from_init: Plugin name passed explicitly at initialization;
            takes precedence over the model-config field.

    Returns:
        An instantiated ``IOProcessor``, or None when no plugin is
        requested by either source.

    Raises:
        ValueError: A plugin was requested but no loadable plugin with
            that name is installed.
    """
    if plugin_from_init:
        model_plugin = plugin_from_init
    else:
        # A plugin can be specified via the model config.
        # This uses a custom field in the hf_config for the model.
        model_plugin = vllm_config.model_config.hf_config.to_dict().get(
            "io_processor_plugin")

    if model_plugin is None:
        logger.info("No IOProcessor plugins requested by the model")
        return None

    logger.debug("IOProcessor plugin to be loaded %s", model_plugin)

    # Load all installed plugins in the group and collect the processor
    # class names they advertise; a broken plugin is skipped with a warning
    # rather than aborting startup.
    io_processor_plugins = load_plugins_by_group('vllm.io_processor_plugins')

    loadable_plugins = {}
    for name, func in io_processor_plugins.items():
        try:
            assert callable(func)
            processor_cls_qualname = func()
            if processor_cls_qualname is not None:
                loadable_plugins[name] = processor_cls_qualname
        except Exception:
            logger.warning("Failed to load plugin %s.", name, exc_info=True)

    if not loadable_plugins:
        raise ValueError("No IOProcessor plugins installed"
                         f" but one is required ({model_plugin}).")

    if model_plugin not in loadable_plugins:
        raise ValueError(
            f"The model requires the '{model_plugin}' IO Processor plugin "
            "but it is not installed. "
            f"Available plugins: {list(loadable_plugins.keys())}")

    activated_plugin_cls = loadable_plugins[model_plugin]

    return resolve_obj_by_qualname(activated_plugin_cls)(vllm_config)

View File

@ -0,0 +1,62 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Sequence
from typing import Any, Generic, Optional, TypeVar, Union
from vllm.config import VllmConfig
from vllm.entrypoints.openai.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput
IOProcessorInput = TypeVar('IOProcessorInput')
IOProcessorOutput = TypeVar('IOProcessorOutput')
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
    """Base class for IO Processor plugins.

    An IO processor converts a plugin-defined input into one or more model
    prompts before inference (``pre_process``) and converts the pooled
    model outputs back into the plugin-defined output type
    (``post_process``). vLLM performs no validation of these custom
    payloads; the plugin is responsible for feeding correct data to the
    model and returning correct data to the user.
    """

    def __init__(self, vllm_config: VllmConfig):
        # Keep the full vLLM config so subclasses can inspect model settings.
        self.vllm_config = vllm_config

    @abstractmethod
    def pre_process(
        self,
        prompt: IOProcessorInput,
        request_id: Optional[str] = None,
        **kwargs,
    ) -> Union[PromptType, Sequence[PromptType]]:
        """Convert the plugin input into one or more model prompts."""
        raise NotImplementedError

    async def pre_process_async(
        self,
        prompt: IOProcessorInput,
        request_id: Optional[str] = None,
        **kwargs,
    ) -> Union[PromptType, Sequence[PromptType]]:
        """Async variant; by default delegates to the sync ``pre_process``."""
        return self.pre_process(prompt, request_id, **kwargs)

    @abstractmethod
    def post_process(self,
                     model_output: Sequence[PoolingRequestOutput],
                     request_id: Optional[str] = None,
                     **kwargs) -> IOProcessorOutput:
        """Convert pooled model outputs into the plugin output type."""
        raise NotImplementedError

    async def post_process_async(
        self,
        model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
        request_id: Optional[str] = None,
        **kwargs,
    ) -> IOProcessorOutput:
        """Drain the async output stream, then delegate to ``post_process``.

        NOTE(review): outputs are collected in stream arrival order and the
        yielded index ``i`` is discarded — confirm upstream guarantees
        ordering if the plugin is order-sensitive.
        """
        collected_output = [item async for i, item in model_output]
        return self.post_process(collected_output, request_id, **kwargs)

    @abstractmethod
    def parse_request(self, request: Any) -> IOProcessorInput:
        """Parse/validate a raw request into the plugin input type."""
        raise NotImplementedError

    @abstractmethod
    def output_to_response(
            self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
        """Wrap the plugin output into an API ``IOProcessorResponse``."""
        raise NotImplementedError

View File

@ -3290,7 +3290,7 @@ def sha256_cbor_64bit(input) -> int:
return full_hash & ((1 << 64) - 1)
def get_hash_fn_by_name(hash_fn_name: str) -> Callable:
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]:
"""Get a hash function by name, or raise an error if
the function is not found.
Args:

View File

@ -4,8 +4,9 @@ from collections import defaultdict
from collections.abc import Iterable
from typing import Optional
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
BlockStored, KVCacheEvent)
from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared,
BlockRemoved, BlockStored,
KVCacheEvent)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
FreeKVCacheBlockQueue, KVCacheBlock)
@ -156,6 +157,7 @@ class BlockPool:
block_size=block_size,
lora_id=request.lora_request.id
if request.lora_request else None,
medium=MEDIUM_GPU,
))
def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
@ -218,7 +220,8 @@ class BlockPool:
# we disable hybrid kv cache manager when kv cache event is
# enabled, so there is only one group.
self.kv_event_queue.append(
BlockRemoved(block_hashes=[block_hash.get_hash_value()]))
BlockRemoved(block_hashes=[block_hash.get_hash_value()],
medium=MEDIUM_GPU))
return True
def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:

View File

@ -527,6 +527,7 @@ def hash_block_tokens(
hash values for the same block contents.
Args:
hash_function: The hash function used to compute block hash.
parent_block_hash: The hash of the parent block. None
if this is the first block.
curr_block_token_ids: A list of token ids in the current

View File

@ -589,7 +589,19 @@ class Scheduler(SchedulerInterface):
meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
# collect KV cache events from KV cache manager
events = self.kv_cache_manager.take_events()
# collect KV cache events from connector
if self.connector is not None:
connector_events = self.connector.take_events()
if connector_events:
if events is None:
events = list(connector_events)
else:
events.extend(connector_events)
# publish collected KV cache events
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
@ -1207,7 +1219,7 @@ class Scheduler(SchedulerInterface):
# Now that the blocks are ready, actually cache them.
(block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
num_computed_tokens = len(block_ids) * self.block_size
# Handle the case where num request tokens less then one block.
# Handle the case where num request tokens less than one block.
num_computed_tokens = min(num_computed_tokens, request.num_tokens)
if num_computed_tokens == request.num_tokens:
num_computed_tokens -= 1

View File

@ -47,7 +47,7 @@ class SingleTypeKVCacheManager(ABC):
# {req_id: The number of cached blocks for this given request}
# This is used to track the number of cached blocks for each request.
# This is only used to track the RUNNING requests, we do not track the
# data for reempted ones.
# data for preempted ones.
self.num_cached_block: dict[str, int] = {}
self.kv_cache_group_id = kv_cache_group_id

View File

@ -138,14 +138,14 @@ def _torch_cuda_wrapper():
@contextmanager
def _set_global_compilation_settings(config: VllmConfig):
import torch._inductor.config
import torch._inductor.config as torch_inductor_config
inductor_config = config.compilation_config.inductor_compile_config
# Note: The MKLDNN and CPPGEMM backend requires freezing parameters.
freezing_value = torch._inductor.config.freezing
freezing_value = torch_inductor_config.freezing
try:
if inductor_config.get("max_autotune", False):
torch._inductor.config.freezing = True
torch_inductor_config.freezing = True
yield
finally:
torch._inductor.config.freezing = freezing_value
torch_inductor_config.freezing = freezing_value

View File

@ -600,28 +600,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits_indices_padded = None
if self.cache_config.kv_sharing_fast_prefill:
assert self.kv_sharing_fast_prefill_logits_indices is not None
num_logits = logits_indices.shape[0]
assert num_logits > 0
self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(
logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
logits_indices)
# There might have leftover indices in logits_indices[num_logits:]
# from previous iterations, whose values may be greater than the
# batch size in the current iteration. To ensure indices are always
# valid, we fill the padded indices with the last index.
self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
logits_indices[-1].item())
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and num_logits <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_logits_padded = self.vllm_config.pad_for_cudagraph(
num_logits)
else:
num_logits_padded = num_logits
logits_indices_padded = (
self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]
)
# Used in the below loop.
query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
@ -644,13 +624,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
blk_table_tensor = torch.zeros(
(num_reqs, 1),
dtype=torch.int32,
pin_memory=self.pin_memory,
device="cpu").to(self.device, non_blocking=True)
slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
dtype=torch.int32,
pin_memory=self.pin_memory,
device="cpu").to(self.device,
non_blocking=True)
device=self.device,
)
slot_mapping = torch.zeros(
(total_num_scheduled_tokens, ),
dtype=torch.int64,
device=self.device,
)
num_common_prefix_blocks = 0
else:
blk_table_tensor = block_tables[kv_cache_group_id]
@ -889,6 +869,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
mrope_pos_ptr += completion_part_len
def _prepare_kv_sharing_fast_prefill(
self,
logits_indices: torch.Tensor,
) -> torch.Tensor:
assert self.kv_sharing_fast_prefill_logits_indices is not None
num_logits = logits_indices.shape[0]
assert num_logits > 0
self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(
logits_indices)
# There might have leftover indices in logits_indices[num_logits:]
# from previous iterations, whose values may be greater than the
# batch size in the current iteration. To ensure indices are always
# valid, we fill the padded indices with the last index.
self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
logits_indices[-1].item())
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
and num_logits <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits)
else:
num_logits_padded = num_logits
logits_indices_padded = (
self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded])
return logits_indices_padded
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs: