[Bugfix] [Model] Missing MRoPE function definition from KeyeForConditionalGeneration (#27895)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
parent 879a06579e
commit e2347dbf58
tests/models/multimodal/generation/test_keye.py (new file, 86 lines)
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import asdict
from typing import NamedTuple

import pytest
from PIL.Image import Image
from transformers import AutoProcessor

from vllm import LLM, EngineArgs, SamplingParams
from vllm.multimodal.utils import encode_image_base64

MODEL_NAME = "Kwai-Keye/Keye-VL-8B-Preview"

QUESTION = "What is the content of each image?"


class ModelRequestData(NamedTuple):
    engine_args: EngineArgs
    prompt: str
    image_data: list[Image]
    stop_token_ids: list[int] | None = None
    chat_template: str | None = None
    sampling_params: SamplingParams | None = None


@pytest.mark.core_model
@pytest.mark.parametrize("question", [QUESTION])
def test_keye_vl(
    image_assets,
    question: str,
):
    images = [asset.pil_image for asset in image_assets]

    image_urls = [
        f"data:image/jpeg;base64,{encode_image_base64(image)}" for image in images
    ]

    engine_args = EngineArgs(
        model=MODEL_NAME,
        trust_remote_code=True,
        max_model_len=8192,
        max_num_seqs=5,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [
        {
            "role": "user",
            "content": [
                *placeholders,
                {"type": "text", "text": question},
            ],
        },
    ]

    processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)

    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    engine_args = asdict(engine_args) | {"seed": 42}
    llm = LLM(**engine_args)

    sampling_params = SamplingParams(
        temperature=0.0, max_tokens=256, stop_token_ids=None
    )

    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {"image": images},
        },
        sampling_params=sampling_params,
    )

    print("-" * 50)
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)
        assert len(generated_text) > 10, (
            f"Generated text is too short: {generated_text}"
        )
    print("-" * 50)
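For comparison, the same multi-image request can also be driven through vLLM's chat entry point, which renders the chat template itself instead of calling AutoProcessor.apply_chat_template by hand. The snippet below is a minimal sketch, not part of this commit, and it assumes LLM.chat accepts OpenAI-style image_url content parts carrying base64 data URLs for this model.

# Hypothetical usage sketch (not part of this commit): same model, but using
# LLM.chat so that vLLM applies the HF chat template internally.
from PIL import Image

from vllm import LLM, SamplingParams
from vllm.multimodal.utils import encode_image_base64

# Two synthetic images stand in for the test assets.
image_urls = [
    f"data:image/jpeg;base64,{encode_image_base64(Image.new('RGB', (336, 336), c))}"
    for c in ("red", "blue")
]

llm = LLM(
    model="Kwai-Keye/Keye-VL-8B-Preview",
    trust_remote_code=True,
    max_model_len=8192,
    limit_mm_per_prompt={"image": len(image_urls)},
)

messages = [
    {
        "role": "user",
        "content": [
            *[{"type": "image_url", "image_url": {"url": u}} for u in image_urls],
            {"type": "text", "text": "What is the content of each image?"},
        ],
    }
]

outputs = llm.chat(
    messages, sampling_params=SamplingParams(temperature=0.0, max_tokens=256)
)
print(outputs[0].outputs[0].text)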
vllm/model_executor/models/keye.py
@@ -17,7 +17,9 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from transformers.utils import torch_int

 from vllm.attention.backends.registry import _Backend
-from vllm.attention.layer import check_upstream_fa_availability
+from vllm.attention.layer import (
+    maybe_get_vit_flash_attn_backend,
+)
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -56,12 +58,14 @@ from vllm.multimodal.processing import (
     PromptUpdate,
 )
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import (
     MultiModalEmbeddings,
     SupportsLoRA,
+    SupportsMRoPE,
     SupportsMultiModal,
     SupportsPP,
 )
@@ -337,7 +341,10 @@
     cos = cos.chunk(2, dim=-1)[0].contiguous()
     sin = sin.chunk(2, dim=-1)[0].contiguous()

-    from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
+    if current_platform.is_cuda():
+        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
+    elif current_platform.is_rocm():
+        from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb

     q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
     k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
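Where neither the CUDA nor the ROCm flash-attn rotary kernel is importable, a plain PyTorch version can serve as a readable reference. The sketch below is not part of this commit; it assumes the non-interleaved (NeoX-style) rotary convention that flash-attn's apply_rotary_emb implements, with cos/sin of shape [seqlen, head_dim // 2] as prepared in the hunk above.

# Hypothetical pure-PyTorch reference (not part of this commit), assuming the
# non-interleaved rotary convention: x is [..., seqlen, heads, head_dim],
# cos/sin are [seqlen, head_dim // 2].
import torch


def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    # Broadcast cos/sin over the head dimension: [seqlen, 1, head_dim // 2].
    cos = cos.unsqueeze(-2)
    sin = sin.unsqueeze(-2)
    # Split the head dimension into the two halves that get rotated together.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)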
@@ -398,18 +405,28 @@ class KeyeSiglipAttention(nn.Module):
             attn_backend_override=attn_backend_override,
         )

-        self.use_upstream_fa = False
-        if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
-            torch.get_default_dtype()
-        ):
-            self.attn_backend = _Backend.FLASH_ATTN
-            self.use_upstream_fa = True
+        self.attn_backend, self.flash_attn_varlen_func = (
+            maybe_get_vit_flash_attn_backend(
+                self.attn_backend,
+                use_upstream_fa=False,
+                attn_backend_override=attn_backend_override,
+            )
+        )

-        if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
+        if self.attn_backend not in {
+            _Backend.FLASH_ATTN,
+            _Backend.XFORMERS,
+            _Backend.ROCM_AITER_FA,
+        }:
             raise RuntimeError(
                 f"Keye-VL does not support {self.attn_backend} backend now."
             )

+        self.is_flash_attn_backend = self.attn_backend in {
+            _Backend.FLASH_ATTN,
+            _Backend.ROCM_AITER_FA,
+        }
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -457,15 +474,10 @@ class KeyeSiglipAttention(nn.Module):
             self.head_dim,
         )

-        if self.attn_backend == _Backend.FLASH_ATTN:
-            if self.use_upstream_fa:
-                from flash_attn import flash_attn_varlen_func
-            else:
-                from vllm.vllm_flash_attn import flash_attn_varlen_func
-
+        if self.is_flash_attn_backend:
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

-            output = flash_attn_varlen_func(
+            output = self.flash_attn_varlen_func(
                 q,
                 k,
                 v,
@@ -1542,7 +1554,7 @@ class BaseKeyeModule(nn.Module):
     dummy_inputs=KeyeDummyInputsBuilder,
 )
 class KeyeForConditionalGeneration(
-    BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP
+    BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
     def _build_projector(
         self,
@@ -1611,3 +1623,142 @@ class KeyeForConditionalGeneration(
         return tuple(
             self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos)
         )
+
+    def get_mrope_input_positions(
+        self,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: list[list[int]] | torch.Tensor,
+        video_grid_thw: list[list[int]] | torch.Tensor,
+        context_len: int = 0,
+        seq_len: int | None = None,
+        second_per_grid_ts: list[float] | None = None,
+        audio_feature_lengths: torch.Tensor | None = None,
+        use_audio_in_video: bool = False,
+    ) -> tuple[torch.Tensor, int]:
+        """Get mrope input positions and delta value (Keye series)."""
+        if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
+            video_grid_thw = video_grid_thw[0]
+
+        def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
+            """
+            Split grid_thw along the t dimension.
+
+            Args:
+                grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
+
+            Returns:
+                List of [1, h, w] rows, repeated t times for each original row.
+            """
+            if isinstance(grid_thw, list):
+                grid_thw = torch.tensor(grid_thw, dtype=torch.long)
+
+            if grid_thw.numel() == 0:
+                return []
+
+            t, hw = grid_thw[:, 0], grid_thw[:, 1:]
+            ones = torch.ones_like(hw[:, :1])  # [N, 1]
+            out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
+            return out.tolist()
+
+        video_grid_thw = split_thw(video_grid_thw)
+
+        image_token_id = hf_config.image_token_id
+        video_token_id = hf_config.video_token_id
+        spatial_merge_size = hf_config.vision_config.spatial_merge_size
+
+        image_nums = len(image_grid_thw)
+        frame_nums = len(video_grid_thw)
+        llm_pos_ids_list: list = []
+
+        st = 0
+        remain_images, remain_frames = image_nums, frame_nums
+
+        image_index, video_index = 0, 0
+        for _ in range(image_nums + frame_nums):
+            if remain_images > 0:
+                try:
+                    ed_image = input_tokens.index(image_token_id, st)
+                except ValueError:
+                    ed_image = len(input_tokens) + 1
+            else:
+                ed_image = len(input_tokens) + 1
+            if remain_frames > 0:
+                try:
+                    ed_video = input_tokens.index(video_token_id, st)
+                except ValueError:
+                    ed_video = len(input_tokens) + 1
+            else:
+                ed_video = len(input_tokens) + 1
+
+            if ed_image < ed_video:
+                t, h, w = (
+                    image_grid_thw[image_index][0],
+                    image_grid_thw[image_index][1],
+                    image_grid_thw[image_index][2],
+                )
+                image_index += 1
+                remain_images -= 1
+                ed = ed_image
+            else:
+                t, h, w = (
+                    video_grid_thw[video_index][0],
+                    video_grid_thw[video_index][1],
+                    video_grid_thw[video_index][2],
+                )
+                video_index += 1
+                remain_frames -= 1
+                ed = ed_video
+
+            llm_grid_t, llm_grid_h, llm_grid_w = (
+                t,
+                h // spatial_merge_size,
+                w // spatial_merge_size,
+            )
+            text_len = ed - st
+
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+            )
+
+            t_index = (
+                (
+                    torch.arange(llm_grid_t)
+                    .view(-1, 1)
+                    .expand(-1, llm_grid_h * llm_grid_w)
+                )
+                .long()
+                .flatten()
+            )
+
+            h_index = (
+                torch.arange(llm_grid_h)
+                .view(1, -1, 1)
+                .expand(llm_grid_t, -1, llm_grid_w)
+                .flatten()
+            )
+            w_index = (
+                torch.arange(llm_grid_w)
+                .view(1, 1, -1)
+                .expand(llm_grid_t, llm_grid_h, -1)
+                .flatten()
+            )
+            llm_pos_ids_list.append(
+                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+            )
+            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+        if st < len(input_tokens):
+            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
+            text_len = len(input_tokens) - st
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+            )
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
+        llm_positions = llm_positions[:, context_len:seq_len]
+
+        return llm_positions, mrope_position_delta
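To make the new MRoPE logic concrete, the standalone snippet below (not part of the commit) reproduces the 3-axis position layout that get_mrope_input_positions emits for a toy prompt: 3 text tokens, one image whose post-merge grid is 1x2x2, and 2 trailing text tokens. The grid sizes and merge factor are made-up example values.

# Toy illustration (not part of this commit) of the MRoPE position layout:
# text tokens share the same index on all three axes, vision tokens carry
# (t, h, w) grid indices offset by the preceding text, and trailing text
# resumes after the largest position used so far.
import torch

text_len, grid_t, grid_h, grid_w, tail_len = 3, 1, 2, 2, 2

# Leading text: identical positions on the temporal, height, and width axes.
text_pos = torch.arange(text_len).view(1, -1).expand(3, -1)

# Vision tokens: (t, h, w) indices, shifted by the preceding text length.
t_idx = torch.arange(grid_t).view(-1, 1).expand(-1, grid_h * grid_w).flatten()
h_idx = torch.arange(grid_h).view(1, -1, 1).expand(grid_t, -1, grid_w).flatten()
w_idx = torch.arange(grid_w).view(1, 1, -1).expand(grid_t, grid_h, -1).flatten()
vision_pos = torch.stack([t_idx, h_idx, w_idx]) + text_len

# Trailing text continues from max(position) + 1.
tail_start = vision_pos.max() + 1
tail_pos = torch.arange(tail_len).view(1, -1).expand(3, -1) + tail_start

positions = torch.cat([text_pos, vision_pos, tail_pos], dim=1)
print(positions)
# tensor([[0, 1, 2, 3, 3, 3, 3, 5, 6],
#         [0, 1, 2, 3, 3, 4, 4, 5, 6],
#         [0, 1, 2, 3, 4, 3, 4, 5, 6]])

num_tokens = text_len + grid_t * grid_h * grid_w + tail_len
mrope_position_delta = (positions.max() + 1 - num_tokens).item()
print(mrope_position_delta)  # -2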