mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:34:57 +08:00
[Core] Cast multimodal input in hf processor (#18862)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
parent
4555143ea7
commit
1409ef9134
@ -4,9 +4,12 @@ from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from vllm.jsontree import JSONTree, json_map_leaves
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.processor import cached_processor_from_config
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import resolve_mm_processor_kwargs
|
||||
@ -21,6 +24,8 @@ _T = TypeVar("_T")
|
||||
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
|
||||
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class InputContext:
|
||||
@ -134,7 +139,7 @@ class InputProcessingContext(InputContext):
|
||||
hf_processor: ProcessorMixin,
|
||||
data: Mapping[str, object],
|
||||
kwargs: Mapping[str, object] = {},
|
||||
) -> BatchFeature:
|
||||
) -> Union[BatchFeature, JSONTree]:
|
||||
"""
|
||||
Call `hf_processor` on the prompt `data`
|
||||
(text, image, audio...) with configurable options `kwargs`.
|
||||
@ -154,8 +159,25 @@ class InputProcessingContext(InputContext):
|
||||
allow_var_kwargs=True,
|
||||
)
|
||||
|
||||
def maybe_cast_dtype(x):
    """Cast `x` to `self.model_config.dtype` if it is a floating-point tensor.

    Any other value (a non-tensor, or a tensor that is not floating point)
    is returned unchanged. `self` comes from the enclosing scope.
    """
    # This mimics the behavior of transformers.BatchFeature
    if not (isinstance(x, torch.Tensor) and x.is_floating_point()):
        return x

    return x.to(dtype=self.model_config.dtype)
|
||||
|
||||
try:
|
||||
return hf_processor(**data, **merged_kwargs, return_tensors="pt")
|
||||
output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
|
||||
# This emulates `output.to(dtype=self.model_config.dtype)`: only floating-point tensor leaves are cast.
|
||||
cast_output = json_map_leaves(maybe_cast_dtype, output)
|
||||
if isinstance(output, BatchFeature):
|
||||
return BatchFeature(cast_output)
|
||||
|
||||
logger.warning_once(
|
||||
f"{type(hf_processor).__name__} did not return `BatchFeature`. "
|
||||
"Make sure to match the behaviour of `ProcessorMixin` when "
|
||||
"implementing custom processors.")
|
||||
return cast_output
|
||||
|
||||
except Exception as exc:
|
||||
msg = (f"Failed to apply {type(hf_processor).__name__} "
|
||||
f"on data={data} with kwargs={merged_kwargs}")
|
||||
|
||||
@ -747,17 +747,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||
batched_inputs: BatchedTensorInputs,
|
||||
*,
|
||||
device: torch.types.Device,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> BatchedTensorInputs:
|
||||
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
|
||||
|
||||
def maybe_cast_dtype(x: torch.Tensor):
|
||||
# This mimics the behavior of transformers.BatchFeature
|
||||
return x.to(dtype=dtype) if x.is_floating_point() else x
|
||||
|
||||
json_mapped = json_map_leaves(
|
||||
# NOTE: Cast the dtype before sending it to device
|
||||
lambda x: maybe_cast_dtype(x).to(device=device, non_blocking=True),
|
||||
lambda x: x.to(device=device, non_blocking=True),
|
||||
json_inputs,
|
||||
)
|
||||
|
||||
|
||||
@ -297,7 +297,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
multi_modal_kwargs,
|
||||
dtype=self.model_runner.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
**model_execute_kwargs,
|
||||
|
||||
@ -957,7 +957,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
|
||||
batched_mm_inputs = MultiModalKwargs.as_kwargs(
|
||||
batched_mm_inputs,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
@ -1951,7 +1950,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
[dummy_mm_kwargs] * max_num_mm_items)
|
||||
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
|
||||
batched_dummy_mm_inputs,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
|
||||
@ -718,7 +718,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
|
||||
batched_mm_inputs = MultiModalKwargs.as_kwargs(
|
||||
batched_mm_inputs,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
@ -1560,7 +1559,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
||||
batch_size)
|
||||
return MultiModalKwargs.as_kwargs(
|
||||
batched_dummy_mm_inputs,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
|
||||
@ -300,7 +300,6 @@ class CPUEncoderDecoderModelRunner(
|
||||
model_input.encoder_input_positions,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
"intermediate_tensors":
|
||||
|
||||
@ -630,7 +630,6 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
|
||||
if model_input.multi_modal_kwargs is not None:
|
||||
multimodal_kwargs = MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
execute_model_kwargs = {}
|
||||
|
||||
@ -53,7 +53,6 @@ class CPUPoolingModelRunner(
|
||||
model_input.input_positions,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
**cross_enc_kwargs,
|
||||
|
||||
@ -205,7 +205,6 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
multi_modal_kwargs,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
**seqlen_agnostic_kwargs,
|
||||
|
||||
@ -1848,7 +1848,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
multi_modal_kwargs,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
**seqlen_agnostic_kwargs,
|
||||
|
||||
@ -73,7 +73,6 @@ class MultiStepNeuronModelRunner(NeuronModelRunner):
|
||||
input_block_ids=model_input.input_block_ids,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
|
||||
@ -52,7 +52,6 @@ class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner):
|
||||
sampling_params=sampling_params,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
|
||||
@ -395,7 +395,6 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
|
||||
adapter_ids=model_input.adapter_ids,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
@ -408,7 +407,6 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
|
||||
input_block_ids=model_input.input_block_ids,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
|
||||
@ -122,7 +122,6 @@ class PoolingModelRunner(
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
multi_modal_kwargs,
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
**cross_enc_kwargs,
|
||||
|
||||
@ -565,7 +565,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
dtype=self.model_config.dtype,
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user