[Model] Define merge_by_field_config MM interface (#25676)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-09-26 01:13:07 +08:00 committed by GitHub
parent b8d9e4a326
commit 0ea80c87d9
5 changed files with 44 additions and 12 deletions

View File

@@ -19,6 +19,8 @@ from vllm.distributed import (cleanup_dist_env_and_memory,
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.models.interfaces import (SupportsMultiModal,
+                                                    supports_multimodal)
 from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensorInputs
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                          InputProcessingContext)
@@ -88,6 +90,7 @@ def resize_mm_data(
 def create_batched_mm_kwargs(
+    model_cls: type[SupportsMultiModal],
     model_config: ModelConfig,
     processor: BaseMultiModalProcessor,
     size_factors: tuple[float, ...] = (1.0, 0.5, 0.25),
@@ -127,16 +130,22 @@ def create_batched_mm_kwargs(
         mm_data=resized_mm_data,
         hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs,
         tokenization_kwargs=processor_inputs.tokenization_kwargs,
-    )["mm_kwargs"]
+    )["mm_kwargs"].require_data()
     items = [
         item for modality in supported_mm_limits
         for item in mm_kwargs[modality]
     ]
-    return group_mm_kwargs_by_modality(items)
+    return group_mm_kwargs_by_modality(
+        items,
+        merge_by_field_config=model_cls.merge_by_field_config,
+    )


 @contextmanager
-def initialize_dummy_model(model_cls: nn.Module, model_config: ModelConfig):
+def initialize_dummy_model(
+    model_cls: type[nn.Module],
+    model_config: ModelConfig,
+):
     temp_file = tempfile.mkstemp()[1]
     init_distributed_environment(
         world_size=1,
@@ -198,8 +207,12 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
         hf_overrides=hf_overrides_fn,
         skip_tokenizer_init=model_info.skip_tokenizer_init,
         enforce_eager=model_info.enforce_eager,
-        dtype=model_info.dtype)
+        dtype=model_info.dtype,
+    )
     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
+    assert supports_multimodal(model_cls)

     factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
     inputs_parse_methods = []
@@ -228,7 +241,7 @@ def test_model_tensor_schema(model_arch: str, model_id: str):
     with initialize_dummy_model(model_cls, model_config) as model:
         for modality, _, mm_kwargs in create_batched_mm_kwargs(
-                model_config, processor):
+                model_cls, model_config, processor):
             for method_name in inputs_parse_methods:
                 print(f"Testing `{method_name}` with modality={modality} "
                       f"and mm_kwargs{list(mm_kwargs.keys())}")

View File

@@ -63,13 +63,12 @@ ConvertType = Literal["none", "embed", "classify", "reward"]
 ConvertOption = Literal["auto", ConvertType]
 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
                      "score", "reward", "transcription", "draft"]
-_ResolvedTask = Literal["generate", "transcription", "encode", "embed",
-                        "classify", "reward", "draft"]
 TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
 ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
 LogprobsMode = Literal["raw_logits", "raw_logprobs", "processed_logits",
                        "processed_logprobs"]
-HfOverrides = Union[dict[str, Any], Callable[[type], type]]
+HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig],
+                                              PretrainedConfig]]
 ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"]

 _RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = {
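With the updated `HfOverrides` alias, a callable override receives and returns a `transformers.PretrainedConfig` rather than a bare `type`. A hedged sketch of such a callable (the attribute it edits is purely illustrative; it would be passed wherever `hf_overrides` is accepted, e.g. as the `hf_overrides_fn` used in the test above):

    from transformers import PretrainedConfig

    def hf_overrides_fn(config: PretrainedConfig) -> PretrainedConfig:
        # Illustrative tweak only, e.g. shrinking the model for tests.
        config.num_hidden_layers = 2
        return config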

View File

@@ -64,6 +64,12 @@ class SupportsMultiModal(Protocol):
     `multimodal_config.mm_encoder_tp_mode="data"`.
     """

+    merge_by_field_config: ClassVar[bool] = False
+    """
+    A flag that indicates which implementation of
+    `vllm.multimodal.utils.group_mm_kwargs_by_modality` to use.
+    """
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         """

View File

@@ -40,7 +40,8 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
-from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
+from vllm.model_executor.models.interfaces import (SupportsMultiModal,
+                                                    is_mixture_of_experts,
                                                     supports_eagle3,
                                                     supports_mrope,
                                                     supports_transcription)
@@ -777,11 +778,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 mm_kwargs.append(feature.data)

         # Input all modalities at once
+        model = cast(SupportsMultiModal, self.model)
         mm_kwargs_combined: BatchedTensorInputs = {}
         for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
                 mm_kwargs,
                 device=self.device,
                 pin_memory=self.pin_memory,
+                merge_by_field_config=model.merge_by_field_config,
         ):
             mm_kwargs_combined.update(mm_kwargs_group)
@@ -1525,11 +1528,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # in the same batch while still being able to benefit from batching
         # multimodal inputs. The proper solution should be reordering the
         # encoder outputs.
+        model = cast(SupportsMultiModal, self.model)
         encoder_outputs = []
         for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
                 mm_kwargs,
                 device=self.device,
                 pin_memory=self.pin_memory,
+                merge_by_field_config=model.merge_by_field_config,
         ):
             # Run the encoder.
             # `curr_group_outputs` is either of the following:
@@ -1538,7 +1543,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # 2. A list or tuple (length: num_items) of tensors, each of shape
             # (feature_size, hidden_size) in case the feature size is dynamic
             # depending on the input multimodal items.
-            curr_group_outputs = self.model.get_multimodal_embeddings(
+            curr_group_outputs = model.get_multimodal_embeddings(
                 **mm_kwargs_group)

             sanity_check_mm_encoder_outputs(
@@ -1623,11 +1628,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             return {}

         # Group MM kwargs by modality and extract features
+        model = cast(SupportsMultiModal, self.model)
         encoder_features = {}
         for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
                 mm_kwargs,
                 device=self.device,
                 pin_memory=self.pin_memory,
+                merge_by_field_config=model.merge_by_field_config,
         ):
             # Add the grouped features to encoder_features dict
             # This allows the model to receive them as kwargs (e.g.,
@@ -2839,11 +2846,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         dummy_mm_item = dummy_mm_data[modality][0]
         dummy_mm_items = [dummy_mm_item] * max_items_per_batch

+        model = cast(SupportsMultiModal, self.model)
         return next(mm_kwargs_group
                     for _, _, mm_kwargs_group in group_mm_kwargs_by_modality(
                         dummy_mm_items,
                         device=self.device,
                         pin_memory=self.pin_memory,
+                        merge_by_field_config=model.merge_by_field_config,
                     ))

     @torch.inference_mode()
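The same pattern recurs at each call site in this runner: cast `self.model` to `SupportsMultiModal`, read the class-level flag, and forward it to `group_mm_kwargs_by_modality`. A consolidated sketch of that pattern as a free-standing helper (the helper itself is hypothetical and not part of the diff; `runner` stands in for the model runner instance):

    from typing import cast

    from vllm.model_executor.models.interfaces import SupportsMultiModal
    from vllm.multimodal.utils import group_mm_kwargs_by_modality


    def iter_grouped_mm_kwargs(runner, mm_kwargs):
        # Read the per-model opt-in flag once and pass it through, keeping
        # every call site consistent.
        model = cast(SupportsMultiModal, runner.model)
        yield from group_mm_kwargs_by_modality(
            mm_kwargs,
            device=runner.device,
            pin_memory=runner.pin_memory,
            merge_by_field_config=model.merge_by_field_config,
        )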

View File

@@ -30,7 +30,8 @@ from vllm.logger import init_logger
 from vllm.lora.layers import BaseLayerWithLoRA
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.tpu import TPUModelLoader
-from vllm.model_executor.models.interfaces import supports_transcription
+from vllm.model_executor.models.interfaces import (SupportsMultiModal,
+                                                    supports_transcription)
 from vllm.model_executor.models.interfaces_base import (
     is_pooling_model, is_text_generation_model)
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -834,11 +835,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # in the same batch while still being able to benefit from batching
         # multimodal inputs. The proper solution should be reordering the
         # encoder outputs.
+        model = cast(SupportsMultiModal, self.model)
         encoder_outputs = []
         for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
                 mm_kwargs,
                 device=self.device,
                 pin_memory=self.pin_memory,
+                merge_by_field_config=model.merge_by_field_config,
         ):
             # Run the encoder.
             # `curr_group_outputs` is either of the following:
@@ -848,7 +851,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # (feature_size, hidden_size) in case the feature size is dynamic
             # depending on the input multimodal items.
             torch_xla.sync(wait=False)
-            curr_group_outputs = self.model.get_multimodal_embeddings(
+            curr_group_outputs = model.get_multimodal_embeddings(
                 **mm_kwargs_group)
             torch_xla.sync(wait=False)
@@ -1805,11 +1808,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         dummy_mm_item = dummy_mm_data[modality][0]
         dummy_mm_items = [dummy_mm_item] * max_items_per_batch

+        model = cast(SupportsMultiModal, self.model)
         return next(grouped_mm_kwargs
                     for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality(
                         dummy_mm_items,
                         device=self.device,
                         pin_memory=self.pin_memory,
+                        merge_by_field_config=model.merge_by_field_config,
                     ))