mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-17 00:17:05 +08:00
merge
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
commit
4c2a337e67
@ -566,8 +566,7 @@ steps:
|
||||
- tests/models/multimodal
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
|
||||
- pytest -v -s models/multimodal/processing/test_tensor_schema.py
|
||||
- pytest -v -s models/multimodal/processing
|
||||
|
||||
- label: Multi-Modal Models Test (Standard)
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -770,6 +769,11 @@ steps:
|
||||
- pytest -v -s plugins_tests/test_platform_plugins.py
|
||||
- pip uninstall vllm_add_dummy_platform -y
|
||||
# end platform plugin tests
|
||||
# begin io_processor plugins test, all the code in between uses the prithvi_io_processor plugin
|
||||
- pip install -e ./plugins/prithvi_io_processor_plugin
|
||||
- pytest -v -s plugins_tests/test_io_processor_plugins.py
|
||||
- pip uninstall prithvi_io_processor_plugin -y
|
||||
# end io_processor plugins test
|
||||
# other tests continue here:
|
||||
- pytest -v -s plugins_tests/test_scheduler_plugins.py
|
||||
- pip install -e ./plugins/vllm_add_dummy_model
|
||||
|
||||
78
docs/design/io_processor_plugins.md
Normal file
78
docs/design/io_processor_plugins.md
Normal file
@ -0,0 +1,78 @@
|
||||
# IO Processor Plugins
|
||||
|
||||
IO Processor plugins are a feature that allows pre and post processing of the model input and output for pooling models. The idea is that users are allowed to pass a custom input to vLLM that is converted into one or more model prompts and fed to the model `encode` method. One potential use-case of such plugins is that of using vLLM for generating multi-modal data. Say users feed an image to vLLM and get an image in output.
|
||||
|
||||
When performing an inference with IO Processor plugins, the prompt type is defined by the plugin and the same is valid for the final request output. vLLM does not perform any validation of input/output data, and it is up to the plugin to ensure the correct data is being fed to the model and returned to the user. As of now these plugins support only pooling models and can be triggerd via the `encode` method in `LLM` and `AsyncLLM`, or in online serving mode via the `/pooling` endpoint.
|
||||
|
||||
## Writing an IO Processor Plugin
|
||||
|
||||
IO Processor plugins implement the `IOProcessor` interface (<gh-file:vllm/plugins/io_processors/interface.py>):
|
||||
|
||||
```python
|
||||
IOProcessorInput = TypeVar('IOProcessorInput')
|
||||
IOProcessorOutput = TypeVar('IOProcessorOutput')
|
||||
|
||||
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
self.vllm_config = vllm_config
|
||||
|
||||
@abstractmethod
|
||||
def pre_process(
|
||||
self,
|
||||
prompt: IOProcessorInput,
|
||||
request_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Union[PromptType, Sequence[PromptType]]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def pre_process_async(
|
||||
self,
|
||||
prompt: IOProcessorInput,
|
||||
request_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Union[PromptType, Sequence[PromptType]]:
|
||||
return self.pre_process(prompt, request_id, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def post_process(self,
|
||||
model_output: Sequence[PoolingRequestOutput],
|
||||
request_id: Optional[str] = None,
|
||||
**kwargs) -> IOProcessorOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
async def post_process_async(
|
||||
self,
|
||||
model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
|
||||
request_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> IOProcessorOutput:
|
||||
collected_output = [item async for i, item in model_output]
|
||||
return self.post_process(collected_output, request_id, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def parse_request(self, request: Any) -> IOProcessorInput:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def output_to_response(
|
||||
self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
|
||||
raise NotImplementedError
|
||||
```
|
||||
|
||||
The `parse_request` method is used for validating the user prompt and converting it into the input expected by the `pre_process`/`pre_process_async` methods.
|
||||
The `pre_process*` methods take the validated plugin input to generate vLLM's model prompts for regular inference.
|
||||
The `post_process*` methods take `PoolingRequestOutput` objects as input and generate a custom plugin output.
|
||||
|
||||
The `output_to_response` method is used only for online serving and converts the plugin output to the `IOProcessorResponse` type that is then returned by the API Server. The implementation of the `/io_processor_pooling` serving endpoint is [here](../../vllm/entrypoints/openai/serving_pooling_with_io_plugin.py).
|
||||
|
||||
An example implementation of a plugin that enables generating geotiff images with the PrithviGeospatialMAE model is available [here](https://github.com/christian-pinto/prithvi_io_processor_plugin). Please, also refer to our [online](../../examples/online_serving/prithvi_geospatial_mae.py) and [offline](../../examples/offline_inference/prithvi_geospatial_mae_io_processor.py) inference examples.
|
||||
|
||||
## Using an IO Processor plugin
|
||||
|
||||
IO Processor plugins are loaded at engine startup and there are two methods for specifying the name of the plugin to be loaded:
|
||||
|
||||
1. Via vLLM's `EngineArgs`: setting the `io_processor_plugin` argument in the `EngineArgs` used to initialize the `AsyncLLM`. The same can be achieved by passing the `io_processor_plugin` argument to `LLM` in offline mode, or by passing the `--io-processor-plugin` argument in serving mode.
|
||||
2. Via the model HF configuration: adding an `io_processor_plugin` field to the model config (config.json).
|
||||
|
||||
The order also determines method priority. i.e., setting the plugin name via `EngineArgs` will override any plugin name specified in the model HF config (config.json).
|
||||
@ -49,6 +49,8 @@ Every plugin has three parts:
|
||||
|
||||
- **Platform plugins** (with group name `vllm.platform_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree platforms into vLLM. The plugin function should return `None` when the platform is not supported in the current environment, or the platform class's fully qualified name when the platform is supported.
|
||||
|
||||
- **IO Processor plugins** (with group name `vllm.io_processor_plugins`): The primary use case for these plugins is to register custom pre/post processing of the model prompt and model output for poling models. The plugin function returns the IOProcessor's class fully qualified name.
|
||||
|
||||
## Guidelines for Writing Plugins
|
||||
|
||||
- **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes.
|
||||
|
||||
@ -0,0 +1,60 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import LLM
|
||||
from vllm.pooling_params import PoolingParams
|
||||
|
||||
# This example shows how to perform an offline inference that generates
|
||||
# multimodal data. In this specific case this example will take a geotiff
|
||||
# image as input, process it using the multimodal data processor, and
|
||||
# perform inference.
|
||||
# Reuirement - install plugin at:
|
||||
# https://github.com/christian-pinto/prithvi_io_processor_plugin
|
||||
|
||||
|
||||
def main():
|
||||
torch.set_default_dtype(torch.float16)
|
||||
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif" # noqa: E501
|
||||
|
||||
img_prompt = dict(
|
||||
data=image_url,
|
||||
data_format="url",
|
||||
image_format="tiff",
|
||||
out_data_format="b64_json",
|
||||
)
|
||||
|
||||
llm = LLM(
|
||||
model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
|
||||
skip_tokenizer_init=True,
|
||||
trust_remote_code=True,
|
||||
enforce_eager=True,
|
||||
# Limit the maximum number of parallel requests
|
||||
# to avoid the model going OOM.
|
||||
# The maximum number depends on the available GPU memory
|
||||
max_num_seqs=32,
|
||||
io_processor_plugin="prithvi_to_tiff_india",
|
||||
)
|
||||
|
||||
pooling_params = PoolingParams(task="encode", softmax=False)
|
||||
pooler_output = llm.encode(
|
||||
img_prompt,
|
||||
pooling_params=pooling_params,
|
||||
)
|
||||
output = pooler_output[0].outputs
|
||||
|
||||
print(output)
|
||||
decoded_data = base64.b64decode(output.data)
|
||||
|
||||
file_path = os.path.join(os.getcwd(), "offline_prediction.tiff")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(decoded_data)
|
||||
|
||||
print(f"Output file path: {file_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -27,10 +27,12 @@ class BlockStored(KVCacheEvent):
|
||||
token_ids: list[int]
|
||||
block_size: int
|
||||
lora_id: Optional[int]
|
||||
medium: Optional[str]
|
||||
|
||||
|
||||
class BlockRemoved(KVCacheEvent):
|
||||
block_hashes: list[int]
|
||||
medium: Optional[str]
|
||||
|
||||
|
||||
class AllBlocksCleared(KVCacheEvent):
|
||||
|
||||
54
examples/online_serving/prithvi_geospatial_mae.py
Normal file
54
examples/online_serving/prithvi_geospatial_mae.py
Normal file
@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import base64
|
||||
import os
|
||||
|
||||
import requests
|
||||
|
||||
# This example shows how to perform an online inference that generates
|
||||
# multimodal data. In this specific case this example will take a geotiff
|
||||
# image as input, process it using the multimodal data processor, and
|
||||
# perform inference.
|
||||
# Reuirements :
|
||||
# - install plugin at:
|
||||
# https://github.com/christian-pinto/prithvi_io_processor_plugin
|
||||
# - start vllm in serving mode with the below args
|
||||
# --model='christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM'
|
||||
# --task embed --trust-remote-code
|
||||
# --skip-tokenizer-init --enforce-eager
|
||||
# --io-processor-plugin prithvi_to_tiff_india
|
||||
|
||||
|
||||
def main():
|
||||
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/India_900498_S2Hand.tif" # noqa: E501
|
||||
server_endpoint = "http://localhost:8000/pooling"
|
||||
|
||||
request_payload_url = {
|
||||
"data": {
|
||||
"data": image_url,
|
||||
"data_format": "url",
|
||||
"image_format": "tiff",
|
||||
"out_data_format": "b64_json",
|
||||
},
|
||||
"priority": 0,
|
||||
"model": "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
|
||||
}
|
||||
|
||||
ret = requests.post(server_endpoint, json=request_payload_url)
|
||||
|
||||
print(f"response.status_code: {ret.status_code}")
|
||||
print(f"response.reason:{ret.reason}")
|
||||
|
||||
response = ret.json()
|
||||
|
||||
decoded_image = base64.b64decode(response["data"]["data"])
|
||||
|
||||
out_path = os.path.join(os.getcwd(), "online_prediction.tiff")
|
||||
|
||||
with open(out_path, "wb") as f:
|
||||
f.write(decoded_image)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1120,6 +1120,9 @@ class VllmRunner:
|
||||
|
||||
return self.llm.llm_engine.collective_rpc(_apply_model)
|
||||
|
||||
def get_llm(self) -> LLM:
|
||||
return self.llm
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
|
||||
@ -1,30 +1,31 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import tempfile
|
||||
from collections.abc import Iterable
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from typing import Any, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch.nn as nn
|
||||
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
|
||||
UserMessage)
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from PIL import Image
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
|
||||
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed import (cleanup_dist_env_and_memory,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
|
||||
MultiModalKwargs)
|
||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor
|
||||
from vllm.multimodal.utils import group_mm_kwargs_by_modality
|
||||
from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
|
||||
from vllm.utils import GiB_bytes, is_list_of, set_default_torch_num_threads
|
||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
|
||||
from vllm.v1.engine.core import EngineCore as V1EngineCore
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from ....conftest import VllmRunner
|
||||
from ...registry import _MULTIMODAL_EXAMPLE_MODELS, HF_EXAMPLE_MODELS
|
||||
from ...utils import dummy_hf_overrides
|
||||
|
||||
@ -137,6 +138,27 @@ def create_batched_mm_kwargs(
|
||||
return group_mm_kwargs_by_modality(items)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def initialize_dummy_model(model_cls: nn.Module, model_config: ModelConfig):
|
||||
temp_file = tempfile.mkstemp()[1]
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
rank=0,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
local_rank=0,
|
||||
backend="nccl",
|
||||
)
|
||||
initialize_model_parallel(tensor_model_parallel_size=1)
|
||||
vllm_config = VllmConfig(model_config=model_config)
|
||||
with set_current_vllm_config(vllm_config=vllm_config):
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
model = model_cls(vllm_config=vllm_config)
|
||||
yield model
|
||||
|
||||
del model
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
def get_model_id_to_test(
|
||||
model_arch_list: Iterable[str]) -> list[tuple[str, str]]:
|
||||
filtered_results = []
|
||||
@ -155,8 +177,7 @@ def get_model_id_to_test(
|
||||
@pytest.mark.parametrize(
|
||||
"model_arch, model_id",
|
||||
get_model_id_to_test(_MULTIMODAL_EXAMPLE_MODELS.keys()))
|
||||
def test_model_tensor_schema(model_arch: str, model_id: str,
|
||||
vllm_runner: type[VllmRunner], monkeypatch):
|
||||
def test_model_tensor_schema(model_arch: str, model_id: str):
|
||||
if model_arch in ARCH_TO_SKIP:
|
||||
pytest.skip(f"Skipping {model_arch} due to {ARCH_TO_SKIP[model_arch]}")
|
||||
if model_id in REPO_ID_TO_SKIP:
|
||||
@ -177,14 +198,20 @@ def test_model_tensor_schema(model_arch: str, model_id: str,
|
||||
tokenizer_mode=model_info.tokenizer_mode,
|
||||
revision=model_info.revision,
|
||||
trust_remote_code=model_info.trust_remote_code,
|
||||
hf_overrides=model_info.hf_overrides,
|
||||
hf_overrides=hf_overrides_fn,
|
||||
)
|
||||
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
|
||||
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
|
||||
|
||||
if not any(
|
||||
hasattr(model_cls, f"_parse_and_validate_{m}_input")
|
||||
for m in ["image", "video", "audio"]):
|
||||
inputs_parse_methods = []
|
||||
for attr_name in dir(model_cls):
|
||||
attr = getattr(model_cls, attr_name)
|
||||
if hasattr(attr, "__annotations__"):
|
||||
return_type = attr.__annotations__.get("return", None)
|
||||
if return_type is not None and "Input" in str(return_type):
|
||||
inputs_parse_methods.append(attr_name)
|
||||
|
||||
if not any(inputs_parse_methods):
|
||||
pytest.skip(f"{model_arch} does not support tensor schema validation.")
|
||||
|
||||
ctx = InputProcessingContext(
|
||||
@ -197,68 +224,13 @@ def test_model_tensor_schema(model_arch: str, model_id: str,
|
||||
modality: 3 if limit is None else limit
|
||||
for modality, limit in supported_mm_limits.items()
|
||||
}
|
||||
model_config.get_multimodal_config().limit_per_prompt = limit_mm_per_prompt
|
||||
processor = factories.build_processor(ctx, cache=None)
|
||||
|
||||
# Avoid calling model.forward()
|
||||
def _initialize_kv_caches_v0(self) -> None:
|
||||
self.cache_config.num_gpu_blocks = 0
|
||||
self.cache_config.num_cpu_blocks = 0
|
||||
|
||||
def _initialize_kv_caches_v1(self, vllm_config):
|
||||
kv_cache_specs = self.model_executor.get_kv_cache_specs()
|
||||
scheduler_kv_cache_config = get_kv_cache_config(
|
||||
vllm_config,
|
||||
kv_cache_specs[0],
|
||||
10 * GiB_bytes,
|
||||
)
|
||||
|
||||
# gpu_blocks (> 0), cpu_blocks, scheduler_kv_cache_config
|
||||
return 1, 0, scheduler_kv_cache_config
|
||||
|
||||
with (patch.object(V0LLMEngine, "_initialize_kv_caches",
|
||||
_initialize_kv_caches_v0),
|
||||
patch.object(V1EngineCore, "_initialize_kv_caches",
|
||||
_initialize_kv_caches_v1), monkeypatch.context() as m):
|
||||
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
if model_info.v0_only:
|
||||
m.setenv("VLLM_USE_V1", "0")
|
||||
|
||||
# TODO(Isotr0py): Can we avoid initializing engine?
|
||||
with (
|
||||
set_default_torch_num_threads(1),
|
||||
vllm_runner(
|
||||
model_id,
|
||||
tokenizer_name=model_info.tokenizer,
|
||||
tokenizer_mode=model_info.tokenizer_mode,
|
||||
revision=model_info.revision,
|
||||
trust_remote_code=model_info.trust_remote_code,
|
||||
max_model_len=model_info.max_model_len,
|
||||
load_format="dummy",
|
||||
hf_overrides=hf_overrides_fn,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
enforce_eager=True,
|
||||
) as vllm_model,
|
||||
):
|
||||
model_config = vllm_model.llm.llm_engine.model_config
|
||||
llm_engine = vllm_model.llm.llm_engine
|
||||
|
||||
if hasattr(llm_engine, "processor"):
|
||||
# v1 processor
|
||||
mm_registry = llm_engine.processor.mm_registry
|
||||
else:
|
||||
# v0 input_preprocessor
|
||||
mm_registry = llm_engine.input_preprocessor.mm_registry
|
||||
|
||||
processor = mm_registry.create_processor(model_config)
|
||||
|
||||
def validate_model_input(model, modality: str,
|
||||
mm_kwargs: MultiModalKwargs):
|
||||
method_name = f"_parse_and_validate_{modality}_input"
|
||||
if hasattr(model, method_name):
|
||||
getattr(model, method_name)(**mm_kwargs)
|
||||
|
||||
for modality, _, mm_kwargs in create_batched_mm_kwargs(
|
||||
model_config, processor):
|
||||
valid_func = partial(validate_model_input,
|
||||
modality=modality,
|
||||
mm_kwargs=mm_kwargs)
|
||||
vllm_model.apply_model(valid_func)
|
||||
with initialize_dummy_model(model_cls, model_config) as model:
|
||||
for modality, _, mm_kwargs in create_batched_mm_kwargs(
|
||||
model_config, processor):
|
||||
for method_name in inputs_parse_methods:
|
||||
print(f"Testing `{method_name}` with modality={modality} "
|
||||
f"and mm_kwargs{list(mm_kwargs.keys())}")
|
||||
getattr(model, method_name)(modality=modality, **mm_kwargs)
|
||||
|
||||
@ -0,0 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
def register_prithvi_india():
|
||||
return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorIndia" # noqa: E501
|
||||
|
||||
|
||||
def register_prithvi_valencia():
|
||||
return "prithvi_io_processor.prithvi_processor.PrithviMultimodalDataProcessorValencia" # noqa: E501
|
||||
@ -0,0 +1,449 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import datetime
|
||||
import os
|
||||
import tempfile
|
||||
import urllib.request
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import albumentations
|
||||
import numpy as np
|
||||
import rasterio
|
||||
import regex as re
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.openai.protocol import (IOProcessorRequest,
|
||||
IOProcessorResponse)
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.plugins.io_processors.interface import (IOProcessor,
|
||||
IOProcessorInput,
|
||||
IOProcessorOutput)
|
||||
|
||||
from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
NO_DATA = -9999
|
||||
NO_DATA_FLOAT = 0.0001
|
||||
OFFSET = 0
|
||||
PERCENTILE = 99
|
||||
|
||||
DEFAULT_INPUT_INDICES = [0, 1, 2, 3, 4, 5]
|
||||
|
||||
datamodule_config: DataModuleConfig = {
|
||||
"bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
|
||||
"batch_size":
|
||||
16,
|
||||
"constant_scale":
|
||||
0.0001,
|
||||
"data_root":
|
||||
"/dccstor/geofm-finetuning/datasets/sen1floods11",
|
||||
"drop_last":
|
||||
True,
|
||||
"no_data_replace":
|
||||
0.0,
|
||||
"no_label_replace":
|
||||
-1,
|
||||
"num_workers":
|
||||
8,
|
||||
"test_transform": [
|
||||
albumentations.Resize(always_apply=False,
|
||||
height=448,
|
||||
interpolation=1,
|
||||
p=1,
|
||||
width=448),
|
||||
albumentations.pytorch.ToTensorV2(transpose_mask=False,
|
||||
always_apply=True,
|
||||
p=1.0),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def save_geotiff(image: torch.Tensor, meta: dict,
|
||||
out_format: str) -> str | bytes:
|
||||
"""Save multi-band image in Geotiff file.
|
||||
|
||||
Args:
|
||||
image: np.ndarray with shape (bands, height, width)
|
||||
output_path: path where to save the image
|
||||
meta: dict with meta info.
|
||||
"""
|
||||
if out_format == "path":
|
||||
# create temp file
|
||||
file_path = os.path.join(os.getcwd(), "prediction.tiff")
|
||||
with rasterio.open(file_path, "w", **meta) as dest:
|
||||
for i in range(image.shape[0]):
|
||||
dest.write(image[i, :, :], i + 1)
|
||||
|
||||
return file_path
|
||||
elif out_format == "b64_json":
|
||||
with tempfile.NamedTemporaryFile() as tmpfile:
|
||||
with rasterio.open(tmpfile.name, "w", **meta) as dest:
|
||||
for i in range(image.shape[0]):
|
||||
dest.write(image[i, :, :], i + 1)
|
||||
|
||||
file_data = tmpfile.read()
|
||||
return base64.b64encode(file_data)
|
||||
|
||||
else:
|
||||
raise ValueError("Unknown output format")
|
||||
|
||||
|
||||
def _convert_np_uint8(float_image: torch.Tensor):
|
||||
image = float_image.numpy() * 255.0
|
||||
image = image.astype(dtype=np.uint8)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def read_geotiff(
|
||||
file_path: Optional[str] = None,
|
||||
path_type: Optional[str] = None,
|
||||
file_data: Optional[bytes] = None,
|
||||
) -> tuple[torch.Tensor, dict, tuple[float, float] | None]:
|
||||
"""Read all bands from *file_path* and return image + meta info.
|
||||
|
||||
Args:
|
||||
file_path: path to image file.
|
||||
|
||||
Returns:
|
||||
np.ndarray with shape (bands, height, width)
|
||||
meta info dict
|
||||
"""
|
||||
|
||||
if all([x is None for x in [file_path, path_type, file_data]]):
|
||||
raise Exception("All input fields to read_geotiff are None")
|
||||
write_to_file: Optional[bytes] = None
|
||||
path: Optional[str] = None
|
||||
if file_data is not None:
|
||||
# with tempfile.NamedTemporaryFile() as tmpfile:
|
||||
# tmpfile.write(file_data)
|
||||
# path = tmpfile.name
|
||||
|
||||
write_to_file = file_data
|
||||
elif file_path is not None and path_type == "url":
|
||||
resp = urllib.request.urlopen(file_path)
|
||||
# with tempfile.NamedTemporaryFile() as tmpfile:
|
||||
# tmpfile.write(resp.read())
|
||||
# path = tmpfile.name
|
||||
write_to_file = resp.read()
|
||||
elif file_path is not None and path_type == "path":
|
||||
path = file_path
|
||||
elif file_path is not None and path_type == "b64_json":
|
||||
image_data = base64.b64decode(file_path)
|
||||
# with tempfile.NamedTemporaryFile() as tmpfile:
|
||||
# tmpfile.write(image_data)
|
||||
# path = tmpfile.name
|
||||
write_to_file = image_data
|
||||
else:
|
||||
raise Exception("Wrong combination of parameters to read_geotiff")
|
||||
|
||||
with tempfile.NamedTemporaryFile() as tmpfile:
|
||||
path_to_use = None
|
||||
if write_to_file:
|
||||
tmpfile.write(write_to_file)
|
||||
path_to_use = tmpfile.name
|
||||
elif path:
|
||||
path_to_use = path
|
||||
|
||||
with rasterio.open(path_to_use) as src:
|
||||
img = src.read()
|
||||
meta = src.meta
|
||||
try:
|
||||
coords = src.lnglat()
|
||||
except Exception:
|
||||
# Cannot read coords
|
||||
coords = None
|
||||
|
||||
return img, meta, coords
|
||||
|
||||
|
||||
def load_image(
|
||||
data: Union[list[str]],
|
||||
path_type: str,
|
||||
mean: Optional[list[float]] = None,
|
||||
std: Optional[list[float]] = None,
|
||||
indices: Optional[Union[list[int], None]] = None,
|
||||
):
|
||||
"""Build an input example by loading images in *file_paths*.
|
||||
|
||||
Args:
|
||||
file_paths: list of file paths .
|
||||
mean: list containing mean values for each band in the
|
||||
images in *file_paths*.
|
||||
std: list containing std values for each band in the
|
||||
images in *file_paths*.
|
||||
|
||||
Returns:
|
||||
np.array containing created example
|
||||
list of meta info for each image in *file_paths*
|
||||
"""
|
||||
|
||||
imgs = []
|
||||
metas = []
|
||||
temporal_coords = []
|
||||
location_coords = []
|
||||
|
||||
for file in data:
|
||||
# if isinstance(file, bytes):
|
||||
# img, meta, coords = read_geotiff(file_data=file)
|
||||
# else:
|
||||
img, meta, coords = read_geotiff(file_path=file, path_type=path_type)
|
||||
# Rescaling (don't normalize on nodata)
|
||||
img = np.moveaxis(img, 0, -1) # channels last for rescaling
|
||||
if indices is not None:
|
||||
img = img[..., indices]
|
||||
if mean is not None and std is not None:
|
||||
img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
|
||||
|
||||
imgs.append(img)
|
||||
metas.append(meta)
|
||||
if coords is not None:
|
||||
location_coords.append(coords)
|
||||
|
||||
try:
|
||||
match = re.search(r"(\d{7,8}T\d{6})", file)
|
||||
if match:
|
||||
year = int(match.group(1)[:4])
|
||||
julian_day = match.group(1).split("T")[0][4:]
|
||||
if len(julian_day) == 3:
|
||||
julian_day = int(julian_day)
|
||||
else:
|
||||
julian_day = (datetime.datetime.strptime(
|
||||
julian_day, "%m%d").timetuple().tm_yday)
|
||||
temporal_coords.append([year, julian_day])
|
||||
except Exception:
|
||||
logger.exception("Could not extract timestamp for %s", file)
|
||||
|
||||
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
|
||||
imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
|
||||
imgs = np.expand_dims(imgs, axis=0) # add batch di
|
||||
|
||||
return imgs, temporal_coords, location_coords, metas
|
||||
|
||||
|
||||
class PrithviMultimodalDataProcessor(IOProcessor):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
|
||||
super().__init__(vllm_config)
|
||||
|
||||
self.datamodule = Sen1Floods11NonGeoDataModule(
|
||||
data_root=datamodule_config["data_root"],
|
||||
batch_size=datamodule_config["batch_size"],
|
||||
num_workers=datamodule_config["num_workers"],
|
||||
bands=datamodule_config["bands"],
|
||||
drop_last=datamodule_config["drop_last"],
|
||||
test_transform=datamodule_config["test_transform"],
|
||||
)
|
||||
self.img_size = 512
|
||||
self.h1 = 1
|
||||
self.w1 = 1
|
||||
self.original_h = 512
|
||||
self.original_w = 512
|
||||
self.batch_size = 1
|
||||
self.meta_data = None
|
||||
self.requests_cache: dict[str, dict[str, Any]] = {}
|
||||
self.indices = DEFAULT_INPUT_INDICES
|
||||
|
||||
def parse_request(self, request: Any) -> IOProcessorInput:
|
||||
if type(request) is dict:
|
||||
image_prompt = ImagePrompt(**request)
|
||||
return image_prompt
|
||||
if isinstance(request, IOProcessorRequest):
|
||||
if not hasattr(request, "data"):
|
||||
raise ValueError(
|
||||
"missing 'data' field in OpenAIBaseModel Request")
|
||||
|
||||
request_data = request.data
|
||||
|
||||
if type(request_data) is dict:
|
||||
return ImagePrompt(**request_data)
|
||||
else:
|
||||
raise ValueError("Unable to parse the request data")
|
||||
|
||||
raise ValueError("Unable to parse request")
|
||||
|
||||
def output_to_response(
|
||||
self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
|
||||
return IOProcessorResponse(
|
||||
request_id=plugin_output.request_id,
|
||||
data=plugin_output,
|
||||
)
|
||||
|
||||
def pre_process(
|
||||
self,
|
||||
prompt: IOProcessorInput,
|
||||
request_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Union[PromptType, Sequence[PromptType]]:
|
||||
|
||||
image_data = dict(prompt)
|
||||
|
||||
if request_id:
|
||||
self.requests_cache[request_id] = {
|
||||
"out_format": image_data["out_data_format"],
|
||||
}
|
||||
|
||||
input_data, temporal_coords, location_coords, meta_data = load_image(
|
||||
data=[image_data["data"]],
|
||||
indices=self.indices,
|
||||
path_type=image_data["data_format"],
|
||||
)
|
||||
|
||||
self.meta_data = meta_data[0]
|
||||
|
||||
if input_data.mean() > 1:
|
||||
input_data = input_data / 10000 # Convert to range 0-1
|
||||
|
||||
self.original_h, self.original_w = input_data.shape[-2:]
|
||||
pad_h = (self.img_size -
|
||||
(self.original_h % self.img_size)) % self.img_size
|
||||
pad_w = (self.img_size -
|
||||
(self.original_w % self.img_size)) % self.img_size
|
||||
input_data = np.pad(
|
||||
input_data,
|
||||
((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)),
|
||||
mode="reflect",
|
||||
)
|
||||
|
||||
batch = torch.tensor(input_data)
|
||||
windows = batch.unfold(3, self.img_size,
|
||||
self.img_size).unfold(4, self.img_size,
|
||||
self.img_size)
|
||||
self.h1, self.w1 = windows.shape[3:5]
|
||||
windows = rearrange(
|
||||
windows,
|
||||
"b c t h1 w1 h w -> (b h1 w1) c t h w",
|
||||
h=self.img_size,
|
||||
w=self.img_size,
|
||||
)
|
||||
|
||||
# Split into batches if number of windows > batch_size
|
||||
num_batches = (windows.shape[0] // self.batch_size
|
||||
if windows.shape[0] > self.batch_size else 1)
|
||||
windows = torch.tensor_split(windows, num_batches, dim=0)
|
||||
|
||||
if temporal_coords:
|
||||
temporal_coords = torch.tensor(temporal_coords).unsqueeze(0)
|
||||
else:
|
||||
temporal_coords = None
|
||||
if location_coords:
|
||||
location_coords = torch.tensor(location_coords[0]).unsqueeze(0)
|
||||
else:
|
||||
location_coords = None
|
||||
|
||||
prompts = []
|
||||
for window in windows:
|
||||
# Apply standardization
|
||||
window = self.datamodule.test_transform(
|
||||
image=window.squeeze().numpy().transpose(1, 2, 0))
|
||||
window = self.datamodule.aug(window)["image"]
|
||||
prompts.append({
|
||||
"prompt_token_ids": [1],
|
||||
"multi_modal_data": {
|
||||
"pixel_values": window.to(torch.float16)[0],
|
||||
"location_coords": location_coords.to(torch.float16),
|
||||
},
|
||||
})
|
||||
|
||||
return prompts
|
||||
|
||||
async def pre_process_async(
|
||||
self,
|
||||
prompt: IOProcessorInput,
|
||||
request_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Union[PromptType, Sequence[PromptType]]:
|
||||
return self.pre_process(prompt, request_id, **kwargs)
|
||||
|
||||
def post_process(
|
||||
self,
|
||||
model_output: Sequence[PoolingRequestOutput],
|
||||
request_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> IOProcessorOutput:
|
||||
|
||||
pred_imgs_list = []
|
||||
|
||||
if request_id and (request_id in self.requests_cache):
|
||||
out_format = self.requests_cache[request_id]["out_format"]
|
||||
else:
|
||||
out_format = "b64_json"
|
||||
|
||||
for output in model_output:
|
||||
y_hat = output.outputs.data.argmax(dim=1)
|
||||
pred = torch.nn.functional.interpolate(
|
||||
y_hat.unsqueeze(1).float(),
|
||||
size=self.img_size,
|
||||
mode="nearest",
|
||||
)
|
||||
pred_imgs_list.append(pred)
|
||||
|
||||
pred_imgs: torch.Tensor = torch.concat(pred_imgs_list, dim=0)
|
||||
|
||||
# Build images from patches
|
||||
pred_imgs = rearrange(
|
||||
pred_imgs,
|
||||
"(b h1 w1) c h w -> b c (h1 h) (w1 w)",
|
||||
h=self.img_size,
|
||||
w=self.img_size,
|
||||
b=1,
|
||||
c=1,
|
||||
h1=self.h1,
|
||||
w1=self.w1,
|
||||
)
|
||||
|
||||
# Cut padded area back to original size
|
||||
pred_imgs = pred_imgs[..., :self.original_h, :self.original_w]
|
||||
|
||||
# Squeeze (batch size 1)
|
||||
pred_imgs = pred_imgs[0]
|
||||
|
||||
if not self.meta_data:
|
||||
raise ValueError("No metadata available for the current task")
|
||||
self.meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
|
||||
out_data = save_geotiff(_convert_np_uint8(pred_imgs), self.meta_data,
|
||||
out_format)
|
||||
|
||||
return ImageRequestOutput(type=out_format,
|
||||
format="tiff",
|
||||
data=out_data,
|
||||
request_id=request_id)
|
||||
|
||||
async def post_process_async(
|
||||
self,
|
||||
model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
|
||||
request_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> IOProcessorOutput:
|
||||
collected_output = [item async for i, item in model_output]
|
||||
return self.post_process(collected_output, request_id, **kwargs)
|
||||
|
||||
|
||||
class PrithviMultimodalDataProcessorIndia(PrithviMultimodalDataProcessor):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
|
||||
super().__init__(vllm_config)
|
||||
|
||||
self.indices = [1, 2, 3, 8, 11, 12]
|
||||
|
||||
|
||||
class PrithviMultimodalDataProcessorValencia(PrithviMultimodalDataProcessor):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
|
||||
super().__init__(vllm_config)
|
||||
|
||||
self.indices = [0, 1, 2, 3, 4, 5]
|
||||
@ -0,0 +1,59 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Literal, Optional, TypedDict, Union
|
||||
|
||||
import albumentations
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DataModuleConfig(TypedDict):
|
||||
bands: list[str]
|
||||
batch_size: int
|
||||
constant_scale: float
|
||||
data_root: str
|
||||
drop_last: bool
|
||||
no_data_replace: float
|
||||
no_label_replace: int
|
||||
num_workers: int
|
||||
test_transform: list[
|
||||
albumentations.core.transforms_interface.BasicTransform]
|
||||
|
||||
|
||||
class ImagePrompt(BaseModel):
|
||||
|
||||
data_format: Literal["b64_json", "bytes", "url"]
|
||||
"""
|
||||
This is the data type for the input image
|
||||
"""
|
||||
|
||||
image_format: str
|
||||
"""
|
||||
This is the image format (e.g., jpeg, png, etc.)
|
||||
"""
|
||||
|
||||
out_data_format: Literal["b64_json", "url"]
|
||||
|
||||
data: Any
|
||||
"""
|
||||
Input image data
|
||||
"""
|
||||
|
||||
|
||||
MultiModalPromptType = Union[ImagePrompt]
|
||||
|
||||
|
||||
class ImageRequestOutput(BaseModel):
|
||||
"""
|
||||
The output data of an image request to vLLM.
|
||||
|
||||
Args:
|
||||
type (str): The data content type [path, object]
|
||||
format (str): The image format (e.g., jpeg, png, etc.)
|
||||
data (Any): The resulting data.
|
||||
"""
|
||||
|
||||
type: Literal["path", "b64_json"]
|
||||
format: str
|
||||
data: str
|
||||
request_id: Optional[str] = None
|
||||
16
tests/plugins/prithvi_io_processor_plugin/setup.py
Normal file
16
tests/plugins/prithvi_io_processor_plugin/setup.py
Normal file
@ -0,0 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from setuptools import setup
|
||||
|
||||
setup(
|
||||
name="prithvi_io_processor_plugin",
|
||||
version="0.1",
|
||||
packages=["prithvi_io_processor"],
|
||||
entry_points={
|
||||
"vllm.io_processor_plugins": [
|
||||
"prithvi_to_tiff_india = prithvi_io_processor:register_prithvi_india", # noqa: E501
|
||||
"prithvi_to_tiff_valencia = prithvi_io_processor:register_prithvi_valencia", # noqa: E501
|
||||
]
|
||||
},
|
||||
)
|
||||
@ -1,12 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
Since this module is V0 only, set VLLM_USE_V1=0 for
|
||||
all tests in the module.
|
||||
"""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
137
tests/plugins_tests/test_io_processor_plugins.py
Normal file
137
tests/plugins_tests/test_io_processor_plugins.py
Normal file
@ -0,0 +1,137 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.llm import LLM
|
||||
from vllm.entrypoints.openai.protocol import IOProcessorResponse
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
|
||||
MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
|
||||
|
||||
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
|
||||
|
||||
|
||||
def test_loading_missing_plugin():
|
||||
vllm_config = VllmConfig()
|
||||
with pytest.raises(ValueError):
|
||||
get_io_processor(vllm_config, "wrong_plugin")
|
||||
|
||||
|
||||
def test_loading_engine_with_wrong_plugin():
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LLM(
|
||||
model=MODEL_NAME,
|
||||
skip_tokenizer_init=True,
|
||||
trust_remote_code=True,
|
||||
enforce_eager=True,
|
||||
# Limit the maximum number of parallel requests
|
||||
# to avoid the model going OOM in CI.
|
||||
max_num_seqs=32,
|
||||
io_processor_plugin="wrong_plugin",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
|
||||
|
||||
img_prompt = dict(
|
||||
data=image_url,
|
||||
data_format="url",
|
||||
image_format="tiff",
|
||||
out_data_format="b64_json",
|
||||
)
|
||||
|
||||
pooling_params = PoolingParams(task="encode", softmax=False)
|
||||
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
runner="pooling",
|
||||
skip_tokenizer_init=True,
|
||||
trust_remote_code=True,
|
||||
enforce_eager=True,
|
||||
# Limit the maximum number of parallel requests
|
||||
# to avoid the model going OOM in CI.
|
||||
max_num_seqs=1,
|
||||
io_processor_plugin="prithvi_to_tiff_valencia",
|
||||
) as llm_runner:
|
||||
pooler_output = llm_runner.get_llm().encode(
|
||||
img_prompt,
|
||||
pooling_params=pooling_params,
|
||||
)
|
||||
output = pooler_output[0].outputs
|
||||
|
||||
# verify the output is formatted as expected for this plugin
|
||||
assert all(
|
||||
hasattr(output, attr)
|
||||
for attr in ["type", "format", "data", "request_id"])
|
||||
|
||||
# We just check that the output is a valid base64 string.
|
||||
# Raises an exception and fails the test if the string is corrupted.
|
||||
base64.b64decode(output.data)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
"--runner",
|
||||
"pooling",
|
||||
"--enforce-eager",
|
||||
"--trust-remote-code",
|
||||
"--skip-tokenizer-init",
|
||||
# Limit the maximum number of parallel requests
|
||||
# to avoid the model going OOM in CI.
|
||||
"--max-num-seqs",
|
||||
"32",
|
||||
"--io-processor-plugin",
|
||||
"prithvi_to_tiff_valencia"
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_prithvi_mae_plugin_online(
|
||||
server: RemoteOpenAIServer,
|
||||
model_name: str,
|
||||
):
|
||||
|
||||
request_payload_url = {
|
||||
"data": {
|
||||
"data": image_url,
|
||||
"data_format": "url",
|
||||
"image_format": "tiff",
|
||||
"out_data_format": "b64_json",
|
||||
},
|
||||
"priority": 0,
|
||||
"model": model_name,
|
||||
}
|
||||
|
||||
ret = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json=request_payload_url,
|
||||
)
|
||||
|
||||
response = ret.json()
|
||||
|
||||
# verify the request response is in the correct format
|
||||
assert (parsed_response := IOProcessorResponse(**response))
|
||||
|
||||
# verify the output is formatted as expected for this plugin
|
||||
plugin_data = parsed_response.data
|
||||
|
||||
assert all(
|
||||
plugin_data.get(attr)
|
||||
for attr in ["type", "format", "data", "request_id"])
|
||||
|
||||
# We just check that the output is a valid base64 string.
|
||||
# Raises an exception and fails the test if the string is corrupted.
|
||||
base64.b64decode(plugin_data["data"])
|
||||
@ -7,6 +7,15 @@ import torch
|
||||
from vllm.plugins import load_general_plugins
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_v0_only(monkeypatch):
|
||||
"""
|
||||
Since this module is V0 only, set VLLM_USE_V1=0 for
|
||||
all tests in the module.
|
||||
"""
|
||||
monkeypatch.setenv('VLLM_USE_V1', '0')
|
||||
|
||||
|
||||
def test_platform_plugins():
|
||||
# simulate workload by running an example
|
||||
import runpy
|
||||
|
||||
@ -2,12 +2,17 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage, FunctionCall,
|
||||
ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers import xLAMToolParser
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.transformers_utils.detokenizer import detokenize_incrementally
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
|
||||
|
||||
# Use a common model that is likely to be available
|
||||
MODEL = "Salesforce/Llama-xLAM-2-8B-fc-r"
|
||||
@ -36,6 +41,56 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||
assert actual_tool_call.function == expected_tool_call.function
|
||||
|
||||
|
||||
def stream_delta_message_generator(
|
||||
xlam_tool_parser: xLAMToolParser,
|
||||
xlam_tokenizer: AnyTokenizer,
|
||||
model_output: str,
|
||||
request: Optional[ChatCompletionRequest] = None,
|
||||
) -> Generator[DeltaMessage, None, None]:
|
||||
all_token_ids = xlam_tokenizer.encode(model_output,
|
||||
add_special_tokens=False)
|
||||
|
||||
previous_text = ""
|
||||
previous_tokens = None
|
||||
prefix_offset = 0
|
||||
read_offset = 0
|
||||
for i, delta_token in enumerate(all_token_ids):
|
||||
delta_token_ids = [delta_token]
|
||||
previous_token_ids = all_token_ids[:i]
|
||||
current_token_ids = all_token_ids[:i + 1]
|
||||
|
||||
(new_tokens, delta_text, new_prefix_offset,
|
||||
new_read_offset) = (detokenize_incrementally(
|
||||
tokenizer=xlam_tokenizer,
|
||||
all_input_ids=current_token_ids,
|
||||
prev_tokens=previous_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=False,
|
||||
spaces_between_special_tokens=True,
|
||||
))
|
||||
|
||||
current_text = previous_text + delta_text
|
||||
|
||||
delta_message = xlam_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text,
|
||||
current_text,
|
||||
delta_text,
|
||||
previous_token_ids,
|
||||
current_token_ids,
|
||||
delta_token_ids,
|
||||
request=request,
|
||||
)
|
||||
if delta_message:
|
||||
yield delta_message
|
||||
|
||||
previous_text = current_text
|
||||
previous_tokens = (previous_tokens +
|
||||
new_tokens if previous_tokens else new_tokens)
|
||||
prefix_offset = new_prefix_offset
|
||||
read_offset = new_read_offset
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(xlam_tool_parser):
|
||||
model_output = "This is a test"
|
||||
extracted_tool_calls = xlam_tool_parser.extract_tool_calls(
|
||||
@ -51,6 +106,7 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
|
||||
"single_tool_with_think_tag",
|
||||
"single_tool_with_json_code_block",
|
||||
"single_tool_with_tool_calls_tag",
|
||||
"single_tool_with_tool_call_xml_tags",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
@ -118,6 +174,20 @@ def test_extract_tool_calls_no_tools(xlam_tool_parser):
|
||||
],
|
||||
"I'll check the weather for you.",
|
||||
),
|
||||
(
|
||||
"""I'll help you check the weather.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
],
|
||||
"I'll help you check the weather.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(xlam_tool_parser, model_output,
|
||||
@ -245,3 +315,147 @@ def test_streaming_with_list_structure(xlam_tool_parser):
|
||||
assert hasattr(result, "tool_calls")
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "get_current_weather"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"parallel_tool_calls",
|
||||
"single_tool_with_think_tag",
|
||||
"single_tool_with_json_code_block",
|
||||
"single_tool_with_tool_calls_tag",
|
||||
"single_tool_with_tool_call_xml_tags",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
(
|
||||
"""[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}, {"name": "get_current_weather", "arguments": {"city": "Orlando", "state": "FL", "unit": "fahrenheit"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
)),
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Orlando",
|
||||
"state": "FL",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
)),
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""<think>I'll help you with that.</think>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
],
|
||||
"<think>I'll help you with that.</think>",
|
||||
),
|
||||
(
|
||||
"""```json\n[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]\n```""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""[TOOL_CALLS][{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
],
|
||||
"",
|
||||
),
|
||||
(
|
||||
"""I can help with that.<tool_call>[{"name": "get_current_weather", "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]</tool_call>""", # noqa: E501
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
],
|
||||
"I can help with that.",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls_streaming_incremental(
|
||||
xlam_tool_parser,
|
||||
xlam_tokenizer,
|
||||
model_output,
|
||||
expected_tool_calls,
|
||||
expected_content,
|
||||
):
|
||||
"""Verify the XLAM Parser streaming behavior by verifying each chunk is as expected.""" # noqa: E501
|
||||
request = ChatCompletionRequest(model=MODEL, messages=[], tools=[])
|
||||
|
||||
chunks = []
|
||||
for delta_message in stream_delta_message_generator(
|
||||
xlam_tool_parser, xlam_tokenizer, model_output, request):
|
||||
chunks.append(delta_message)
|
||||
|
||||
# Should have multiple chunks
|
||||
assert len(chunks) >= 3
|
||||
|
||||
# Should have a chunk with tool header (id, name, type) for the first tool call # noqa: E501
|
||||
header_found = False
|
||||
expected_first_tool = expected_tool_calls[0]
|
||||
for chunk in chunks:
|
||||
if chunk.tool_calls and chunk.tool_calls[0].id:
|
||||
header_found = True
|
||||
assert (chunk.tool_calls[0].function.name ==
|
||||
expected_first_tool.function.name)
|
||||
assert chunk.tool_calls[0].type == "function"
|
||||
# Arguments may be empty initially or None
|
||||
if chunk.tool_calls[0].function.arguments is not None:
|
||||
# If present, should be empty string initially
|
||||
assert chunk.tool_calls[0].function.arguments == ""
|
||||
break
|
||||
assert header_found
|
||||
|
||||
# Should have chunks with incremental arguments
|
||||
arg_chunks = []
|
||||
for chunk in chunks:
|
||||
if (chunk.tool_calls and chunk.tool_calls[0].function.arguments
|
||||
and chunk.tool_calls[0].function.arguments != ""
|
||||
and chunk.tool_calls[0].index ==
|
||||
0 # Only collect arguments from the first tool call
|
||||
):
|
||||
arg_chunks.append(chunk.tool_calls[0].function.arguments)
|
||||
|
||||
# Arguments should be streamed incrementally
|
||||
assert len(arg_chunks) > 1
|
||||
|
||||
# Concatenated arguments should form valid JSON for the first tool call
|
||||
full_args = "".join(arg_chunks)
|
||||
parsed_args = json.loads(full_args)
|
||||
expected_args = json.loads(expected_first_tool.function.arguments)
|
||||
assert parsed_args == expected_args
|
||||
|
||||
@ -501,6 +501,8 @@ class ModelConfig:
|
||||
logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None
|
||||
"""One or more logits processors' fully-qualified class names or class
|
||||
definitions"""
|
||||
io_processor_plugin: Optional[str] = None
|
||||
"""IOProcessor plugin name to load at model startup"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
|
||||
@ -40,16 +40,21 @@ class KVCacheEvent(
|
||||
"""Base class for all KV cache-related events"""
|
||||
|
||||
|
||||
MEDIUM_GPU = "GPU"
|
||||
|
||||
|
||||
class BlockStored(KVCacheEvent):
|
||||
block_hashes: list[int]
|
||||
parent_block_hash: Optional[int]
|
||||
token_ids: list[int]
|
||||
block_size: int
|
||||
lora_id: Optional[int]
|
||||
medium: Optional[str]
|
||||
|
||||
|
||||
class BlockRemoved(KVCacheEvent):
|
||||
block_hashes: list[int]
|
||||
medium: Optional[str]
|
||||
|
||||
|
||||
class AllBlocksCleared(KVCacheEvent):
|
||||
|
||||
@ -19,6 +19,8 @@ The class provides the following primitives:
|
||||
Returns whether KV cache should be freed now or will be
|
||||
freed asynchronously and optionally returns KV transfer
|
||||
params.
|
||||
take_events() - returns new KV events that were collected
|
||||
by the connector since the last call.
|
||||
|
||||
Worker-side: runs in each worker, loads/saves KV cache to/from
|
||||
the Connector based on the metadata.
|
||||
@ -34,6 +36,7 @@ The class provides the following primitives:
|
||||
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
|
||||
|
||||
import torch
|
||||
@ -45,6 +48,7 @@ from vllm.v1.outputs import KVConnectorOutput
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.request import Request
|
||||
@ -313,6 +317,15 @@ class KVConnectorBase_V1(ABC):
|
||||
"""
|
||||
return False, None
|
||||
|
||||
def take_events(self) -> Iterable["KVCacheEvent"]:
|
||||
"""
|
||||
Take the KV cache events from the connector.
|
||||
|
||||
Yields:
|
||||
New KV cache events since the last call.
|
||||
"""
|
||||
return ()
|
||||
|
||||
@classmethod
|
||||
def get_required_kvcache_layout(
|
||||
cls, vllm_config: "VllmConfig") -> Optional[str]:
|
||||
|
||||
@ -1,12 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
@ -208,6 +210,10 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
|
||||
return async_saves > 0, kv_txfer_params
|
||||
|
||||
def take_events(self) -> Iterable[KVCacheEvent]:
|
||||
for c in self._connectors:
|
||||
yield from c.take_events()
|
||||
|
||||
@classmethod
|
||||
def get_required_kvcache_layout(
|
||||
cls, vllm_config: "VllmConfig") -> Optional[str]:
|
||||
|
||||
@ -364,6 +364,7 @@ class EngineArgs:
|
||||
disable_mm_preprocessor_cache: bool = False # DEPRECATED
|
||||
mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
|
||||
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
|
||||
io_processor_plugin: Optional[str] = None
|
||||
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
|
||||
# LoRA fields
|
||||
enable_lora: bool = False
|
||||
@ -577,6 +578,8 @@ class EngineArgs:
|
||||
**model_kwargs["override_attention_dtype"])
|
||||
model_group.add_argument("--logits-processors",
|
||||
**model_kwargs["logits_processors"])
|
||||
model_group.add_argument("--io-processor-plugin",
|
||||
**model_kwargs["io_processor_plugin"])
|
||||
|
||||
# Model loading arguments
|
||||
load_kwargs = get_kwargs(LoadConfig)
|
||||
@ -993,6 +996,7 @@ class EngineArgs:
|
||||
model_impl=self.model_impl,
|
||||
override_attention_dtype=self.override_attention_dtype,
|
||||
logits_processors=self.logits_processors,
|
||||
io_processor_plugin=self.io_processor_plugin,
|
||||
)
|
||||
|
||||
def validate_tensorizer_args(self):
|
||||
@ -1432,17 +1436,6 @@ class EngineArgs:
|
||||
recommend_to_remove=True)
|
||||
return False
|
||||
|
||||
# Triton v3.3 has f16 conversion regression issue on Turing and Volta,
|
||||
# which broke fp16 inference
|
||||
# see: https://github.com/triton-lang/triton/issues/6698
|
||||
if (current_platform.is_cuda()
|
||||
and not current_platform.has_device_capability(80)
|
||||
and model_config.dtype == torch.float16):
|
||||
_raise_or_fallback(
|
||||
feature_name="Compute Capability < 8.0 with FP16",
|
||||
recommend_to_remove=False)
|
||||
return False
|
||||
|
||||
if self.kv_cache_dtype != "auto":
|
||||
supported = current_platform.is_kv_cache_dtype_supported(
|
||||
self.kv_cache_dtype, model_config)
|
||||
|
||||
@ -15,6 +15,7 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput
|
||||
from vllm.plugins.io_processors.interface import IOProcessor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
@ -267,6 +268,9 @@ class EngineClient(ABC):
|
||||
"""Get the appropriate tokenizer for the request"""
|
||||
...
|
||||
|
||||
async def get_io_processor(self) -> IOProcessor:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def is_tracing_enabled(self) -> bool:
|
||||
...
|
||||
|
||||
@ -37,13 +37,15 @@ from vllm.entrypoints.score_utils import (ScoreContentPartParam,
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.utils import (_validate_truncation_size,
|
||||
log_non_default_args)
|
||||
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
|
||||
TokensPrompt)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
|
||||
PoolingRequestOutput, RequestOutput,
|
||||
ScoringRequestOutput)
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
|
||||
SamplingParams)
|
||||
@ -284,6 +286,11 @@ class LLM:
|
||||
|
||||
self.supported_tasks = supported_tasks
|
||||
|
||||
# Load the Input/Output processor plugin if any
|
||||
io_processor_plugin = self.llm_engine.model_config.io_processor_plugin
|
||||
self.io_processor = get_io_processor(self.llm_engine.vllm_config,
|
||||
io_processor_plugin)
|
||||
|
||||
def get_tokenizer(
|
||||
self,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
@ -833,7 +840,7 @@ class LLM:
|
||||
|
||||
def encode(
|
||||
self,
|
||||
prompts: Union[PromptType, Sequence[PromptType]],
|
||||
prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
|
||||
pooling_params: Optional[Union[PoolingParams,
|
||||
Sequence[PoolingParams]]] = None,
|
||||
*,
|
||||
@ -915,6 +922,22 @@ class LLM:
|
||||
if truncate_prompt_tokens is not None:
|
||||
param.truncate_prompt_tokens = truncate_prompt_tokens
|
||||
|
||||
io_processor_prompt = False
|
||||
if isinstance(prompts, dict) and "data" in prompts:
|
||||
io_processor_prompt = True
|
||||
if self.io_processor is None:
|
||||
raise ValueError(
|
||||
"No IOProcessor plugin installed. Please refer "
|
||||
"to the documentation and to the "
|
||||
"'prithvi_geospatial_mae_io_processor' "
|
||||
"offline inference example for more details.")
|
||||
|
||||
# Validate the request data is valid for the loaded plugin
|
||||
validated_prompt = self.io_processor.parse_request(prompts)
|
||||
|
||||
# obtain the actual model prompts from the pre-processor
|
||||
prompts = self.io_processor.pre_process(prompt=validated_prompt)
|
||||
|
||||
self._validate_and_add_requests(
|
||||
prompts=prompts,
|
||||
params=pooling_params,
|
||||
@ -923,8 +946,24 @@ class LLM:
|
||||
)
|
||||
|
||||
outputs = self._run_engine(use_tqdm=use_tqdm)
|
||||
return self.engine_class.validate_outputs(outputs,
|
||||
PoolingRequestOutput)
|
||||
|
||||
model_outputs = self.engine_class.validate_outputs(
|
||||
outputs, PoolingRequestOutput)
|
||||
|
||||
if io_processor_prompt:
|
||||
# get the post-processed model outputs
|
||||
assert self.io_processor is not None
|
||||
processed_outputs = self.io_processor.post_process(
|
||||
model_output=model_outputs)
|
||||
|
||||
return [
|
||||
PoolingRequestOutput[Any](request_id="",
|
||||
outputs=processed_outputs,
|
||||
prompt_token_ids=[],
|
||||
finished=True)
|
||||
]
|
||||
else:
|
||||
return model_outputs
|
||||
|
||||
def embed(
|
||||
self,
|
||||
|
||||
@ -64,6 +64,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse, ErrorInfo,
|
||||
ErrorResponse,
|
||||
IOProcessorResponse,
|
||||
LoadLoRAAdapterRequest,
|
||||
PoolingRequest, PoolingResponse,
|
||||
RerankRequest, RerankResponse,
|
||||
@ -795,7 +796,7 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.error.code)
|
||||
elif isinstance(generator, PoolingResponse):
|
||||
elif isinstance(generator, (PoolingResponse, IOProcessorResponse)):
|
||||
return JSONResponse(content=generator.model_dump())
|
||||
|
||||
assert_never(generator)
|
||||
@ -1782,7 +1783,7 @@ async def init_app_state(
|
||||
) if "generate" in supported_tasks else None
|
||||
state.openai_serving_pooling = OpenAIServingPooling(
|
||||
engine_client,
|
||||
model_config,
|
||||
vllm_config,
|
||||
state.openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
|
||||
@ -6,7 +6,8 @@
|
||||
import json
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
from typing import Annotated, Any, ClassVar, Literal, Optional, Union
|
||||
from typing import (Annotated, Any, ClassVar, Generic, Literal, Optional,
|
||||
TypeVar, Union)
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
@ -1405,7 +1406,46 @@ EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest]
|
||||
|
||||
PoolingCompletionRequest = EmbeddingCompletionRequest
|
||||
PoolingChatRequest = EmbeddingChatRequest
|
||||
PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest]
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class IOProcessorRequest(OpenAIBaseModel, Generic[T]):
|
||||
model: Optional[str] = None
|
||||
|
||||
priority: int = Field(default=0)
|
||||
"""
|
||||
The priority of the request (lower means earlier handling;
|
||||
default: 0). Any priority other than 0 will raise an error
|
||||
if the served model does not use priority scheduling.
|
||||
"""
|
||||
data: T
|
||||
"""
|
||||
When using plugins IOProcessor plugins, the actual input is processed
|
||||
by the plugin itself. Hence, we use a generic type for the request data
|
||||
"""
|
||||
|
||||
def to_pooling_params(self):
|
||||
return PoolingParams(task="encode")
|
||||
|
||||
|
||||
class IOProcessorResponse(OpenAIBaseModel, Generic[T]):
|
||||
|
||||
request_id: Optional[str] = None
|
||||
"""
|
||||
The request_id associated with this response
|
||||
"""
|
||||
created_at: int = Field(default_factory=lambda: int(time.time()))
|
||||
|
||||
data: T
|
||||
"""
|
||||
When using plugins IOProcessor plugins, the actual output is generated
|
||||
by the plugin itself. Hence, we use a generic type for the response data
|
||||
"""
|
||||
|
||||
|
||||
PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest,
|
||||
IOProcessorRequest]
|
||||
|
||||
|
||||
class ScoreRequest(OpenAIBaseModel):
|
||||
|
||||
@ -49,9 +49,11 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
EmbeddingCompletionRequest,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse, ErrorInfo,
|
||||
ErrorResponse, PoolingResponse,
|
||||
RerankRequest, ResponsesRequest,
|
||||
ScoreRequest, ScoreResponse,
|
||||
ErrorResponse,
|
||||
IOProcessorRequest,
|
||||
PoolingResponse, RerankRequest,
|
||||
ResponsesRequest, ScoreRequest,
|
||||
ScoreResponse,
|
||||
TokenizeChatRequest,
|
||||
TokenizeCompletionRequest,
|
||||
TokenizeResponse,
|
||||
@ -89,7 +91,7 @@ ChatLikeRequest = Union[ChatCompletionRequest, EmbeddingChatRequest,
|
||||
TokenizeChatRequest]
|
||||
SpeechToTextRequest = Union[TranscriptionRequest, TranslationRequest]
|
||||
AnyRequest = Union[CompletionLikeRequest, ChatLikeRequest, SpeechToTextRequest,
|
||||
ResponsesRequest]
|
||||
ResponsesRequest, IOProcessorRequest]
|
||||
|
||||
AnyResponse = Union[
|
||||
CompletionResponse,
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from typing import Final, Literal, Optional, Union, cast
|
||||
|
||||
import jinja2
|
||||
@ -13,19 +13,25 @@ import torch
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
||||
IOProcessorRequest,
|
||||
IOProcessorResponse,
|
||||
PoolingChatRequest,
|
||||
PoolingCompletionRequest,
|
||||
PoolingRequest, PoolingResponse,
|
||||
PoolingResponseData, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing, RequestPrompt
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.utils import merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -52,7 +58,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
vllm_config: VllmConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
@ -61,19 +67,21 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
model_config=vllm_config.model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
io_processor_plugin = self.model_config.io_processor_plugin
|
||||
self.io_processor = get_io_processor(vllm_config, io_processor_plugin)
|
||||
|
||||
async def create_pooling(
|
||||
self,
|
||||
request: PoolingRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[PoolingResponse, ErrorResponse]:
|
||||
) -> Union[PoolingResponse, IOProcessorResponse, ErrorResponse]:
|
||||
"""
|
||||
See https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
for the API specification. This API mimics the OpenAI Embedding API.
|
||||
@ -82,20 +90,13 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
encoding_format = request.encoding_format
|
||||
if request.dimensions is not None:
|
||||
return self.create_error_response(
|
||||
"dimensions is currently not supported")
|
||||
|
||||
model_name = self._get_model_name(request.model)
|
||||
|
||||
request_id = f"pool-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
truncate_prompt_tokens = request.truncate_prompt_tokens
|
||||
|
||||
is_io_processor_request = isinstance(request, IOProcessorRequest)
|
||||
try:
|
||||
truncate_prompt_tokens = _validate_truncation_size(
|
||||
self.max_model_len, truncate_prompt_tokens)
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
@ -104,7 +105,32 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request
|
||||
)
|
||||
|
||||
if isinstance(request, PoolingChatRequest):
|
||||
if getattr(request, "dimensions", None) is not None:
|
||||
return self.create_error_response(
|
||||
"dimensions is currently not supported")
|
||||
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
|
||||
None)
|
||||
truncate_prompt_tokens = _validate_truncation_size(
|
||||
self.max_model_len, truncate_prompt_tokens)
|
||||
|
||||
if is_io_processor_request:
|
||||
if self.io_processor is None:
|
||||
raise ValueError(
|
||||
"No IOProcessor plugin installed. Please refer "
|
||||
"to the documentation and to the "
|
||||
"'prithvi_geospatial_mae_io_processor' "
|
||||
"offline inference example for more details.")
|
||||
|
||||
validated_prompt = self.io_processor.parse_request(request)
|
||||
|
||||
engine_prompts = await self.io_processor.pre_process_async(
|
||||
prompt=validated_prompt, request_id=request_id)
|
||||
request_prompts: Sequence[RequestPrompt] = [
|
||||
""
|
||||
] * len(engine_prompts)
|
||||
|
||||
elif isinstance(request, PoolingChatRequest):
|
||||
(
|
||||
_,
|
||||
request_prompts,
|
||||
@ -122,7 +148,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
continue_final_message=False,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
elif isinstance(request, PoolingCompletionRequest):
|
||||
(request_prompts,
|
||||
engine_prompts) = await self._preprocess_completion(
|
||||
request,
|
||||
@ -130,6 +156,9 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
request.input,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported request of type {type(request)}")
|
||||
except (ValueError, TypeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
@ -171,6 +200,16 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
if is_io_processor_request:
|
||||
assert self.io_processor is not None
|
||||
output = await self.io_processor.post_process_async(
|
||||
model_output=result_generator,
|
||||
request_id=request_id,
|
||||
)
|
||||
return self.io_processor.output_to_response(output)
|
||||
|
||||
assert isinstance(request,
|
||||
(PoolingCompletionRequest, PoolingChatRequest))
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Non-streaming response
|
||||
@ -190,7 +229,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
encoding_format,
|
||||
request.encoding_format,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
|
||||
@ -186,11 +186,31 @@ class xLAMToolParser(ToolParser):
|
||||
"""
|
||||
Extract tool calls for streaming mode.
|
||||
"""
|
||||
# Simplify detection: if it begins with "[" treat it as a function call
|
||||
is_function_call = (current_text.strip().startswith("["))
|
||||
# First, check for a definitive start of a tool call block.
|
||||
# This prevents premature parsing of incomplete output.
|
||||
stripped_text = current_text.strip()
|
||||
preprocessed_content, preprocessed_tool_calls = (
|
||||
self.preprocess_model_output(current_text))
|
||||
|
||||
# If not a function call, return normal content
|
||||
if not is_function_call:
|
||||
# For JSON code blocks, we need to detect them earlier, even if incomplete
|
||||
has_potential_json_block = ("```json" in current_text
|
||||
or "```\n[" in current_text
|
||||
or "[TOOL_CALLS]" in current_text
|
||||
or "<tool_call>" in current_text)
|
||||
|
||||
is_tool_call_block = (
|
||||
stripped_text.startswith("[")
|
||||
or stripped_text.startswith("<tool_call>")
|
||||
or stripped_text.startswith("[TOOL_CALLS]") or
|
||||
# Check if we have thinking tags with JSON-like content following
|
||||
("</think>[" in current_text) or
|
||||
# Check if the text contains a JSON array after preprocessing
|
||||
preprocessed_tool_calls is not None or
|
||||
# For JSON code blocks, detect early if we see enough structure
|
||||
(has_potential_json_block and '"name"' in current_text
|
||||
and '"arguments"' in current_text))
|
||||
|
||||
if not is_tool_call_block:
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
try:
|
||||
@ -204,7 +224,10 @@ class xLAMToolParser(ToolParser):
|
||||
|
||||
# Try parsing as JSON to check for complete tool calls
|
||||
try:
|
||||
parsed_tools = json.loads(current_text)
|
||||
# Use preprocessed tool calls if available
|
||||
tool_calls_text = (preprocessed_tool_calls if
|
||||
preprocessed_tool_calls else current_text)
|
||||
parsed_tools = json.loads(tool_calls_text)
|
||||
if isinstance(parsed_tools, list):
|
||||
# Update our tool array for next time
|
||||
self.prev_tool_call_arr = parsed_tools
|
||||
@ -257,13 +280,40 @@ class xLAMToolParser(ToolParser):
|
||||
return delta
|
||||
|
||||
# Use regex to identify tool calls in the output
|
||||
# Use preprocessed tool calls text for better parsing, but also try to extract from incomplete JSON blocks
|
||||
search_text = (preprocessed_tool_calls
|
||||
if preprocessed_tool_calls else current_text)
|
||||
|
||||
# For JSON code blocks that aren't complete yet, try to extract the JSON content
|
||||
if not preprocessed_tool_calls and has_potential_json_block:
|
||||
# Try to extract the JSON array from within the code block
|
||||
json_match = re.search(r"```(?:json)?\s*([\s\S]*?)(?:```|$)",
|
||||
current_text)
|
||||
if json_match:
|
||||
potential_json = json_match.group(1).strip()
|
||||
# Use this as search text even if it's incomplete
|
||||
if potential_json.startswith("[") and (
|
||||
'"name"' in potential_json
|
||||
and '"arguments"' in potential_json):
|
||||
search_text = potential_json
|
||||
|
||||
# Try to find complete tool names first
|
||||
name_pattern = r'"name"\s*:\s*"([^"]+)"'
|
||||
name_matches = list(re.finditer(name_pattern, current_text))
|
||||
name_matches = list(re.finditer(name_pattern, search_text))
|
||||
tool_count = len(name_matches)
|
||||
|
||||
# If no tools found yet, return
|
||||
# If no complete tool names found, check for partial tool names
|
||||
if tool_count == 0:
|
||||
return None
|
||||
# Check if we're in the middle of parsing a tool name
|
||||
partial_name_pattern = r'"name"\s*:\s*"([^"]*)'
|
||||
partial_matches = list(
|
||||
re.finditer(partial_name_pattern, search_text))
|
||||
if partial_matches:
|
||||
# We have a partial tool name - not ready to emit yet
|
||||
return None
|
||||
else:
|
||||
# No tools found at all
|
||||
return None
|
||||
|
||||
# Ensure our state arrays are large enough
|
||||
while len(self.streaming_state["sent_tools"]) < tool_count:
|
||||
@ -332,7 +382,7 @@ class xLAMToolParser(ToolParser):
|
||||
# First, check for the empty arguments case: "arguments": {}
|
||||
empty_args_pattern = (
|
||||
r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}')
|
||||
empty_args_match = re.search(empty_args_pattern, current_text)
|
||||
empty_args_match = re.search(empty_args_pattern, search_text)
|
||||
|
||||
# Check if this tool has empty arguments
|
||||
if empty_args_match and empty_args_match.start() > 0:
|
||||
@ -376,7 +426,7 @@ class xLAMToolParser(ToolParser):
|
||||
|
||||
# Extract arguments for current tool using regex for non-empty arguments
|
||||
args_pattern = r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
|
||||
args_matches = list(re.finditer(args_pattern, current_text))
|
||||
args_matches = list(re.finditer(args_pattern, search_text))
|
||||
|
||||
if current_idx < len(args_matches):
|
||||
args_text = args_matches[current_idx].group(1)
|
||||
@ -384,17 +434,25 @@ class xLAMToolParser(ToolParser):
|
||||
# Handle transition between tools
|
||||
is_last_tool = current_idx == tool_count - 1
|
||||
|
||||
# Find where the arguments for our current tool end
|
||||
if not is_last_tool:
|
||||
# If we have more tools after this one, try to find the complete argument block
|
||||
next_tool_pos = current_text.find(
|
||||
"},{", args_matches[current_idx].start())
|
||||
if next_tool_pos != -1:
|
||||
args_end_pos = (next_tool_pos + 1
|
||||
) # +1 to include the '}'
|
||||
args_text = (current_text[args_matches[current_idx]
|
||||
.start():args_end_pos].
|
||||
split('"arguments":')[1].strip())
|
||||
# For multiple tools, extract only the arguments for the current tool
|
||||
if tool_count > 1:
|
||||
# Parse the entire JSON structure to properly extract arguments for each tool
|
||||
try:
|
||||
parsed_tools = json.loads(search_text)
|
||||
if isinstance(
|
||||
parsed_tools,
|
||||
list) and current_idx < len(parsed_tools):
|
||||
current_tool = parsed_tools[current_idx]
|
||||
if isinstance(current_tool.get("arguments"),
|
||||
dict):
|
||||
args_text = json.dumps(
|
||||
current_tool["arguments"])
|
||||
else:
|
||||
args_text = str(
|
||||
current_tool.get("arguments", "{}"))
|
||||
except (json.JSONDecodeError, KeyError, IndexError):
|
||||
# Fallback to regex-based extraction
|
||||
pass
|
||||
|
||||
# If arguments haven't been sent yet
|
||||
sent_args = self.streaming_state["sent_tools"][
|
||||
@ -419,7 +477,7 @@ class xLAMToolParser(ToolParser):
|
||||
index=current_idx,
|
||||
function=DeltaFunctionCall(
|
||||
arguments="{").model_dump(
|
||||
exclude_none=True), # type: ignore
|
||||
exclude_none=True), # type: ignore
|
||||
)
|
||||
])
|
||||
return delta
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
|
||||
from .data import (DataPrompt, DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
|
||||
EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
|
||||
ProcessorInputs, PromptType, SingletonInputs,
|
||||
SingletonPrompt, TextPrompt, TokenInputs, TokensPrompt,
|
||||
@ -18,6 +18,7 @@ target model.
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
"DataPrompt",
|
||||
"TextPrompt",
|
||||
"TokensPrompt",
|
||||
"PromptType",
|
||||
|
||||
@ -95,6 +95,16 @@ class EmbedsPrompt(TypedDict):
|
||||
"""
|
||||
|
||||
|
||||
class DataPrompt(TypedDict):
|
||||
"""Represents generic inputs handled by IO processor plugins."""
|
||||
|
||||
data: Any
|
||||
"""The input data"""
|
||||
|
||||
data_format: str
|
||||
"""The input data format"""
|
||||
|
||||
|
||||
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
|
||||
"""
|
||||
Set of possible schemas for a single prompt:
|
||||
|
||||
@ -37,6 +37,7 @@ class GPTQConfig(QuantizationConfig):
|
||||
desc_act: bool,
|
||||
lm_head_quantized: bool,
|
||||
dynamic: dict[str, dict[str, Union[int, bool]]],
|
||||
autoround_version: str = "",
|
||||
) -> None:
|
||||
# GPTQModel use `dynamic` config property to allow per module
|
||||
# quantization config so each module can be individually optimized.
|
||||
@ -74,6 +75,9 @@ class GPTQConfig(QuantizationConfig):
|
||||
"Currently, only 2/3/4/8-bit weight quantization is "
|
||||
f"supported for GPTQ, but got {self.weight_bits} bits.")
|
||||
|
||||
# used to identify GPTQ model quantized by autoround
|
||||
self.autoround_version = autoround_version
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
@ -108,8 +112,10 @@ class GPTQConfig(QuantizationConfig):
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
autoround_version = cls.get_from_keys_or(config, ["autoround_version"],
|
||||
default="")
|
||||
return cls(weight_bits, group_size, desc_act, lm_head_quantized,
|
||||
dynamic)
|
||||
dynamic, autoround_version)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
|
||||
@ -119,6 +119,9 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
|
||||
self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
|
||||
|
||||
# used to identify GPTQ model quantized by autoround
|
||||
self.autoround_version = full_config.get("autoround_version", "")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
|
||||
f"group_size={self.group_size}, "
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from copy import deepcopy
|
||||
from fractions import Fraction
|
||||
from typing import Optional, Union
|
||||
|
||||
import regex as re
|
||||
@ -29,7 +30,7 @@ def override_config(config: QuantizationConfig, prefix: str):
|
||||
if isinstance(desc_act, bool):
|
||||
config.desc_act = desc_act
|
||||
|
||||
config.pack_factor = 32 // config.weight_bits # packed into int32
|
||||
config.pack_factor = Fraction(32, config.weight_bits) # packed into int32
|
||||
if config.get_name() == "gptq_marlin":
|
||||
is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
|
||||
if isinstance(is_sym, bool):
|
||||
|
||||
@ -549,7 +549,7 @@ class GraniteSpeechForConditionalGeneration(
|
||||
|
||||
raise ValueError("Only audio modality is supported")
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
@ -1371,7 +1371,8 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
output_tensor[i, :t.size(0)] = t
|
||||
return output_tensor
|
||||
|
||||
def _parse_and_validate_image_input(self, **kwargs: object):
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[MllamaImagePixelInputs]:
|
||||
# tensor with the same shape will be batched together by
|
||||
# MultiModalKwargs.batch, so pixel_values here can be:
|
||||
# - list[torch.Tensor]:
|
||||
|
||||
@ -209,7 +209,7 @@ class OvisImagePatchInputs(TypedDict):
|
||||
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
|
||||
"""
|
||||
|
||||
inducator_tokens: torch.Tensor
|
||||
indicator_tokens: torch.Tensor
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * (num_patches + 1))`
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
""" PyTorch Ovis model."""
|
||||
from collections.abc import Iterable, Mapping
|
||||
from functools import partial
|
||||
from typing import Optional, Union
|
||||
from typing import Literal, Optional, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -50,6 +50,27 @@ IMAGE_PAD_TOKEN_ID_MAP = {
|
||||
}
|
||||
|
||||
|
||||
class OvisVideoPatchInputs(TypedDict):
|
||||
type: Literal["video_patches"]
|
||||
flat_data: torch.Tensor
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
|
||||
"""
|
||||
|
||||
indicator_tokens: torch.Tensor
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * (num_patches + 1))`
|
||||
"""
|
||||
|
||||
patches_per_image: list[int]
|
||||
"""
|
||||
List of number of total patches for each frame in the video.
|
||||
This is used to restore the first two dimensions of `flat_data`.
|
||||
"""
|
||||
|
||||
|
||||
def _ovis2_5_field_config():
|
||||
return dict(pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
grids=MultiModalFieldConfig.batched("image"),
|
||||
@ -429,17 +450,11 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.get_language_model().make_empty_intermediate_tensors)
|
||||
|
||||
def _parse_and_validate_visual_input(
|
||||
self, is_video,
|
||||
**kwargs: object) -> Optional[OvisImagePatchInputs]:
|
||||
if is_video:
|
||||
pixel_values = kwargs.pop("video_pixel_values", None)
|
||||
indicator_tokens = kwargs.pop("video_indicator_tokens", None)
|
||||
grids = kwargs.pop("video_grids", None)
|
||||
else:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
indicator_tokens = kwargs.pop("indicator_tokens", None)
|
||||
grids = kwargs.pop("grids", None)
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
indicator_tokens = kwargs.pop("indicator_tokens", None)
|
||||
grids = kwargs.pop("grids", None)
|
||||
if pixel_values is None and indicator_tokens is None:
|
||||
return None
|
||||
|
||||
@ -466,8 +481,40 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _parse_and_validate_video_input(
|
||||
self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
|
||||
pixel_values = kwargs.pop("video_pixel_values", None)
|
||||
indicator_tokens = kwargs.pop("video_indicator_tokens", None)
|
||||
grids = kwargs.pop("video_grids", None)
|
||||
if pixel_values is None and indicator_tokens is None:
|
||||
return None
|
||||
|
||||
if pixel_values is not None and indicator_tokens is not None:
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
if not isinstance(indicator_tokens, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of indicator_tokens. "
|
||||
f"Got type: {type(indicator_tokens)}")
|
||||
|
||||
return OvisVideoPatchInputs(
|
||||
type="video_patches",
|
||||
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
|
||||
patches_per_image=[
|
||||
x.shape[0] // (self.config.vit_config.hidden_stride**2)
|
||||
for x in flatten_bn(pixel_values)
|
||||
],
|
||||
indicator_tokens=flatten_bn(flatten_bn(indicator_tokens),
|
||||
concat=True),
|
||||
grids=flatten_bn(flatten_bn(grids), concat=True),
|
||||
)
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _process_image_input(
|
||||
self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings:
|
||||
self, image_input: Union[OvisImagePatchInputs, OvisVideoPatchInputs]
|
||||
) -> MultiModalEmbeddings:
|
||||
image_patches_flat = image_input["flat_data"]
|
||||
patches_per_image = image_input["patches_per_image"]
|
||||
indicator_tokens = image_input["indicator_tokens"]
|
||||
@ -500,21 +547,44 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
torch.cat(vision_embeddings_per_image, dim=0))
|
||||
return tuple(vision_embeddings)
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
embeddings = []
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
modalities = {}
|
||||
|
||||
# NOTE: _parse_and_validate_visual_input has side-effects and pops
|
||||
# keys from kwargs. We process images first, then videos.
|
||||
image_input = self._parse_and_validate_visual_input(False, **kwargs)
|
||||
if image_input:
|
||||
embeddings.extend(self._process_image_input(image_input))
|
||||
# Preserve the order of modalities if there are multiple of them
|
||||
# from the order of kwargs.
|
||||
for input_key in kwargs:
|
||||
if input_key in ("pixel_values", "indicator_tokens",
|
||||
"grids") and "images" not in modalities:
|
||||
modalities["images"] = self._parse_and_validate_image_input(
|
||||
**kwargs)
|
||||
if input_key in ("video_pixel_values", "video_indicator_tokens",
|
||||
"video_grids") and "videos" not in modalities:
|
||||
modalities["videos"] = self._parse_and_validate_video_input(
|
||||
**kwargs)
|
||||
|
||||
video_input = self._parse_and_validate_visual_input(True, **kwargs)
|
||||
if video_input:
|
||||
embeddings.extend(self._process_image_input(video_input))
|
||||
return modalities
|
||||
|
||||
return tuple(embeddings) if embeddings else None
|
||||
def get_multimodal_embeddings(self,
|
||||
**kwargs: object) -> MultiModalEmbeddings:
|
||||
|
||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||
if not modalities:
|
||||
return []
|
||||
|
||||
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
||||
# NOTE: It is important to iterate over the keys in this dictionary
|
||||
# to preserve the order of the modalities.
|
||||
for modality in modalities:
|
||||
if modality == "images":
|
||||
image_input = modalities["images"]
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
multimodal_embeddings += vision_embeddings
|
||||
if modality == "videos":
|
||||
video_input = modalities["videos"]
|
||||
video_embeddings = self._process_image_input(video_input)
|
||||
multimodal_embeddings += video_embeddings
|
||||
|
||||
return multimodal_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, TypedDict, Union
|
||||
from typing import Annotated, Any, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -40,6 +40,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .idefics2_vision_model import Idefics2VisionTransformer
|
||||
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
|
||||
@ -615,50 +616,90 @@ class Phi4MMAudioEmbedding(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Phi4MMImagePixelInputs(TypedDict):
|
||||
class Phi4MMImagePixelInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bn: Batch size * number of images
|
||||
- p: Number of patches (1 + num_patches)
|
||||
- c: Number of channels (3)
|
||||
- h: Height of each image patch
|
||||
- w: Width of each image patch
|
||||
- nc: Number of crops
|
||||
- H_mask: Height of attention mask
|
||||
- W_mask: Width of attention mask
|
||||
"""
|
||||
|
||||
type: Literal["pixel_values"]
|
||||
data: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
|
||||
|
||||
Note that `num_patches` may be different per batch and image,
|
||||
in which case the data is passed as a list instead of a batched tensor.
|
||||
data: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}
|
||||
), # may be different per batch and image
|
||||
]
|
||||
|
||||
image_sizes: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("bn", 2), # (height, width)
|
||||
]
|
||||
|
||||
num_img_tokens: Annotated[
|
||||
list[int],
|
||||
TensorShape("bn"),
|
||||
]
|
||||
|
||||
image_attention_mask: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("bn", "nc", 32, 32), # H_mask, W_mask
|
||||
]
|
||||
|
||||
|
||||
class Phi4MMImageEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bn: Batch size * number of images
|
||||
- f: Image feature size
|
||||
- h: Hidden size (must match language model backbone)
|
||||
"""
|
||||
|
||||
image_sizes: torch.Tensor
|
||||
"""
|
||||
Shape: `(batch_size * num_images, 2)`
|
||||
|
||||
This should be in `(height, width)` format.
|
||||
"""
|
||||
|
||||
num_img_tokens: list[int]
|
||||
"""Shape: `(batch_size * num_images)`"""
|
||||
|
||||
image_attention_mask: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images, H_mask, W_mask)`"""
|
||||
|
||||
|
||||
class Phi4MMImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
data: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
|
||||
|
||||
`hidden_size` must match the hidden size of language model backbone.
|
||||
data: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("bn", "f", "h"),
|
||||
]
|
||||
|
||||
|
||||
class Phi4MMAudioFeatureInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bn: Batch size * number of audios
|
||||
- f: Number of Mel filterbank bins (80)
|
||||
- t: Time frames (M)
|
||||
"""
|
||||
|
||||
|
||||
class Phi4MMAudioFeatureInputs(TypedDict):
|
||||
type: Literal["audio_features"]
|
||||
data: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size * num_audios, 80, M)"""
|
||||
|
||||
data: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("bn", "t", 80, dynamic_dims={"t"}),
|
||||
]
|
||||
|
||||
|
||||
class Phi4MMAudioEmbeddingInputs(TypedDict):
|
||||
class Phi4MMAudioEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- b: Batch size
|
||||
- n: Number of audios
|
||||
- f: Audio feature size
|
||||
- h: Hidden size (must match language model backbone)
|
||||
"""
|
||||
|
||||
type: Literal["audio_embeds"]
|
||||
data: NestedTensors
|
||||
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
|
||||
|
||||
data: Annotated[
|
||||
NestedTensors,
|
||||
TensorShape("b", "n", "f", "h"),
|
||||
]
|
||||
|
||||
|
||||
Phi4MMImageInput = Union[Phi4MMImagePixelInputs, Phi4MMImageEmbeddingInputs]
|
||||
@ -1170,18 +1211,10 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
return None
|
||||
|
||||
if audio_features is not None:
|
||||
if not isinstance(audio_features, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio features. "
|
||||
f"Got type: {type(audio_features)}")
|
||||
|
||||
return Phi4MMAudioFeatureInputs(type="audio_features",
|
||||
data=flatten_bn(audio_features))
|
||||
|
||||
if audio_embeds is not None:
|
||||
if not isinstance(audio_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio embeds. "
|
||||
f"Got type: {type(audio_embeds)}")
|
||||
|
||||
return Phi4MMAudioEmbeddingInputs(type="audio_embeds",
|
||||
data=audio_embeds)
|
||||
|
||||
@ -1259,7 +1292,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
elif isinstance(image_sizes, torch.Tensor):
|
||||
image_sizes = image_sizes.flatten(0, 1)
|
||||
else:
|
||||
raise ValueError("Incorrect image_attention_mask inputs")
|
||||
raise ValueError("Incorrect image_sizes inputs")
|
||||
|
||||
if isinstance(num_img_tokens, list):
|
||||
num_img_tokens = [
|
||||
@ -1269,7 +1302,7 @@ class Phi4MultimodalForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
elif isinstance(num_img_tokens, torch.Tensor):
|
||||
num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
|
||||
else:
|
||||
raise ValueError("Incorrect image_attention_mask inputs")
|
||||
raise ValueError("Incorrect num_img_tokens inputs")
|
||||
|
||||
return Phi4MMImagePixelInputs(
|
||||
type="pixel_values",
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Any, Literal, Optional, TypedDict, Union
|
||||
from typing import Annotated, Any, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -31,6 +31,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .idefics2_vision_model import Idefics2VisionTransformer
|
||||
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
|
||||
@ -391,41 +392,71 @@ class Phi4MMImageEncoder(nn.Module):
|
||||
return img_set_tensor
|
||||
|
||||
|
||||
class Phi4MMImagePixelInputs(TypedDict):
|
||||
class Phi4MMImagePixelInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bn: Batch size * number of images
|
||||
- p: Number of patches (1 + num_patches)
|
||||
- c: Number of channels (3)
|
||||
- h: Height of each image patch
|
||||
- w: Width of each image patch
|
||||
- nc: Number of crops
|
||||
- H_mask: Height of attention mask
|
||||
- W_mask: Width of attention mask
|
||||
"""
|
||||
|
||||
type: Literal["pixel_values"]
|
||||
data: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
|
||||
|
||||
Note that `num_patches` may be different per batch and image,
|
||||
in which case the data is passed as a list instead of a batched tensor.
|
||||
data: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}
|
||||
), # may be different per batch and image
|
||||
]
|
||||
|
||||
image_sizes: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("bn", 2), # (height, width)
|
||||
]
|
||||
|
||||
num_img_tokens: Annotated[
|
||||
list[int],
|
||||
TensorShape("bn"),
|
||||
]
|
||||
|
||||
image_attention_mask: Annotated[
|
||||
torch.Tensor,
|
||||
TensorShape("bn", "nc", 32, 32), # H_mask, W_mask
|
||||
]
|
||||
|
||||
|
||||
class Phi4MMAudioFeatureInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bn: Batch size * number of audios
|
||||
- t: Time frames (M)
|
||||
"""
|
||||
|
||||
image_sizes: torch.Tensor
|
||||
"""
|
||||
Shape: `(batch_size * num_images, 2)`
|
||||
|
||||
This should be in `(height, width)` format.
|
||||
"""
|
||||
|
||||
num_img_tokens: list[int]
|
||||
"""Shape: `(batch_size * num_images)`"""
|
||||
|
||||
image_attention_mask: torch.Tensor
|
||||
"""Shape: `(batch_size * num_images, H_mask, W_mask)`"""
|
||||
|
||||
|
||||
class Phi4MMAudioFeatureInputs(TypedDict):
|
||||
type: Literal["audio_features"]
|
||||
data: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size * num_audios, 80, M)"""
|
||||
|
||||
data: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("bn", "t", 80, dynamic_dims={"t"}),
|
||||
]
|
||||
|
||||
|
||||
class Phi4MMAudioEmbeddingInputs(TypedDict):
|
||||
class Phi4MMAudioEmbeddingInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- b: Batch size
|
||||
- n: Number of audios
|
||||
- f: Audio feature size
|
||||
- h: Hidden size (must match language model backbone)
|
||||
"""
|
||||
type: Literal["audio_embeds"]
|
||||
data: NestedTensors
|
||||
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
|
||||
data: Annotated[
|
||||
NestedTensors,
|
||||
TensorShape("b", "n", "f", "h"),
|
||||
]
|
||||
|
||||
|
||||
Phi4MMAudioInputs = Union[Phi4MMAudioFeatureInputs, Phi4MMAudioEmbeddingInputs]
|
||||
@ -985,18 +1016,10 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
return None
|
||||
|
||||
if audio_features is not None:
|
||||
if not isinstance(audio_features, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio features. "
|
||||
f"Got type: {type(audio_features)}")
|
||||
|
||||
return Phi4MMAudioFeatureInputs(type="audio_features",
|
||||
data=flatten_bn(audio_features))
|
||||
|
||||
if audio_embeds is not None:
|
||||
if not isinstance(audio_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio embeds. "
|
||||
f"Got type: {type(audio_embeds)}")
|
||||
|
||||
return Phi4MMAudioEmbeddingInputs(type="audio_embeds",
|
||||
data=audio_embeds)
|
||||
|
||||
@ -1031,8 +1054,8 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
]
|
||||
return audio_embeds
|
||||
|
||||
def _parse_and_validate_image_input(self,
|
||||
**kwargs: object) -> Optional[dict]:
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[Phi4MMImagePixelInputs]:
|
||||
input_image_embeds: NestedTensors = kwargs.get("input_image_embeds")
|
||||
if input_image_embeds is None:
|
||||
return None
|
||||
@ -1074,7 +1097,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
elif isinstance(image_sizes, torch.Tensor):
|
||||
image_sizes = image_sizes.flatten(0, 1)
|
||||
else:
|
||||
raise ValueError("Incorrect image_attention_mask inputs")
|
||||
raise ValueError("Incorrect image_sizes inputs")
|
||||
|
||||
if isinstance(num_img_tokens, list):
|
||||
num_img_tokens = [
|
||||
@ -1084,7 +1107,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
elif isinstance(num_img_tokens, torch.Tensor):
|
||||
num_img_tokens = num_img_tokens.flatten(0, 1).tolist()
|
||||
else:
|
||||
raise ValueError("Incorrect image_attention_mask inputs")
|
||||
raise ValueError("Incorrect num_img_tokens inputs")
|
||||
|
||||
return Phi4MMImagePixelInputs(
|
||||
type="pixel_values",
|
||||
|
||||
@ -159,9 +159,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
|
||||
|
||||
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
|
||||
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
|
||||
# seems to avoid gate quantization.
|
||||
# See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
|
||||
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
|
||||
# seems to avoid gate quantization while AutoRound does.
|
||||
# See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4,
|
||||
# and https://huggingface.co/jart25/Qwen3-Coder-30B-A3B-Instruct-Int4-gptq
|
||||
if isinstance(
|
||||
quant_config,
|
||||
(GPTQConfig,
|
||||
GPTQMarlinConfig)) and not quant_config.autoround_version:
|
||||
return None
|
||||
return quant_config
|
||||
|
||||
|
||||
68
vllm/plugins/io_processors/__init__.py
Normal file
68
vllm/plugins/io_processors/__init__.py
Normal file
@ -0,0 +1,68 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.plugins import load_plugins_by_group
|
||||
from vllm.plugins.io_processors.interface import IOProcessor
|
||||
from vllm.utils import resolve_obj_by_qualname
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_io_processor(
|
||||
vllm_config: VllmConfig,
|
||||
plugin_from_init: Optional[str] = None) -> IOProcessor | None:
|
||||
# Input.Output processors are loaded as plugins under the
|
||||
# 'vllm.io_processor_plugins' group. Similar to platform
|
||||
# plugins, these plugins register a function that returns the class
|
||||
# name for the processor to install.
|
||||
|
||||
if plugin_from_init:
|
||||
model_plugin = plugin_from_init
|
||||
else:
|
||||
# A plugin can be specified via the model config
|
||||
# Retrieve the model specific plugin if available
|
||||
# This is using a custom field in the hf_config for the model
|
||||
hf_config = vllm_config.model_config.hf_config.to_dict()
|
||||
config_plugin = hf_config.get("io_processor_plugin")
|
||||
model_plugin = config_plugin
|
||||
|
||||
if model_plugin is None:
|
||||
logger.info("No IOProcessor plugins requested by the model")
|
||||
return None
|
||||
|
||||
logger.debug("IOProcessor plugin to be loaded %s", model_plugin)
|
||||
|
||||
# Load all installed plugin in the group
|
||||
multimodal_data_processor_plugins = \
|
||||
load_plugins_by_group('vllm.io_processor_plugins')
|
||||
|
||||
loadable_plugins = {}
|
||||
for name, func in multimodal_data_processor_plugins.items():
|
||||
try:
|
||||
assert callable(func)
|
||||
processor_cls_qualname = func()
|
||||
if processor_cls_qualname is not None:
|
||||
loadable_plugins[name] = processor_cls_qualname
|
||||
except Exception:
|
||||
logger.warning("Failed to load plugin %s.", name, exc_info=True)
|
||||
|
||||
num_available_plugins = len(loadable_plugins.keys())
|
||||
if num_available_plugins == 0:
|
||||
raise ValueError("No IOProcessor plugins installed"
|
||||
f" but one is required ({model_plugin}).")
|
||||
|
||||
if model_plugin not in loadable_plugins:
|
||||
raise ValueError(
|
||||
f"The model requires the '{model_plugin}' IO Processor plugin "
|
||||
"but it is not installed. "
|
||||
f"Available plugins: {list(loadable_plugins.keys())}")
|
||||
|
||||
activated_plugin_cls = loadable_plugins[model_plugin]
|
||||
|
||||
return resolve_obj_by_qualname(activated_plugin_cls)(vllm_config)
|
||||
62
vllm/plugins/io_processors/interface.py
Normal file
62
vllm/plugins/io_processors/interface.py
Normal file
@ -0,0 +1,62 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from typing import Any, Generic, Optional, TypeVar, Union
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.openai.protocol import IOProcessorResponse
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
|
||||
IOProcessorInput = TypeVar('IOProcessorInput')
|
||||
IOProcessorOutput = TypeVar('IOProcessorOutput')
|
||||
|
||||
|
||||
class IOProcessor(ABC, Generic[IOProcessorInput, IOProcessorOutput]):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
self.vllm_config = vllm_config
|
||||
|
||||
@abstractmethod
|
||||
def pre_process(
|
||||
self,
|
||||
prompt: IOProcessorInput,
|
||||
request_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Union[PromptType, Sequence[PromptType]]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def pre_process_async(
|
||||
self,
|
||||
prompt: IOProcessorInput,
|
||||
request_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Union[PromptType, Sequence[PromptType]]:
|
||||
return self.pre_process(prompt, request_id, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def post_process(self,
|
||||
model_output: Sequence[PoolingRequestOutput],
|
||||
request_id: Optional[str] = None,
|
||||
**kwargs) -> IOProcessorOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
async def post_process_async(
|
||||
self,
|
||||
model_output: AsyncGenerator[tuple[int, PoolingRequestOutput]],
|
||||
request_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> IOProcessorOutput:
|
||||
collected_output = [item async for i, item in model_output]
|
||||
return self.post_process(collected_output, request_id, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def parse_request(self, request: Any) -> IOProcessorInput:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def output_to_response(
|
||||
self, plugin_output: IOProcessorOutput) -> IOProcessorResponse:
|
||||
raise NotImplementedError
|
||||
@ -3290,7 +3290,7 @@ def sha256_cbor_64bit(input) -> int:
|
||||
return full_hash & ((1 << 64) - 1)
|
||||
|
||||
|
||||
def get_hash_fn_by_name(hash_fn_name: str) -> Callable:
|
||||
def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]:
|
||||
"""Get a hash function by name, or raise an error if
|
||||
the function is not found.
|
||||
Args:
|
||||
|
||||
@ -4,8 +4,9 @@ from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
|
||||
BlockStored, KVCacheEvent)
|
||||
from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared,
|
||||
BlockRemoved, BlockStored,
|
||||
KVCacheEvent)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
|
||||
FreeKVCacheBlockQueue, KVCacheBlock)
|
||||
@ -156,6 +157,7 @@ class BlockPool:
|
||||
block_size=block_size,
|
||||
lora_id=request.lora_request.id
|
||||
if request.lora_request else None,
|
||||
medium=MEDIUM_GPU,
|
||||
))
|
||||
|
||||
def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
|
||||
@ -218,7 +220,8 @@ class BlockPool:
|
||||
# we disable hybrid kv cache manager when kv cache event is
|
||||
# enabled, so there is only one group.
|
||||
self.kv_event_queue.append(
|
||||
BlockRemoved(block_hashes=[block_hash.get_hash_value()]))
|
||||
BlockRemoved(block_hashes=[block_hash.get_hash_value()],
|
||||
medium=MEDIUM_GPU))
|
||||
return True
|
||||
|
||||
def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
|
||||
|
||||
@ -527,6 +527,7 @@ def hash_block_tokens(
|
||||
hash values for the same block contents.
|
||||
|
||||
Args:
|
||||
hash_function: The hash function used to compute block hash.
|
||||
parent_block_hash: The hash of the parent block. None
|
||||
if this is the first block.
|
||||
curr_block_token_ids: A list of token ids in the current
|
||||
|
||||
@ -589,7 +589,19 @@ class Scheduler(SchedulerInterface):
|
||||
meta = self.connector.build_connector_meta(scheduler_output)
|
||||
scheduler_output.kv_connector_metadata = meta
|
||||
|
||||
# collect KV cache events from KV cache manager
|
||||
events = self.kv_cache_manager.take_events()
|
||||
|
||||
# collect KV cache events from connector
|
||||
if self.connector is not None:
|
||||
connector_events = self.connector.take_events()
|
||||
if connector_events:
|
||||
if events is None:
|
||||
events = list(connector_events)
|
||||
else:
|
||||
events.extend(connector_events)
|
||||
|
||||
# publish collected KV cache events
|
||||
if events:
|
||||
batch = KVEventBatch(ts=time.time(), events=events)
|
||||
self.kv_event_publisher.publish(batch)
|
||||
@ -1207,7 +1219,7 @@ class Scheduler(SchedulerInterface):
|
||||
# Now that the blocks are ready, actually cache them.
|
||||
(block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id)
|
||||
num_computed_tokens = len(block_ids) * self.block_size
|
||||
# Handle the case where num request tokens less then one block.
|
||||
# Handle the case where num request tokens less than one block.
|
||||
num_computed_tokens = min(num_computed_tokens, request.num_tokens)
|
||||
if num_computed_tokens == request.num_tokens:
|
||||
num_computed_tokens -= 1
|
||||
|
||||
@ -47,7 +47,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
# {req_id: The number of cached blocks for this given request}
|
||||
# This is used to track the number of cached blocks for each request.
|
||||
# This is only used to track the RUNNING requests, we do not track the
|
||||
# data for reempted ones.
|
||||
# data for preempted ones.
|
||||
self.num_cached_block: dict[str, int] = {}
|
||||
|
||||
self.kv_cache_group_id = kv_cache_group_id
|
||||
|
||||
@ -138,14 +138,14 @@ def _torch_cuda_wrapper():
|
||||
|
||||
@contextmanager
|
||||
def _set_global_compilation_settings(config: VllmConfig):
|
||||
import torch._inductor.config
|
||||
import torch._inductor.config as torch_inductor_config
|
||||
|
||||
inductor_config = config.compilation_config.inductor_compile_config
|
||||
# Note: The MKLDNN and CPPGEMM backend requires freezing parameters.
|
||||
freezing_value = torch._inductor.config.freezing
|
||||
freezing_value = torch_inductor_config.freezing
|
||||
try:
|
||||
if inductor_config.get("max_autotune", False):
|
||||
torch._inductor.config.freezing = True
|
||||
torch_inductor_config.freezing = True
|
||||
yield
|
||||
finally:
|
||||
torch._inductor.config.freezing = freezing_value
|
||||
torch_inductor_config.freezing = freezing_value
|
||||
|
||||
@ -600,28 +600,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
logits_indices_padded = None
|
||||
if self.cache_config.kv_sharing_fast_prefill:
|
||||
assert self.kv_sharing_fast_prefill_logits_indices is not None
|
||||
num_logits = logits_indices.shape[0]
|
||||
assert num_logits > 0
|
||||
self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(
|
||||
logits_indices_padded = self._prepare_kv_sharing_fast_prefill(
|
||||
logits_indices)
|
||||
# There might have leftover indices in logits_indices[num_logits:]
|
||||
# from previous iterations, whose values may be greater than the
|
||||
# batch size in the current iteration. To ensure indices are always
|
||||
# valid, we fill the padded indices with the last index.
|
||||
self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
|
||||
logits_indices[-1].item())
|
||||
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
and num_logits <= self.cudagraph_batch_sizes[-1]):
|
||||
# Use piecewise CUDA graphs.
|
||||
# Add padding to the batch size.
|
||||
num_logits_padded = self.vllm_config.pad_for_cudagraph(
|
||||
num_logits)
|
||||
else:
|
||||
num_logits_padded = num_logits
|
||||
logits_indices_padded = (
|
||||
self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded]
|
||||
)
|
||||
|
||||
# Used in the below loop.
|
||||
query_start_loc_cpu = self.query_start_loc.cpu[:num_reqs + 1]
|
||||
@ -644,13 +624,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
blk_table_tensor = torch.zeros(
|
||||
(num_reqs, 1),
|
||||
dtype=torch.int32,
|
||||
pin_memory=self.pin_memory,
|
||||
device="cpu").to(self.device, non_blocking=True)
|
||||
slot_mapping = torch.zeros((total_num_scheduled_tokens, ),
|
||||
dtype=torch.int32,
|
||||
pin_memory=self.pin_memory,
|
||||
device="cpu").to(self.device,
|
||||
non_blocking=True)
|
||||
device=self.device,
|
||||
)
|
||||
slot_mapping = torch.zeros(
|
||||
(total_num_scheduled_tokens, ),
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
num_common_prefix_blocks = 0
|
||||
else:
|
||||
blk_table_tensor = block_tables[kv_cache_group_id]
|
||||
@ -889,6 +869,32 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
mrope_pos_ptr += completion_part_len
|
||||
|
||||
def _prepare_kv_sharing_fast_prefill(
|
||||
self,
|
||||
logits_indices: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert self.kv_sharing_fast_prefill_logits_indices is not None
|
||||
num_logits = logits_indices.shape[0]
|
||||
assert num_logits > 0
|
||||
self.kv_sharing_fast_prefill_logits_indices[:num_logits].copy_(
|
||||
logits_indices)
|
||||
# There might have leftover indices in logits_indices[num_logits:]
|
||||
# from previous iterations, whose values may be greater than the
|
||||
# batch size in the current iteration. To ensure indices are always
|
||||
# valid, we fill the padded indices with the last index.
|
||||
self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_(
|
||||
logits_indices[-1].item())
|
||||
if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
|
||||
and num_logits <= self.cudagraph_batch_sizes[-1]):
|
||||
# Use piecewise CUDA graphs.
|
||||
# Add padding to the batch size.
|
||||
num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits)
|
||||
else:
|
||||
num_logits_padded = num_logits
|
||||
logits_indices_padded = (
|
||||
self.kv_sharing_fast_prefill_logits_indices[:num_logits_padded])
|
||||
return logits_indices_padded
|
||||
|
||||
def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
|
||||
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
|
||||
if not scheduled_encoder_inputs:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user