diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index e8fe77e8d6c98..4b4cebb6a31c2 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -634,7 +634,8 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `InternS1ForConditionalGeneration` | Intern-S1 | T + IE+ + VE+ | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + IE+ + (VE+) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + IE+ + VE+ | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ |
+| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | ✅︎ |
+| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + IE+ + VE+ | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ | ✅︎ |
 | `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I+ | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | ✅︎ |
 | `Llama4ForConditionalGeneration` | Llama 4 | T + I+ | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ |
 | `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + IE+ | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ |
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 4e879666f61d7..b104113b88213 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -683,6 +683,37 @@ def run_keye_vl(questions: list[str], modality: str) -> ModelRequestData:
     )
+
+# Keye-VL-1.5
+def run_keye_vl1_5(questions: list[str], modality: str) -> ModelRequestData:
+    model_name = "Kwai-Keye/Keye-VL-1_5-8B"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=8192,
+        trust_remote_code=True,
+        limit_mm_per_prompt={modality: 1},
+    )
+
+    if modality == "image":
+        placeholder = "<|image_pad|>"
+    elif modality == "video":
+        placeholder = "<|video_pad|>"
+
+    prompts = [
+        (
+            f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
+            f"{question}<|im_end|>\n"
+            "<|im_start|>assistant\n"
+        )
+        for question in questions
+    ]
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
+
 # Kimi-VL
 def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"
@@ -1648,6 +1679,7 @@ model_example_map = {
     "interns1": run_interns1,
     "internvl_chat": run_internvl,
     "keye_vl": run_keye_vl,
+    "keye_vl1_5": run_keye_vl1_5,
     "kimi_vl": run_kimi_vl,
     "llama4": run_llama4,
     "llava": run_llava,
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index d9242efa85470..01c2905cf26d8 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -542,6 +542,43 @@ def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData:
     )
+
+def load_keye_vl1_5(question: str, image_urls: list[str]) -> 
ModelRequestData: + model_name = "Kwai-Keye/Keye-VL-1_5-8B" + + engine_args = EngineArgs( + model=model_name, + trust_remote_code=True, + max_model_len=8192, + max_num_seqs=5, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [ + { + "role": "user", + "content": [ + *placeholders, + {"type": "text", "text": question}, + ], + }, + ] + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) + + prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + image_data = [fetch_image(url) for url in image_urls] + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=image_data, + ) + + def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "moonshotai/Kimi-VL-A3B-Instruct" @@ -1209,6 +1246,7 @@ model_example_map = { "interns1": load_interns1, "internvl_chat": load_internvl, "keye_vl": load_keye_vl, + "keye_vl1_5": load_keye_vl1_5, "kimi_vl": load_kimi_vl, "llama4": load_llama4, "llava": load_llava, diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 3ff4360b83345..16c0428c6d8f1 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -293,6 +293,7 @@ def _test_processing_correctness_one( "OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview", "OpenGVLab/InternVL3_5-30B-A3B", "Kwai-Keye/Keye-VL-8B-Preview", + "Kwai-Keye/Keye-VL-1_5-8B", "moonshotai/Kimi-VL-A3B-Instruct", "meta-llama/Llama-4-Scout-17B-16E-Instruct", "llava-hf/llava-1.5-7b-hf", diff --git a/tests/models/registry.py b/tests/models/registry.py index a37ffdc311514..3b5cec2dc7022 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -438,6 +438,8 @@ _MULTIMODAL_EXAMPLE_MODELS = { "InternVLForConditionalGeneration": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"), # noqa: E501 "KeyeForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-8B-Preview", # noqa: E501 trust_remote_code=True), + "KeyeVL1_5ForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-1_5-8B", # noqa: E501 + trust_remote_code=True), "KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501 extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501 trust_remote_code=True), diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 5686ec7b35de8..0ab4bc5375daf 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -402,6 +402,15 @@ class MRotaryEmbedding(RotaryEmbedding): context_len=context_len, seq_len=seq_len, ) + elif "KeyeVL1_5" in hf_config.model_type: + return cls._keye_get_input_positions_tensor( + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + context_len=context_len, + seq_len=seq_len, + ) else: return cls._vl_get_input_positions_tensor( input_tokens=input_tokens, @@ -636,6 +645,126 @@ class MRotaryEmbedding(RotaryEmbedding): len(input_tokens)).item() return llm_positions, mrope_position_delta + @classmethod + def _keye_get_input_positions_tensor( + cls, + input_tokens: list[int], + hf_config: PretrainedConfig, + image_grid_thw: Union[list[list[int]], torch.Tensor], + video_grid_thw: Union[list[list[int]], torch.Tensor], + 
context_len: int = 0, + seq_len: Optional[int] = None, + ) -> tuple[torch.Tensor, int]: + if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0: + video_grid_thw = video_grid_thw[0] + """Get mrope input positions and delta value (Keye series).""" + + def split_thw( + grid_thw: Union[torch.Tensor, list[int]]) -> list[list[int]]: + """ + Split grid_thw along the t dimension. + + Args: + grid_thw: shape [N, 3] tensor or nested list of [t, h, w]. + + Returns: + List of [1, h, w] rows, repeated t times for each original row. + """ + + if isinstance(grid_thw, list): + grid_thw = torch.tensor(grid_thw, dtype=torch.long) + + if grid_thw.numel() == 0: + return [] + + t, hw = grid_thw[:, 0], grid_thw[:, 1:] + ones = torch.ones_like(hw[:, :1]) # [N,1] + out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0) + return out.tolist() + + video_grid_thw = split_thw(video_grid_thw) + + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + + image_nums = len(image_grid_thw) + frame_nums = len(video_grid_thw) + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_frames = image_nums, frame_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + frame_nums): + if remain_images > 0: + try: + ed_image = input_tokens.index(image_token_id, st) + except ValueError: + ed_image = len(input_tokens) + 1 + else: + ed_image = len(input_tokens) + 1 + if remain_frames > 0: + try: + ed_video = input_tokens.index(video_token_id, st) + except ValueError: + ed_video = len(input_tokens) + 1 + else: + ed_video = len(input_tokens) + 1 + + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_frames -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_merge_size, w // spatial_merge_size + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = (torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w)).long().flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + + return llm_positions, mrope_position_delta + @classmethod def _vl_get_input_positions_tensor( cls, diff --git a/vllm/model_executor/models/keye.py b/vllm/model_executor/models/keye.py index c6dbd62b905e1..710b805acb3ea 100644 --- a/vllm/model_executor/models/keye.py +++ b/vllm/model_executor/models/keye.py @@ -1,9 
+1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math +from abc import abstractmethod from collections.abc import Iterable, Mapping, Sequence from functools import partial -from typing import Annotated, Any, Literal, Optional, Union +from typing import Annotated, Any, Literal, Optional, TypeVar, Union import numpy as np import torch @@ -57,16 +58,13 @@ from .vision import get_vit_attn_backend logger = init_logger(__name__) -_MAX_FRAMES_PER_VIDEO = 16 -_MAX_IMAGE_SIZE = 9999999 - def smart_resize( height: int, width: int, - factor: int = 28, - min_pixels: int = 28 * 28 * 130, - max_pixels: int = 28 * 28 * 1280, + factor: int, + min_pixels: int, + max_pixels: int, ): if height < factor: logger.warning( @@ -887,9 +885,9 @@ class Projector(nn.Module): def forward( self, - image_features: torch.Tensor, + image_features: Union[torch.Tensor, list[torch.Tensor]], image_grid_thw: list[tuple[int, int, int]], - ) -> torch.Tensor: + ) -> Union[torch.Tensor, list[torch.Tensor]]: m1, m2 = self.merge_kernel_size if isinstance(image_features, (list, tuple)): processed_features = list() @@ -986,6 +984,12 @@ class KeyeMultiModalDataParser(MultiModalDataParser): class KeyeProcessingInfo(BaseProcessingInfo): + def get_max_image_size(self) -> int: + return 9999999 #_MAX_IMAGE_SIZE + + def get_max_frame_per_video(self) -> int: + return 16 #_MAX_FRAMES_PER_VIDEO + def get_image_processor(self, **kwargs: object): return self.get_hf_processor(**kwargs).image_processor @@ -1077,8 +1081,8 @@ class KeyeProcessingInfo(BaseProcessingInfo): def get_image_size_with_most_features(self, ) -> ImageSize: max_image_size, _ = self._get_vision_info( - image_width=_MAX_IMAGE_SIZE, - image_height=_MAX_IMAGE_SIZE, + image_width=self.get_max_image_size(), + image_height=self.get_max_image_size(), image_processor=None, ) return max_image_size @@ -1123,7 +1127,7 @@ class KeyeProcessingInfo(BaseProcessingInfo): max_image_tokens) max_frames_per_video = min( max_total_frames // max(max_videos, 1), - _MAX_FRAMES_PER_VIDEO, + self.get_max_frame_per_video(), ) return max(max_frames_per_video, 1) @@ -1139,7 +1143,10 @@ class KeyeProcessingInfo(BaseProcessingInfo): ) -class KeyeDummyInputsBuilder(BaseDummyInputsBuilder[KeyeProcessingInfo]): +_I = TypeVar("_I", bound=KeyeProcessingInfo) + + +class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]): def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: num_images = mm_counts.get("image", 0) @@ -1183,6 +1190,10 @@ class KeyeDummyInputsBuilder(BaseDummyInputsBuilder[KeyeProcessingInfo]): return mm_data +class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]): + ... 
+ + class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): def _get_data_parser(self) -> MultiModalDataParser: @@ -1231,13 +1242,7 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]): return _keye_field_config(hf_inputs) -@MULTIMODAL_REGISTRY.register_processor( - KeyeMultiModalProcessor, - info=KeyeProcessingInfo, - dummy_inputs=KeyeDummyInputsBuilder, -) -class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, - SupportsPP): +class BaseKeyeModule(nn.Module): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -1264,6 +1269,11 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, raise ValueError("Only image or video modality is supported") + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + return None + return quant_config + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config: PretrainedConfig = vllm_config.model_config.hf_config @@ -1278,7 +1288,8 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, quant_config=self._maybe_ignore_quant_config(quant_config), prefix=maybe_prefix(prefix, "visual"), ) - self.mlp_AR = Projector( + + self.mlp_AR = self._build_projector( config, config.vision_config, quant_config=self._maybe_ignore_quant_config(quant_config), @@ -1294,13 +1305,287 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) - def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): - if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): - return None - return quant_config + @abstractmethod + def _build_projector(self, + text_config: PretrainedConfig, + vision_config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> nn.Module: + raise ValueError("Need projector") - def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors, - name: str) -> torch.Tensor: + def _process_image_input(self, + image_input: Any) -> tuple[torch.Tensor, ...]: + siglip_position_ids = list() + image_grid_hws = list() + sample_indices = list() + cu_seqlens = [0] + + image_grid_thw = image_input["image_grid_thw"] + assert image_grid_thw.ndim == 2 + + for idx, thaw in enumerate(image_grid_thw): + thw_tuple = tuple(thaw.detach().cpu().numpy().tolist()) + numel = np.prod(thw_tuple) + image_grid_hws.append(thw_tuple) + image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) + siglip_position_ids.append(image_position_ids) + sample_indices.append(torch.full((numel, ), idx, + dtype=torch.int64)) + cu_seqlens.append(cu_seqlens[-1] + numel) + + if image_input["type"] == "image_embeds": + raise ValueError( + "Image embeddings are not supported for this processing path.") + else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + siglip_position_ids = torch.concat(siglip_position_ids, + dim=0).to(pixel_values.device) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to( + pixel_values.device) + sample_indices = torch.concat(sample_indices, + dim=0).to(pixel_values.device) + + image_embeds = self.visual( + pixel_values=pixel_values, + image_grid_thw=image_grid_hws, + position_ids=siglip_position_ids, + vision_return_embed_list=False, + interpolate_pos_encoding=True, + sample_indices=sample_indices, + cu_seqlens=cu_seqlens, + 
use_rope=True, + window_size=-1, + ) + image_embeds = tuple(self.mlp_AR(image_embeds, image_grid_thw)) + return image_embeds + + def _process_video_embeds( + self, + video_type: Literal["video_embeds", "pixel_values_videos"], + video_grid_thw: list[torch.Tensor], + pixel_values_videos: Optional[torch.Tensor] = None + ) -> Union[torch.Tensor, list[torch.Tensor]]: + siglip_position_ids = list() + video_grid_hws = list() + sample_indices = list() + cu_seqlens = [0] + + assert video_grid_thw.ndim == 2 + for idx, sub_thw in enumerate(video_grid_thw): + thw_tuple = tuple(sub_thw.detach().cpu().numpy().tolist()) + numel = np.prod(thw_tuple) + + video_grid_hws.append(thw_tuple) + video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) + siglip_position_ids.append(video_position_ids) + sample_indices.append(torch.full((numel, ), idx, + dtype=torch.int64)) + cu_seqlens.append(cu_seqlens[-1] + numel) + + if video_type == "video_embeds": + raise ValueError( + "Video embeddings are not supported for this processing path.") + else: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to( + pixel_values_videos.device) + cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to( + pixel_values_videos.device) + sample_indices = torch.concat(sample_indices, + dim=0).to(pixel_values_videos.device) + + video_embeds = self.visual( + pixel_values=pixel_values_videos, + image_grid_thw=video_grid_hws, + position_ids=siglip_position_ids, + vision_return_embed_list=True, + interpolate_pos_encoding=True, + sample_indices=sample_indices, + cu_seqlens=cu_seqlens, + use_rope=True, + window_size=-1, + ) + video_embeds = self.mlp_AR(video_embeds, video_grid_thw) + return video_embeds + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + for input_key in kwargs: + if (input_key in ("pixel_values", "image_embeds") + and "images" not in modalities): + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if (input_key in ("pixel_values_videos", "video_embeds") + and "videos" not in modalities): + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + + return modalities + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def get_multimodal_embeddings( + self, **kwargs: object) -> Optional[MultiModalEmbeddings]: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + multimodal_embeddings: tuple[torch.Tensor, ...] 
= () + + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[MultiModalEmbeddings] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + multimodal_embeddings, + [ + self.config.image_token_id, + self.config.video_token_id, + ], + ) + return inputs_embeds + + def get_input_embeddings_v0( + self, + input_ids: torch.Tensor, + image_input: Optional[Any] = None, + video_input: Optional[Any] = None, + ) -> torch.Tensor: + inputs_embeds = self.get_input_embeddings(input_ids) + if image_input is not None: + image_embeds = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config.image_token_id, + ) + + if video_input is not None: + video_embeds = self._process_video_input(video_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + video_embeds, + placeholder_token_id=self.config.video_token_id, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for Keye-VL. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Flattened (concatenated) position ids corresponding to a + batch. + **NOTE**: If mrope is enabled (default setting for Qwen2-VL + opensource models), the shape will be `(3, seq_len)`, + otherwise it will be `(seq_len,). + pixel_values: Pixel values to be fed to a model. + `None` if no images are passed. + image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. + `None` if no images are passed. + pixel_values_videos: Pixel values of videos to be fed to a model. + `None` if no videos are passed. + video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. + `None` if no videos are passed. 
+ """ + if intermediate_tensors is not None: + inputs_embeds = None + + elif inputs_embeds is None: + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + if image_input is None and video_input is None: + inputs_embeds = None + else: + if uses_mrope(self.config): + assert positions.ndim == 2 and positions.size(0) == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.size()}") + inputs_embeds = self.get_input_embeddings_v0( + input_ids, + image_input=image_input, + video_input=video_input, + ) + input_ids = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_mm_mapping(self) -> MultiModelKeys: + """Get the module prefix in multimodal models.""" + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="mlp_AR.", + tower_model="visual.", + ) + + +@MULTIMODAL_REGISTRY.register_processor( + KeyeMultiModalProcessor, + info=KeyeProcessingInfo, + dummy_inputs=KeyeDummyInputsBuilder, +) +class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal, + SupportsLoRA, SupportsPP): + + def _build_projector(self, + text_config: PretrainedConfig, + vision_config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> nn.Module: + return Projector(text_config, vision_config, quant_config, prefix) + + def _validate_and_reshape_mm_tensor( + self, mm_input: NestedTensors, + name: str) -> Union[torch.Tensor, list[torch.Tensor]]: if not isinstance(mm_input, (torch.Tensor, list)): raise ValueError(f"Incorrect type of {name}. 
" f"Got type: {type(mm_input)}") @@ -1388,257 +1673,12 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA, video_grid_thw=video_grid_thw, ) - def _process_image_input( - self, image_input: KeyeImageInputs) -> tuple[torch.Tensor, ...]: - siglip_position_ids = list() - image_grid_hws = list() - sample_indices = list() - cu_seqlens = [0] - - image_grid_thw = image_input["image_grid_thw"] - assert image_grid_thw.ndim == 2 - - for idx, thaw in enumerate(image_grid_thw): - thw_tuple = tuple(thaw.detach().cpu().numpy().tolist()) - numel = np.prod(thw_tuple) - image_grid_hws.append(thw_tuple) - image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) - siglip_position_ids.append(image_position_ids) - sample_indices.append(torch.full((numel, ), idx, - dtype=torch.int64)) - cu_seqlens.append(cu_seqlens[-1] + numel) - - if image_input["type"] == "image_embeds": - raise ValueError( - "Image embeddings are not supported for this processing path.") - else: - pixel_values = image_input["pixel_values"].type(self.visual.dtype) - siglip_position_ids = torch.concat(siglip_position_ids, - dim=0).to(pixel_values.device) - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to( - pixel_values.device) - sample_indices = torch.concat(sample_indices, - dim=0).to(pixel_values.device) - - image_embeds = self.visual( - pixel_values=pixel_values, - image_grid_thw=image_grid_hws, - position_ids=siglip_position_ids, - vision_return_embed_list=False, - interpolate_pos_encoding=True, - sample_indices=sample_indices, - cu_seqlens=cu_seqlens, - use_rope=True, - window_size=-1, - ) - image_embeds = tuple(self.mlp_AR(image_embeds, image_grid_thw)) - return image_embeds - def _process_video_input( self, video_input: KeyeVideoInputs) -> tuple[torch.Tensor, ...]: - siglip_position_ids = list() - video_grid_hws = list() - sample_indices = list() - cu_seqlens = [0] - + video_type = video_input["type"] video_grid_thw = video_input["video_grid_thw"] - assert video_grid_thw.ndim == 2 + pixel_values_videos = video_input.get("pixel_values_videos", None) - for idx, thaw in enumerate(video_grid_thw): - thw_tuple = tuple(thaw.detach().cpu().numpy().tolist()) - numel = np.prod(thw_tuple) - - video_grid_hws.append(thw_tuple) - video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:]) - siglip_position_ids.append(video_position_ids) - sample_indices.append(torch.full((numel, ), idx, - dtype=torch.int64)) - cu_seqlens.append(cu_seqlens[-1] + numel) - - if video_input["type"] == "video_embeds": - raise ValueError( - "Video embeddings are not supported for this processing path.") - else: - pixel_values_videos = video_input["pixel_values_videos"].type( - self.visual.dtype) - siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to( - pixel_values_videos.device) - cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to( - pixel_values_videos.device) - sample_indices = torch.concat(sample_indices, - dim=0).to(pixel_values_videos.device) - - video_embeds = self.visual( - pixel_values=pixel_values_videos, - image_grid_thw=video_grid_hws, - position_ids=siglip_position_ids, - vision_return_embed_list=True, - interpolate_pos_encoding=True, - sample_indices=sample_indices, - cu_seqlens=cu_seqlens, - use_rope=True, - window_size=-1, - ) - video_embeds = tuple(self.mlp_AR(video_embeds, video_grid_thw)) - return video_embeds - - def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: - modalities = {} - - for input_key in kwargs: - if (input_key in ("pixel_values", 
"image_embeds") - and "images" not in modalities): - modalities["images"] = self._parse_and_validate_image_input( - **kwargs) - if (input_key in ("pixel_values_videos", "video_embeds") - and "videos" not in modalities): - modalities["videos"] = self._parse_and_validate_video_input( - **kwargs) - - return modalities - - def get_language_model(self) -> torch.nn.Module: - return self.language_model - - def get_multimodal_embeddings( - self, **kwargs: object) -> Optional[MultiModalEmbeddings]: - - modalities = self._parse_and_validate_multimodal_inputs(**kwargs) - if not modalities: - return None - - multimodal_embeddings: tuple[torch.Tensor, ...] = () - - for modality in modalities: - if modality == "images": - image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) - multimodal_embeddings += vision_embeddings - if modality == "videos": - video_input = modalities["videos"] - video_embeddings = self._process_video_input(video_input) - multimodal_embeddings += video_embeddings - return multimodal_embeddings - - def get_input_embeddings( - self, - input_ids: torch.Tensor, - multimodal_embeddings: Optional[MultiModalEmbeddings] = None, - ) -> torch.Tensor: - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if multimodal_embeddings is not None: - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - multimodal_embeddings, - [ - self.config.image_token_id, - self.config.video_token_id, - ], - ) - return inputs_embeds - - def get_input_embeddings_v0( - self, - input_ids: torch.Tensor, - image_input: Optional[KeyeImagePixelInputs] = None, - video_input: Optional[KeyeVideoPixelInputs] = None, - ) -> torch.Tensor: - inputs_embeds = self.get_input_embeddings(input_ids) - if image_input is not None: - image_embeds = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - image_embeds, - placeholder_token_id=self.config.image_token_id, - ) - - if video_input is not None: - video_embeds = self._process_video_input(video_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, - inputs_embeds, - video_embeds, - placeholder_token_id=self.config.video_token_id, - ) - return inputs_embeds - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: object, - ) -> Union[torch.Tensor, IntermediateTensors]: - """Run forward pass for Qwen2-VL. - - Args: - input_ids: Flattened (concatenated) input_ids corresponding to a - batch. - positions: Flattened (concatenated) position ids corresponding to a - batch. - **NOTE**: If mrope is enabled (default setting for Qwen2-VL - opensource models), the shape will be `(3, seq_len)`, - otherwise it will be `(seq_len,). - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. - `None` if no images are passed. - pixel_values_videos: Pixel values of videos to be fed to a model. - `None` if no videos are passed. - video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. - `None` if no videos are passed. 
- """ - - if intermediate_tensors is not None: - inputs_embeds = None - - elif inputs_embeds is None: - image_input = self._parse_and_validate_image_input(**kwargs) - video_input = self._parse_and_validate_video_input(**kwargs) - - if image_input is None and video_input is None: - inputs_embeds = None - else: - if uses_mrope(self.config): - assert positions.ndim == 2 and positions.size(0) == 3, ( - "multimodal section rotary embedding requires " - f"(3, seq_len) positions, but got {positions.size()}") - inputs_embeds = self.get_input_embeddings_v0( - input_ids, - image_input=image_input, - video_input=video_input, - ) - input_ids = None - - hidden_states = self.language_model.model( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - ) - return hidden_states - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - return self.language_model.compute_logits(hidden_states, - sampling_metadata) - - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - - loader = AutoWeightsLoader(self) - return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) - - def get_mm_mapping(self) -> MultiModelKeys: - """Get the module prefix in multimodal models.""" - return MultiModelKeys.from_string_field( - language_model="language_model", - connector="visual.", - tower_model="mlp_AR.", - ) + return tuple( + self._process_video_embeds(video_type, video_grid_thw, + pixel_values_videos)) diff --git a/vllm/model_executor/models/keye_vl1_5.py b/vllm/model_executor/models/keye_vl1_5.py new file mode 100644 index 0000000000000..605c6d3eaf643 --- /dev/null +++ b/vllm/model_executor/models/keye_vl1_5.py @@ -0,0 +1,601 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import itertools +from collections.abc import Mapping, Sequence +from functools import partial +from typing import Annotated, Any, Literal, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from transformers import PretrainedConfig +from transformers.activations import GELUActivation +from transformers.feature_extraction_utils import BatchFeature + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors +from vllm.multimodal.inputs import (ImageItem, ModalityData, + MultiModalFieldConfig, + MultiModalKwargsItems, VideoItem) +from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems, + MultiModalDataItems, MultiModalDataParser) +from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, + PromptUpdateDetails) +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP +from .keye import (BaseKeyeModule, BaseMultiModalProcessor, + KeyeBaseDummyInputsBuilder, KeyeProcessingInfo) + +logger = init_logger(__name__) + + +def split_thw(grid_thw: torch.Tensor) -> torch.Tensor: + """ + Split grid_thw in t dimension. 
+ + Args: + grid_thw: [N, 3] tensor of [t, h, w] + + Returns: + [Σt, 3] tensor where each row is [1, h, w] + + Example: + >>> grid_thw = torch.tensor([[2, 3, 4], [1, 5, 6]]) + >>> split_thw(grid_thw) + tensor([[1, 3, 4], + [1, 3, 4], + [1, 5, 6]]) + """ + t = grid_thw[:, 0] + h_w = grid_thw[:, 1:] + ones = torch.ones_like(h_w[:, :1]) + return torch.cat([ones, h_w], dim=1).repeat_interleave(t, dim=0) + + +def get_num_patches(grid_thw: torch.Tensor, num_frames: Union[list[int], + torch.Tensor]): + """ + Return num_patches per video. + + Args: + t: tensor with shape [N, ...] where each item is a list/tensor + cu_seqlens: list indicating the boundaries of groups + + Returns: + list of ints representing the sum of products for each group + + Examples: + >>> # Suppose there are 2 videos with a total of 3 grids + >>> grid_thw = torch.tensor([[2, 2, 2], # grid 0: 2*2*2=8 patches + ... [2, 2, 2], # grid 1: 2*2*2=8 patches + ... [1, 1, 1]]) # grid 2: 1*1*1=1 patches + >>> num_frames = [2, 1] # The first video contains 2 grids, + the second contains 1 grid. + >>> get_num_patches(grid_thw, num_frames) + tensor([16, 1]) # Total patches for first video: 8+8=16, + second video: 1. + """ + + assert len(grid_thw.shape) == 2 + if isinstance(num_frames, torch.Tensor): + num_frames = num_frames.clone().tolist() + + num_grids_per_frame = grid_thw.prod(dim=1) + start_idx_per_video = [0, *itertools.accumulate(num_frames)] + num_patches = [ + num_grids_per_frame[start_idx_per_video[i]:start_idx_per_video[i + 1]]. + sum() for i in range(len(num_frames)) + ] + return torch.stack(num_patches) if num_patches else torch.zeros( + 0, dtype=grid_thw.dtype, device=grid_thw.device) + + +class KeyeVL1_5ImagePixelInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - np: Number of patches + - c: Number of channels + - ps: Patch size + - ni: Number of images + - g: Grid dimensions (3 for t, h, w) + """ + type: Literal["pixel_values"] + + pixel_values: Annotated[ + torch.Tensor, + TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})] + + image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] + + +class KeyeVL1_5ImageEmbeddingInputs(TensorSchema): + """ + Dimensions: + - nf: Number of image features + - hs: Hidden size (must match the hidden size of language model + backbone) + - ni: Number of images + - g: Grid dimensions (3 for t, h, w) + """ + type: Literal["image_embeds"] + image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] + image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)] + + +KeyeVL1_5ImageInputs = Union[KeyeVL1_5ImagePixelInputs, + KeyeVL1_5ImageEmbeddingInputs] + + +class KeyeVL1_5VideoPixelInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - np: Number of patches + - c: Number of channels + - ps: Patch size + - ni: Number of images + - g: Grid dimensions (3 for t, h, w) + """ + type: Literal["pixel_values_videos"] + pixel_values_videos: Annotated[ + torch.Tensor, + TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})] + video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] + + num_frames: torch.Tensor + + +class KeyeVL1_5VideoEmbeddingInputs(TensorSchema): + """ + Dimensions: + - nf: Number of video features + - hs: Hidden size (must match the hidden size of language model + backbone) + - nv: Number of videos + - g: Grid dimensions (3 for t, h, w) + """ + type: Literal["video_embeds"] + video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")] + video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)] + num_frames: torch.Tensor + + 
+KeyeVL1_5VideoInputs = Union[KeyeVL1_5VideoPixelInputs, + KeyeVL1_5VideoEmbeddingInputs] + + +class KeyeVL1_5Projector(nn.Module): + + def __init__( + self, + text_config: PretrainedConfig, + vision_config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.text_config = text_config + self.vision_config = vision_config + self.merge_kernel_size = (2, 2) + + self.hidden_size = (self.vision_config.hidden_size * + self.merge_kernel_size[0] * + self.merge_kernel_size[1]) + + self.pre_norm = torch.nn.LayerNorm(self.hidden_size, eps=1e-05) + self.act = GELUActivation() + + self.linear_1 = ColumnParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_1", + ) + self.linear_2 = RowParallelLinear( + self.hidden_size, + self.text_config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_2", + ) + + def forward( + self, + image_features: Union[torch.Tensor, tuple[torch.Tensor], + list[torch.Tensor]], + image_grid_thw: list[tuple[int, int, int]], + ) -> Union[torch.Tensor, list[torch.Tensor]]: + m1, m2 = self.merge_kernel_size + if isinstance(image_features, (list, tuple)): + processed_features = list() + for image_feature, image_grid in zip(image_features, + image_grid_thw): + t, h, w = image_grid + image_feature = rearrange( + image_feature, + "(t h p1 w p2) d -> (t h w) (p1 p2 d)", + t=t, + h=h // m1, + p1=m1, + w=w // m2, + p2=m2, + ) + image_feature = self.pre_norm(image_feature) + hidden_states, _ = self.linear_1(image_feature) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.linear_2(hidden_states) + processed_features.append(hidden_states) + + return processed_features + + dims = image_features.shape[:-1] + dim = image_features.shape[-1] + image_features = image_features.view(np.prod(dims), dim) + hidden_states = self.pre_norm(image_features.view( + -1, self.hidden_size)) + hidden_states = self.linear_1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + + return hidden_states.view(*dims, -1) + + +class KeyeVL1_5ProcessingInfo(KeyeProcessingInfo): + + def get_max_frame_per_video(self) -> int: + return 2048 + + def get_supported_mm_limits(self, ) -> Mapping[str, Optional[int]]: + return {"image": None, "video": 1} + + +def _keye_field_config(hf_inputs: Mapping[str, torch.Tensor], ): + image_grid_thw = hf_inputs.get("image_grid_thw", + torch.empty((0, 3), dtype=torch.int64)) + image_grid_sizes = image_grid_thw.prod(-1) + + video_grid_thw = hf_inputs.get("video_grid_thw", + torch.empty((0, 3), dtype=torch.int64)) + video_grid_thw = split_thw(video_grid_thw) + num_frames = hf_inputs.get("num_frames", + video_grid_thw[:, 0]).clone().tolist() + + video_num_patches = get_num_patches(video_grid_thw, num_frames) + + video_num_grids = [] + if len(num_frames) > 0: + i = 0 + j = 1 + cur_frames = num_frames[i] + for t, _, _ in video_grid_thw.tolist(): + cur_frames -= t + if cur_frames == 0: + video_num_grids.append(j) + i += 1 + if i < len(num_frames): + cur_frames = num_frames[i] + j = 1 + else: + j += 1 + video_num_grids = torch.tensor(video_num_grids) + return dict(pixel_values=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_embeds=MultiModalFieldConfig.flat_from_sizes( + "image", image_grid_sizes), + image_grid_thw=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.flat_from_sizes( + "video", 
video_num_patches), + video_embeds=MultiModalFieldConfig.flat_from_sizes( + "video", video_num_patches), + video_grid_thw=MultiModalFieldConfig.flat_from_sizes( + "video", video_num_grids), + num_frames=MultiModalFieldConfig.batched("video")) + + +class KeyeVL1_5MultiModalDataParser(MultiModalDataParser): + + def _parse_image_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="image", + required_fields={ + "image_embeds", + "image_grid_thw", + }, + fields_factory=_keye_field_config, + ) + + return super()._parse_image_data(data) + + def _parse_video_data( + self, + data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], + ) -> ModalityDataItems[Any, Any]: + if isinstance(data, dict): + return DictEmbeddingItems( + data, + modality="video", + required_fields={ + "video_embeds", + "video_grid_thw", + }, + fields_factory=_keye_field_config, + ) + + return super()._parse_video_data(data) + + +class KeyeVL1_5MultiModalProcessor( + BaseMultiModalProcessor[KeyeVL1_5ProcessingInfo]): + + def _get_data_parser(self) -> MultiModalDataParser: + return KeyeVL1_5MultiModalDataParser() + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_processor = self.info.get_image_processor( + **hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + image_token_id = vocab[hf_processor.image_token] + video_token_id = vocab[hf_processor.video_token] + placeholder = {"image": image_token_id, "video": video_token_id} + merge_length = image_processor.merge_size**2 + + out_mm_kwargs_data = out_mm_kwargs.get_data() + frame_types: list[torch.Tensor] = \ + hf_processor_mm_kwargs.get("frame_types", None) + timestamps: list[torch.Tensor] = \ + hf_processor_mm_kwargs.get("timestamps", None) + num_videos = mm_items.get_count("video", strict=False) + + if frame_types is None: + frame_types = [None] * num_videos + assert len(frame_types) == num_videos, \ + f"Number of frame_types={len(frame_types)} " \ + f"doesn't equal to number of videos={num_videos}" + if timestamps is None: + timestamps = [None] * num_videos + assert len(timestamps) == num_videos, \ + f"Number of timestamps={len(timestamps)} " \ + f"doesn't equal to number of videos={num_videos}" + + video_grid_thw = out_mm_kwargs_data.get( + 'video_grid_thw', torch.empty((0, 3), dtype=torch.int64)) + num_frames = out_mm_kwargs_data.get( + 'num_frames', torch.tensor([], dtype=torch.int64)) + + assert len(num_frames) == num_videos, \ + f"Size of num_frames={len(num_frames)} " \ + f"doesn't equal to number of videos={num_videos}" + + video_grid_hws = split_thw(video_grid_thw) + assert int(num_frames.sum().tolist()) == video_grid_hws.shape[0], ( + f"The first dimension of `video_grid_hws`={video_grid_hws.shape[0]}" + f"doesn't equal to num of frames.") + + cu_seqlens = torch.cumsum(torch.tensor([0] + num_frames.tolist()), + dim=-1) + + def get_replacement_keye(item_idx: int, modality: str): + """ + Args: + item_idx(int): The item index of modality to replace + modality(str): The modality + """ + if modality == "image": + out_item = out_mm_kwargs[modality][item_idx] + grid_thw = out_item[f"{modality}_grid_thw"].data + assert isinstance(grid_thw, torch.Tensor) + + num_tokens = 
int(grid_thw.prod()) // merge_length + return [image_token_id] * num_tokens + elif modality == "video": + placeholders = [] + video_timestamps = timestamps[item_idx] + video_frame_types = frame_types[item_idx] + grid_thw = video_grid_hws[ + cu_seqlens[item_idx]:cu_seqlens[item_idx + 1]] + + nframes = grid_thw.shape[0] + + if video_timestamps is None: + video_timestamps = [""] * nframes + else: + video_timestamps = [ + format(ts, ".1f") for ts in video_timestamps + ] + + if video_frame_types is None: + video_frame_types = [0] * nframes + for i, sub_thw in enumerate(grid_thw): + s = f"{hf_processor.frame_token}{video_timestamps[i]}" + if video_frame_types[i] == 1: + s += hf_processor.fast_start + placeholders.extend(tokenizer.encode(s)) + num_frame_tokens = int(sub_thw.prod()) // merge_length + placeholders.extend([video_token_id] * num_frame_tokens) + if video_frame_types[i] == 1: + placeholders.append(vocab[hf_processor.fast_end]) + + return PromptUpdateDetails.select_token_id( + placeholders, embed_token_id=video_token_id) + else: + raise ValueError(f"Unsupported modality {modality}") + + return [ + PromptReplacement( + modality=modality, + target=[placeholder[modality]], + replacement=partial(get_replacement_keye, modality=modality), + ) for modality in ("image", "video") + ] + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return _keye_field_config(hf_inputs) + + +class KeyeVL1_5DummyInputsBuilder( + KeyeBaseDummyInputsBuilder[KeyeVL1_5ProcessingInfo]): + ... + + +@MULTIMODAL_REGISTRY.register_processor( + KeyeVL1_5MultiModalProcessor, + info=KeyeVL1_5ProcessingInfo, + dummy_inputs=KeyeVL1_5DummyInputsBuilder, +) +class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal, + SupportsLoRA, SupportsPP): + + def _build_projector(self, + text_config: PretrainedConfig, + vision_config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> nn.Module: + return KeyeVL1_5Projector(text_config, vision_config, quant_config, + prefix) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config: PretrainedConfig = vllm_config.model_config.hf_config + self.merge_size = config.vision_config.spatial_merge_size + super().__init__(vllm_config=vllm_config, prefix=prefix) + + def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors, + expected_dim: int, name: str): + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == expected_dim: + return mm_input + elif mm_input.ndim == expected_dim + 1: + return torch.concat(list(mm_input)) + else: + raise ValueError( + f"{name} should be {expected_dim}D or " + f"batched {expected_dim}D tensor." 
+ f"Got ndim: {mm_input.ndim} (shape={mm_input.shape})") + else: + return torch.concat(list(mm_input)) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[KeyeVL1_5ImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, expected_dim=4, name="image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, expected_dim=2, name="image grid_thw") + + return KeyeVL1_5ImagePixelInputs( + type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, expected_dim=2, name="image embeds") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, expected_dim=2, name="image grid_thw") + + return KeyeVL1_5ImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw, + ) + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[KeyeVL1_5VideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + num_frames = kwargs.pop("num_frames", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, + expected_dim=4, + name="video pixel values", + ) + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, expected_dim=2, name="video grid_thw") + + num_frames = self._validate_and_reshape_mm_tensor( + num_frames, expected_dim=1, name="video num frames") + + return KeyeVL1_5VideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + num_frames=num_frames) + + if video_embeds is not None: + video_embeds = self._validate_and_reshape_mm_tensor( + video_embeds, expected_dim=2, name="video embeds") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, expected_dim=2, name="video grid_thw") + + return KeyeVL1_5VideoEmbeddingInputs(type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw, + num_frames=num_frames) + + def _process_video_input( + self, + video_input: KeyeVL1_5VideoInputs) -> tuple[torch.Tensor, ...]: + video_type = video_input["type"] + video_grid_thw = split_thw(video_input["video_grid_thw"]) + pixel_values_videos = video_input.get("pixel_values_videos", None) + + video_embeds = self._process_video_embeds(video_type, video_grid_thw, + pixel_values_videos) + video_embeds = torch.concat(video_embeds, dim=0) + + num_frames = video_input["num_frames"].clone().tolist() + + num_patches = get_num_patches(video_grid_thw, num_frames).tolist() + + patch_cu_seqlens = torch.cumsum( + torch.tensor([0] + num_patches).detach().clone(), dim=-1) + patch_cu_seqlens = torch.div(patch_cu_seqlens, + self.merge_size**2, + rounding_mode="floor") + + new_video_embeds = [] + for idx in range(patch_cu_seqlens.shape[0] - 1): + start = patch_cu_seqlens[idx] + end = patch_cu_seqlens[idx + 1] + new_video_embeds.append(video_embeds[start:end]) + return tuple(new_video_embeds) diff --git 
a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 98115f8623563..edb7f24214406 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -227,6 +227,7 @@ _MULTIMODAL_MODELS = { "Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"), "SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501 "KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"), + "KeyeVL1_5ForConditionalGeneration": ("keye_vl1_5", "KeyeVL1_5ForConditionalGeneration"), # noqa: E501 "RForConditionalGeneration": ("rvl", "RForConditionalGeneration"), "KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501 "Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
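
As a quick end-to-end check of the new registration, below is a minimal offline-inference sketch that mirrors the run_keye_vl1_5 example added in this patch. It assumes the Kwai-Keye/Keye-VL-1_5-8B checkpoint and vLLM's bundled "cherry_blossom" image asset are available; the chat/placeholder format is the one used in examples/offline_inference/vision_language.py, not an officially documented template.

# Minimal sketch: single-image generation with the newly registered
# KeyeVL1_5ForConditionalGeneration (assumptions noted in the paragraph above).
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset


def main() -> None:
    llm = LLM(
        model="Kwai-Keye/Keye-VL-1_5-8B",
        trust_remote_code=True,
        max_model_len=8192,
        limit_mm_per_prompt={"image": 1},
    )

    # Same prompt format as run_keye_vl1_5: the <|image_pad|> placeholder is
    # expanded to the right number of image tokens by the multimodal processor.
    prompt = (
        "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
        "What is in this image?<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    image = ImageAsset("cherry_blossom").pil_image.convert("RGB")

    outputs = llm.generate(
        {"prompt": prompt, "multi_modal_data": {"image": image}},
        SamplingParams(temperature=0.0, max_tokens=64),
    )
    print(outputs[0].outputs[0].text)


if __name__ == "__main__":
    main()

Serving through the OpenAI-compatible server should work the same way, e.g. `vllm serve Kwai-Keye/Keye-VL-1_5-8B --trust-remote-code`.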