mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-23 17:51:21 +08:00
[Model]: support KeyeVL-1_5-8B (#23838)
Signed-off-by: wangruitao <wangruitao@kuaishou.com> Co-authored-by: wangruitao <wangruitao@kuaishou.com>
This commit is contained in:
parent
3e330fcb21
commit
7c8271cd1e
@ -634,7 +634,8 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
|
|||||||
| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `InternS1ForConditionalGeneration` | Intern-S1 | T + I<sup>E+</sup> + V<sup>E+</sup> | `internlm/Intern-S1`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `InternVLChatModel` | InternVL 3.5, InternVL 3.0, InternVideo 2.5, InternVL 2.5, Mono-InternVL, InternVL 2.0 | T + I<sup>E+</sup> + (V<sup>E+</sup>) | `OpenGVLab/InternVL3_5-14B`, `OpenGVLab/InternVL3-9B`, `OpenGVLab/InternVideo2_5_Chat_8B`, `OpenGVLab/InternVL2_5-4B`, `OpenGVLab/Mono-InternVL-2B`, `OpenGVLab/InternVL2-4B`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
| `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + I<sup>E+</sup> + V<sup>E+</sup> | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
| `InternVLForConditionalGeneration` | InternVL 3.0 (HF format) | T + I<sup>E+</sup> + V<sup>E+</sup> | `OpenGVLab/InternVL3-1B-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||||
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ |
|
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ | ✅︎ |
|
||||||
|
| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ | ✅︎ |
|
||||||
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | ✅︎ |
|
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ | ✅︎ |
|
||||||
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ |
|
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ |
|
||||||
| `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ |
|
| `Llama_Nemotron_Nano_VL` | Llama Nemotron Nano VL | T + I<sup>E+</sup> | `nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1` | ✅︎ | ✅︎ | ✅︎ |
|
||||||
|
|||||||
@ -683,6 +683,37 @@ def run_keye_vl(questions: list[str], modality: str) -> ModelRequestData:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Keye-VL-1.5
|
||||||
|
def run_keye_vl1_5(questions: list[str], modality: str) -> ModelRequestData:
|
||||||
|
model_name = "Kwai-Keye/Keye-VL-1.5-8B"
|
||||||
|
|
||||||
|
engine_args = EngineArgs(
|
||||||
|
model=model_name,
|
||||||
|
max_model_len=8192,
|
||||||
|
trust_remote_code=True,
|
||||||
|
limit_mm_per_prompt={modality: 1},
|
||||||
|
)
|
||||||
|
|
||||||
|
if modality == "image":
|
||||||
|
placeholder = "<|image_pad|>"
|
||||||
|
elif modality == "video":
|
||||||
|
placeholder = "<|video_pad|>"
|
||||||
|
|
||||||
|
prompts = [
|
||||||
|
(
|
||||||
|
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
|
||||||
|
f"{question}<|im_end|>\n"
|
||||||
|
"<|im_start|>assistant\n"
|
||||||
|
)
|
||||||
|
for question in questions
|
||||||
|
]
|
||||||
|
|
||||||
|
return ModelRequestData(
|
||||||
|
engine_args=engine_args,
|
||||||
|
prompts=prompts,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Kimi-VL
|
# Kimi-VL
|
||||||
def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
|
def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData:
|
||||||
assert modality == "image"
|
assert modality == "image"
|
||||||
@ -1648,6 +1679,7 @@ model_example_map = {
|
|||||||
"interns1": run_interns1,
|
"interns1": run_interns1,
|
||||||
"internvl_chat": run_internvl,
|
"internvl_chat": run_internvl,
|
||||||
"keye_vl": run_keye_vl,
|
"keye_vl": run_keye_vl,
|
||||||
|
"keye_vl1_5": run_keye_vl1_5,
|
||||||
"kimi_vl": run_kimi_vl,
|
"kimi_vl": run_kimi_vl,
|
||||||
"llama4": run_llama4,
|
"llama4": run_llama4,
|
||||||
"llava": run_llava,
|
"llava": run_llava,
|
||||||
|
|||||||
@ -542,6 +542,43 @@ def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||||
|
model_name = "Kwai-Keye/Keye-VL-1_5-8B"
|
||||||
|
|
||||||
|
engine_args = EngineArgs(
|
||||||
|
model=model_name,
|
||||||
|
trust_remote_code=True,
|
||||||
|
max_model_len=8192,
|
||||||
|
max_num_seqs=5,
|
||||||
|
limit_mm_per_prompt={"image": len(image_urls)},
|
||||||
|
)
|
||||||
|
|
||||||
|
placeholders = [{"type": "image", "image": url} for url in image_urls]
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
*placeholders,
|
||||||
|
{"type": "text", "text": question},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
|
||||||
|
prompt = processor.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
|
||||||
|
image_data = [fetch_image(url) for url in image_urls]
|
||||||
|
|
||||||
|
return ModelRequestData(
|
||||||
|
engine_args=engine_args,
|
||||||
|
prompt=prompt,
|
||||||
|
image_data=image_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
|
def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
|
||||||
model_name = "moonshotai/Kimi-VL-A3B-Instruct"
|
model_name = "moonshotai/Kimi-VL-A3B-Instruct"
|
||||||
|
|
||||||
@ -1209,6 +1246,7 @@ model_example_map = {
|
|||||||
"interns1": load_interns1,
|
"interns1": load_interns1,
|
||||||
"internvl_chat": load_internvl,
|
"internvl_chat": load_internvl,
|
||||||
"keye_vl": load_keye_vl,
|
"keye_vl": load_keye_vl,
|
||||||
|
"keye_vl1_5": load_keye_vl1_5,
|
||||||
"kimi_vl": load_kimi_vl,
|
"kimi_vl": load_kimi_vl,
|
||||||
"llama4": load_llama4,
|
"llama4": load_llama4,
|
||||||
"llava": load_llava,
|
"llava": load_llava,
|
||||||
|
|||||||
@ -293,6 +293,7 @@ def _test_processing_correctness_one(
|
|||||||
"OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview",
|
"OpenGVLab/InternVL3_5-GPT-OSS-20B-A4B-Preview",
|
||||||
"OpenGVLab/InternVL3_5-30B-A3B",
|
"OpenGVLab/InternVL3_5-30B-A3B",
|
||||||
"Kwai-Keye/Keye-VL-8B-Preview",
|
"Kwai-Keye/Keye-VL-8B-Preview",
|
||||||
|
"Kwai-Keye/Keye-VL-1_5-8B",
|
||||||
"moonshotai/Kimi-VL-A3B-Instruct",
|
"moonshotai/Kimi-VL-A3B-Instruct",
|
||||||
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||||
"llava-hf/llava-1.5-7b-hf",
|
"llava-hf/llava-1.5-7b-hf",
|
||||||
|
|||||||
@ -438,6 +438,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
|||||||
"InternVLForConditionalGeneration": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"), # noqa: E501
|
"InternVLForConditionalGeneration": _HfExamplesInfo("OpenGVLab/InternVL3-1B-hf"), # noqa: E501
|
||||||
"KeyeForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-8B-Preview", # noqa: E501
|
"KeyeForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-8B-Preview", # noqa: E501
|
||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
|
"KeyeVL1_5ForConditionalGeneration": _HfExamplesInfo("Kwai-Keye/Keye-VL-1_5-8B", # noqa: E501
|
||||||
|
trust_remote_code=True),
|
||||||
"KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
|
"KimiVLForConditionalGeneration": _HfExamplesInfo("moonshotai/Kimi-VL-A3B-Instruct", # noqa: E501
|
||||||
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
|
extras={"thinking": "moonshotai/Kimi-VL-A3B-Thinking"}, # noqa: E501
|
||||||
trust_remote_code=True),
|
trust_remote_code=True),
|
||||||
|
|||||||
@ -402,6 +402,15 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
context_len=context_len,
|
context_len=context_len,
|
||||||
seq_len=seq_len,
|
seq_len=seq_len,
|
||||||
)
|
)
|
||||||
|
elif "KeyeVL1_5" in hf_config.model_type:
|
||||||
|
return cls._keye_get_input_positions_tensor(
|
||||||
|
input_tokens=input_tokens,
|
||||||
|
hf_config=hf_config,
|
||||||
|
image_grid_thw=image_grid_thw,
|
||||||
|
video_grid_thw=video_grid_thw,
|
||||||
|
context_len=context_len,
|
||||||
|
seq_len=seq_len,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return cls._vl_get_input_positions_tensor(
|
return cls._vl_get_input_positions_tensor(
|
||||||
input_tokens=input_tokens,
|
input_tokens=input_tokens,
|
||||||
@ -636,6 +645,126 @@ class MRotaryEmbedding(RotaryEmbedding):
|
|||||||
len(input_tokens)).item()
|
len(input_tokens)).item()
|
||||||
return llm_positions, mrope_position_delta
|
return llm_positions, mrope_position_delta
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _keye_get_input_positions_tensor(
|
||||||
|
cls,
|
||||||
|
input_tokens: list[int],
|
||||||
|
hf_config: PretrainedConfig,
|
||||||
|
image_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||||
|
video_grid_thw: Union[list[list[int]], torch.Tensor],
|
||||||
|
context_len: int = 0,
|
||||||
|
seq_len: Optional[int] = None,
|
||||||
|
) -> tuple[torch.Tensor, int]:
|
||||||
|
if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
|
||||||
|
video_grid_thw = video_grid_thw[0]
|
||||||
|
"""Get mrope input positions and delta value (Keye series)."""
|
||||||
|
|
||||||
|
def split_thw(
|
||||||
|
grid_thw: Union[torch.Tensor, list[int]]) -> list[list[int]]:
|
||||||
|
"""
|
||||||
|
Split grid_thw along the t dimension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of [1, h, w] rows, repeated t times for each original row.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(grid_thw, list):
|
||||||
|
grid_thw = torch.tensor(grid_thw, dtype=torch.long)
|
||||||
|
|
||||||
|
if grid_thw.numel() == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
t, hw = grid_thw[:, 0], grid_thw[:, 1:]
|
||||||
|
ones = torch.ones_like(hw[:, :1]) # [N,1]
|
||||||
|
out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
|
||||||
|
return out.tolist()
|
||||||
|
|
||||||
|
video_grid_thw = split_thw(video_grid_thw)
|
||||||
|
|
||||||
|
image_token_id = hf_config.image_token_id
|
||||||
|
video_token_id = hf_config.video_token_id
|
||||||
|
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||||
|
|
||||||
|
image_nums = len(image_grid_thw)
|
||||||
|
frame_nums = len(video_grid_thw)
|
||||||
|
llm_pos_ids_list: list = []
|
||||||
|
|
||||||
|
st = 0
|
||||||
|
remain_images, remain_frames = image_nums, frame_nums
|
||||||
|
|
||||||
|
image_index, video_index = 0, 0
|
||||||
|
for _ in range(image_nums + frame_nums):
|
||||||
|
if remain_images > 0:
|
||||||
|
try:
|
||||||
|
ed_image = input_tokens.index(image_token_id, st)
|
||||||
|
except ValueError:
|
||||||
|
ed_image = len(input_tokens) + 1
|
||||||
|
else:
|
||||||
|
ed_image = len(input_tokens) + 1
|
||||||
|
if remain_frames > 0:
|
||||||
|
try:
|
||||||
|
ed_video = input_tokens.index(video_token_id, st)
|
||||||
|
except ValueError:
|
||||||
|
ed_video = len(input_tokens) + 1
|
||||||
|
else:
|
||||||
|
ed_video = len(input_tokens) + 1
|
||||||
|
|
||||||
|
if ed_image < ed_video:
|
||||||
|
t, h, w = (
|
||||||
|
image_grid_thw[image_index][0],
|
||||||
|
image_grid_thw[image_index][1],
|
||||||
|
image_grid_thw[image_index][2],
|
||||||
|
)
|
||||||
|
image_index += 1
|
||||||
|
remain_images -= 1
|
||||||
|
ed = ed_image
|
||||||
|
else:
|
||||||
|
t, h, w = (
|
||||||
|
video_grid_thw[video_index][0],
|
||||||
|
video_grid_thw[video_index][1],
|
||||||
|
video_grid_thw[video_index][2],
|
||||||
|
)
|
||||||
|
video_index += 1
|
||||||
|
remain_frames -= 1
|
||||||
|
ed = ed_video
|
||||||
|
|
||||||
|
llm_grid_t, llm_grid_h, llm_grid_w = \
|
||||||
|
t, h // spatial_merge_size, w // spatial_merge_size
|
||||||
|
text_len = ed - st
|
||||||
|
|
||||||
|
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
|
||||||
|
llm_pos_ids_list) > 0 else 0
|
||||||
|
llm_pos_ids_list.append(
|
||||||
|
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
||||||
|
|
||||||
|
t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
|
||||||
|
-1, llm_grid_h * llm_grid_w)).long().flatten()
|
||||||
|
|
||||||
|
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
|
||||||
|
llm_grid_t, -1, llm_grid_w).flatten()
|
||||||
|
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
|
||||||
|
llm_grid_t, llm_grid_h, -1).flatten()
|
||||||
|
llm_pos_ids_list.append(
|
||||||
|
torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
|
||||||
|
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||||
|
|
||||||
|
if st < len(input_tokens):
|
||||||
|
st_idx = llm_pos_ids_list[-1].max() + 1 if len(
|
||||||
|
llm_pos_ids_list) > 0 else 0
|
||||||
|
text_len = len(input_tokens) - st
|
||||||
|
llm_pos_ids_list.append(
|
||||||
|
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
||||||
|
|
||||||
|
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||||
|
mrope_position_delta = (llm_positions.max() + 1 -
|
||||||
|
len(input_tokens)).item()
|
||||||
|
llm_positions = llm_positions[:, context_len:seq_len]
|
||||||
|
|
||||||
|
return llm_positions, mrope_position_delta
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _vl_get_input_positions_tensor(
|
def _vl_get_input_positions_tensor(
|
||||||
cls,
|
cls,
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import math
|
import math
|
||||||
|
from abc import abstractmethod
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Annotated, Any, Literal, Optional, Union
|
from typing import Annotated, Any, Literal, Optional, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -57,16 +58,13 @@ from .vision import get_vit_attn_backend
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
_MAX_FRAMES_PER_VIDEO = 16
|
|
||||||
_MAX_IMAGE_SIZE = 9999999
|
|
||||||
|
|
||||||
|
|
||||||
def smart_resize(
|
def smart_resize(
|
||||||
height: int,
|
height: int,
|
||||||
width: int,
|
width: int,
|
||||||
factor: int = 28,
|
factor: int,
|
||||||
min_pixels: int = 28 * 28 * 130,
|
min_pixels: int,
|
||||||
max_pixels: int = 28 * 28 * 1280,
|
max_pixels: int,
|
||||||
):
|
):
|
||||||
if height < factor:
|
if height < factor:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -887,9 +885,9 @@ class Projector(nn.Module):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
image_features: torch.Tensor,
|
image_features: Union[torch.Tensor, list[torch.Tensor]],
|
||||||
image_grid_thw: list[tuple[int, int, int]],
|
image_grid_thw: list[tuple[int, int, int]],
|
||||||
) -> torch.Tensor:
|
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||||
m1, m2 = self.merge_kernel_size
|
m1, m2 = self.merge_kernel_size
|
||||||
if isinstance(image_features, (list, tuple)):
|
if isinstance(image_features, (list, tuple)):
|
||||||
processed_features = list()
|
processed_features = list()
|
||||||
@ -986,6 +984,12 @@ class KeyeMultiModalDataParser(MultiModalDataParser):
|
|||||||
|
|
||||||
class KeyeProcessingInfo(BaseProcessingInfo):
|
class KeyeProcessingInfo(BaseProcessingInfo):
|
||||||
|
|
||||||
|
def get_max_image_size(self) -> int:
|
||||||
|
return 9999999 #_MAX_IMAGE_SIZE
|
||||||
|
|
||||||
|
def get_max_frame_per_video(self) -> int:
|
||||||
|
return 16 #_MAX_FRAMES_PER_VIDEO
|
||||||
|
|
||||||
def get_image_processor(self, **kwargs: object):
|
def get_image_processor(self, **kwargs: object):
|
||||||
return self.get_hf_processor(**kwargs).image_processor
|
return self.get_hf_processor(**kwargs).image_processor
|
||||||
|
|
||||||
@ -1077,8 +1081,8 @@ class KeyeProcessingInfo(BaseProcessingInfo):
|
|||||||
|
|
||||||
def get_image_size_with_most_features(self, ) -> ImageSize:
|
def get_image_size_with_most_features(self, ) -> ImageSize:
|
||||||
max_image_size, _ = self._get_vision_info(
|
max_image_size, _ = self._get_vision_info(
|
||||||
image_width=_MAX_IMAGE_SIZE,
|
image_width=self.get_max_image_size(),
|
||||||
image_height=_MAX_IMAGE_SIZE,
|
image_height=self.get_max_image_size(),
|
||||||
image_processor=None,
|
image_processor=None,
|
||||||
)
|
)
|
||||||
return max_image_size
|
return max_image_size
|
||||||
@ -1123,7 +1127,7 @@ class KeyeProcessingInfo(BaseProcessingInfo):
|
|||||||
max_image_tokens)
|
max_image_tokens)
|
||||||
max_frames_per_video = min(
|
max_frames_per_video = min(
|
||||||
max_total_frames // max(max_videos, 1),
|
max_total_frames // max(max_videos, 1),
|
||||||
_MAX_FRAMES_PER_VIDEO,
|
self.get_max_frame_per_video(),
|
||||||
)
|
)
|
||||||
|
|
||||||
return max(max_frames_per_video, 1)
|
return max(max_frames_per_video, 1)
|
||||||
@ -1139,7 +1143,10 @@ class KeyeProcessingInfo(BaseProcessingInfo):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class KeyeDummyInputsBuilder(BaseDummyInputsBuilder[KeyeProcessingInfo]):
|
_I = TypeVar("_I", bound=KeyeProcessingInfo)
|
||||||
|
|
||||||
|
|
||||||
|
class KeyeBaseDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
||||||
|
|
||||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||||
num_images = mm_counts.get("image", 0)
|
num_images = mm_counts.get("image", 0)
|
||||||
@ -1183,6 +1190,10 @@ class KeyeDummyInputsBuilder(BaseDummyInputsBuilder[KeyeProcessingInfo]):
|
|||||||
return mm_data
|
return mm_data
|
||||||
|
|
||||||
|
|
||||||
|
class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
|
class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
|
||||||
|
|
||||||
def _get_data_parser(self) -> MultiModalDataParser:
|
def _get_data_parser(self) -> MultiModalDataParser:
|
||||||
@ -1231,13 +1242,7 @@ class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
|
|||||||
return _keye_field_config(hf_inputs)
|
return _keye_field_config(hf_inputs)
|
||||||
|
|
||||||
|
|
||||||
@MULTIMODAL_REGISTRY.register_processor(
|
class BaseKeyeModule(nn.Module):
|
||||||
KeyeMultiModalProcessor,
|
|
||||||
info=KeyeProcessingInfo,
|
|
||||||
dummy_inputs=KeyeDummyInputsBuilder,
|
|
||||||
)
|
|
||||||
class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
|
|
||||||
SupportsPP):
|
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": [
|
"qkv_proj": [
|
||||||
"q_proj",
|
"q_proj",
|
||||||
@ -1264,6 +1269,11 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
|
|||||||
|
|
||||||
raise ValueError("Only image or video modality is supported")
|
raise ValueError("Only image or video modality is supported")
|
||||||
|
|
||||||
|
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
|
||||||
|
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
|
||||||
|
return None
|
||||||
|
return quant_config
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config: PretrainedConfig = vllm_config.model_config.hf_config
|
config: PretrainedConfig = vllm_config.model_config.hf_config
|
||||||
@ -1278,7 +1288,8 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
|
|||||||
quant_config=self._maybe_ignore_quant_config(quant_config),
|
quant_config=self._maybe_ignore_quant_config(quant_config),
|
||||||
prefix=maybe_prefix(prefix, "visual"),
|
prefix=maybe_prefix(prefix, "visual"),
|
||||||
)
|
)
|
||||||
self.mlp_AR = Projector(
|
|
||||||
|
self.mlp_AR = self._build_projector(
|
||||||
config,
|
config,
|
||||||
config.vision_config,
|
config.vision_config,
|
||||||
quant_config=self._maybe_ignore_quant_config(quant_config),
|
quant_config=self._maybe_ignore_quant_config(quant_config),
|
||||||
@ -1294,13 +1305,287 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
|
|||||||
self.make_empty_intermediate_tensors = (
|
self.make_empty_intermediate_tensors = (
|
||||||
self.language_model.make_empty_intermediate_tensors)
|
self.language_model.make_empty_intermediate_tensors)
|
||||||
|
|
||||||
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
|
@abstractmethod
|
||||||
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
|
def _build_projector(self,
|
||||||
return None
|
text_config: PretrainedConfig,
|
||||||
return quant_config
|
vision_config: PretrainedConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "") -> nn.Module:
|
||||||
|
raise ValueError("Need projector")
|
||||||
|
|
||||||
def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors,
|
def _process_image_input(self,
|
||||||
name: str) -> torch.Tensor:
|
image_input: Any) -> tuple[torch.Tensor, ...]:
|
||||||
|
siglip_position_ids = list()
|
||||||
|
image_grid_hws = list()
|
||||||
|
sample_indices = list()
|
||||||
|
cu_seqlens = [0]
|
||||||
|
|
||||||
|
image_grid_thw = image_input["image_grid_thw"]
|
||||||
|
assert image_grid_thw.ndim == 2
|
||||||
|
|
||||||
|
for idx, thaw in enumerate(image_grid_thw):
|
||||||
|
thw_tuple = tuple(thaw.detach().cpu().numpy().tolist())
|
||||||
|
numel = np.prod(thw_tuple)
|
||||||
|
image_grid_hws.append(thw_tuple)
|
||||||
|
image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
|
||||||
|
siglip_position_ids.append(image_position_ids)
|
||||||
|
sample_indices.append(torch.full((numel, ), idx,
|
||||||
|
dtype=torch.int64))
|
||||||
|
cu_seqlens.append(cu_seqlens[-1] + numel)
|
||||||
|
|
||||||
|
if image_input["type"] == "image_embeds":
|
||||||
|
raise ValueError(
|
||||||
|
"Image embeddings are not supported for this processing path.")
|
||||||
|
else:
|
||||||
|
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
||||||
|
siglip_position_ids = torch.concat(siglip_position_ids,
|
||||||
|
dim=0).to(pixel_values.device)
|
||||||
|
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
|
||||||
|
pixel_values.device)
|
||||||
|
sample_indices = torch.concat(sample_indices,
|
||||||
|
dim=0).to(pixel_values.device)
|
||||||
|
|
||||||
|
image_embeds = self.visual(
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
image_grid_thw=image_grid_hws,
|
||||||
|
position_ids=siglip_position_ids,
|
||||||
|
vision_return_embed_list=False,
|
||||||
|
interpolate_pos_encoding=True,
|
||||||
|
sample_indices=sample_indices,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
use_rope=True,
|
||||||
|
window_size=-1,
|
||||||
|
)
|
||||||
|
image_embeds = tuple(self.mlp_AR(image_embeds, image_grid_thw))
|
||||||
|
return image_embeds
|
||||||
|
|
||||||
|
def _process_video_embeds(
|
||||||
|
self,
|
||||||
|
video_type: Literal["video_embeds", "pixel_values_videos"],
|
||||||
|
video_grid_thw: list[torch.Tensor],
|
||||||
|
pixel_values_videos: Optional[torch.Tensor] = None
|
||||||
|
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||||
|
siglip_position_ids = list()
|
||||||
|
video_grid_hws = list()
|
||||||
|
sample_indices = list()
|
||||||
|
cu_seqlens = [0]
|
||||||
|
|
||||||
|
assert video_grid_thw.ndim == 2
|
||||||
|
for idx, sub_thw in enumerate(video_grid_thw):
|
||||||
|
thw_tuple = tuple(sub_thw.detach().cpu().numpy().tolist())
|
||||||
|
numel = np.prod(thw_tuple)
|
||||||
|
|
||||||
|
video_grid_hws.append(thw_tuple)
|
||||||
|
video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
|
||||||
|
siglip_position_ids.append(video_position_ids)
|
||||||
|
sample_indices.append(torch.full((numel, ), idx,
|
||||||
|
dtype=torch.int64))
|
||||||
|
cu_seqlens.append(cu_seqlens[-1] + numel)
|
||||||
|
|
||||||
|
if video_type == "video_embeds":
|
||||||
|
raise ValueError(
|
||||||
|
"Video embeddings are not supported for this processing path.")
|
||||||
|
else:
|
||||||
|
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||||
|
siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(
|
||||||
|
pixel_values_videos.device)
|
||||||
|
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
|
||||||
|
pixel_values_videos.device)
|
||||||
|
sample_indices = torch.concat(sample_indices,
|
||||||
|
dim=0).to(pixel_values_videos.device)
|
||||||
|
|
||||||
|
video_embeds = self.visual(
|
||||||
|
pixel_values=pixel_values_videos,
|
||||||
|
image_grid_thw=video_grid_hws,
|
||||||
|
position_ids=siglip_position_ids,
|
||||||
|
vision_return_embed_list=True,
|
||||||
|
interpolate_pos_encoding=True,
|
||||||
|
sample_indices=sample_indices,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
use_rope=True,
|
||||||
|
window_size=-1,
|
||||||
|
)
|
||||||
|
video_embeds = self.mlp_AR(video_embeds, video_grid_thw)
|
||||||
|
return video_embeds
|
||||||
|
|
||||||
|
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||||
|
modalities = {}
|
||||||
|
|
||||||
|
for input_key in kwargs:
|
||||||
|
if (input_key in ("pixel_values", "image_embeds")
|
||||||
|
and "images" not in modalities):
|
||||||
|
modalities["images"] = self._parse_and_validate_image_input(
|
||||||
|
**kwargs)
|
||||||
|
if (input_key in ("pixel_values_videos", "video_embeds")
|
||||||
|
and "videos" not in modalities):
|
||||||
|
modalities["videos"] = self._parse_and_validate_video_input(
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
return modalities
|
||||||
|
|
||||||
|
def get_language_model(self) -> torch.nn.Module:
|
||||||
|
return self.language_model
|
||||||
|
|
||||||
|
def get_multimodal_embeddings(
|
||||||
|
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||||
|
|
||||||
|
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
||||||
|
if not modalities:
|
||||||
|
return None
|
||||||
|
|
||||||
|
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
||||||
|
|
||||||
|
for modality in modalities:
|
||||||
|
if modality == "images":
|
||||||
|
image_input = modalities["images"]
|
||||||
|
vision_embeddings = self._process_image_input(image_input)
|
||||||
|
multimodal_embeddings += vision_embeddings
|
||||||
|
if modality == "videos":
|
||||||
|
video_input = modalities["videos"]
|
||||||
|
video_embeddings = self._process_video_input(video_input)
|
||||||
|
multimodal_embeddings += video_embeddings
|
||||||
|
return multimodal_embeddings
|
||||||
|
|
||||||
|
def get_input_embeddings(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||||
|
if multimodal_embeddings is not None:
|
||||||
|
inputs_embeds = merge_multimodal_embeddings(
|
||||||
|
input_ids,
|
||||||
|
inputs_embeds,
|
||||||
|
multimodal_embeddings,
|
||||||
|
[
|
||||||
|
self.config.image_token_id,
|
||||||
|
self.config.video_token_id,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
def get_input_embeddings_v0(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
image_input: Optional[Any] = None,
|
||||||
|
video_input: Optional[Any] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||||
|
if image_input is not None:
|
||||||
|
image_embeds = self._process_image_input(image_input)
|
||||||
|
inputs_embeds = merge_multimodal_embeddings(
|
||||||
|
input_ids,
|
||||||
|
inputs_embeds,
|
||||||
|
image_embeds,
|
||||||
|
placeholder_token_id=self.config.image_token_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if video_input is not None:
|
||||||
|
video_embeds = self._process_video_input(video_input)
|
||||||
|
inputs_embeds = merge_multimodal_embeddings(
|
||||||
|
input_ids,
|
||||||
|
inputs_embeds,
|
||||||
|
video_embeds,
|
||||||
|
placeholder_token_id=self.config.video_token_id,
|
||||||
|
)
|
||||||
|
return inputs_embeds
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs: object,
|
||||||
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
|
"""Run forward pass for Keye-VL.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_ids: Flattened (concatenated) input_ids corresponding to a
|
||||||
|
batch.
|
||||||
|
positions: Flattened (concatenated) position ids corresponding to a
|
||||||
|
batch.
|
||||||
|
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
|
||||||
|
opensource models), the shape will be `(3, seq_len)`,
|
||||||
|
otherwise it will be `(seq_len,).
|
||||||
|
pixel_values: Pixel values to be fed to a model.
|
||||||
|
`None` if no images are passed.
|
||||||
|
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
|
||||||
|
`None` if no images are passed.
|
||||||
|
pixel_values_videos: Pixel values of videos to be fed to a model.
|
||||||
|
`None` if no videos are passed.
|
||||||
|
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
|
||||||
|
`None` if no videos are passed.
|
||||||
|
"""
|
||||||
|
if intermediate_tensors is not None:
|
||||||
|
inputs_embeds = None
|
||||||
|
|
||||||
|
elif inputs_embeds is None:
|
||||||
|
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||||
|
video_input = self._parse_and_validate_video_input(**kwargs)
|
||||||
|
if image_input is None and video_input is None:
|
||||||
|
inputs_embeds = None
|
||||||
|
else:
|
||||||
|
if uses_mrope(self.config):
|
||||||
|
assert positions.ndim == 2 and positions.size(0) == 3, (
|
||||||
|
"multimodal section rotary embedding requires "
|
||||||
|
f"(3, seq_len) positions, but got {positions.size()}")
|
||||||
|
inputs_embeds = self.get_input_embeddings_v0(
|
||||||
|
input_ids,
|
||||||
|
image_input=image_input,
|
||||||
|
video_input=video_input,
|
||||||
|
)
|
||||||
|
input_ids = None
|
||||||
|
|
||||||
|
hidden_states = self.language_model.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
positions=positions,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def compute_logits(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
sampling_metadata: SamplingMetadata,
|
||||||
|
) -> Optional[torch.Tensor]:
|
||||||
|
return self.language_model.compute_logits(hidden_states,
|
||||||
|
sampling_metadata)
|
||||||
|
|
||||||
|
def load_weights(self, weights: Iterable[tuple[str,
|
||||||
|
torch.Tensor]]) -> set[str]:
|
||||||
|
loader = AutoWeightsLoader(self)
|
||||||
|
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||||
|
|
||||||
|
def get_mm_mapping(self) -> MultiModelKeys:
|
||||||
|
"""Get the module prefix in multimodal models."""
|
||||||
|
return MultiModelKeys.from_string_field(
|
||||||
|
language_model="language_model",
|
||||||
|
connector="mlp_AR.",
|
||||||
|
tower_model="visual.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
|
KeyeMultiModalProcessor,
|
||||||
|
info=KeyeProcessingInfo,
|
||||||
|
dummy_inputs=KeyeDummyInputsBuilder,
|
||||||
|
)
|
||||||
|
class KeyeForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
|
||||||
|
SupportsLoRA, SupportsPP):
|
||||||
|
|
||||||
|
def _build_projector(self,
|
||||||
|
text_config: PretrainedConfig,
|
||||||
|
vision_config: PretrainedConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "") -> nn.Module:
|
||||||
|
return Projector(text_config, vision_config, quant_config, prefix)
|
||||||
|
|
||||||
|
def _validate_and_reshape_mm_tensor(
|
||||||
|
self, mm_input: NestedTensors,
|
||||||
|
name: str) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||||
if not isinstance(mm_input, (torch.Tensor, list)):
|
if not isinstance(mm_input, (torch.Tensor, list)):
|
||||||
raise ValueError(f"Incorrect type of {name}. "
|
raise ValueError(f"Incorrect type of {name}. "
|
||||||
f"Got type: {type(mm_input)}")
|
f"Got type: {type(mm_input)}")
|
||||||
@ -1388,257 +1673,12 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
|
|||||||
video_grid_thw=video_grid_thw,
|
video_grid_thw=video_grid_thw,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _process_image_input(
|
|
||||||
self, image_input: KeyeImageInputs) -> tuple[torch.Tensor, ...]:
|
|
||||||
siglip_position_ids = list()
|
|
||||||
image_grid_hws = list()
|
|
||||||
sample_indices = list()
|
|
||||||
cu_seqlens = [0]
|
|
||||||
|
|
||||||
image_grid_thw = image_input["image_grid_thw"]
|
|
||||||
assert image_grid_thw.ndim == 2
|
|
||||||
|
|
||||||
for idx, thaw in enumerate(image_grid_thw):
|
|
||||||
thw_tuple = tuple(thaw.detach().cpu().numpy().tolist())
|
|
||||||
numel = np.prod(thw_tuple)
|
|
||||||
image_grid_hws.append(thw_tuple)
|
|
||||||
image_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
|
|
||||||
siglip_position_ids.append(image_position_ids)
|
|
||||||
sample_indices.append(torch.full((numel, ), idx,
|
|
||||||
dtype=torch.int64))
|
|
||||||
cu_seqlens.append(cu_seqlens[-1] + numel)
|
|
||||||
|
|
||||||
if image_input["type"] == "image_embeds":
|
|
||||||
raise ValueError(
|
|
||||||
"Image embeddings are not supported for this processing path.")
|
|
||||||
else:
|
|
||||||
pixel_values = image_input["pixel_values"].type(self.visual.dtype)
|
|
||||||
siglip_position_ids = torch.concat(siglip_position_ids,
|
|
||||||
dim=0).to(pixel_values.device)
|
|
||||||
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
|
|
||||||
pixel_values.device)
|
|
||||||
sample_indices = torch.concat(sample_indices,
|
|
||||||
dim=0).to(pixel_values.device)
|
|
||||||
|
|
||||||
image_embeds = self.visual(
|
|
||||||
pixel_values=pixel_values,
|
|
||||||
image_grid_thw=image_grid_hws,
|
|
||||||
position_ids=siglip_position_ids,
|
|
||||||
vision_return_embed_list=False,
|
|
||||||
interpolate_pos_encoding=True,
|
|
||||||
sample_indices=sample_indices,
|
|
||||||
cu_seqlens=cu_seqlens,
|
|
||||||
use_rope=True,
|
|
||||||
window_size=-1,
|
|
||||||
)
|
|
||||||
image_embeds = tuple(self.mlp_AR(image_embeds, image_grid_thw))
|
|
||||||
return image_embeds
|
|
||||||
|
|
||||||
def _process_video_input(
|
def _process_video_input(
|
||||||
self, video_input: KeyeVideoInputs) -> tuple[torch.Tensor, ...]:
|
self, video_input: KeyeVideoInputs) -> tuple[torch.Tensor, ...]:
|
||||||
siglip_position_ids = list()
|
video_type = video_input["type"]
|
||||||
video_grid_hws = list()
|
|
||||||
sample_indices = list()
|
|
||||||
cu_seqlens = [0]
|
|
||||||
|
|
||||||
video_grid_thw = video_input["video_grid_thw"]
|
video_grid_thw = video_input["video_grid_thw"]
|
||||||
assert video_grid_thw.ndim == 2
|
pixel_values_videos = video_input.get("pixel_values_videos", None)
|
||||||
|
|
||||||
for idx, thaw in enumerate(video_grid_thw):
|
return tuple(
|
||||||
thw_tuple = tuple(thaw.detach().cpu().numpy().tolist())
|
self._process_video_embeds(video_type, video_grid_thw,
|
||||||
numel = np.prod(thw_tuple)
|
pixel_values_videos))
|
||||||
|
|
||||||
video_grid_hws.append(thw_tuple)
|
|
||||||
video_position_ids = torch.arange(numel) % np.prod(thw_tuple[1:])
|
|
||||||
siglip_position_ids.append(video_position_ids)
|
|
||||||
sample_indices.append(torch.full((numel, ), idx,
|
|
||||||
dtype=torch.int64))
|
|
||||||
cu_seqlens.append(cu_seqlens[-1] + numel)
|
|
||||||
|
|
||||||
if video_input["type"] == "video_embeds":
|
|
||||||
raise ValueError(
|
|
||||||
"Video embeddings are not supported for this processing path.")
|
|
||||||
else:
|
|
||||||
pixel_values_videos = video_input["pixel_values_videos"].type(
|
|
||||||
self.visual.dtype)
|
|
||||||
siglip_position_ids = torch.concat(siglip_position_ids, dim=0).to(
|
|
||||||
pixel_values_videos.device)
|
|
||||||
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32).to(
|
|
||||||
pixel_values_videos.device)
|
|
||||||
sample_indices = torch.concat(sample_indices,
|
|
||||||
dim=0).to(pixel_values_videos.device)
|
|
||||||
|
|
||||||
video_embeds = self.visual(
|
|
||||||
pixel_values=pixel_values_videos,
|
|
||||||
image_grid_thw=video_grid_hws,
|
|
||||||
position_ids=siglip_position_ids,
|
|
||||||
vision_return_embed_list=True,
|
|
||||||
interpolate_pos_encoding=True,
|
|
||||||
sample_indices=sample_indices,
|
|
||||||
cu_seqlens=cu_seqlens,
|
|
||||||
use_rope=True,
|
|
||||||
window_size=-1,
|
|
||||||
)
|
|
||||||
video_embeds = tuple(self.mlp_AR(video_embeds, video_grid_thw))
|
|
||||||
return video_embeds
|
|
||||||
|
|
||||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
|
||||||
modalities = {}
|
|
||||||
|
|
||||||
for input_key in kwargs:
|
|
||||||
if (input_key in ("pixel_values", "image_embeds")
|
|
||||||
and "images" not in modalities):
|
|
||||||
modalities["images"] = self._parse_and_validate_image_input(
|
|
||||||
**kwargs)
|
|
||||||
if (input_key in ("pixel_values_videos", "video_embeds")
|
|
||||||
and "videos" not in modalities):
|
|
||||||
modalities["videos"] = self._parse_and_validate_video_input(
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
return modalities
|
|
||||||
|
|
||||||
def get_language_model(self) -> torch.nn.Module:
|
|
||||||
return self.language_model
|
|
||||||
|
|
||||||
def get_multimodal_embeddings(
|
|
||||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
|
||||||
|
|
||||||
modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
|
|
||||||
if not modalities:
|
|
||||||
return None
|
|
||||||
|
|
||||||
multimodal_embeddings: tuple[torch.Tensor, ...] = ()
|
|
||||||
|
|
||||||
for modality in modalities:
|
|
||||||
if modality == "images":
|
|
||||||
image_input = modalities["images"]
|
|
||||||
vision_embeddings = self._process_image_input(image_input)
|
|
||||||
multimodal_embeddings += vision_embeddings
|
|
||||||
if modality == "videos":
|
|
||||||
video_input = modalities["videos"]
|
|
||||||
video_embeddings = self._process_video_input(video_input)
|
|
||||||
multimodal_embeddings += video_embeddings
|
|
||||||
return multimodal_embeddings
|
|
||||||
|
|
||||||
def get_input_embeddings(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
|
||||||
if multimodal_embeddings is not None:
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids,
|
|
||||||
inputs_embeds,
|
|
||||||
multimodal_embeddings,
|
|
||||||
[
|
|
||||||
self.config.image_token_id,
|
|
||||||
self.config.video_token_id,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def get_input_embeddings_v0(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
image_input: Optional[KeyeImagePixelInputs] = None,
|
|
||||||
video_input: Optional[KeyeVideoPixelInputs] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
inputs_embeds = self.get_input_embeddings(input_ids)
|
|
||||||
if image_input is not None:
|
|
||||||
image_embeds = self._process_image_input(image_input)
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids,
|
|
||||||
inputs_embeds,
|
|
||||||
image_embeds,
|
|
||||||
placeholder_token_id=self.config.image_token_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if video_input is not None:
|
|
||||||
video_embeds = self._process_video_input(video_input)
|
|
||||||
inputs_embeds = merge_multimodal_embeddings(
|
|
||||||
input_ids,
|
|
||||||
inputs_embeds,
|
|
||||||
video_embeds,
|
|
||||||
placeholder_token_id=self.config.video_token_id,
|
|
||||||
)
|
|
||||||
return inputs_embeds
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
positions: torch.Tensor,
|
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
|
||||||
**kwargs: object,
|
|
||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
|
||||||
"""Run forward pass for Qwen2-VL.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_ids: Flattened (concatenated) input_ids corresponding to a
|
|
||||||
batch.
|
|
||||||
positions: Flattened (concatenated) position ids corresponding to a
|
|
||||||
batch.
|
|
||||||
**NOTE**: If mrope is enabled (default setting for Qwen2-VL
|
|
||||||
opensource models), the shape will be `(3, seq_len)`,
|
|
||||||
otherwise it will be `(seq_len,).
|
|
||||||
pixel_values: Pixel values to be fed to a model.
|
|
||||||
`None` if no images are passed.
|
|
||||||
image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
|
|
||||||
`None` if no images are passed.
|
|
||||||
pixel_values_videos: Pixel values of videos to be fed to a model.
|
|
||||||
`None` if no videos are passed.
|
|
||||||
video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
|
|
||||||
`None` if no videos are passed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if intermediate_tensors is not None:
|
|
||||||
inputs_embeds = None
|
|
||||||
|
|
||||||
elif inputs_embeds is None:
|
|
||||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
|
||||||
video_input = self._parse_and_validate_video_input(**kwargs)
|
|
||||||
|
|
||||||
if image_input is None and video_input is None:
|
|
||||||
inputs_embeds = None
|
|
||||||
else:
|
|
||||||
if uses_mrope(self.config):
|
|
||||||
assert positions.ndim == 2 and positions.size(0) == 3, (
|
|
||||||
"multimodal section rotary embedding requires "
|
|
||||||
f"(3, seq_len) positions, but got {positions.size()}")
|
|
||||||
inputs_embeds = self.get_input_embeddings_v0(
|
|
||||||
input_ids,
|
|
||||||
image_input=image_input,
|
|
||||||
video_input=video_input,
|
|
||||||
)
|
|
||||||
input_ids = None
|
|
||||||
|
|
||||||
hidden_states = self.language_model.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
positions=positions,
|
|
||||||
intermediate_tensors=intermediate_tensors,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
)
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
def compute_logits(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
sampling_metadata: SamplingMetadata,
|
|
||||||
) -> Optional[torch.Tensor]:
|
|
||||||
return self.language_model.compute_logits(hidden_states,
|
|
||||||
sampling_metadata)
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str,
|
|
||||||
torch.Tensor]]) -> set[str]:
|
|
||||||
|
|
||||||
loader = AutoWeightsLoader(self)
|
|
||||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
|
||||||
|
|
||||||
def get_mm_mapping(self) -> MultiModelKeys:
|
|
||||||
"""Get the module prefix in multimodal models."""
|
|
||||||
return MultiModelKeys.from_string_field(
|
|
||||||
language_model="language_model",
|
|
||||||
connector="visual.",
|
|
||||||
tower_model="mlp_AR.",
|
|
||||||
)
|
|
||||||
|
|||||||
601
vllm/model_executor/models/keye_vl1_5.py
Normal file
601
vllm/model_executor/models/keye_vl1_5.py
Normal file
@ -0,0 +1,601 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import itertools
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from functools import partial
|
||||||
|
from typing import Annotated, Any, Literal, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from einops import rearrange
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
from transformers.activations import GELUActivation
|
||||||
|
from transformers.feature_extraction_utils import BatchFeature
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
|
RowParallelLinear)
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
|
||||||
|
from vllm.multimodal.inputs import (ImageItem, ModalityData,
|
||||||
|
MultiModalFieldConfig,
|
||||||
|
MultiModalKwargsItems, VideoItem)
|
||||||
|
from vllm.multimodal.parse import (DictEmbeddingItems, ModalityDataItems,
|
||||||
|
MultiModalDataItems, MultiModalDataParser)
|
||||||
|
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
|
||||||
|
PromptUpdateDetails)
|
||||||
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
|
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
|
||||||
|
from .keye import (BaseKeyeModule, BaseMultiModalProcessor,
|
||||||
|
KeyeBaseDummyInputsBuilder, KeyeProcessingInfo)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def split_thw(grid_thw: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Split grid_thw in t dimension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grid_thw: [N, 3] tensor of [t, h, w]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[Σt, 3] tensor where each row is [1, h, w]
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> grid_thw = torch.tensor([[2, 3, 4], [1, 5, 6]])
|
||||||
|
>>> split_thw(grid_thw)
|
||||||
|
tensor([[1, 3, 4],
|
||||||
|
[1, 3, 4],
|
||||||
|
[1, 5, 6]])
|
||||||
|
"""
|
||||||
|
t = grid_thw[:, 0]
|
||||||
|
h_w = grid_thw[:, 1:]
|
||||||
|
ones = torch.ones_like(h_w[:, :1])
|
||||||
|
return torch.cat([ones, h_w], dim=1).repeat_interleave(t, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
def get_num_patches(grid_thw: torch.Tensor, num_frames: Union[list[int],
|
||||||
|
torch.Tensor]):
|
||||||
|
"""
|
||||||
|
Return num_patches per video.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
t: tensor with shape [N, ...] where each item is a list/tensor
|
||||||
|
cu_seqlens: list indicating the boundaries of groups
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of ints representing the sum of products for each group
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # Suppose there are 2 videos with a total of 3 grids
|
||||||
|
>>> grid_thw = torch.tensor([[2, 2, 2], # grid 0: 2*2*2=8 patches
|
||||||
|
... [2, 2, 2], # grid 1: 2*2*2=8 patches
|
||||||
|
... [1, 1, 1]]) # grid 2: 1*1*1=1 patches
|
||||||
|
>>> num_frames = [2, 1] # The first video contains 2 grids,
|
||||||
|
the second contains 1 grid.
|
||||||
|
>>> get_num_patches(grid_thw, num_frames)
|
||||||
|
tensor([16, 1]) # Total patches for first video: 8+8=16,
|
||||||
|
second video: 1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert len(grid_thw.shape) == 2
|
||||||
|
if isinstance(num_frames, torch.Tensor):
|
||||||
|
num_frames = num_frames.clone().tolist()
|
||||||
|
|
||||||
|
num_grids_per_frame = grid_thw.prod(dim=1)
|
||||||
|
start_idx_per_video = [0, *itertools.accumulate(num_frames)]
|
||||||
|
num_patches = [
|
||||||
|
num_grids_per_frame[start_idx_per_video[i]:start_idx_per_video[i + 1]].
|
||||||
|
sum() for i in range(len(num_frames))
|
||||||
|
]
|
||||||
|
return torch.stack(num_patches) if num_patches else torch.zeros(
|
||||||
|
0, dtype=grid_thw.dtype, device=grid_thw.device)
|
||||||
|
|
||||||
|
|
||||||
|
class KeyeVL1_5ImagePixelInputs(TensorSchema):
|
||||||
|
"""
|
||||||
|
Dimensions:
|
||||||
|
- b: Batch size
|
||||||
|
- np: Number of patches
|
||||||
|
- c: Number of channels
|
||||||
|
- ps: Patch size
|
||||||
|
- ni: Number of images
|
||||||
|
- g: Grid dimensions (3 for t, h, w)
|
||||||
|
"""
|
||||||
|
type: Literal["pixel_values"]
|
||||||
|
|
||||||
|
pixel_values: Annotated[
|
||||||
|
torch.Tensor,
|
||||||
|
TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})]
|
||||||
|
|
||||||
|
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
|
||||||
|
|
||||||
|
|
||||||
|
class KeyeVL1_5ImageEmbeddingInputs(TensorSchema):
|
||||||
|
"""
|
||||||
|
Dimensions:
|
||||||
|
- nf: Number of image features
|
||||||
|
- hs: Hidden size (must match the hidden size of language model
|
||||||
|
backbone)
|
||||||
|
- ni: Number of images
|
||||||
|
- g: Grid dimensions (3 for t, h, w)
|
||||||
|
"""
|
||||||
|
type: Literal["image_embeds"]
|
||||||
|
image_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
|
||||||
|
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
|
||||||
|
|
||||||
|
|
||||||
|
KeyeVL1_5ImageInputs = Union[KeyeVL1_5ImagePixelInputs,
|
||||||
|
KeyeVL1_5ImageEmbeddingInputs]
|
||||||
|
|
||||||
|
|
||||||
|
class KeyeVL1_5VideoPixelInputs(TensorSchema):
|
||||||
|
"""
|
||||||
|
Dimensions:
|
||||||
|
- b: Batch size
|
||||||
|
- np: Number of patches
|
||||||
|
- c: Number of channels
|
||||||
|
- ps: Patch size
|
||||||
|
- ni: Number of images
|
||||||
|
- g: Grid dimensions (3 for t, h, w)
|
||||||
|
"""
|
||||||
|
type: Literal["pixel_values_videos"]
|
||||||
|
pixel_values_videos: Annotated[
|
||||||
|
torch.Tensor,
|
||||||
|
TensorShape("np", 3, "ps", "ps", dynamic_dims={"np"})]
|
||||||
|
video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
|
||||||
|
|
||||||
|
num_frames: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class KeyeVL1_5VideoEmbeddingInputs(TensorSchema):
|
||||||
|
"""
|
||||||
|
Dimensions:
|
||||||
|
- nf: Number of video features
|
||||||
|
- hs: Hidden size (must match the hidden size of language model
|
||||||
|
backbone)
|
||||||
|
- nv: Number of videos
|
||||||
|
- g: Grid dimensions (3 for t, h, w)
|
||||||
|
"""
|
||||||
|
type: Literal["video_embeds"]
|
||||||
|
video_embeds: Annotated[torch.Tensor, TensorShape("nf", "hs")]
|
||||||
|
video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
|
||||||
|
num_frames: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
KeyeVL1_5VideoInputs = Union[KeyeVL1_5VideoPixelInputs,
|
||||||
|
KeyeVL1_5VideoEmbeddingInputs]
|
||||||
|
|
||||||
|
|
||||||
|
class KeyeVL1_5Projector(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
text_config: PretrainedConfig,
|
||||||
|
vision_config: PretrainedConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.text_config = text_config
|
||||||
|
self.vision_config = vision_config
|
||||||
|
self.merge_kernel_size = (2, 2)
|
||||||
|
|
||||||
|
self.hidden_size = (self.vision_config.hidden_size *
|
||||||
|
self.merge_kernel_size[0] *
|
||||||
|
self.merge_kernel_size[1])
|
||||||
|
|
||||||
|
self.pre_norm = torch.nn.LayerNorm(self.hidden_size, eps=1e-05)
|
||||||
|
self.act = GELUActivation()
|
||||||
|
|
||||||
|
self.linear_1 = ColumnParallelLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.hidden_size,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.linear_1",
|
||||||
|
)
|
||||||
|
self.linear_2 = RowParallelLinear(
|
||||||
|
self.hidden_size,
|
||||||
|
self.text_config.hidden_size,
|
||||||
|
bias=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.linear_2",
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
image_features: Union[torch.Tensor, tuple[torch.Tensor],
|
||||||
|
list[torch.Tensor]],
|
||||||
|
image_grid_thw: list[tuple[int, int, int]],
|
||||||
|
) -> Union[torch.Tensor, list[torch.Tensor]]:
|
||||||
|
m1, m2 = self.merge_kernel_size
|
||||||
|
if isinstance(image_features, (list, tuple)):
|
||||||
|
processed_features = list()
|
||||||
|
for image_feature, image_grid in zip(image_features,
|
||||||
|
image_grid_thw):
|
||||||
|
t, h, w = image_grid
|
||||||
|
image_feature = rearrange(
|
||||||
|
image_feature,
|
||||||
|
"(t h p1 w p2) d -> (t h w) (p1 p2 d)",
|
||||||
|
t=t,
|
||||||
|
h=h // m1,
|
||||||
|
p1=m1,
|
||||||
|
w=w // m2,
|
||||||
|
p2=m2,
|
||||||
|
)
|
||||||
|
image_feature = self.pre_norm(image_feature)
|
||||||
|
hidden_states, _ = self.linear_1(image_feature)
|
||||||
|
hidden_states = self.act(hidden_states)
|
||||||
|
hidden_states, _ = self.linear_2(hidden_states)
|
||||||
|
processed_features.append(hidden_states)
|
||||||
|
|
||||||
|
return processed_features
|
||||||
|
|
||||||
|
dims = image_features.shape[:-1]
|
||||||
|
dim = image_features.shape[-1]
|
||||||
|
image_features = image_features.view(np.prod(dims), dim)
|
||||||
|
hidden_states = self.pre_norm(image_features.view(
|
||||||
|
-1, self.hidden_size))
|
||||||
|
hidden_states = self.linear_1(hidden_states)
|
||||||
|
hidden_states = self.act(hidden_states)
|
||||||
|
hidden_states = self.linear_2(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states.view(*dims, -1)
|
||||||
|
|
||||||
|
|
||||||
|
class KeyeVL1_5ProcessingInfo(KeyeProcessingInfo):
|
||||||
|
|
||||||
|
def get_max_frame_per_video(self) -> int:
|
||||||
|
return 2048
|
||||||
|
|
||||||
|
def get_supported_mm_limits(self, ) -> Mapping[str, Optional[int]]:
|
||||||
|
return {"image": None, "video": 1}
|
||||||
|
|
||||||
|
|
||||||
|
def _keye_field_config(hf_inputs: Mapping[str, torch.Tensor], ):
|
||||||
|
image_grid_thw = hf_inputs.get("image_grid_thw",
|
||||||
|
torch.empty((0, 3), dtype=torch.int64))
|
||||||
|
image_grid_sizes = image_grid_thw.prod(-1)
|
||||||
|
|
||||||
|
video_grid_thw = hf_inputs.get("video_grid_thw",
|
||||||
|
torch.empty((0, 3), dtype=torch.int64))
|
||||||
|
video_grid_thw = split_thw(video_grid_thw)
|
||||||
|
num_frames = hf_inputs.get("num_frames",
|
||||||
|
video_grid_thw[:, 0]).clone().tolist()
|
||||||
|
|
||||||
|
video_num_patches = get_num_patches(video_grid_thw, num_frames)
|
||||||
|
|
||||||
|
video_num_grids = []
|
||||||
|
if len(num_frames) > 0:
|
||||||
|
i = 0
|
||||||
|
j = 1
|
||||||
|
cur_frames = num_frames[i]
|
||||||
|
for t, _, _ in video_grid_thw.tolist():
|
||||||
|
cur_frames -= t
|
||||||
|
if cur_frames == 0:
|
||||||
|
video_num_grids.append(j)
|
||||||
|
i += 1
|
||||||
|
if i < len(num_frames):
|
||||||
|
cur_frames = num_frames[i]
|
||||||
|
j = 1
|
||||||
|
else:
|
||||||
|
j += 1
|
||||||
|
video_num_grids = torch.tensor(video_num_grids)
|
||||||
|
return dict(pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||||
|
"image", image_grid_sizes),
|
||||||
|
image_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||||
|
"image", image_grid_sizes),
|
||||||
|
image_grid_thw=MultiModalFieldConfig.batched("image"),
|
||||||
|
pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
|
||||||
|
"video", video_num_patches),
|
||||||
|
video_embeds=MultiModalFieldConfig.flat_from_sizes(
|
||||||
|
"video", video_num_patches),
|
||||||
|
video_grid_thw=MultiModalFieldConfig.flat_from_sizes(
|
||||||
|
"video", video_num_grids),
|
||||||
|
num_frames=MultiModalFieldConfig.batched("video"))
|
||||||
|
|
||||||
|
|
||||||
|
class KeyeVL1_5MultiModalDataParser(MultiModalDataParser):
|
||||||
|
|
||||||
|
def _parse_image_data(
|
||||||
|
self,
|
||||||
|
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
|
||||||
|
) -> ModalityDataItems[Any, Any]:
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return DictEmbeddingItems(
|
||||||
|
data,
|
||||||
|
modality="image",
|
||||||
|
required_fields={
|
||||||
|
"image_embeds",
|
||||||
|
"image_grid_thw",
|
||||||
|
},
|
||||||
|
fields_factory=_keye_field_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
return super()._parse_image_data(data)
|
||||||
|
|
||||||
|
def _parse_video_data(
|
||||||
|
self,
|
||||||
|
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
|
||||||
|
) -> ModalityDataItems[Any, Any]:
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return DictEmbeddingItems(
|
||||||
|
data,
|
||||||
|
modality="video",
|
||||||
|
required_fields={
|
||||||
|
"video_embeds",
|
||||||
|
"video_grid_thw",
|
||||||
|
},
|
||||||
|
fields_factory=_keye_field_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
return super()._parse_video_data(data)
|
||||||
|
|
||||||
|
|
||||||
|
class KeyeVL1_5MultiModalProcessor(
|
||||||
|
BaseMultiModalProcessor[KeyeVL1_5ProcessingInfo]):
|
||||||
|
|
||||||
|
def _get_data_parser(self) -> MultiModalDataParser:
|
||||||
|
return KeyeVL1_5MultiModalDataParser()
|
||||||
|
|
||||||
|
def _get_prompt_updates(
|
||||||
|
self,
|
||||||
|
mm_items: MultiModalDataItems,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, Any],
|
||||||
|
out_mm_kwargs: MultiModalKwargsItems,
|
||||||
|
) -> Sequence[PromptUpdate]:
|
||||||
|
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||||
|
image_processor = self.info.get_image_processor(
|
||||||
|
**hf_processor_mm_kwargs)
|
||||||
|
tokenizer = self.info.get_tokenizer()
|
||||||
|
vocab = tokenizer.get_vocab()
|
||||||
|
image_token_id = vocab[hf_processor.image_token]
|
||||||
|
video_token_id = vocab[hf_processor.video_token]
|
||||||
|
placeholder = {"image": image_token_id, "video": video_token_id}
|
||||||
|
merge_length = image_processor.merge_size**2
|
||||||
|
|
||||||
|
out_mm_kwargs_data = out_mm_kwargs.get_data()
|
||||||
|
frame_types: list[torch.Tensor] = \
|
||||||
|
hf_processor_mm_kwargs.get("frame_types", None)
|
||||||
|
timestamps: list[torch.Tensor] = \
|
||||||
|
hf_processor_mm_kwargs.get("timestamps", None)
|
||||||
|
num_videos = mm_items.get_count("video", strict=False)
|
||||||
|
|
||||||
|
if frame_types is None:
|
||||||
|
frame_types = [None] * num_videos
|
||||||
|
assert len(frame_types) == num_videos, \
|
||||||
|
f"Number of frame_types={len(frame_types)} " \
|
||||||
|
f"doesn't equal to number of videos={num_videos}"
|
||||||
|
if timestamps is None:
|
||||||
|
timestamps = [None] * num_videos
|
||||||
|
assert len(timestamps) == num_videos, \
|
||||||
|
f"Number of timestamps={len(timestamps)} " \
|
||||||
|
f"doesn't equal to number of videos={num_videos}"
|
||||||
|
|
||||||
|
video_grid_thw = out_mm_kwargs_data.get(
|
||||||
|
'video_grid_thw', torch.empty((0, 3), dtype=torch.int64))
|
||||||
|
num_frames = out_mm_kwargs_data.get(
|
||||||
|
'num_frames', torch.tensor([], dtype=torch.int64))
|
||||||
|
|
||||||
|
assert len(num_frames) == num_videos, \
|
||||||
|
f"Size of num_frames={len(num_frames)} " \
|
||||||
|
f"doesn't equal to number of videos={num_videos}"
|
||||||
|
|
||||||
|
video_grid_hws = split_thw(video_grid_thw)
|
||||||
|
assert int(num_frames.sum().tolist()) == video_grid_hws.shape[0], (
|
||||||
|
f"The first dimension of `video_grid_hws`={video_grid_hws.shape[0]}"
|
||||||
|
f"doesn't equal to num of frames.")
|
||||||
|
|
||||||
|
cu_seqlens = torch.cumsum(torch.tensor([0] + num_frames.tolist()),
|
||||||
|
dim=-1)
|
||||||
|
|
||||||
|
def get_replacement_keye(item_idx: int, modality: str):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
item_idx(int): The item index of modality to replace
|
||||||
|
modality(str): The modality
|
||||||
|
"""
|
||||||
|
if modality == "image":
|
||||||
|
out_item = out_mm_kwargs[modality][item_idx]
|
||||||
|
grid_thw = out_item[f"{modality}_grid_thw"].data
|
||||||
|
assert isinstance(grid_thw, torch.Tensor)
|
||||||
|
|
||||||
|
num_tokens = int(grid_thw.prod()) // merge_length
|
||||||
|
return [image_token_id] * num_tokens
|
||||||
|
elif modality == "video":
|
||||||
|
placeholders = []
|
||||||
|
video_timestamps = timestamps[item_idx]
|
||||||
|
video_frame_types = frame_types[item_idx]
|
||||||
|
grid_thw = video_grid_hws[
|
||||||
|
cu_seqlens[item_idx]:cu_seqlens[item_idx + 1]]
|
||||||
|
|
||||||
|
nframes = grid_thw.shape[0]
|
||||||
|
|
||||||
|
if video_timestamps is None:
|
||||||
|
video_timestamps = [""] * nframes
|
||||||
|
else:
|
||||||
|
video_timestamps = [
|
||||||
|
format(ts, ".1f") for ts in video_timestamps
|
||||||
|
]
|
||||||
|
|
||||||
|
if video_frame_types is None:
|
||||||
|
video_frame_types = [0] * nframes
|
||||||
|
for i, sub_thw in enumerate(grid_thw):
|
||||||
|
s = f"{hf_processor.frame_token}{video_timestamps[i]}"
|
||||||
|
if video_frame_types[i] == 1:
|
||||||
|
s += hf_processor.fast_start
|
||||||
|
placeholders.extend(tokenizer.encode(s))
|
||||||
|
num_frame_tokens = int(sub_thw.prod()) // merge_length
|
||||||
|
placeholders.extend([video_token_id] * num_frame_tokens)
|
||||||
|
if video_frame_types[i] == 1:
|
||||||
|
placeholders.append(vocab[hf_processor.fast_end])
|
||||||
|
|
||||||
|
return PromptUpdateDetails.select_token_id(
|
||||||
|
placeholders, embed_token_id=video_token_id)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported modality {modality}")
|
||||||
|
|
||||||
|
return [
|
||||||
|
PromptReplacement(
|
||||||
|
modality=modality,
|
||||||
|
target=[placeholder[modality]],
|
||||||
|
replacement=partial(get_replacement_keye, modality=modality),
|
||||||
|
) for modality in ("image", "video")
|
||||||
|
]
|
||||||
|
|
||||||
|
def _get_mm_fields_config(
|
||||||
|
self,
|
||||||
|
hf_inputs: BatchFeature,
|
||||||
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
|
) -> Mapping[str, MultiModalFieldConfig]:
|
||||||
|
return _keye_field_config(hf_inputs)
|
||||||
|
|
||||||
|
|
||||||
|
class KeyeVL1_5DummyInputsBuilder(
|
||||||
|
KeyeBaseDummyInputsBuilder[KeyeVL1_5ProcessingInfo]):
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@MULTIMODAL_REGISTRY.register_processor(
|
||||||
|
KeyeVL1_5MultiModalProcessor,
|
||||||
|
info=KeyeVL1_5ProcessingInfo,
|
||||||
|
dummy_inputs=KeyeVL1_5DummyInputsBuilder,
|
||||||
|
)
|
||||||
|
class KeyeVL1_5ForConditionalGeneration(BaseKeyeModule, SupportsMultiModal,
|
||||||
|
SupportsLoRA, SupportsPP):
|
||||||
|
|
||||||
|
def _build_projector(self,
|
||||||
|
text_config: PretrainedConfig,
|
||||||
|
vision_config: PretrainedConfig,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "") -> nn.Module:
|
||||||
|
return KeyeVL1_5Projector(text_config, vision_config, quant_config,
|
||||||
|
prefix)
|
||||||
|
|
||||||
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
|
config: PretrainedConfig = vllm_config.model_config.hf_config
|
||||||
|
self.merge_size = config.vision_config.spatial_merge_size
|
||||||
|
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||||
|
|
||||||
|
def _validate_and_reshape_mm_tensor(self, mm_input: NestedTensors,
|
||||||
|
expected_dim: int, name: str):
|
||||||
|
if not isinstance(mm_input, (torch.Tensor, list)):
|
||||||
|
raise ValueError(f"Incorrect type of {name}. "
|
||||||
|
f"Got type: {type(mm_input)}")
|
||||||
|
if isinstance(mm_input, torch.Tensor):
|
||||||
|
if mm_input.ndim == expected_dim:
|
||||||
|
return mm_input
|
||||||
|
elif mm_input.ndim == expected_dim + 1:
|
||||||
|
return torch.concat(list(mm_input))
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"{name} should be {expected_dim}D or "
|
||||||
|
f"batched {expected_dim}D tensor."
|
||||||
|
f"Got ndim: {mm_input.ndim} (shape={mm_input.shape})")
|
||||||
|
else:
|
||||||
|
return torch.concat(list(mm_input))
|
||||||
|
|
||||||
|
def _parse_and_validate_image_input(
|
||||||
|
self, **kwargs: object) -> Optional[KeyeVL1_5ImageInputs]:
|
||||||
|
pixel_values = kwargs.pop("pixel_values", None)
|
||||||
|
image_embeds = kwargs.pop("image_embeds", None)
|
||||||
|
image_grid_thw = kwargs.pop("image_grid_thw", None)
|
||||||
|
|
||||||
|
if pixel_values is None and image_embeds is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if pixel_values is not None:
|
||||||
|
pixel_values = self._validate_and_reshape_mm_tensor(
|
||||||
|
pixel_values, expected_dim=4, name="image pixel values")
|
||||||
|
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||||
|
image_grid_thw, expected_dim=2, name="image grid_thw")
|
||||||
|
|
||||||
|
return KeyeVL1_5ImagePixelInputs(
|
||||||
|
type="pixel_values",
|
||||||
|
pixel_values=pixel_values,
|
||||||
|
image_grid_thw=image_grid_thw,
|
||||||
|
)
|
||||||
|
|
||||||
|
if image_embeds is not None:
|
||||||
|
image_embeds = self._validate_and_reshape_mm_tensor(
|
||||||
|
image_embeds, expected_dim=2, name="image embeds")
|
||||||
|
image_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||||
|
image_grid_thw, expected_dim=2, name="image grid_thw")
|
||||||
|
|
||||||
|
return KeyeVL1_5ImageEmbeddingInputs(
|
||||||
|
type="image_embeds",
|
||||||
|
image_embeds=image_embeds,
|
||||||
|
image_grid_thw=image_grid_thw,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _parse_and_validate_video_input(
|
||||||
|
self, **kwargs: object) -> Optional[KeyeVL1_5VideoInputs]:
|
||||||
|
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
|
||||||
|
video_embeds = kwargs.pop("video_embeds", None)
|
||||||
|
video_grid_thw = kwargs.pop("video_grid_thw", None)
|
||||||
|
num_frames = kwargs.pop("num_frames", None)
|
||||||
|
|
||||||
|
if pixel_values_videos is None and video_embeds is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if pixel_values_videos is not None:
|
||||||
|
pixel_values_videos = self._validate_and_reshape_mm_tensor(
|
||||||
|
pixel_values_videos,
|
||||||
|
expected_dim=4,
|
||||||
|
name="video pixel values",
|
||||||
|
)
|
||||||
|
video_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||||
|
video_grid_thw, expected_dim=2, name="video grid_thw")
|
||||||
|
|
||||||
|
num_frames = self._validate_and_reshape_mm_tensor(
|
||||||
|
num_frames, expected_dim=1, name="video num frames")
|
||||||
|
|
||||||
|
return KeyeVL1_5VideoPixelInputs(
|
||||||
|
type="pixel_values_videos",
|
||||||
|
pixel_values_videos=pixel_values_videos,
|
||||||
|
video_grid_thw=video_grid_thw,
|
||||||
|
num_frames=num_frames)
|
||||||
|
|
||||||
|
if video_embeds is not None:
|
||||||
|
video_embeds = self._validate_and_reshape_mm_tensor(
|
||||||
|
video_embeds, expected_dim=2, name="video embeds")
|
||||||
|
video_grid_thw = self._validate_and_reshape_mm_tensor(
|
||||||
|
video_grid_thw, expected_dim=2, name="video grid_thw")
|
||||||
|
|
||||||
|
return KeyeVL1_5VideoEmbeddingInputs(type="video_embeds",
|
||||||
|
video_embeds=video_embeds,
|
||||||
|
video_grid_thw=video_grid_thw,
|
||||||
|
num_frames=num_frames)
|
||||||
|
|
||||||
|
def _process_video_input(
|
||||||
|
self,
|
||||||
|
video_input: KeyeVL1_5VideoInputs) -> tuple[torch.Tensor, ...]:
|
||||||
|
video_type = video_input["type"]
|
||||||
|
video_grid_thw = split_thw(video_input["video_grid_thw"])
|
||||||
|
pixel_values_videos = video_input.get("pixel_values_videos", None)
|
||||||
|
|
||||||
|
video_embeds = self._process_video_embeds(video_type, video_grid_thw,
|
||||||
|
pixel_values_videos)
|
||||||
|
video_embeds = torch.concat(video_embeds, dim=0)
|
||||||
|
|
||||||
|
num_frames = video_input["num_frames"].clone().tolist()
|
||||||
|
|
||||||
|
num_patches = get_num_patches(video_grid_thw, num_frames).tolist()
|
||||||
|
|
||||||
|
patch_cu_seqlens = torch.cumsum(
|
||||||
|
torch.tensor([0] + num_patches).detach().clone(), dim=-1)
|
||||||
|
patch_cu_seqlens = torch.div(patch_cu_seqlens,
|
||||||
|
self.merge_size**2,
|
||||||
|
rounding_mode="floor")
|
||||||
|
|
||||||
|
new_video_embeds = []
|
||||||
|
for idx in range(patch_cu_seqlens.shape[0] - 1):
|
||||||
|
start = patch_cu_seqlens[idx]
|
||||||
|
end = patch_cu_seqlens[idx + 1]
|
||||||
|
new_video_embeds.append(video_embeds[start:end])
|
||||||
|
return tuple(new_video_embeds)
|
||||||
@ -227,6 +227,7 @@ _MULTIMODAL_MODELS = {
|
|||||||
"Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
|
"Idefics3ForConditionalGeneration":("idefics3","Idefics3ForConditionalGeneration"),
|
||||||
"SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501
|
"SmolVLMForConditionalGeneration": ("smolvlm","SmolVLMForConditionalGeneration"), # noqa: E501
|
||||||
"KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
|
"KeyeForConditionalGeneration": ("keye", "KeyeForConditionalGeneration"),
|
||||||
|
"KeyeVL1_5ForConditionalGeneration": ("keye_vl1_5", "KeyeVL1_5ForConditionalGeneration"), # noqa: E501
|
||||||
"RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
|
"RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
|
||||||
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
|
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
|
||||||
"Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
|
"Llama_Nemotron_Nano_VL": ("nemotron_vl", "LlamaNemotronVLChatModel"),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user