[Model] Use merge_by_field_config for MM models (Qwen series) (#27546)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent 63b22e0dbb
commit cbd5e07a51
@@ -126,12 +126,12 @@ class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
     type: Literal["audio_features"]
     input_features: Annotated[
         torch.Tensor | list[torch.Tensor],
-        TensorShape("nmb", "tsl"),
+        TensorShape("nmb", "tsl", dynamic_dims={"tsl"}),
     ]

     feature_attention_mask: Annotated[
-        torch.Tensor,
-        TensorShape("na", "msl"),
+        torch.Tensor | list[torch.Tensor],
+        TensorShape("na", "msl", dynamic_dims={"msl"}),
     ]

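The schema change above makes the time dimension variable: with dynamic_dims={"tsl"}, input_features may arrive as a list of per-clip tensors whose time-step length differs, instead of one padded batch tensor (likewise "msl" for feature_attention_mask). Below is a minimal sketch of the constraint this expresses; the check_audio_features helper is hypothetical, not part of vLLM, and only illustrates that the fixed dimension must match while the dynamic one may vary.

import torch

def check_audio_features(items: list[torch.Tensor], num_mel_bins: int) -> None:
    # Every item must be ("nmb", "tsl"); only "tsl" may differ across items.
    for i, feats in enumerate(items):
        assert feats.ndim == 2, f"item {i}: expected (nmb, tsl), got {tuple(feats.shape)}"
        assert feats.shape[0] == num_mel_bins, f"item {i}: mel-bin count mismatch"

# Two clips of different lengths are acceptable once "tsl" is dynamic.
check_audio_features([torch.randn(128, 300), torch.randn(128, 750)], num_mel_bins=128)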
@@ -651,18 +651,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(


 class Qwen2_5OmniConditionalGenerationMixin:
-    def _validate_and_reshape_mm_tensor(
-        self, mm_input: object, name: str, dim: int = 0
-    ) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if dim == 0:
-                return mm_input.reshape(-1, *mm_input.shape[2:])
-            return torch.concat(list(mm_input), dim=dim)
-        else:
-            return torch.concat(mm_input, dim=dim)
-
     def _parse_and_validate_audio_input(
         self, **kwargs: object
     ) -> Qwen2_5OmniAudioFeatureInputs | None:
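The deleted _validate_and_reshape_mm_tensor helper undid the per-request batching that models previously had to flatten themselves; with merge_by_field_config enabled (added further down), inputs already arrive merged per field, so the helper becomes redundant. A standalone sketch of the behaviour it provided, kept only for reference:

import torch

def flatten_batched_mm_tensor(mm_input, dim: int = 0) -> torch.Tensor:
    # A batched (B, N, ...) tensor collapses its first two dims;
    # a list of tensors is concatenated along `dim`.
    if isinstance(mm_input, torch.Tensor):
        if dim == 0:
            return mm_input.reshape(-1, *mm_input.shape[2:])
        return torch.concat(list(mm_input), dim=dim)
    return torch.concat(mm_input, dim=dim)

print(flatten_batched_mm_tensor(torch.zeros(2, 3, 128, 300)).shape)               # (6, 128, 300)
print(flatten_batched_mm_tensor([torch.zeros(4, 16), torch.zeros(2, 16)]).shape)  # (6, 16)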
@@ -671,18 +659,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
         feature_attention_mask = kwargs.pop("feature_attention_mask", None)
         if input_audio_features is None:
             return None
-        input_audio_features = self._validate_and_reshape_mm_tensor(
-            input_audio_features, "input_audio_features", dim=1
-        )
-        if feature_attention_mask is not None:
-            feature_attention_mask = self._validate_and_reshape_mm_tensor(
-                feature_attention_mask, "feature_attention_mask"
-            )
-        if not isinstance(input_audio_features, (torch.Tensor, list)):
-            raise ValueError(
-                "Incorrect type of audio input features. "
-                f"Got type: {type(input_audio_features)}"
-            )

         return Qwen2_5OmniAudioFeatureInputs(
             type="audio_features",
             input_features=input_audio_features,
@@ -702,19 +679,6 @@ class Qwen2_5OmniConditionalGenerationMixin:
             return None

         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values"
-            )
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw"
-            )
-
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image pixel values. "
-                    f"Got type: {type(pixel_values)}"
-                )
-
             return Qwen2_5_VLImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
@@ -722,18 +686,6 @@ class Qwen2_5OmniConditionalGenerationMixin:
             )

         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds"
-            )
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw"
-            )
-
-            if not isinstance(image_embeds, torch.Tensor):
-                raise ValueError(
-                    "Incorrect type of image embeddings. "
-                    f"Got type: {type(image_embeds)}"
-                )
             return Qwen2_5_VLImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -752,13 +704,6 @@ class Qwen2_5OmniConditionalGenerationMixin:
             return None

         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos, "video pixel values"
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw"
-            )
-
             return Qwen2_5_VLVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -766,13 +711,6 @@ class Qwen2_5OmniConditionalGenerationMixin:
             )

         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, "video embeds"
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw"
-            )
-
             if not isinstance(video_embeds, torch.Tensor):
                 raise ValueError(
                     "Incorrect type of video embeddings. "
@@ -787,23 +725,18 @@ class Qwen2_5OmniConditionalGenerationMixin:
     def _process_audio_input(
         self,
         audio_input: Qwen2_5OmniAudioFeatureInputs,
-        audio_hashes: list[str] = None,
-        cached_audio_features: torch.Tensor = None,
+        audio_hashes: list[str] | None = None,
+        cached_audio_features: torch.Tensor | None = None,
     ) -> torch.Tensor:
         input_features = audio_input["input_features"]
         audio_feature_lengths = audio_input["audio_feature_lengths"]
-        if input_features.ndim == 3:
-            assert input_features.shape[0] == 1
-            input_features = input_features.squeeze(0)
-        if audio_feature_lengths.ndim == 2:
-            assert (
-                audio_feature_lengths.shape[0] == 1
-                or audio_feature_lengths.shape[1] == 1
-            )
-            if audio_feature_lengths.shape[0] == 1:
-                audio_feature_lengths = audio_feature_lengths.squeeze(0)
-            else:
-                audio_feature_lengths = audio_feature_lengths.squeeze(1)
+
+        if audio_feature_lengths.shape[0] == 1:
+            audio_feature_lengths = audio_feature_lengths.squeeze(0)
+        elif audio_feature_lengths.shape[1] == 1:
+            audio_feature_lengths = audio_feature_lengths.squeeze(1)
+        else:
+            raise AssertionError(audio_feature_lengths.shape)

         audio_feat_lengths, audio_output_lengths = (
             self.audio_tower._get_feat_extract_output_lengths(audio_feature_lengths)
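The replacement logic above normalises audio_feature_lengths to a 1-D vector whichever way the singleton dimension lands, and fails loudly on anything else. A small illustration in plain PyTorch, independent of vLLM:

import torch

def normalize_lengths(lengths: torch.Tensor) -> torch.Tensor:
    # (1, N) and (N, 1) both become (N,); any other 2-D shape is rejected.
    if lengths.shape[0] == 1:
        return lengths.squeeze(0)
    if lengths.shape[1] == 1:
        return lengths.squeeze(1)
    raise AssertionError(lengths.shape)

print(normalize_lengths(torch.tensor([[3000, 1500]])).shape)    # torch.Size([2])
print(normalize_lengths(torch.tensor([[3000], [1500]])).shape)  # torch.Size([2])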
@@ -867,6 +800,8 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
     SupportsMRoPE,
     Qwen2_5OmniConditionalGenerationMixin,
 ):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             "thinker.lm_head.": "language_model.lm_head.",

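merge_by_field_config = True is the switch this commit flips on each Qwen model: the multimodal kwargs are then merged according to their field configuration before they reach the model, so fixed-shape fields come in as one flattened tensor and dynamic-shape fields as lists of per-item tensors, which is why the hand-written reshape helpers can be dropped. A schematic sketch of that merging for a batch of two requests; the merge_field helper is invented for this note and is not the vLLM implementation.

import torch

def merge_field(items: list[torch.Tensor], dynamic: bool) -> torch.Tensor | list[torch.Tensor]:
    # Fixed-shape fields are concatenated into a single tensor;
    # fields with a dynamic dim stay as a list of per-item tensors.
    return list(items) if dynamic else torch.concat(items, dim=0)

audio = [torch.randn(128, 300), torch.randn(128, 750)]              # two clips, different lengths
grid_thw = [torch.tensor([[1, 16, 16]]), torch.tensor([[1, 8, 8]])]

merged_audio = merge_field(audio, dynamic=True)     # list of 2 tensors
merged_grid = merge_field(grid_thw, dynamic=False)  # tensor of shape (2, 3)
print(len(merged_audio), merged_grid.shape)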
@@ -1071,6 +1071,8 @@ class Qwen2_5_VLForConditionalGeneration(
     SupportsMultiModalPruning,
     SupportsMRoPE,
 ):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -1273,24 +1275,6 @@ class Qwen2_5_VLForConditionalGeneration(
         num_layers = len(self.language_model.model.layers)
         return (2, num_layers // 2, num_layers - 3)

-    def _validate_and_reshape_mm_tensor(
-        self, mm_input: object, name: str
-    ) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(
-                    f"{name} should be 2D or batched 3D tensor. "
-                    f"Got ndim: {mm_input.ndim} "
-                    f"(shape={mm_input.shape})"
-                )
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
         self, **kwargs: object
     ) -> Qwen2_5_VLImageInputs | None:
@@ -1302,13 +1286,6 @@ class Qwen2_5_VLForConditionalGeneration(
             return None

         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values"
-            )
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw"
-            )
-
             return Qwen2_5_VLImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
@@ -1316,13 +1293,6 @@ class Qwen2_5_VLForConditionalGeneration(
             )

         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds"
-            )
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw"
-            )
-
             return Qwen2_5_VLImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -1341,14 +1311,6 @@ class Qwen2_5_VLForConditionalGeneration(
             return None

         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos, "video pixel values"
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw"
-            )
-            if second_per_grid_ts is not None and second_per_grid_ts.ndim == 2:
-                second_per_grid_ts = second_per_grid_ts.squeeze(-1)
             return Qwen2_5_VLVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -1357,13 +1319,6 @@ class Qwen2_5_VLForConditionalGeneration(
             )

         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, "video embeds"
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw"
-            )
-
             return Qwen2_5_VLVideoEmbeddingInputs(
                 type="video_embeds",
                 video_embeds=video_embeds,

@@ -313,6 +313,8 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessing
     dummy_inputs=Qwen2AudioDummyInputsBuilder,
 )
 class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> str | None:
         if modality.startswith("audio"):
@@ -346,16 +348,6 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
             self.language_model.make_empty_intermediate_tensors
         )

-    def _validate_and_reshape_mm_tensor(
-        self, mm_input: object, name: str
-    ) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            return mm_input.reshape(-1, *mm_input.shape[2:])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_audio_input(
         self, **kwargs: object
     ) -> Qwen2AudioInputs | None:
@@ -367,24 +359,11 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, Supports
             return None

         if audio_embeds is not None:
-            if not isinstance(audio_embeds, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of audio embeds. Got type: {type(audio_embeds)}"
-                )
-            audio_embeds = self._validate_and_reshape_mm_tensor(
-                audio_embeds, "audio_embeds"
-            )
             return Qwen2AudioEmbeddingInputs(
                 type="audio_embeds", audio_embeds=audio_embeds
             )

         if input_features is not None:
-            input_features = self._validate_and_reshape_mm_tensor(
-                input_features, "input_features"
-            )
-            feature_attention_mask = self._validate_and_reshape_mm_tensor(
-                feature_attention_mask, "feature_attention_mask"
-            )
             return Qwen2AudioFeatureInputs(
                 type="audio_features",
                 input_features=input_features,

@@ -1213,6 +1213,8 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo])
 class Qwen2VLForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
+    merge_by_field_config = True
+
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -1406,24 +1408,6 @@ class Qwen2VLForConditionalGeneration(
             self.language_model.make_empty_intermediate_tensors
         )

-    def _validate_and_reshape_mm_tensor(
-        self, mm_input: object, name: str
-    ) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(
-                    f"{name} should be 2D or batched 3D tensor. "
-                    f"Got ndim: {mm_input.ndim} "
-                    f"(shape={mm_input.shape})"
-                )
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
         self, **kwargs: object
     ) -> Qwen2VLImageInputs | None:
@@ -1435,13 +1419,6 @@ class Qwen2VLForConditionalGeneration(
             return None

         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values"
-            )
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw"
-            )
-
             return Qwen2VLImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
@@ -1449,13 +1426,6 @@ class Qwen2VLForConditionalGeneration(
             )

         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds"
-            )
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw"
-            )
-
             return Qwen2VLImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -1473,13 +1443,6 @@ class Qwen2VLForConditionalGeneration(
             return None

         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos, "video pixel values"
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw"
-            )
-
             return Qwen2VLVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -1487,13 +1450,6 @@ class Qwen2VLForConditionalGeneration(
             )

         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, "video embeds"
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw"
-            )
-
             return Qwen2VLVideoEmbeddingInputs(
                 type="video_embeds",
                 video_embeds=video_embeds,

@@ -63,10 +63,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.qwen2_audio import (
-    Qwen2AudioFeatureInputs,
-    Qwen2AudioProcessingInfo,
-)
+from vllm.model_executor.models.qwen2_audio import Qwen2AudioProcessingInfo
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalKwargsItems
 from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems
@@ -86,6 +83,7 @@ from .interfaces import (
     SupportsPP,
 )
 from .qwen2_5_omni_thinker import (
+    Qwen2_5OmniAudioFeatureInputs,
     Qwen2_5OmniConditionalGenerationMixin,
     Qwen2_5OmniThinkerDummyInputsBuilder,
     Qwen2_5OmniThinkerMultiModalProcessor,
@@ -101,6 +99,7 @@ from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
     _merge_multimodal_embeddings,
+    flatten_bn,
     maybe_prefix,
 )
 from .vision import (
@@ -1056,41 +1055,16 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(


 class Qwen3OmniMoeConditionalGenerationMixin(Qwen2_5OmniConditionalGenerationMixin):
-    def _validate_and_reshape_mm_tensor(
-        self, mm_input: object, name: str, dim: int = 0
-    ) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if name == "feature_attention_mask":
-            dim = -1
-        if isinstance(mm_input, torch.Tensor):
-            return torch.concat(list(mm_input), dim=dim)
-        else:
-            if isinstance(mm_input[0], list):
-                return torch.concat(
-                    [torch.concat(mm_input[i], dim=dim) for i in range(len(mm_input))],
-                    dim=dim,
-                )
-            else:
-                return torch.concat(mm_input, dim=dim)
-
     def _process_audio_input(
         self,
-        audio_input: Qwen2AudioFeatureInputs,
-        audio_hashes: list[str] = None,
-        cached_audio_features: torch.Tensor = None,
+        audio_input: Qwen2_5OmniAudioFeatureInputs,
+        audio_hashes: list[str] | None = None,
+        cached_audio_features: torch.Tensor | None = None,
     ) -> torch.Tensor:
         input_features = audio_input["input_features"]
         audio_feature_lengths = audio_input["audio_feature_lengths"]

-        if input_features.ndim == 3:
-            assert input_features.shape[0] == 1
-            input_features = input_features.squeeze(0)
-
-        if not isinstance(audio_feature_lengths, torch.Tensor):
-            audio_feature_lengths = torch.cat(audio_feature_lengths)
-        if audio_feature_lengths.ndim == 2:
-            audio_feature_lengths = audio_feature_lengths.reshape(-1)
+        audio_feature_lengths = flatten_bn(audio_feature_lengths, concat=True)

         audio_feat_lengths, audio_output_lengths = _get_feat_extract_output_lengths(
             audio_feature_lengths
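Here the manual cat/reshape handling of audio_feature_lengths is replaced by the flatten_bn utility newly imported from .utils above, which covers both the tensor and list-of-tensors cases. A rough sketch of the behaviour being relied on; the real flatten_bn may differ in detail:

import torch

def flatten_bn_sketch(x, concat: bool = False):
    # Tensor input: fold the leading batch dim into the next one.
    # List input: concatenate along the batch dim when concat=True.
    if isinstance(x, torch.Tensor):
        return x.flatten(0, 1)
    return torch.concat(x) if concat else x

print(flatten_bn_sketch(torch.ones(2, 3, dtype=torch.long)).shape)                  # (6,)
print(flatten_bn_sketch([torch.ones(3, dtype=torch.long)] * 2, concat=True).shape)  # (6,)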
@@ -1117,6 +1091,8 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
     SupportsMRoPE,
     Qwen3OmniMoeConditionalGenerationMixin,
 ):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             "thinker.lm_head.": "language_model.lm_head.",

@@ -1175,6 +1175,8 @@ class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
 class Qwen3VLForConditionalGeneration(
     nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -1298,24 +1300,6 @@ class Qwen3VLForConditionalGeneration(
         for idx in range(self.deepstack_num_level):
             self.deepstack_input_embeds[idx][:num_tokens].zero_()

-    def _validate_and_reshape_mm_tensor(
-        self, mm_input: object, name: str
-    ) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(
-                    f"{name} should be 2D or batched 3D tensor. "
-                    f"Got ndim: {mm_input.ndim} "
-                    f"(shape={mm_input.shape})"
-                )
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
         self, **kwargs: object
     ) -> Qwen2_5_VLImageInputs | None:
@@ -1327,19 +1311,6 @@ class Qwen3VLForConditionalGeneration(
             return None

         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values"
-            )
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw"
-            )
-
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image pixel values. "
-                    f"Got type: {type(pixel_values)}"
-                )
-
             return Qwen2_5_VLImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
@@ -1347,18 +1318,6 @@ class Qwen3VLForConditionalGeneration(
             )

         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds"
-            )
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw"
-            )
-
-            if not isinstance(image_embeds, torch.Tensor):
-                raise ValueError(
-                    "Incorrect type of image embeddings. "
-                    f"Got type: {type(image_embeds)}"
-                )
             return Qwen2_5_VLImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -1377,13 +1336,6 @@ class Qwen3VLForConditionalGeneration(
             return None

         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos, "video pixel values"
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw"
-            )
-
             return Qwen2_5_VLVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -1392,18 +1344,6 @@ class Qwen3VLForConditionalGeneration(
             )

         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, "video embeds"
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw"
-            )
-
-            if not isinstance(video_embeds, torch.Tensor):
-                raise ValueError(
-                    "Incorrect type of video embeddings. "
-                    f"Got type: {type(video_embeds)}"
-                )
             return Qwen2_5_VLVideoEmbeddingInputs(
                 type="video_embeds",
                 video_embeds=video_embeds,

@@ -58,7 +58,6 @@ from .interfaces import (
     SupportsPP,
 )
 from .qwen import QWenBaseModel, QWenModel
-from .utils import flatten_bn


 class QwenImagePixelInputs(TensorSchema):
@@ -703,6 +702,8 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
 class QwenVLForConditionalGeneration(
     QWenBaseModel, SupportsPP, SupportsLoRA, SupportsMultiModal
 ):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "c_attn": ["c_attn"],
         "gate_up_proj": [
@@ -750,30 +751,19 @@ class QwenVLForConditionalGeneration(
         image_embeds = kwargs.pop("image_embeds", None)

         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
-                )
-
             expected_h = expected_w = self.config.visual["image_size"]
             resolve_bindings = {"h": expected_h, "w": expected_w}

             return QwenImagePixelInputs(
                 type="pixel_values",
-                data=flatten_bn(pixel_values, concat=True),
+                data=pixel_values,
                 resolve_bindings=resolve_bindings,
             )

         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image embeddings. "
-                    f"Got type: {type(image_embeds)}"
-                )
-
             return QwenImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds, concat=True),
+                data=image_embeds,
             )

         return None
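In the QwenVL hunk above, the explicit isinstance checks and flatten_bn calls disappear: pixel_values and image_embeds are handed to the TensorSchema-backed input classes as-is, and resolve_bindings pins the symbolic "h"/"w" dimensions to the configured image size so the schema can check shapes. A toy illustration of binding symbolic dimensions to concrete sizes; resolve_shape and the 448 value are invented for this note and are not vLLM API.

import torch

def resolve_shape(shape: tuple, bindings: dict[str, int]) -> tuple[int, ...]:
    # Replace symbolic dims ("h", "w", ...) with their bound concrete sizes.
    return tuple(bindings[d] if isinstance(d, str) else d for d in shape)

bindings = {"h": 448, "w": 448}
expected = resolve_shape((3, "h", "w"), bindings)  # (3, 448, 448)
pixel_values = torch.randn(2, *expected)           # a batch of two images
assert pixel_values.shape[1:] == torch.Size(expected)
print(expected)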