[Refactor] Allow optional MultiModalKwargsItem in IPC (#23022)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

commit 4dff91c93d (parent de9cb61763)
Tests: KV-cache manager

@@ -7,9 +7,7 @@ import pytest
 import torch
 
 from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
-from vllm.multimodal.inputs import (MultiModalBatchedField,
-                                    MultiModalFieldElem, MultiModalKwargsItem,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.utils import GiB_bytes, sha256, sha256_cbor_64bit
 from vllm.v1.core.kv_cache_manager import KVCacheManager
@@ -42,13 +40,7 @@ def make_request(
     if mm_positions is None:
         mm_kwargs = None
     else:
-        mm_elem = MultiModalFieldElem(
-            modality="dummy_m",
-            key="dummy_k",
-            data=None,
-            field=MultiModalBatchedField(),
-        )
-        mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+        mm_item = MultiModalKwargsItem.dummy("dummy_m")
         mm_kwargs = [mm_item] * len(mm_positions)
 
     return Request(request_id=request_id,
Tests: block pool / KV events

@@ -9,9 +9,7 @@ import pytest
 import torch
 
 from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
-from vllm.multimodal.inputs import (MultiModalBatchedField,
-                                    MultiModalFieldElem, MultiModalKwargsItem,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.utils import sha256, sha256_cbor_64bit
 from vllm.v1.core.block_pool import BlockPool
@@ -37,13 +35,7 @@ def make_request(
     if mm_positions is None:
         mm_kwargs = None
     else:
-        mm_elem = MultiModalFieldElem(
-            modality="dummy_m",
-            key="dummy_k",
-            data=None,
-            field=MultiModalBatchedField(),
-        )
-        mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+        mm_item = MultiModalKwargsItem.dummy("dummy_m")
         mm_kwargs = [mm_item] * len(mm_positions)
 
     return Request(request_id=request_id,
Tests: scheduler

@@ -8,9 +8,7 @@ import torch
 
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import (MultiModalBatchedField,
-                                    MultiModalFieldElem, MultiModalKwargsItem,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
@@ -1328,13 +1326,7 @@ def create_requests_with_priority(
     for i in range(num_requests):
         if mm_positions is not None:
             mm_position = mm_positions[i]
-            mm_elem = MultiModalFieldElem(
-                modality="dummy_m",
-                key="dummy_k",
-                data=None,
-                field=MultiModalBatchedField(),
-            )
-            mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+            mm_item = MultiModalKwargsItem.dummy("dummy_m")
             mm_kwargs = [mm_item] * len(mm_position)
         else:
             mm_position = None
Tests: request-creation helpers

@@ -6,9 +6,7 @@ import torch
 
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import (MultiModalBatchedField,
-                                    MultiModalFieldElem, MultiModalKwargsItem,
-                                    PlaceholderRange)
+from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
                                          init_none_hash)
@@ -143,13 +141,7 @@ def create_requests(
    for i in range(num_requests):
        if mm_positions is not None:
            mm_position = mm_positions[i]
-            mm_elem = MultiModalFieldElem(
-                modality="dummy_m",
-                key="dummy_k",
-                data=None,
-                field=MultiModalBatchedField(),
-            )
-            mm_item = MultiModalKwargsItem.from_elems([mm_elem])
+            mm_item = MultiModalKwargsItem.dummy("dummy_m")
            mm_kwargs = [mm_item] * len(mm_position)
            mm_hashes = ["hash"] * len(mm_position)
        else:
vllm/multimodal/inputs.py

@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
 from collections.abc import Mapping, Sequence
-from dataclasses import dataclass, replace
+from dataclasses import dataclass
 from functools import partial
 from itertools import accumulate
 from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
@@ -218,7 +218,7 @@ class MultiModalFieldElem:
     i.e. the name of the keyword argument to be passed to the model.
     """
 
-    data: Optional[NestedTensors]
+    data: NestedTensors
     """
     The tensor data of this field in
    [`MultiModalKwargs`][vllm.multimodal.inputs.MultiModalKwargs],
@@ -315,13 +315,8 @@ class BaseMultiModalField(ABC):
         if len(set(field_types)) > 1:
             raise ValueError(f"Cannot merge different {field_types=}")
 
-        validated_data = list[NestedTensors]()
-        for i, elem in enumerate(elems):
-            assert elem.data is not None, (
-                f"Cannot merge with empty `elems[{i}]`")
-            validated_data.append(elem.data)
-
-        return self._reduce_data(validated_data, pin_memory=pin_memory)
+        batch = [elem.data for elem in elems]
+        return self._reduce_data(batch, pin_memory=pin_memory)
 
 
 @dataclass(frozen=True)
@@ -643,6 +638,17 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
     [`MultiModalDataItems`][vllm.multimodal.parse.MultiModalDataItems].
     """
 
+    @staticmethod
+    def dummy(modality: str):
+        """Convenience method for testing."""
+        mm_elem = MultiModalFieldElem(
+            modality=modality,
+            key="dummy",
+            data=torch.empty(1),
+            field=MultiModalSharedField(1),
+        )
+        return MultiModalKwargsItem.from_elems([mm_elem])
+
     @staticmethod
     def from_elems(elems: Sequence[MultiModalFieldElem]):
         return MultiModalKwargsItem({elem.key: elem for elem in elems})
@@ -654,46 +660,12 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
         assert len(modalities) == 1, f"Found different modalities={modalities}"
         self._modality = next(iter(modalities))
 
-        self._is_empty = any(elem.data is None for elem in self.values())
-
     @property
     def modality(self) -> str:
         return self._modality
 
-    @property
-    def is_empty(self) -> bool:
-        return self._is_empty
-
-    def get_data(self) -> Optional[Mapping[str, NestedTensors]]:
-        if self._is_empty:
-            return None
-
-        out_data = dict[str, NestedTensors]()
-        for key, elem in self.items():
-            assert elem.data is not None, (
-                f"Cannot get data of empty `elem[{key!r}]`")
-            out_data[key] = elem.data
-
-        return out_data
-
-    def require_data(self) -> Mapping[str, NestedTensors]:
-        if (data := self.get_data()) is None:
-            raise RuntimeError("Cannot get data of empty item")
-
-        return data
-
-    # These methods create a new item to avoid mutating cached items in place
-    def with_data(self, data: Mapping[str, NestedTensors]):
-        return MultiModalKwargsItem({
-            key: replace(elem, data=data[key])
-            for key, elem in self.items()
-        })
-
-    def without_data(self):
-        return MultiModalKwargsItem({
-            key: replace(elem, data=None)
-            for key, elem in self.items()
-        })
+    def get_data(self) -> Mapping[str, NestedTensors]:
+        return {key: elem.data for key, elem in self.items()}
 
 
 # NOTE: UserDict is for V0 compatibility.
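A quick usage sketch of the new test helper and the simplified accessor (my example, assuming a vLLM build containing this commit; the "dummy" key and the (1,) shape follow from the dummy() implementation above):

from vllm.multimodal.inputs import MultiModalKwargsItem

item = MultiModalKwargsItem.dummy("image")
assert item.modality == "image"

# get_data() now always returns a Mapping; the None / require_data() dance is gone.
data = item.get_data()
assert data["dummy"].shape == (1,)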
vllm/v1/engine (EngineCoreRequest)

@@ -3,6 +3,7 @@
 
 import enum
 import time
+from collections.abc import Sequence
 from typing import Any, Optional, Union
 
 import msgspec
@@ -47,7 +48,7 @@ class EngineCoreRequest(
 
     request_id: str
     prompt_token_ids: list[int]
-    mm_kwargs: Optional[list[MultiModalKwargsItem]]
+    mm_kwargs: Optional[Sequence[Optional[MultiModalKwargsItem]]]
     mm_hashes: Optional[list[str]]
     mm_placeholders: Optional[list[PlaceholderRange]]
     sampling_params: Optional[SamplingParams]
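With mm_kwargs typed as Optional[Sequence[Optional[MultiModalKwargsItem]]], a frontend cache hit can cross the IPC boundary as a bare None instead of a data-stripped placeholder item. A minimal sketch of the payload difference using msgspec directly (FakeItem is a hypothetical stand-in, not vLLM's actual wire format):

import msgspec

class FakeItem(msgspec.Struct):
    # Hypothetical stand-in for MultiModalKwargsItem; real items carry tensors.
    modality: str
    key: str
    data: list[float]

enc = msgspec.msgpack.Encoder()
full = FakeItem(modality="image", key="pixel_values", data=[0.0] * 1024)

print(len(enc.encode(full)))  # several KB: a cache miss ships the whole item
print(len(enc.encode(None)))  # 1 byte (msgpack nil): a cache hit ships a sentinel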
vllm/v1/engine/mm_input_cache.py

@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Mapping
-from typing import TYPE_CHECKING
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Optional
 
 from vllm.multimodal import MultiModalRegistry
 from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
-from vllm.multimodal.inputs import MultiModalKwargsItem, NestedTensors
+from vllm.multimodal.inputs import MultiModalKwargsItem
+from vllm.utils import is_list_of
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig
@@ -58,21 +59,21 @@ class MultiModalInputCacheClient:
 
     def get_and_update(
         self,
-        mm_kwargs: list[MultiModalKwargsItem],
+        mm_kwargs: Sequence[MultiModalKwargsItem],
         mm_hashes: list[str],
-    ) -> list[MultiModalKwargsItem]:
+    ) -> list[Optional[MultiModalKwargsItem]]:
         if not self.enabled:
-            return mm_kwargs
+            return list(mm_kwargs)
 
         assert len(mm_kwargs) == len(mm_hashes)
 
-        out_mm_items = list[MultiModalKwargsItem]()
+        out_mm_items = list[Optional[MultiModalKwargsItem]]()
         for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
             if self.mm_cache.get(mm_hash) is not None:
-                out_mm_items.append(mm_item.without_data())
+                out_mm_items.append(None)
             else:
                 self.mm_cache[mm_hash] = \
-                    MultiModalCacheItemMetadata.wraps(mm_item.require_data())
+                    MultiModalCacheItemMetadata.wraps(mm_item)
                 out_mm_items.append(mm_item)
 
         return out_mm_items
@@ -91,25 +92,27 @@ class MultiModalInputCacheServer:
         self.enabled = mm_registry.enable_mm_input_cache(model_config)
         self.mm_cache = MultiModalCache.get_lru_cache(
             model_config.get_mm_input_cache_gb(),
-            Mapping[str, NestedTensors],
+            MultiModalKwargsItem,
         )
 
     def get_and_update(
         self,
-        mm_kwargs: list[MultiModalKwargsItem],
+        mm_kwargs: Sequence[Optional[MultiModalKwargsItem]],
         mm_hashes: list[str],
     ) -> list[MultiModalKwargsItem]:
         if not self.enabled:
-            return mm_kwargs
+            mm_kwargs_lst = list(mm_kwargs)
+            assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem)
+            return mm_kwargs_lst
 
         assert len(mm_kwargs) == len(mm_hashes)
 
         out_mm_items = list[MultiModalKwargsItem]()
         for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
-            if (mm_data := mm_item.get_data()) is None:
-                out_mm_items.append(mm_item.with_data(self.mm_cache[mm_hash]))
+            if mm_item is None:
+                out_mm_items.append(self.mm_cache[mm_hash])
             else:
-                self.mm_cache[mm_hash] = mm_data
+                self.mm_cache[mm_hash] = mm_item
                 out_mm_items.append(mm_item)
 
         return out_mm_items
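For orientation, a runnable sketch (my simplification, with str standing in for MultiModalKwargsItem) of the client/server handshake these two classes implement after this change: the frontend replaces already-sent items with None, and the engine core restores them from its own cache keyed by mm_hash:

from typing import Optional

Item = str  # stand-in for MultiModalKwargsItem

class ClientCache:
    """Frontend side: strips items the engine core is known to have."""

    def __init__(self) -> None:
        self.seen: set[str] = set()

    def get_and_update(self, items: list[Item],
                       hashes: list[str]) -> list[Optional[Item]]:
        out: list[Optional[Item]] = []
        for item, h in zip(items, hashes):
            if h in self.seen:
                out.append(None)      # cache hit: send only the sentinel
            else:
                self.seen.add(h)
                out.append(item)      # cache miss: ship the full item once
        return out

class ServerCache:
    """Engine-core side: fills None entries back in from its cache."""

    def __init__(self) -> None:
        self.cache: dict[str, Item] = {}

    def get_and_update(self, items: list[Optional[Item]],
                       hashes: list[str]) -> list[Item]:
        out: list[Item] = []
        for item, h in zip(items, hashes):
            if item is None:
                out.append(self.cache[h])   # restore the full item from cache
            else:
                self.cache[h] = item
                out.append(item)
        return out

client, server = ClientCache(), ServerCache()
sent = client.get_and_update(["img_a", "img_b"], ["h1", "h2"])
assert server.get_and_update(sent, ["h1", "h2"]) == ["img_a", "img_b"]
# A repeat of the same inputs now crosses IPC as pure sentinels:
assert client.get_and_update(["img_a", "img_b"], ["h1", "h2"]) == [None, None]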
vllm/v1/engine (Processor)

@@ -17,6 +17,7 @@ from vllm.multimodal.utils import argsort_mm_positions
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
+from vllm.utils import is_list_of
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
 from vllm.v1.structured_output.backend_guidance import (
@@ -295,7 +296,7 @@ class Processor:
         pooling_params = params.clone()
 
         # Multimodal related.
-        sorted_mm_inputs: Optional[list[MultiModalKwargsItem]] = None
+        sorted_mm_inputs: Optional[list[Optional[MultiModalKwargsItem]]] = None
         sorted_mm_positions: Optional[list[PlaceholderRange]] = None
         sorted_mm_hashes: Optional[list[str]] = None
         if decoder_inputs["type"] == "multimodal":
@@ -308,7 +309,7 @@ class Processor:
             # in the input sequence.
             sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
 
-            sorted_mm_inputs = [
+            orig_sorted_mm_inputs = [
                 decoder_mm_inputs.get_item(modality, idx)
                 for modality, idx in sorted_mm_idxs
             ]
@@ -323,9 +324,12 @@ class Processor:
 
            if sorted_mm_hashes is not None:
                sorted_mm_inputs = self.mm_input_cache_client.get_and_update(
-                    sorted_mm_inputs,
+                    orig_sorted_mm_inputs,
                    sorted_mm_hashes,
                )
+            else:
+                assert is_list_of(orig_sorted_mm_inputs, MultiModalKwargsItem)
+                sorted_mm_inputs = orig_sorted_mm_inputs
 
         return decoder_inputs.get("prompt"), EngineCoreRequest(
             request_id=request_id,
vllm/v1 (Request)

@@ -125,14 +125,17 @@ class Request:
         block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
     ) -> "Request":
         if request.mm_kwargs is not None:
-            assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), (
+            mm_kwargs_lst = list(request.mm_kwargs)
+            assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem), (
                 "mm_kwargs was not updated in EngineCore.add_request")
+        else:
+            mm_kwargs_lst = None
 
         return cls(
             request_id=request.request_id,
             client_index=request.client_index,
             prompt_token_ids=request.prompt_token_ids,
-            multi_modal_kwargs=request.mm_kwargs,
+            multi_modal_kwargs=mm_kwargs_lst,
             multi_modal_hashes=request.mm_hashes,
             multi_modal_placeholders=request.mm_placeholders,
             sampling_params=request.sampling_params,
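The assert above narrows the IPC-level Sequence[Optional[MultiModalKwargsItem]] back to a plain list of items; by this point EngineCore.add_request must have filled every None back in. A toy illustration (is_list_of below is a simplified stand-in for vllm.utils.is_list_of, which by default only checks the first element):

from typing import Optional

class Item:  # stand-in for MultiModalKwargsItem
    pass

def is_list_of(value: object, typ: type) -> bool:
    # Simplified stand-in for vllm.utils.is_list_of.
    return isinstance(value, list) and all(isinstance(v, typ) for v in value)

mm_kwargs: list[Optional[Item]] = [Item(), None, Item()]  # as received over IPC
assert not is_list_of(mm_kwargs, Item)   # the Request assert would fire here

restored = [v if v is not None else Item() for v in mm_kwargs]  # engine-side fill
assert is_list_of(restored, Item)        # safe to build the Request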
vllm/v1 worker (GPUModelRunner)

@@ -500,8 +500,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             second_per_grid_ts = []
             audio_feature_lengths = []
             use_audio_in_video = False
-            for item in self.requests[req_id].mm_kwargs:
-                mm_input = item.require_data()
+            for mm_item in self.requests[req_id].mm_kwargs:
+                mm_input = mm_item.get_data()
                 if mm_input.get("image_grid_thw") is not None:
                     image_grid_thw.append(
                         mm_input["image_grid_thw"].tolist())