[Bugfix] Fix broadcasting logic for multi_modal_kwargs (#6836)

Cyrus Leung 2024-07-31 10:38:45 +08:00 committed by GitHub
parent da1f7cc12a
commit f230cc2ca6
16 changed files with 254 additions and 211 deletions


@ -56,7 +56,6 @@ steps:
fast_check: true
commands:
- pytest -v -s core
- pytest -v -s distributed/test_parallel_state.py
- label: Distributed Comm Ops Test
#mirror_hardwares: [amd]
@ -90,13 +89,13 @@ steps:
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=llava-hf/llava-v1.6-mistral-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
- TEST_DIST_MODEL=llava-hf/llava-v1.6-mistral-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py


@ -44,6 +44,8 @@ Base Classes
.. autodata:: vllm.multimodal.BatchedTensors
.. autodata:: vllm.multimodal.BatchedTensorInputs
.. autoclass:: vllm.multimodal.MultiModalDataBuiltins
:members:
:show-inheritance:


@ -19,10 +19,10 @@ from vllm.utils import cuda_device_count_stateless
model = os.environ["TEST_DIST_MODEL"]
if model.startswith("llava-hf/llava"):
if model.startswith("llava-hf/llava-1.5"):
from ..models.test_llava import models, run_test
elif model.startswith("microsoft/Phi-3-vision"):
from ..models.test_phi3v import models, run_test
elif model.startswith("llava-hf/llava-v1.6"):
from ..models.test_llava_next import models, run_test
else:
raise NotImplementedError(f"Unsupported model: {model}")
@ -45,7 +45,8 @@ def test_models(hf_runner, vllm_runner, image_assets,
vllm_runner,
image_assets,
model=models[0],
size_factors=[1.0],
# So that LLaVA-NeXT processor may return nested list
size_factors=[0.25, 0.5, 1.0],
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,


@ -1,57 +0,0 @@
from typing import Any, Dict
import pytest
import torch
from vllm.distributed.parallel_state import (_split_tensor_dict,
_update_nested_dict)
def test_split_tensor_dict():
test_dict = {
"key_a": "a",
"key_b": torch.arange(8, dtype=torch.float32),
"key_c": {
"key_1": torch.arange(5, dtype=torch.float32),
"key_2": torch.tensor([], dtype=torch.float32),
"key_3": 123,
},
"key_d": {},
}
metadata_list, tensor_list = _split_tensor_dict(test_dict)
assert len(metadata_list) == 6
assert torch.allclose(tensor_list[0], test_dict["key_b"])
assert torch.allclose(tensor_list[1], test_dict["key_c"]["key_1"])
assert torch.allclose(tensor_list[2], test_dict["key_c"]["key_2"])
def test_split_tensor_dict_invalid_key():
test_dict = {
"a%b": "a",
}
with pytest.raises(AssertionError):
_split_tensor_dict(test_dict)
def test_update_nested_dict():
flattened_keys_values = [("key1%key2%key3", "value1"),
("key1%key2%key4", "value2"),
("key1%key5", "value3"), ("key6%key7", "value4"),
("key8", "value5")]
res: Dict[str, Any] = {}
for flat_key, value in flattened_keys_values:
_update_nested_dict(res, flat_key, value)
assert res == {
"key1": {
"key2": {
"key3": "value1",
"key4": "value2"
},
"key5": "value3"
},
"key6": {
"key7": "value4"
},
"key8": "value5"
}


@ -1,14 +1,12 @@
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Type
import pytest
from transformers import AutoConfig, AutoTokenizer
from vllm.model_executor.models.llava_next import (
get_llava_next_image_feature_size)
from vllm.multimodal.utils import rescale_image_size
from vllm.sequence import SampleLogprobs
from ..conftest import IMAGE_ASSETS
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from .utils import check_logprobs_close
pytestmark = pytest.mark.vlm
@ -27,6 +25,8 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
IMAGE_TOKEN_ID = 32000
models = ["llava-hf/llava-v1.6-vicuna-7b-hf"]
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
Optional[SampleLogprobs]],
@ -50,34 +50,19 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
return hf_output_ids, hf_output_str, out_logprobs
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-vicuna-7b-hf"])
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype, max_tokens, num_logprobs) -> None:
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
def run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
image_assets: _ImageAssets,
model: str,
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
@ -89,6 +74,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
with vllm_runner(model,
dtype=dtype,
max_model_len=4096,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as vllm_model:
vllm_outputs_per_image = [
vllm_model.generate_greedy_logprobs(prompts,
@ -122,9 +109,54 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
)
@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
dtype, max_tokens, num_logprobs) -> None:
"""Inference result should be the same between hf and vllm.
All the image fixtures for the test is under tests/images.
For huggingface runner, we provide the PIL images as input.
For vllm runner, we provide MultiModalDataDict objects
and corresponding vision language config as input.
Note, the text input is also adjusted to abide by vllm contract.
The text output is sanitized to be able to compare with hf.
"""
run_test(
hf_runner,
vllm_runner,
image_assets,
model,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)
@pytest.mark.parametrize("height_and_width_and_result", [(1669, 2560, 2144),
(183, 488, 776)])
def test_image_feature_size(height_and_width_and_result):
# Avoid initializing CUDA too early in distributed tests
from vllm.model_executor.models.llava_next import (
get_llava_next_image_feature_size)
height, width, result = height_and_width_and_result
config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
assert get_llava_next_image_feature_size(config,


@ -45,22 +45,16 @@ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
def _split_tensor_dict(
tensor_dict: Dict[str, Union[torch.Tensor, Any]],
prefix: str = "") -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
tensor_dict: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
If the Tensor is nested under `tensor_dict["key1"]["key2"]`, the key of its
metadata will be "key1%key2".
"""
metadata_list: List[Tuple[str, Any]] = []
tensor_list = []
tensor_list: List[torch.Tensor] = []
for key, value in tensor_dict.items():
assert "%" not in key, (
"Avoid having '%' in key "
"as it is used as a separator for nested entries.")
if isinstance(value, torch.Tensor):
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
@ -68,31 +62,13 @@ def _split_tensor_dict(
# receiving side will set the device index.
device = value.device.type
metadata_list.append(
(prefix + key, TensorMetadata(device, value.dtype,
value.size())))
(key, TensorMetadata(device, value.dtype, value.size())))
tensor_list.append(value)
elif isinstance(value, dict):
if len(value) == 0:
metadata_list.append((prefix + key, value))
inner_metadata_list, inner_tensor_list = _split_tensor_dict(
value, prefix + key + "%")
metadata_list.extend(inner_metadata_list)
tensor_list.extend(inner_tensor_list)
else:
metadata_list.append((prefix + key, value))
metadata_list.append((key, value))
return metadata_list, tensor_list
def _update_nested_dict(nested_dict, flattened_key, value):
key_splits = flattened_key.split("%")
cur_dict = nested_dict
for k in key_splits[:-1]:
if k not in cur_dict:
cur_dict[k] = {}
cur_dict = cur_dict[k]
cur_dict[key_splits[-1]] = value
class GroupCoordinator:
"""
PyTorch ProcessGroup wrapper for a group of processes.
@ -566,7 +542,7 @@ class GroupCoordinator:
device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
_update_nested_dict(tensor_dict, key, tensor)
tensor_dict[key] = tensor
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
@ -583,9 +559,9 @@ class GroupCoordinator:
group=group,
async_op=True)
async_handles.append(handle)
_update_nested_dict(tensor_dict, key, tensor)
tensor_dict[key] = tensor
else:
_update_nested_dict(tensor_dict, key, value)
tensor_dict[key] = value
for async_handle in async_handles:
async_handle.wait()
return tensor_dict
@ -661,7 +637,7 @@ class GroupCoordinator:
device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
_update_nested_dict(tensor_dict, key, tensor)
tensor_dict[key] = tensor
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
@ -673,9 +649,9 @@ class GroupCoordinator:
torch.distributed.recv(tensor,
src=self.ranks[src],
group=group)
_update_nested_dict(tensor_dict, key, tensor)
tensor_dict[key] = tensor
else:
_update_nested_dict(tensor_dict, key, value)
tensor_dict[key] = value
return tensor_dict
def barrier(self):

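To make the simplified broadcasting contract concrete, here is a small, hedged sketch of how the post-change _split_tensor_dict behaves: tensors are replaced by TensorMetadata and collected for broadcast, while any non-tensor value, including a dict, now travels through the metadata list unchanged instead of being flattened into "%"-joined keys. The helper is internal to vllm.distributed.parallel_state, so treat this as an illustration of the behavior shown in the hunk above rather than a supported API; the example dictionary is made up.

import torch
from vllm.distributed.parallel_state import _split_tensor_dict

metadata_list, tensor_list = _split_tensor_dict({
    "key_a": "a",                                   # plain value: kept in metadata as-is
    "key_b": torch.arange(8, dtype=torch.float32),  # tensor: metadata entry + broadcast payload
    "key_c": {"nested": 123},                       # dict: no longer flattened into "key_c%nested"
})

assert [key for key, _ in metadata_list] == ["key_a", "key_b", "key_c"]
assert len(tensor_list) == 1  # only key_b is sent as a separate tensor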

@ -1,5 +1,6 @@
from .base import (BatchedTensors, MultiModalDataBuiltins, MultiModalDataDict,
MultiModalInputs, MultiModalPlugin, NestedTensors)
from .base import (BatchedTensorInputs, BatchedTensors, MultiModalDataBuiltins,
MultiModalDataDict, MultiModalInputs, MultiModalPlugin,
NestedTensors)
from .registry import MultiModalRegistry
MULTIMODAL_REGISTRY = MultiModalRegistry()
@ -12,6 +13,7 @@ See also:
"""
__all__ = [
"BatchedTensorInputs",
"BatchedTensors",
"MultiModalDataBuiltins",
"MultiModalDataDict",


@ -9,10 +9,12 @@ import torch
import torch.types
from PIL import Image
from torch import nn
from typing_extensions import TypeAlias
from vllm.config import ModelConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.utils import JSONTree, json_map_leaves
logger = init_logger(__name__)
@ -22,11 +24,16 @@ Use a list instead of a tensor if the dimensions of each element do not match.
Currently only supports up to singly nested list of tensors.
"""
BatchedTensors = Union[GenericSequence[NestedTensors], NestedTensors]
BatchedTensors: TypeAlias = JSONTree[torch.Tensor]
"""
If each input tensor in the batch has the same size, this is a single batched
tensor; otherwise, this is a list of :class:`NestedTensors` with one element
per item in the batch.
A nested JSON structure of tensors which have been batched via
:meth:`MultiModalInputs.batch`.
"""
BatchedTensorInputs: TypeAlias = Dict[str, JSONTree[torch.Tensor]]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalInputs.batch`.
"""
if sys.version_info < (3, 9):
@ -46,14 +53,17 @@ class MultiModalInputs(_MultiModalInputsBase):
"""
@staticmethod
def try_concat(
def _try_concat(
tensors: List[NestedTensors],
*,
device: torch.types.Device,
) -> BatchedTensors:
) -> Union[GenericSequence[NestedTensors], NestedTensors]:
"""
If each input tensor in the batch has the same shape, return a single
batched tensor; otherwise, return a list of :class:`NestedTensors` with
one element per item in the batch.
"""
# may be list rather than tensors
if isinstance(tensors[0], list):
return [[t.to(device=device) for t in tensor[0]]
return [[t for t in tensor[0]]
for tensor in cast(List[List[torch.Tensor]], tensors)]
tensors_ = cast(List[torch.Tensor], tensors)
@ -62,18 +72,21 @@ class MultiModalInputs(_MultiModalInputsBase):
for tensor in tensors_:
if tensor.shape[1:] != unbatched_shape:
return [
tensor.squeeze(0).to(device=device) for tensor in tensors_
]
return [tensor.squeeze(0) for tensor in tensors_]
return torch.cat(tensors_, dim=0).to(device=device)
return torch.cat(tensors_, dim=0)
@staticmethod
def batch(
inputs_list: List["MultiModalInputs"],
device: torch.types.Device,
) -> Dict[str, BatchedTensors]:
"""Batch multiple inputs together into a dictionary."""
def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs:
"""
Batch multiple inputs together into a dictionary.
The resulting dictionary has the same keys as the inputs.
If the corresponding value from each input is a tensor and they all
share the same shape, the output value is a single batched tensor;
otherwise, the output value is a list containing the original value
from each input.
"""
if len(inputs_list) == 0:
return {}
@ -90,9 +103,18 @@ class MultiModalInputs(_MultiModalInputsBase):
item_lists[k].append(v)
return {
k: MultiModalInputs.try_concat(item_list, device=device)
k: MultiModalInputs._try_concat(item_list)
for k, item_list in item_lists.items()
}
} # type: ignore
@staticmethod
def as_kwargs(
batched_inputs: BatchedTensorInputs,
*,
device: torch.types.Device,
) -> BatchedTensorInputs:
return json_map_leaves(lambda x: x.to(device, non_blocking=True),
batched_inputs)
class MultiModalDataBuiltins(TypedDict, total=False):

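Taken together, the base.py changes split multimodal batching into two steps: MultiModalInputs.batch now batches without moving tensors between devices, and the new MultiModalInputs.as_kwargs moves every leaf tensor to the execution device (non_blocking=True) right before the forward pass. Below is a minimal sketch of the intended flow; the tensor shapes are illustrative and the surrounding runner code is elided.

import torch

from vllm.multimodal import MultiModalInputs

# Per-sequence multimodal inputs as produced by the input mapper
# (shapes are made up for illustration).
inputs_per_seq = [
    MultiModalInputs({"pixel_values": torch.zeros(1, 3, 336, 336)}),
    MultiModalInputs({"pixel_values": torch.zeros(1, 3, 336, 336)}),
]

# Step 1: batch without touching device placement; same-shaped tensors are
# concatenated, mismatched ones are kept as a list (see _try_concat above).
batched = MultiModalInputs.batch(inputs_per_seq)

# Step 2: move every leaf tensor to the execution device just before calling
# the model, as the model runners below now do.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
kwargs = MultiModalInputs.as_kwargs(batched, device=device)
# model(input_ids, positions, kv_caches, attn_metadata, **kwargs)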

@ -15,6 +15,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.multimodal import MultiModalInputs
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SamplerOutput)
from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
@ -323,7 +324,8 @@ class TP1DraftModelRunner(ModelRunner):
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**multi_modal_kwargs,
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
device=self.device),
)
# Compute the logits.


@ -17,7 +17,7 @@ from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar,
Union)
Union, overload)
import numpy as np
import numpy.typing as npt
@ -53,6 +53,7 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
P = ParamSpec('P')
K = TypeVar("K")
T = TypeVar("T")
U = TypeVar("U")
class _Sentinel:
@ -712,6 +713,54 @@ def merge_dicts(dict1: Dict[K, List[T]],
return dict(merged_dict)
JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"],
Tuple["JSONTree[T]", ...], T]
"""A nested JSON structure where the leaves need not be JSON-serializable."""
@overload
def json_map_leaves(
func: Callable[[T], U],
value: Dict[str, JSONTree[T]],
) -> Dict[str, JSONTree[U]]:
...
@overload
def json_map_leaves(
func: Callable[[T], U],
value: List[JSONTree[T]],
) -> List[JSONTree[U]]:
...
@overload
def json_map_leaves(
func: Callable[[T], U],
value: Tuple[JSONTree[T], ...],
) -> Tuple[JSONTree[U], ...]:
...
@overload
def json_map_leaves(
func: Callable[[T], U],
value: JSONTree[T],
) -> JSONTree[U]:
...
def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]:
if isinstance(value, dict):
return {k: json_map_leaves(func, v) for k, v in value.items()}
elif isinstance(value, list):
return [json_map_leaves(func, v) for v in value]
elif isinstance(value, tuple):
return tuple(json_map_leaves(func, v) for v in value)
else:
return func(value)
def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
"""Flatten a list of lists to a single list."""
return [item for sublist in lists for item in sublist]

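The new JSONTree/json_map_leaves utilities are what let MultiModalInputs.as_kwargs move an arbitrarily nested batch of tensors to the target device in one pass while preserving its structure. A quick illustration of the mapping behavior (the input tree is made up, and .to("cpu") stands in for the real device transfer):

import torch

from vllm.utils import json_map_leaves

tree = {
    "pixel_values": torch.zeros(2, 3, 336, 336),
    # Mismatched image sizes stay as a list after batching, so leaves may sit
    # below a list as well as a dict.
    "image_sizes": [torch.tensor([336, 336]), torch.tensor([672, 336])],
}

moved = json_map_leaves(lambda t: t.to("cpu"), tree)

assert isinstance(moved, dict)
assert isinstance(moved["image_sizes"], list)            # structure is preserved
assert moved["pixel_values"].shape == (2, 3, 336, 336)   # leaves are mapped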

@ -1,6 +1,5 @@
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
Type, Union)
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
import torch
from torch import nn
@ -12,7 +11,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
@ -41,7 +40,7 @@ class CPUModelInput(ModelRunnerInputBase):
input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
virtual_engine: Optional[int] = None
def as_broadcastable_tensor_dict(
@ -136,7 +135,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
Mapping[str, BatchedTensors]]:
BatchedTensorInputs]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
@ -214,8 +213,7 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
slot_mapping=slot_mapping,
)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs)
@ -361,11 +359,16 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
model_executable = self.model
execute_model_kwargs = {
"input_ids": model_input.input_tokens,
"positions": model_input.input_positions,
"kv_caches": kv_caches,
"attn_metadata": model_input.attn_metadata,
**(model_input.multi_modal_kwargs or {}),
"input_ids":
model_input.input_tokens,
"positions":
model_input.input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
}
hidden_states = model_executable(**execute_model_kwargs)


@ -8,6 +8,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
PromptAdapterConfig, SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalInputs
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
SequenceGroupMetadata)
@ -99,11 +100,16 @@ class EmbeddingModelRunner(
kv_caches = [None] * num_layers
execute_model_kwargs = {
"input_ids": model_input.input_tokens,
"positions": model_input.input_positions,
"kv_caches": kv_caches,
"attn_metadata": model_input.attn_metadata,
**(model_input.multi_modal_kwargs or {}),
"input_ids":
model_input.input_tokens,
"positions":
model_input.input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
}
hidden_states = model_executable(**execute_model_kwargs)


@ -4,8 +4,8 @@ import time
import warnings
import weakref
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set,
Tuple, Type, TypeVar, Union)
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type,
TypeVar, Union)
import numpy as np
import torch
@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.interfaces import (supports_lora,
supports_vision)
from vllm.model_executor.models.utils import set_cpu_offload_max_bytes
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.prompt_adapter.layers import PromptAdapterMapping
from vllm.prompt_adapter.request import PromptAdapterRequest
@ -94,7 +94,7 @@ class ModelInputForGPU(ModelRunnerInputBase):
attn_metadata: Optional["AttentionMetadata"] = None
prompt_adapter_mapping: Optional[PromptAdapterMapping] = None
prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None
finished_requests_ids: Optional[List[str]] = None
virtual_engine: int = 0
@ -608,8 +608,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
data.multi_modal_inputs for data in self.inter_data_list
if data.multi_modal_inputs is not None
]
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.runner.device)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
return self.model_input_cls(
input_tokens=input_tokens_tensor,
@ -1361,7 +1360,8 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
kv_caches=kv_caches,
attn_metadata=model_input.attn_metadata,
intermediate_tensors=intermediate_tensors,
**multi_modal_kwargs,
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
device=self.device),
**seqlen_agnostic_kwargs)
# Compute the logits in the last pipeline stage.


@ -1,6 +1,5 @@
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
Union)
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
@ -10,7 +9,7 @@ from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.neuron import get_neuron_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sequence import (IntermediateTensors, SamplerOutput,
SequenceGroupMetadata)
@ -32,7 +31,7 @@ class ModelInputForNeuron(ModelRunnerInputBase):
input_positions: Optional[torch.Tensor] = None
input_block_ids: Optional[torch.Tensor] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
@ -84,8 +83,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], Mapping[
str, BatchedTensors]]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int],
BatchedTensorInputs]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
@ -134,8 +133,7 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
dtype=torch.long,
device=self.device)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
return (input_tokens, input_positions, input_block_ids, seq_lens,
multi_modal_kwargs)
@ -244,7 +242,8 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
input_block_ids=model_input.input_block_ids,
**(model_input.multi_modal_kwargs or {}),
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
)
# Compute the logits.


@ -1,4 +1,4 @@
from typing import List, Mapping, NamedTuple, Optional, Tuple
from typing import List, NamedTuple, Optional, Tuple
import openvino as ov
import torch
@ -12,7 +12,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader.openvino import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@ -25,7 +25,7 @@ class ModelInput(NamedTuple):
attn_metadata: Optional[OpenVINOAttentionMetadata]
seq_lens: List[int]
query_lens: List[int]
multi_modal_kwargs: Mapping[str, BatchedTensors]
multi_modal_kwargs: BatchedTensorInputs
@classmethod
def empty(cls, device):
@ -265,8 +265,7 @@ class OpenVINOModelRunner:
max_context_len=max_context_len_tensor,
)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
return ModelInput(
input_tokens,
@ -281,7 +280,7 @@ class OpenVINOModelRunner:
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata,
SamplingMetadata, Mapping[str, BatchedTensors]]:
SamplingMetadata, BatchedTensorInputs]:
# Prepare input tensors.
(
input_tokens,
@ -324,11 +323,16 @@ class OpenVINOModelRunner:
model_executable = self.model
execute_model_kwargs = {
"input_ids": input_tokens,
"positions": input_positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
**(multi_modal_kwargs or {}),
"input_ids":
input_tokens,
"positions":
input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
attn_metadata,
**MultiModalInputs.as_kwargs(multi_modal_kwargs or {},
device=self.device),
}
hidden_states = model_executable(**execute_model_kwargs)


@ -1,6 +1,5 @@
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
Type, Union)
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
import torch
import torch.nn as nn
@ -14,7 +13,7 @@ from vllm.inputs import INPUT_REGISTRY
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.interfaces import supports_vision
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensors,
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
MultiModalInputs)
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, SamplerOutput,
@ -49,7 +48,7 @@ class ModelInputForXPU(ModelRunnerInputBase):
input_positions: Optional[torch.Tensor] = None
attn_metadata: Optional["AttentionMetadata"] = None
sampling_metadata: Optional["SamplingMetadata"] = None
multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
def as_broadcastable_tensor_dict(
self) -> Dict[str, Union[int, torch.Tensor]]:
@ -376,11 +375,16 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
model_executable = self.model
execute_model_kwargs = {
"input_ids": model_input.input_tokens,
"positions": model_input.input_positions,
"kv_caches": kv_caches,
"attn_metadata": model_input.attn_metadata,
**(model_input.multi_modal_kwargs or {}),
"input_ids":
model_input.input_tokens,
"positions":
model_input.input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
}
hidden_states = model_executable(**execute_model_kwargs)
@ -404,7 +408,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
Mapping[str, BatchedTensors]]:
BatchedTensorInputs]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
@ -496,8 +500,7 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPU]):
block_tables=torch.tensor([], device=self.device, dtype=torch.int),
)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list,
device=self.device)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
return (input_tokens, input_positions, attn_metadata, seq_lens,
multi_modal_kwargs)