mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:25:01 +08:00
[Model] Define merge_by_field_config MM interface (#25676)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
b8d9e4a326
commit
0ea80c87d9
@ -19,6 +19,8 @@ from vllm.distributed import (cleanup_dist_env_and_memory,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||
from vllm.model_executor.models.interfaces import (SupportsMultiModal,
|
||||
supports_multimodal)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
InputProcessingContext)
|
||||
@ -88,6 +90,7 @@ def resize_mm_data(
|
||||
|
||||
|
||||
def create_batched_mm_kwargs(
|
||||
model_cls: type[SupportsMultiModal],
|
||||
model_config: ModelConfig,
|
||||
processor: BaseMultiModalProcessor,
|
||||
size_factors: tuple[float, ...] = (1.0, 0.5, 0.25),
|
||||
@ -127,16 +130,22 @@ def create_batched_mm_kwargs(
|
||||
mm_data=resized_mm_data,
|
||||
hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=processor_inputs.tokenization_kwargs,
|
||||
)["mm_kwargs"]
|
||||
)["mm_kwargs"].require_data()
|
||||
items = [
|
||||
item for modality in supported_mm_limits
|
||||
for item in mm_kwargs[modality]
|
||||
]
|
||||
return group_mm_kwargs_by_modality(items)
|
||||
return group_mm_kwargs_by_modality(
|
||||
items,
|
||||
merge_by_field_config=model_cls.merge_by_field_config,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def initialize_dummy_model(model_cls: nn.Module, model_config: ModelConfig):
|
||||
def initialize_dummy_model(
|
||||
model_cls: type[nn.Module],
|
||||
model_config: ModelConfig,
|
||||
):
|
||||
temp_file = tempfile.mkstemp()[1]
|
||||
init_distributed_environment(
|
||||
world_size=1,
|
||||
@ -198,8 +207,12 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
|
||||
hf_overrides=hf_overrides_fn,
|
||||
skip_tokenizer_init=model_info.skip_tokenizer_init,
|
||||
enforce_eager=model_info.enforce_eager,
|
||||
dtype=model_info.dtype)
|
||||
dtype=model_info.dtype,
|
||||
)
|
||||
|
||||
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
|
||||
assert supports_multimodal(model_cls)
|
||||
|
||||
factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
|
||||
|
||||
inputs_parse_methods = []
|
||||
@ -228,7 +241,7 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
|
||||
|
||||
with initialize_dummy_model(model_cls, model_config) as model:
|
||||
for modality, _, mm_kwargs in create_batched_mm_kwargs(
|
||||
model_config, processor):
|
||||
model_cls, model_config, processor):
|
||||
for method_name in inputs_parse_methods:
|
||||
print(f"Testing `{method_name}` with modality={modality} "
|
||||
f"and mm_kwargs{list(mm_kwargs.keys())}")
|
||||
|
||||
@ -63,13 +63,12 @@ ConvertType = Literal["none", "embed", "classify", "reward"]
|
||||
ConvertOption = Literal["auto", ConvertType]
|
||||
TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
|
||||
"score", "reward", "transcription", "draft"]
|
||||
_ResolvedTask = Literal["generate", "transcription", "encode", "embed",
|
||||
"classify", "reward", "draft"]
|
||||
TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
|
||||
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
|
||||
LogprobsMode = Literal["raw_logits", "raw_logprobs", "processed_logits",
|
||||
"processed_logprobs"]
|
||||
HfOverrides = Union[dict[str, Any], Callable[[type], type]]
|
||||
HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig],
|
||||
PretrainedConfig]]
|
||||
ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"]
|
||||
|
||||
_RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = {
|
||||
|
||||
@ -64,6 +64,12 @@ class SupportsMultiModal(Protocol):
|
||||
`multimodal_config.mm_encoder_tp_mode="data"`.
|
||||
"""
|
||||
|
||||
merge_by_field_config: ClassVar[bool] = False
|
||||
"""
|
||||
A flag that indicates which implementation of
|
||||
`vllm.multimodal.utils.group_mm_kwargs_by_modality` to use.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
"""
|
||||
|
||||
@ -40,7 +40,8 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
|
||||
from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
|
||||
from vllm.model_executor.models.interfaces import (SupportsMultiModal,
|
||||
is_mixture_of_experts,
|
||||
supports_eagle3,
|
||||
supports_mrope,
|
||||
supports_transcription)
|
||||
@ -777,11 +778,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
mm_kwargs.append(feature.data)
|
||||
|
||||
# Input all modalities at once
|
||||
model = cast(SupportsMultiModal, self.model)
|
||||
mm_kwargs_combined: BatchedTensorInputs = {}
|
||||
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
|
||||
mm_kwargs,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
merge_by_field_config=model.merge_by_field_config,
|
||||
):
|
||||
mm_kwargs_combined.update(mm_kwargs_group)
|
||||
|
||||
@ -1525,11 +1528,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# in the same batch while still being able to benefit from batching
|
||||
# multimodal inputs. The proper solution should be reordering the
|
||||
# encoder outputs.
|
||||
model = cast(SupportsMultiModal, self.model)
|
||||
encoder_outputs = []
|
||||
for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
|
||||
mm_kwargs,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
merge_by_field_config=model.merge_by_field_config,
|
||||
):
|
||||
# Run the encoder.
|
||||
# `curr_group_outputs` is either of the following:
|
||||
@ -1538,7 +1543,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# 2. A list or tuple (length: num_items) of tensors, each of shape
|
||||
# (feature_size, hidden_size) in case the feature size is dynamic
|
||||
# depending on the input multimodal items.
|
||||
curr_group_outputs = self.model.get_multimodal_embeddings(
|
||||
curr_group_outputs = model.get_multimodal_embeddings(
|
||||
**mm_kwargs_group)
|
||||
|
||||
sanity_check_mm_encoder_outputs(
|
||||
@ -1623,11 +1628,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
return {}
|
||||
|
||||
# Group MM kwargs by modality and extract features
|
||||
model = cast(SupportsMultiModal, self.model)
|
||||
encoder_features = {}
|
||||
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
|
||||
mm_kwargs,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
merge_by_field_config=model.merge_by_field_config,
|
||||
):
|
||||
# Add the grouped features to encoder_features dict
|
||||
# This allows the model to receive them as kwargs (e.g.,
|
||||
@ -2839,11 +2846,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
dummy_mm_item = dummy_mm_data[modality][0]
|
||||
dummy_mm_items = [dummy_mm_item] * max_items_per_batch
|
||||
|
||||
model = cast(SupportsMultiModal, self.model)
|
||||
return next(mm_kwargs_group
|
||||
for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
|
||||
dummy_mm_items,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
merge_by_field_config=model.merge_by_field_config,
|
||||
))
|
||||
|
||||
@torch.inference_mode()
|
||||
|
||||
@ -30,7 +30,8 @@ from vllm.logger import init_logger
|
||||
from vllm.lora.layers import BaseLayerWithLoRA
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.model_executor.model_loader.tpu import TPUModelLoader
|
||||
from vllm.model_executor.models.interfaces import supports_transcription
|
||||
from vllm.model_executor.models.interfaces import (SupportsMultiModal,
|
||||
supports_transcription)
|
||||
from vllm.model_executor.models.interfaces_base import (
|
||||
is_pooling_model, is_text_generation_model)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
@ -834,11 +835,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# in the same batch while still being able to benefit from batching
|
||||
# multimodal inputs. The proper solution should be reordering the
|
||||
# encoder outputs.
|
||||
model = cast(SupportsMultiModal, self.model)
|
||||
encoder_outputs = []
|
||||
for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
|
||||
mm_kwargs,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
merge_by_field_config=model.merge_by_field_config,
|
||||
):
|
||||
# Run the encoder.
|
||||
# `curr_group_outputs` is either of the following:
|
||||
@ -848,7 +851,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
# (feature_size, hidden_size) in case the feature size is dynamic
|
||||
# depending on the input multimodal items.
|
||||
torch_xla.sync(wait=False)
|
||||
curr_group_outputs = self.model.get_multimodal_embeddings(
|
||||
curr_group_outputs = model.get_multimodal_embeddings(
|
||||
**mm_kwargs_group)
|
||||
torch_xla.sync(wait=False)
|
||||
|
||||
@ -1805,11 +1808,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
dummy_mm_item = dummy_mm_data[modality][0]
|
||||
dummy_mm_items = [dummy_mm_item] * max_items_per_batch
|
||||
|
||||
model = cast(SupportsMultiModal, self.model)
|
||||
return next(grouped_mm_kwargs
|
||||
for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality(
|
||||
dummy_mm_items,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
merge_by_field_config=model.merge_by_field_config,
|
||||
))
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user