mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 04:24:57 +08:00
[Core] Cast multimodal input in hf processor (#18862)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
parent
4555143ea7
commit
1409ef9134
@ -4,9 +4,12 @@ from collections.abc import Mapping
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
|
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
|
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
|
||||||
from typing_extensions import TypeVar
|
from typing_extensions import TypeVar
|
||||||
|
|
||||||
|
from vllm.jsontree import JSONTree, json_map_leaves
|
||||||
|
from vllm.logger import init_logger
|
||||||
from vllm.transformers_utils.processor import cached_processor_from_config
|
from vllm.transformers_utils.processor import cached_processor_from_config
|
||||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||||
from vllm.utils import resolve_mm_processor_kwargs
|
from vllm.utils import resolve_mm_processor_kwargs
|
||||||
@ -21,6 +24,8 @@ _T = TypeVar("_T")
|
|||||||
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
|
_C = TypeVar("_C", bound=PretrainedConfig, default=PretrainedConfig)
|
||||||
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
|
_P = TypeVar("_P", bound=ProcessorMixin, default=ProcessorMixin)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class InputContext:
|
class InputContext:
|
||||||
@ -134,7 +139,7 @@ class InputProcessingContext(InputContext):
|
|||||||
hf_processor: ProcessorMixin,
|
hf_processor: ProcessorMixin,
|
||||||
data: Mapping[str, object],
|
data: Mapping[str, object],
|
||||||
kwargs: Mapping[str, object] = {},
|
kwargs: Mapping[str, object] = {},
|
||||||
) -> BatchFeature:
|
) -> Union[BatchFeature, JSONTree]:
|
||||||
"""
|
"""
|
||||||
Call `hf_processor` on the prompt `data`
|
Call `hf_processor` on the prompt `data`
|
||||||
(text, image, audio...) with configurable options `kwargs`.
|
(text, image, audio...) with configurable options `kwargs`.
|
||||||
@ -154,8 +159,25 @@ class InputProcessingContext(InputContext):
|
|||||||
allow_var_kwargs=True,
|
allow_var_kwargs=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def maybe_cast_dtype(x):
|
||||||
|
# This mimics the behavior of transformers.BatchFeature
|
||||||
|
if isinstance(x, torch.Tensor) and x.is_floating_point():
|
||||||
|
return x.to(dtype=self.model_config.dtype)
|
||||||
|
return x
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return hf_processor(**data, **merged_kwargs, return_tensors="pt")
|
output = hf_processor(**data, **merged_kwargs, return_tensors="pt")
|
||||||
|
# this emulates output.to(dtype=self.model_config.dtype)
|
||||||
|
cast_output = json_map_leaves(maybe_cast_dtype, output)
|
||||||
|
if isinstance(output, BatchFeature):
|
||||||
|
return BatchFeature(cast_output)
|
||||||
|
|
||||||
|
logger.warning_once(
|
||||||
|
f"{type(hf_processor).__name__} did not return `BatchFeature`. "
|
||||||
|
"Make sure to match the behaviour of `ProcessorMixin` when "
|
||||||
|
"implementing custom processors.")
|
||||||
|
return cast_output
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
msg = (f"Failed to apply {type(hf_processor).__name__} "
|
msg = (f"Failed to apply {type(hf_processor).__name__} "
|
||||||
f"on data={data} with kwargs={merged_kwargs}")
|
f"on data={data} with kwargs={merged_kwargs}")
|
||||||
|
|||||||
@ -747,17 +747,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
|||||||
batched_inputs: BatchedTensorInputs,
|
batched_inputs: BatchedTensorInputs,
|
||||||
*,
|
*,
|
||||||
device: torch.types.Device,
|
device: torch.types.Device,
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
) -> BatchedTensorInputs:
|
) -> BatchedTensorInputs:
|
||||||
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
|
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
|
||||||
|
|
||||||
def maybe_cast_dtype(x: torch.Tensor):
|
|
||||||
# This mimics the behavior of transformers.BatchFeature
|
|
||||||
return x.to(dtype=dtype) if x.is_floating_point() else x
|
|
||||||
|
|
||||||
json_mapped = json_map_leaves(
|
json_mapped = json_map_leaves(
|
||||||
# NOTE: Cast the dtype before sending it to device
|
lambda x: x.to(device=device, non_blocking=True),
|
||||||
lambda x: maybe_cast_dtype(x).to(device=device, non_blocking=True),
|
|
||||||
json_inputs,
|
json_inputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -297,7 +297,6 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
|
|||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
**MultiModalKwargs.as_kwargs(
|
**MultiModalKwargs.as_kwargs(
|
||||||
multi_modal_kwargs,
|
multi_modal_kwargs,
|
||||||
dtype=self.model_runner.model_config.dtype,
|
|
||||||
device=self.device,
|
device=self.device,
|
||||||
),
|
),
|
||||||
**model_execute_kwargs,
|
**model_execute_kwargs,
|
||||||
|
|||||||
@ -957,7 +957,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
|
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
|
||||||
batched_mm_inputs = MultiModalKwargs.as_kwargs(
|
batched_mm_inputs = MultiModalKwargs.as_kwargs(
|
||||||
batched_mm_inputs,
|
batched_mm_inputs,
|
||||||
dtype=self.model_config.dtype,
|
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1951,7 +1950,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
[dummy_mm_kwargs] * max_num_mm_items)
|
[dummy_mm_kwargs] * max_num_mm_items)
|
||||||
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
|
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
|
||||||
batched_dummy_mm_inputs,
|
batched_dummy_mm_inputs,
|
||||||
dtype=self.model_config.dtype,
|
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -718,7 +718,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
|
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
|
||||||
batched_mm_inputs = MultiModalKwargs.as_kwargs(
|
batched_mm_inputs = MultiModalKwargs.as_kwargs(
|
||||||
batched_mm_inputs,
|
batched_mm_inputs,
|
||||||
dtype=self.model_config.dtype,
|
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1560,7 +1559,6 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
batch_size)
|
batch_size)
|
||||||
return MultiModalKwargs.as_kwargs(
|
return MultiModalKwargs.as_kwargs(
|
||||||
batched_dummy_mm_inputs,
|
batched_dummy_mm_inputs,
|
||||||
dtype=self.model_config.dtype,
|
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -300,7 +300,6 @@ class CPUEncoderDecoderModelRunner(
|
|||||||
model_input.encoder_input_positions,
|
model_input.encoder_input_positions,
|
||||||
**MultiModalKwargs.as_kwargs(
|
**MultiModalKwargs.as_kwargs(
|
||||||
model_input.multi_modal_kwargs or {},
|
model_input.multi_modal_kwargs or {},
|
||||||
dtype=self.model_config.dtype,
|
|
||||||
device=self.device,
|
device=self.device,
|
||||||
),
|
),
|
||||||
"intermediate_tensors":
|
"intermediate_tensors":
|
||||||
|
|||||||
@ -630,7 +630,6 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
|
|||||||
if model_input.multi_modal_kwargs is not None:
|
if model_input.multi_modal_kwargs is not None:
|
||||||
multimodal_kwargs = MultiModalKwargs.as_kwargs(
|
multimodal_kwargs = MultiModalKwargs.as_kwargs(
|
||||||
model_input.multi_modal_kwargs,
|
model_input.multi_modal_kwargs,
|
||||||
dtype=self.model_config.dtype,
|
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
execute_model_kwargs = {}
|
execute_model_kwargs = {}
|
||||||
|
|||||||
@ -53,7 +53,6 @@ class CPUPoolingModelRunner(
|
|||||||
model_input.input_positions,
|
model_input.input_positions,
|
||||||
**MultiModalKwargs.as_kwargs(
|
**MultiModalKwargs.as_kwargs(
|
||||||
model_input.multi_modal_kwargs or {},
|
model_input.multi_modal_kwargs or {},
|
||||||
dtype=self.model_config.dtype,
|
|
||||||
device=self.device,
|
device=self.device,
|
||||||
),
|
),
|
||||||
**cross_enc_kwargs,
|
**cross_enc_kwargs,
|
||||||
|
|||||||
@ -205,7 +205,6 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
|||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
**MultiModalKwargs.as_kwargs(
|
**MultiModalKwargs.as_kwargs(
|
||||||
multi_modal_kwargs,
|
multi_modal_kwargs,
|
||||||
dtype=self.model_config.dtype,
|
|
||||||
device=self.device,
|
device=self.device,
|
||||||
),
|
),
|
||||||
**seqlen_agnostic_kwargs,
|
**seqlen_agnostic_kwargs,
|
||||||
|
|||||||
@ -1848,7 +1848,6 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
|||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
**MultiModalKwargs.as_kwargs(
|
**MultiModalKwargs.as_kwargs(
|
||||||
multi_modal_kwargs,
|
multi_modal_kwargs,
|
||||||
dtype=self.model_config.dtype,
|
|
||||||
device=self.device,
|
device=self.device,
|
||||||
),
|
),
|
||||||
**seqlen_agnostic_kwargs,
|
**seqlen_agnostic_kwargs,
|
||||||
|
|||||||
@ -73,7 +73,6 @@ class MultiStepNeuronModelRunner(NeuronModelRunner):
|
|||||||
input_block_ids=model_input.input_block_ids,
|
input_block_ids=model_input.input_block_ids,
|
||||||
**MultiModalKwargs.as_kwargs(
|
**MultiModalKwargs.as_kwargs(
|
||||||
model_input.multi_modal_kwargs or {},
|
model_input.multi_modal_kwargs or {},
|
||||||
dtype=self.model_config.dtype,
|
|
||||||
device=self.device,
|
device=self.device,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -52,7 +52,6 @@ class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner):
|
|||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
**MultiModalKwargs.as_kwargs(
|
**MultiModalKwargs.as_kwargs(
|
||||||
model_input.multi_modal_kwargs or {},
|
model_input.multi_modal_kwargs or {},
|
||||||
dtype=self.model_config.dtype,
|
|
||||||
device=self.device,
|
device=self.device,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -395,7 +395,6 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
|
|||||||
adapter_ids=model_input.adapter_ids,
|
adapter_ids=model_input.adapter_ids,
|
||||||
**MultiModalKwargs.as_kwargs(
|
**MultiModalKwargs.as_kwargs(
|
||||||
model_input.multi_modal_kwargs or {},
|
model_input.multi_modal_kwargs or {},
|
||||||
dtype=self.model_config.dtype,
|
|
||||||
device=self.device,
|
device=self.device,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -408,7 +407,6 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
|
|||||||
input_block_ids=model_input.input_block_ids,
|
input_block_ids=model_input.input_block_ids,
|
||||||
**MultiModalKwargs.as_kwargs(
|
**MultiModalKwargs.as_kwargs(
|
||||||
model_input.multi_modal_kwargs or {},
|
model_input.multi_modal_kwargs or {},
|
||||||
dtype=self.model_config.dtype,
|
|
||||||
device=self.device,
|
device=self.device,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -122,7 +122,6 @@ class PoolingModelRunner(
|
|||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
**MultiModalKwargs.as_kwargs(
|
**MultiModalKwargs.as_kwargs(
|
||||||
multi_modal_kwargs,
|
multi_modal_kwargs,
|
||||||
dtype=self.model_config.dtype,
|
|
||||||
device=self.device,
|
device=self.device,
|
||||||
),
|
),
|
||||||
**cross_enc_kwargs,
|
**cross_enc_kwargs,
|
||||||
|
|||||||
@ -565,7 +565,6 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
|
|||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
**MultiModalKwargs.as_kwargs(
|
**MultiModalKwargs.as_kwargs(
|
||||||
model_input.multi_modal_kwargs or {},
|
model_input.multi_modal_kwargs or {},
|
||||||
dtype=self.model_config.dtype,
|
|
||||||
device=self.device,
|
device=self.device,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user