Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-14 07:34:58 +08:00)
[Bugfix] Fixes Phi3v & Ultravox Multimodal EmbeddingInputs (#8979)
commit 53b3a33027
parent dac914b0d6
vllm/model_executor/models/phi3v.py
@@ -467,9 +467,10 @@ def input_processor_for_phi3v(ctx: InputContext,
                                              input_height=h,
                                              num_crops=num_crops))
     elif isinstance(image_data, torch.Tensor):
-        num_images, image_feature_size, hidden_size = image_data.shape
+        image_feature_size = [image_data.shape[0]]
+        image_data = [image_data]
     elif is_list_of(image_data, torch.Tensor):
-        image_feature_size = [item.shape[1] for item in image_data]
+        image_feature_size = [item.shape[0] for item in image_data]
     else:
         raise TypeError(f"Invalid image type: {type(image_data)}")

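For reference, a minimal standalone sketch of what the new branches compute for pre-computed image embeddings. The helper name, shapes, and numbers below are illustrative assumptions, not values from the commit:

import torch

# Illustrative per-image embeddings of shape (num_image_tokens, hidden_size).
single = torch.zeros(757, 3072)
batch = [torch.zeros(757, 3072), torch.zeros(145, 3072)]

def normalize_image_embeds(image_data):
    # Mirrors the new branches: a lone tensor is wrapped into a one-element
    # list and its token count read from dim 0; a list reads dim 0 per item.
    if isinstance(image_data, torch.Tensor):
        return [image_data], [image_data.shape[0]]
    if isinstance(image_data, list) and all(
            isinstance(t, torch.Tensor) for t in image_data):
        return image_data, [item.shape[0] for item in image_data]
    raise TypeError(f"Invalid image type: {type(image_data)}")

print(normalize_image_embeds(single)[1])  # [757]
print(normalize_image_embeds(batch)[1])   # [757, 145]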
@@ -611,9 +612,6 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         image_sizes = kwargs.pop("image_sizes", None)
         image_embeds = kwargs.pop("image_embeds", None)

-        if pixel_values is None:
-            return None
-
         if pixel_values is None and image_embeds is None:
             return None

@@ -650,7 +648,17 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> torch.Tensor:

         if image_input["type"] == "image_embeds":
-            return image_input["data"]
+            image_data = image_input["data"]
+            if is_list_of(image_data, torch.Tensor):
+                # it's already a list of tensors
+                return image_data
+            if len(image_data.shape) == 3:
+                # 3D tensor
+                return list(torch.unbind(image_data, dim=0))
+            raise ValueError(
+                "We expect batched 2D tensors;"
+                "this can be either a list of 2D tensors or a single 3D tensor."
+            )

         assert self.vision_embed_tokens is not None
         image_embeds = self.vision_embed_tokens(image_input["data"],
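A small standalone sketch of the unbatching this new image_embeds branch performs with torch.unbind; the tensor sizes are illustrative assumptions:

import torch

# One 3D tensor batching pre-computed embeddings for two images:
# (num_images, num_image_tokens, hidden_size); the sizes are made up.
batched = torch.zeros(2, 757, 3072)

# torch.unbind(dim=0) gives one 2D tensor per image, which is what the
# "single 3D tensor" branch returns as a list.
per_image = list(torch.unbind(batched, dim=0))
assert len(per_image) == 2
assert per_image[0].shape == (757, 3072)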
vllm/model_executor/models/ultravox.py
@@ -38,6 +38,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)
 from vllm.transformers_utils.configs.ultravox import UltravoxConfig
+from vllm.utils import is_list_of

 from .interfaces import SupportsMultiModal, SupportsPP

@@ -119,6 +120,10 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
     if not isinstance(data, list):
         data = [data]

+    # If the audio inputs are embeddings, no need for preprocessing
+    if is_list_of(data, torch.Tensor, check="all"):
+        return MultiModalInputs({"audio_embeds": data})
+
     audio_features = []
     for audio_input in data:
         if not isinstance(audio_input, tuple):
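A minimal sketch of the check this hunk adds, re-implemented so it runs without vLLM installed; is_list_of itself lives in vllm.utils, and the shapes below are made up:

import torch

def is_all_tensors(data):
    # Stand-in for is_list_of(data, torch.Tensor, check="all"): true only
    # when data is a list and every element is already an embedding tensor.
    return isinstance(data, list) and all(
        isinstance(item, torch.Tensor) for item in data)

audio_embeds = [torch.zeros(1, 128, 4096)]   # made-up embedding shape
raw_audio = [([0.0] * 16000, 16000)]         # (waveform, sample_rate) tuple

assert is_all_tensors(audio_embeds)      # would short-circuit to "audio_embeds"
assert not is_all_tensors(raw_audio)     # would fall through to the feature extractor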
@@ -165,25 +170,30 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
         audios = [audios]

     audio_token_counts = []
-    for audio_data, sample_rate in audios:
-        audio_length = audio_data.shape[0]
-        if sample_rate != feature_extractor.sampling_rate:
-            # Account for resampling.
-            adjustment = feature_extractor.sampling_rate / sample_rate
-            audio_length = math.ceil(adjustment * audio_length)
+    for audio in audios:
+        if isinstance(audio, torch.Tensor):
+            audio_num_tokens = audio.shape[1]
+            audio_token_counts.append(audio_num_tokens)
+        else:
+            audio_data, sample_rate = audio
+            audio_length = audio_data.shape[0]
+            if sample_rate != feature_extractor.sampling_rate:
+                # Account for resampling.
+                adjustment = feature_extractor.sampling_rate / sample_rate
+                audio_length = math.ceil(adjustment * audio_length)

-        feature_extractor_output_length = math.ceil(
-            (audio_length - (feature_extractor.hop_length - 1)) /
-            feature_extractor.hop_length)
+            feature_extractor_output_length = math.ceil(
+                (audio_length - (feature_extractor.hop_length - 1)) /
+                feature_extractor.hop_length)

-        uv_config = ctx.get_hf_config(UltravoxConfig)
-        audio_num_tokens = min(
-            max(
-                1,
-                math.ceil(feature_extractor_output_length /
-                          (uv_config.stack_factor * 2))),
-            get_ultravox_max_audio_tokens(ctx))
-        audio_token_counts.append(audio_num_tokens)
+            uv_config = ctx.get_hf_config(UltravoxConfig)
+            audio_num_tokens = min(
+                max(
+                    1,
+                    math.ceil(feature_extractor_output_length /
+                              (uv_config.stack_factor * 2))),
+                get_ultravox_max_audio_tokens(ctx))
+            audio_token_counts.append(audio_num_tokens)

     tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)

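To make the token-count arithmetic in the raw-audio branch concrete, a worked example under assumed values (a Whisper-style feature extractor with sampling_rate=16000 and hop_length=160, stack_factor=8; the cap is a stand-in for get_ultravox_max_audio_tokens(ctx)):

import math

# Assumed example values; the real ones come from the feature extractor,
# UltravoxConfig, and get_ultravox_max_audio_tokens(ctx).
sampling_rate, hop_length, stack_factor = 16000, 160, 8
max_audio_tokens = 200

audio_length = 5 * sampling_rate  # a 5-second clip, already at 16 kHz

feature_extractor_output_length = math.ceil(
    (audio_length - (hop_length - 1)) / hop_length)   # 500 frames

audio_num_tokens = min(
    max(1, math.ceil(feature_extractor_output_length /
                     (stack_factor * 2))),
    max_audio_tokens)                                   # 32 tokens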