mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-28 18:07:05 +08:00
[Core] Automatically cast multi-modal input dtype (#18756)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
6b6d496114
commit
696259ca01
@ -210,9 +210,7 @@ class DeepseekVL2MultiModalProcessor(
|
|||||||
dict(prompt=prompt, **mm_data),
|
dict(prompt=prompt, **mm_data),
|
||||||
mm_kwargs,
|
mm_kwargs,
|
||||||
)
|
)
|
||||||
target_dtype = self.info.ctx.model_config.dtype
|
pixel_values = processed_outputs["pixel_values"]
|
||||||
pixel_values = processed_outputs.pop("pixel_values").to(
|
|
||||||
target_dtype)
|
|
||||||
# split pixel values into patches corresponding to each image
|
# split pixel values into patches corresponding to each image
|
||||||
images_spatial_crop = processed_outputs["images_spatial_crop"]
|
images_spatial_crop = processed_outputs["images_spatial_crop"]
|
||||||
patches_per_image = [
|
patches_per_image = [
|
||||||
|
|||||||
@ -263,11 +263,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
|||||||
mm_data,
|
mm_data,
|
||||||
mm_kwargs,
|
mm_kwargs,
|
||||||
)
|
)
|
||||||
if "pixel_values" in processed_outputs:
|
|
||||||
# Cast pixel values to model dtype already here,
|
|
||||||
# so we need to transfer less data to the GPU
|
|
||||||
processed_outputs["pixel_values"] = processed_outputs[
|
|
||||||
"pixel_values"].to(self.info.ctx.model_config.dtype)
|
|
||||||
|
|
||||||
# HF processor pops the `num_crops` kwarg, which is needed by vLLM
|
# HF processor pops the `num_crops` kwarg, which is needed by vLLM
|
||||||
if (images := mm_data.get("images")) is not None:
|
if (images := mm_data.get("images")) is not None:
|
||||||
|
|||||||
@ -746,11 +746,17 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
|||||||
batched_inputs: BatchedTensorInputs,
|
batched_inputs: BatchedTensorInputs,
|
||||||
*,
|
*,
|
||||||
device: torch.types.Device,
|
device: torch.types.Device,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
) -> BatchedTensorInputs:
|
) -> BatchedTensorInputs:
|
||||||
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
|
json_inputs = cast(JSONTree[torch.Tensor], batched_inputs)
|
||||||
|
|
||||||
|
def maybe_cast_dtype(x: torch.Tensor):
|
||||||
|
# This mimics the behavior of transformers.BatchFeature
|
||||||
|
return x.to(dtype=dtype) if x.is_floating_point() else x
|
||||||
|
|
||||||
json_mapped = json_map_leaves(
|
json_mapped = json_map_leaves(
|
||||||
lambda x: x.to(device, non_blocking=True),
|
# NOTE: Cast the dtype before sending it to device
|
||||||
|
lambda x: maybe_cast_dtype(x).to(device=device, non_blocking=True),
|
||||||
json_inputs,
|
json_inputs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -294,8 +294,11 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
|
|||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
positions=model_input.input_positions,
|
positions=model_input.input_positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
**MultiModalKwargs.as_kwargs(
|
||||||
device=self.device),
|
multi_modal_kwargs,
|
||||||
|
dtype=self.model_runner.model_config.dtype,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
**model_execute_kwargs,
|
**model_execute_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -929,8 +929,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
encoder_outputs = []
|
encoder_outputs = []
|
||||||
for grouped_mm_inputs in grouped_mm_inputs_list:
|
for grouped_mm_inputs in grouped_mm_inputs_list:
|
||||||
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
|
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
|
||||||
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
|
batched_mm_inputs = MultiModalKwargs.as_kwargs(
|
||||||
device=self.device)
|
batched_mm_inputs,
|
||||||
|
dtype=self.model_config.dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
# Run the encoder.
|
# Run the encoder.
|
||||||
# `curr_group_outputs` is either of the following:
|
# `curr_group_outputs` is either of the following:
|
||||||
@ -1874,7 +1877,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
batched_dummy_mm_inputs = MultiModalKwargs.batch(
|
batched_dummy_mm_inputs = MultiModalKwargs.batch(
|
||||||
[dummy_mm_kwargs] * max_num_mm_items)
|
[dummy_mm_kwargs] * max_num_mm_items)
|
||||||
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
|
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
|
||||||
batched_dummy_mm_inputs, device=self.device)
|
batched_dummy_mm_inputs,
|
||||||
|
dtype=self.model_config.dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
# Run multimodal encoder.
|
# Run multimodal encoder.
|
||||||
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
|
dummy_encoder_outputs = self.model.get_multimodal_embeddings(
|
||||||
|
|||||||
@ -652,8 +652,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
encoder_outputs = []
|
encoder_outputs = []
|
||||||
for grouped_mm_inputs in grouped_mm_inputs_list:
|
for grouped_mm_inputs in grouped_mm_inputs_list:
|
||||||
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
|
batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
|
||||||
batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
|
batched_mm_inputs = MultiModalKwargs.as_kwargs(
|
||||||
device=self.device)
|
batched_mm_inputs,
|
||||||
|
dtype=self.model_config.dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
# Run the encoder.
|
# Run the encoder.
|
||||||
# `curr_group_outputs` is either of the following:
|
# `curr_group_outputs` is either of the following:
|
||||||
@ -1435,8 +1438,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
|
batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
|
||||||
batch_size)
|
batch_size)
|
||||||
return MultiModalKwargs.as_kwargs(batched_dummy_mm_inputs,
|
return MultiModalKwargs.as_kwargs(
|
||||||
device=self.device)
|
batched_dummy_mm_inputs,
|
||||||
|
dtype=self.model_config.dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
|
def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
|
||||||
|
|||||||
@ -297,8 +297,11 @@ class CPUEncoderDecoderModelRunner(
|
|||||||
model_input.encoder_input_tokens,
|
model_input.encoder_input_tokens,
|
||||||
"encoder_positions":
|
"encoder_positions":
|
||||||
model_input.encoder_input_positions,
|
model_input.encoder_input_positions,
|
||||||
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
|
**MultiModalKwargs.as_kwargs(
|
||||||
device=self.device),
|
model_input.multi_modal_kwargs or {},
|
||||||
|
dtype=self.model_config.dtype,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
"intermediate_tensors":
|
"intermediate_tensors":
|
||||||
intermediate_tensors,
|
intermediate_tensors,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -628,7 +628,10 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
|
|||||||
multimodal_kwargs = {}
|
multimodal_kwargs = {}
|
||||||
if model_input.multi_modal_kwargs is not None:
|
if model_input.multi_modal_kwargs is not None:
|
||||||
multimodal_kwargs = MultiModalKwargs.as_kwargs(
|
multimodal_kwargs = MultiModalKwargs.as_kwargs(
|
||||||
model_input.multi_modal_kwargs, device=self.device)
|
model_input.multi_modal_kwargs,
|
||||||
|
dtype=self.model_config.dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
execute_model_kwargs = {}
|
execute_model_kwargs = {}
|
||||||
if previous_hidden_states is not None:
|
if previous_hidden_states is not None:
|
||||||
execute_model_kwargs.update(
|
execute_model_kwargs.update(
|
||||||
|
|||||||
@ -50,8 +50,11 @@ class CPUPoolingModelRunner(
|
|||||||
model_input.input_tokens,
|
model_input.input_tokens,
|
||||||
"positions":
|
"positions":
|
||||||
model_input.input_positions,
|
model_input.input_positions,
|
||||||
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
|
**MultiModalKwargs.as_kwargs(
|
||||||
device=self.device),
|
model_input.multi_modal_kwargs or {},
|
||||||
|
dtype=self.model_config.dtype,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
**cross_enc_kwargs,
|
**cross_enc_kwargs,
|
||||||
"intermediate_tensors":
|
"intermediate_tensors":
|
||||||
intermediate_tensors,
|
intermediate_tensors,
|
||||||
|
|||||||
@ -202,9 +202,13 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
|
|||||||
encoder_input_ids=model_input.encoder_input_tokens,
|
encoder_input_ids=model_input.encoder_input_tokens,
|
||||||
encoder_positions=model_input.encoder_input_positions,
|
encoder_positions=model_input.encoder_input_positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
**MultiModalKwargs.as_kwargs(
|
||||||
device=self.device),
|
multi_modal_kwargs,
|
||||||
**seqlen_agnostic_kwargs)
|
dtype=self.model_config.dtype,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
**seqlen_agnostic_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
logits = self.model.compute_logits(hidden_or_intermediate_states,
|
logits = self.model.compute_logits(hidden_or_intermediate_states,
|
||||||
model_input.sampling_metadata)
|
model_input.sampling_metadata)
|
||||||
|
|||||||
@ -1845,8 +1845,11 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
|||||||
inputs_embeds=model_input.inputs_embeds,
|
inputs_embeds=model_input.inputs_embeds,
|
||||||
positions=model_input.input_positions,
|
positions=model_input.input_positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
**MultiModalKwargs.as_kwargs(
|
||||||
device=self.device),
|
multi_modal_kwargs,
|
||||||
|
dtype=self.model_config.dtype,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
**seqlen_agnostic_kwargs,
|
**seqlen_agnostic_kwargs,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -70,8 +70,11 @@ class MultiStepNeuronModelRunner(NeuronModelRunner):
|
|||||||
input_ids=model_input.input_tokens,
|
input_ids=model_input.input_tokens,
|
||||||
positions=model_input.input_positions,
|
positions=model_input.input_positions,
|
||||||
input_block_ids=model_input.input_block_ids,
|
input_block_ids=model_input.input_block_ids,
|
||||||
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
|
**MultiModalKwargs.as_kwargs(
|
||||||
device=self.device),
|
model_input.multi_modal_kwargs or {},
|
||||||
|
dtype=self.model_config.dtype,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
output = self.model.sample(
|
output = self.model.sample(
|
||||||
|
|||||||
@ -49,8 +49,11 @@ class MultiStepNeuronxDistributedModelRunner(NeuronxDistributedModelRunner):
|
|||||||
positions=model_input.input_positions,
|
positions=model_input.input_positions,
|
||||||
input_block_ids=model_input.input_block_ids,
|
input_block_ids=model_input.input_block_ids,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
|
**MultiModalKwargs.as_kwargs(
|
||||||
device=self.device),
|
model_input.multi_modal_kwargs or {},
|
||||||
|
dtype=self.model_config.dtype,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
output = self.model.sample(
|
output = self.model.sample(
|
||||||
|
|||||||
@ -378,9 +378,11 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
|
|||||||
positions=model_input.input_positions,
|
positions=model_input.input_positions,
|
||||||
input_block_ids=model_input.input_block_ids,
|
input_block_ids=model_input.input_block_ids,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
|
**MultiModalKwargs.as_kwargs(
|
||||||
or {},
|
model_input.multi_modal_kwargs or {},
|
||||||
device=self.device),
|
dtype=self.model_config.dtype,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
elif current_platform.use_transformers_neuronx():
|
elif current_platform.use_transformers_neuronx():
|
||||||
# [TODO] validate on-device sampling
|
# [TODO] validate on-device sampling
|
||||||
@ -389,9 +391,11 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
|
|||||||
input_ids=model_input.input_tokens,
|
input_ids=model_input.input_tokens,
|
||||||
positions=model_input.input_positions,
|
positions=model_input.input_positions,
|
||||||
input_block_ids=model_input.input_block_ids,
|
input_block_ids=model_input.input_block_ids,
|
||||||
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
|
**MultiModalKwargs.as_kwargs(
|
||||||
or {},
|
model_input.multi_modal_kwargs or {},
|
||||||
device=self.device),
|
dtype=self.model_config.dtype,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute the logits only if the on-device sampling is turned off as
|
# Compute the logits only if the on-device sampling is turned off as
|
||||||
|
|||||||
@ -119,10 +119,14 @@ class PoolingModelRunner(
|
|||||||
input_ids=model_input.input_tokens,
|
input_ids=model_input.input_tokens,
|
||||||
positions=model_input.input_positions,
|
positions=model_input.input_positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
**MultiModalKwargs.as_kwargs(multi_modal_kwargs,
|
**MultiModalKwargs.as_kwargs(
|
||||||
device=self.device),
|
multi_modal_kwargs,
|
||||||
|
dtype=self.model_config.dtype,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
**cross_enc_kwargs,
|
**cross_enc_kwargs,
|
||||||
**seqlen_agnostic_kwargs)
|
**seqlen_agnostic_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
if (self.observability_config is not None
|
if (self.observability_config is not None
|
||||||
and self.observability_config.collect_model_forward_time):
|
and self.observability_config.collect_model_forward_time):
|
||||||
|
|||||||
@ -562,9 +562,12 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
|
|||||||
input_ids=model_input.input_tokens,
|
input_ids=model_input.input_tokens,
|
||||||
positions=model_input.input_positions,
|
positions=model_input.input_positions,
|
||||||
intermediate_tensors=intermediate_tensors,
|
intermediate_tensors=intermediate_tensors,
|
||||||
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs
|
**MultiModalKwargs.as_kwargs(
|
||||||
or {},
|
model_input.multi_modal_kwargs or {},
|
||||||
device=self.device))
|
dtype=self.model_config.dtype,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
)
|
||||||
# Compute the logits in the last pipeline stage.
|
# Compute the logits in the last pipeline stage.
|
||||||
if not get_pp_group().is_last_rank:
|
if not get_pp_group().is_last_rank:
|
||||||
return hidden_or_intermediate_states
|
return hidden_or_intermediate_states
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user