[Model] Use merge_by_field_config for MM models (H-L) (#26230)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 119f00630b
commit 59a85c366e
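
What this change does: with merge_by_field_config = True, the engine merges
each multimodal kwarg according to the MultiModalFieldConfig declared by the
model's processor, so tensors reach the model already flattened across the
batch (a single "bnp" = batch size * number of patches dimension). The
hand-written isinstance checks, flatten_bn(..., concat=True) calls, and
per-model _validate_and_reshape_mm_tensor helpers removed below thereby
become unnecessary, and shape validation moves into the TensorSchema
annotations. A minimal sketch of the resulting pattern, condensed from the
Idefics3 hunks below (not a drop-in snippet; num_patches and the embedding
branch are omitted):

    from typing import Annotated, Literal

    import torch

    from vllm.utils.tensor_schema import TensorSchema, TensorShape


    class Idefics3ImagePixelInputs(TensorSchema):
        # "bnp": batch size * number of patches; "h"/"w" are bound to the
        # vision encoder's image size via resolve_bindings at parse time.
        type: Literal["pixel_values"]
        pixel_values: Annotated[torch.Tensor,
                                TensorShape("bnp", 3, "h", "w")]
        pixel_attention_mask: Annotated[torch.Tensor,
                                        TensorShape("bnp", "h", "w")]


    def _parse_and_validate_image_input(self, **kwargs: object):
        # Method body as it looks after this commit: the kwargs arrive
        # already merged by field config, so they pass straight through and
        # TensorSchema raises if a shape does not match its annotation.
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None
        size = self.config.vision_config.image_size
        return Idefics3ImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            pixel_attention_mask=kwargs.pop("pixel_attention_mask"),
            resolve_bindings={"h": size, "w": size},
        )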
@@ -548,7 +548,7 @@ def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
     engine_args = EngineArgs(
         model=model_name,
         trust_remote_code=True,
-        max_model_len=8192,
+        max_model_len=32768,
         max_num_seqs=5,
        limit_mm_per_prompt={"image": len(image_urls)},
     )
@@ -53,7 +53,7 @@ from .idefics2_vision_model import (
 # yapf: enable
 from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
 from .llama import LlamaModel
-from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix


 class Idefics3ImagePixelInputs(TensorSchema):
@@ -67,7 +67,7 @@ class Idefics3ImagePixelInputs(TensorSchema):
     """
     type: Literal["pixel_values"]
     pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
-    pixel_attention_mask: torch.Tensor
+    pixel_attention_mask: Annotated[torch.Tensor, TensorShape("bnp", "h", "w")]
     num_patches: Annotated[torch.Tensor, TensorShape("bn")]

@@ -569,6 +569,8 @@ class Idefics3Model(nn.Module):
                         dummy_inputs=Idefics3DummyInputsBuilder)
 class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsLoRA):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -621,37 +623,21 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
             return None

         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of image embeddings. "
-                                 f"Got type: {type(image_embeds)}")
-
             return Idefics3ImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds, concat=True),
+                data=image_embeds,
             )

         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
             pixel_attention_mask = kwargs.pop("pixel_attention_mask")
-            if not isinstance(pixel_attention_mask, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of pixel_attention_mask. "
-                                 f"Got type: {type(pixel_attention_mask)}")
-
             num_patches = kwargs.pop("num_patches")
-            if not isinstance(num_patches, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of num_patches. "
-                                 f"Got type: {type(num_patches)}")
-
             expected_h = expected_w = self.config.vision_config.image_size

             return Idefics3ImagePixelInputs(
                 type="pixel_values",
-                pixel_values=flatten_bn(pixel_values, concat=True),
-                pixel_attention_mask=flatten_bn(pixel_attention_mask,
-                                                concat=True),
-                num_patches=flatten_bn(num_patches, concat=True),
+                pixel_values=pixel_values,
+                pixel_attention_mask=pixel_attention_mask,
+                num_patches=num_patches,
                 resolve_bindings={
                     "h": expected_h,
                     "w": expected_w
@@ -30,7 +30,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems, VideoItem)
@@ -42,7 +42,6 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -100,8 +99,7 @@ def smart_resize(
 class KeyeImagePixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -110,7 +108,7 @@ class KeyeImagePixelInputs(TensorSchema):
     type: Literal["pixel_values"]
     pixel_values: Annotated[
         torch.Tensor,
-        TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
     image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]

@@ -134,8 +132,7 @@ KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs]
 class KeyeVideoPixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -144,7 +141,7 @@ class KeyeVideoPixelInputs(TensorSchema):
     type: Literal["pixel_values_videos"]
     pixel_values_videos: Annotated[
         torch.Tensor,
-        TensorShape("b", "np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
     video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]

@@ -1258,6 +1255,8 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):


 class BaseKeyeModule(nn.Module):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -1524,28 +1523,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
                        prefix: str = "") -> nn.Module:
         return Projector(text_config, vision_config, quant_config, prefix)

-    def _validate_and_reshape_mm_tensor(
-            self, mm_input: NestedTensors,
-            name: str) -> Union[torch.Tensor, list[torch.Tensor]]:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim == 5:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        elif is_list_of(mm_input, torch.Tensor):
-            if all(p.dim() == 4 for p in mm_input) or all(p.dim() == 2
-                                                          for p in mm_input):
-                return mm_input
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[KeyeImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -1556,11 +1533,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             return None

         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, "image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
             return KeyeImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
@@ -1568,11 +1540,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             )

         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, "image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, "image grid_thw")
-
             return KeyeImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -1589,13 +1556,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             return None

         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos,
-                "video pixel values",
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return KeyeVideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -1603,11 +1563,6 @@ class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             )

         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, "video embeds")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, "video grid_thw")
-
             return KeyeVideoEmbeddingInputs(
                 type="video_embeds",
                 video_embeds=video_embeds,
@@ -18,7 +18,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (ImageItem, ModalityData,
                                     MultiModalFieldConfig,
                                     MultiModalKwargsItems, VideoItem)
@@ -100,8 +100,7 @@ def get_num_patches(grid_thw: torch.Tensor,
 class KeyeVL1_5ImagePixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -111,7 +110,7 @@ class KeyeVL1_5ImagePixelInputs(TensorSchema):

     pixel_values: Annotated[
         torch.Tensor,
-        TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]

     image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]

@@ -137,8 +136,7 @@ KeyeVL1_5ImageInputs = Union[KeyeVL1_5ImagePixelInputs,
 class KeyeVL1_5VideoPixelInputs(TensorSchema):
     """
     Dimensions:
-        - b: Batch size
-        - np: Number of patches
+        - bnp: Batch size * Number of patches
         - c: Number of channels
         - ps: Patch size
         - ni: Number of images
@@ -147,7 +145,7 @@ class KeyeVL1_5VideoPixelInputs(TensorSchema):
     type: Literal["pixel_values_videos"]
     pixel_values_videos: Annotated[
         torch.Tensor,
-        TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})]
+        TensorShape("bnp", 3, "ps", "ps", dynamic_dims={"bnp"})]
     video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]

     num_frames: torch.Tensor
@@ -483,24 +481,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
         self.merge_size = config.vision_config.spatial_merge_size
         super().__init__(vllm_config=vllm_config, prefix=prefix)

-    def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors,
-                                        expected_dim: int, name: str):
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == expected_dim:
-                return mm_input
-            elif mm_input.ndim == expected_dim + 1:
-                return mm_input.reshape(-1, *mm_input.shape[2:])
-            else:
-                raise ValueError(
-                    f"{name} should be {expected_dim}D or "
-                    f"batched {expected_dim}D tensor."
-                    f"Got ndim: {mm_input.ndim} (shape={mm_input.shape})")
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[KeyeVL1_5ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -511,11 +491,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             return None

         if pixel_values is not None:
-            pixel_values = self._validate_and_reshape_mm_tensor(
-                pixel_values, expected_dim=4, name="image pixel values")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, expected_dim=2, name="image grid_thw")
-
             return KeyeVL1_5ImagePixelInputs(
                 type="pixel_values",
                 pixel_values=pixel_values,
@@ -523,11 +498,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             )

         if image_embeds is not None:
-            image_embeds = self._validate_and_reshape_mm_tensor(
-                image_embeds, expected_dim=2, name="image embeds")
-            image_grid_thw = self._validate_and_reshape_mm_tensor(
-                image_grid_thw, expected_dim=2, name="image grid_thw")
-
             return KeyeVL1_5ImageEmbeddingInputs(
                 type="image_embeds",
                 image_embeds=image_embeds,
@@ -545,17 +515,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
             return None

         if pixel_values_videos is not None:
-            pixel_values_videos = self._validate_and_reshape_mm_tensor(
-                pixel_values_videos,
-                expected_dim=4,
-                name="video pixel values",
-            )
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, expected_dim=2, name="video grid_thw")
-
-            num_frames = self._validate_and_reshape_mm_tensor(
-                num_frames, expected_dim=1, name="video num frames")
-
             return KeyeVL1_5VideoPixelInputs(
                 type="pixel_values_videos",
                 pixel_values_videos=pixel_values_videos,
@@ -563,11 +522,6 @@ class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
                 num_frames=num_frames)

         if video_embeds is not None:
-            video_embeds = self._validate_and_reshape_mm_tensor(
-                video_embeds, expected_dim=2, name="video embeds")
-            video_grid_thw = self._validate_and_reshape_mm_tensor(
-                video_grid_thw, expected_dim=2, name="video grid_thw")
-
             return KeyeVL1_5VideoEmbeddingInputs(type="video_embeds",
                                                  video_embeds=video_embeds,
                                                  video_grid_thw=video_grid_thw,
@@ -283,6 +283,7 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
                         dummy_inputs=KimiVLDummyInputsBuilder)
 class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):
+    merge_by_field_config = True

     supports_encoder_tp_data = True

@@ -342,23 +343,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                 config.vocab_size, logit_scale)
         self.media_placeholder: int = self.config.media_placeholder_token_id

-    # ref: qwen2_vl.py
-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[KimiVLImageInputs]:
         # image input type must be pixel values now
@@ -368,21 +352,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
         if pixel_values is None:
             return None

-        image_grid_hws = self._validate_and_reshape_mm_tensor(
-            image_grid_hws, "image grid hws")
-        # pixel_values may have complex shapes
-        num_channels = 3
-        patch_size = self.config.vision_config.patch_size
-        if isinstance(pixel_values, list):
-            pixel_values = torch.cat([
-                x.reshape(-1, num_channels, patch_size, patch_size)
-                for x in pixel_values
-            ])
-        else:
-            pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
-                                                patch_size)
-        pixel_values = pixel_values.to(self.vision_tower.dtype)
-
         return KimiVLImagePixelInputs(
             type="pixel_values",
             pixel_values=pixel_values,
@@ -164,7 +164,9 @@ class TensorSchema:

         if len(actual_shape) != len(expected_shape):
             raise ValueError(f"{field_name} has rank {len(actual_shape)} "
-                             f"but expected {len(expected_shape)}")
+                             f"but expected {len(expected_shape)}. "
+                             f"Expected shape: {expected_shape}, "
+                             f"but got {actual_shape}")

         for i, dim in enumerate(expected_shape):
             if dim in dynamic_dims:
@@ -172,7 +174,9 @@ class TensorSchema:
             elif isinstance(dim, int):
                 if actual_shape[i] != dim:
                     raise ValueError(f"{field_name} dim[{i}] expected "
-                                     f"{dim}, got {actual_shape[i]}")
+                                     f"{dim}, got {actual_shape[i]}. "
+                                     f"Expected shape: {expected_shape}, "
+                                     f"but got {actual_shape}")
             elif isinstance(dim, str):
                 if dim in shape_env:
                     if actual_shape[i] != shape_env[dim]:
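
The two tensor_schema.py hunks above make shape failures self-describing:
both the rank check and the fixed-dimension check now print the expected
symbolic shape next to the actual shape. An illustrative repro (class and
field names are hypothetical; this assumes TensorSchema validates at
construction time, which is how the model code in this diff uses it):

    from typing import Annotated, Literal

    import torch

    from vllm.utils.tensor_schema import TensorSchema, TensorShape


    class ExampleImageInputs(TensorSchema):
        type: Literal["pixel_values"]
        # The channel dim is fixed to 3; "bnp"/"h"/"w" are symbolic.
        pixel_values: Annotated[torch.Tensor,
                                TensorShape("bnp", 3, "h", "w")]


    # Wrong channel count (4 instead of 3). Before this commit the error
    # was just "pixel_values dim[1] expected 3, got 4"; now it also reports
    # the expected symbolic shape ('bnp', 3, 'h', 'w') and the actual shape.
    ExampleImageInputs(type="pixel_values",
                       pixel_values=torch.zeros(2, 4, 32, 32))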