[Core] Cast multimodal input in hf processor (#18862)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
Lukas Geiger 2025-06-04 04:24:56 +01:00 committed by GitHub
parent 4555143ea7
commit 1409ef9134
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 25 additions and 25 deletions

View File

@ -4,9 +4,12 @@ from collections.abc import Mapping
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
import torch
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import TypeVar
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import resolve_mm_processor_kwargs
@ -21,6 +24,8 @@ _T = TypeVar("_T")
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
logger = init_logger(__name__)
@dataclass(frozen=True)
class InputContext:
@ -134,7 +139,7 @@ class InputProcessingContext(InputContext):
hf_processor: ProcessorMixin,
data: Mapping[str, object],
kwargs: Mapping[str, object] = {},
) -> BatchFeature:
) -> Union[BatchFeature, JSONTree]:
"""
Call `hf_processor` on the prompt `data`
(text, image, audio...) with configurable options `kwargs`.
@ -154,8 +159,25 @@ class InputProcessingContext(InputContext):
allow_var_kwargs=True,
)
def maybe_cast_dtype(x):
# This mimics the behavior of transformers.BatchFeature
if isinstance(x, torch.Tensor) and x.is_floating_point():
return x.to(dtype=self.model_config.dtype)
return x
try:
return hf_processor(**data, **merged_kwargs, return_tensors="pt")
output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
# this emulates output.to(dtype=self.model_config.dtype)
cast_output = json_map_leaves(maybe_cast_dtype, output)
if isinstance(output, BatchFeature):
return BatchFeature(cast_output)
logger.warning_once(
f"{type(hf_processor).__name__} did not return `BatchFeature`. "
"Make sure to match the behaviour of `ProcessorMixin` when "
"implementing custom processors.")
return cast_output
except Exception as exc:
msg = (f"Failed to apply {type(hf_processor).__name__} "
f"on data={data} with kwargs={merged_kwargs}")

View File

@ -747,17 +747,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
batched_inputs: BatchedTensorInputs,
*,
device: torch.types.Device,
dtype: Optional[torch.dtype] = None,
) -> BatchedTensorInputs:
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
def maybe_cast_dtype(x: torch.Tensor):
# This mimics the behavior of transformers.BatchFeature
return x.to(dtype=dtype) if x.is_floating_point() else x
json_mapped = json_map_leaves(
# NOTE: Cast the dtype before sending it to device
lambda x: maybe_cast_dtype(x).to(device=device, non_blocking=True),
lambda x: x.to(device=device, non_blocking=True),
json_inputs,
)

View File

@ -297,7 +297,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_runner.model_config.dtype,
device=self.device,
),
**model_execute_kwargs,

View File

@ -957,7 +957,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)
@ -1951,7 +1950,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
[dummy_mm_kwargs] * max_num_mm_items)
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)

View File

@ -718,7 +718,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
batched_mm_inputs = MultiModalKwargs.as_kwargs(
batched_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)
@ -1560,7 +1559,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
batch_size)
return MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs,
dtype=self.model_config.dtype,
device=self.device,
)

View File

@ -300,7 +300,6 @@ class CPUEncoderDecoderModelRunner(
model_input.encoder_input_positions,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
"intermediate_tensors":

View File

@ -630,7 +630,6 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
if model_input.multi_modal_kwargs is not None:
multimodal_kwargs = MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
)
execute_model_kwargs = {}

View File

@ -53,7 +53,6 @@ class CPUPoolingModelRunner(
model_input.input_positions,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
**cross_enc_kwargs,

View File

@ -205,7 +205,6 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
),
**seqlen_agnostic_kwargs,

View File

@ -1848,7 +1848,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
),
**seqlen_agnostic_kwargs,

View File

@ -73,7 +73,6 @@ class MultiStepNeuronModelRunner(NeuronModelRunner):
input_block_ids=model_input.input_block_ids,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)

View File

@ -52,7 +52,6 @@ class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner):
sampling_params=sampling_params,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)

View File

@ -395,7 +395,6 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
adapter_ids=model_input.adapter_ids,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)
@ -408,7 +407,6 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
input_block_ids=model_input.input_block_ids,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)

View File

@ -122,7 +122,6 @@ class PoolingModelRunner(
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(
multi_modal_kwargs,
dtype=self.model_config.dtype,
device=self.device,
),
**cross_enc_kwargs,

View File

@ -565,7 +565,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
intermediate_tensors=intermediate_tensors,
**MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs or {},
dtype=self.model_config.dtype,
device=self.device,
),
)