[Bugfix] [Model] Missing MRoPE function definition from KeyeForConditionalGeneration (#27895)

Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Author: TJian
Date: 2025-10-31 22:45:23 -07:00 (committed by GitHub)
Parent: 879a06579e
Commit: e2347dbf58
2 changed files with 254 additions and 17 deletions

File 1 (new): Keye-VL multimodal generation test

@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import asdict
from typing import NamedTuple

import pytest
from PIL.Image import Image
from transformers import AutoProcessor

from vllm import LLM, EngineArgs, SamplingParams
from vllm.multimodal.utils import encode_image_base64

MODEL_NAME = "Kwai-Keye/Keye-VL-8B-Preview"
QUESTION = "What is the content of each image?"


class ModelRequestData(NamedTuple):
    engine_args: EngineArgs
    prompt: str
    image_data: list[Image]
    stop_token_ids: list[int] | None = None
    chat_template: str | None = None
    sampling_params: SamplingParams | None = None


@pytest.mark.core_model
@pytest.mark.parametrize("question", [QUESTION])
def test_keye_vl(
    image_assets,
    question: str,
):
    images = [asset.pil_image for asset in image_assets]
    image_urls = [
        f"data:image/jpeg;base64,{encode_image_base64(image)}" for image in images
    ]

    engine_args = EngineArgs(
        model=MODEL_NAME,
        trust_remote_code=True,
        max_model_len=8192,
        max_num_seqs=5,
        limit_mm_per_prompt={"image": len(image_urls)},
    )

    placeholders = [{"type": "image", "image": url} for url in image_urls]
    messages = [
        {
            "role": "user",
            "content": [
                *placeholders,
                {"type": "text", "text": question},
            ],
        },
    ]

    processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    engine_args = asdict(engine_args) | {"seed": 42}
    llm = LLM(**engine_args)

    sampling_params = SamplingParams(
        temperature=0.0, max_tokens=256, stop_token_ids=None
    )

    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {"image": images},
        },
        sampling_params=sampling_params,
    )

    print("-" * 50)
    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)
        assert len(generated_text) > 10, (
            f"Generated text is too short: {generated_text}"
        )
        print("-" * 50)

File 2 (modified): Keye model definition (vision attention and KeyeForConditionalGeneration)

@@ -17,7 +17,9 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
from transformers.utils import torch_int
from vllm.attention.backends.registry import _Backend
from vllm.attention.layer import check_upstream_fa_availability
from vllm.attention.layer import (
    maybe_get_vit_flash_attn_backend,
)
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
@@ -56,12 +58,14 @@ from vllm.multimodal.processing import (
    PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
    MultiModalEmbeddings,
    SupportsLoRA,
    SupportsMRoPE,
    SupportsMultiModal,
    SupportsPP,
)
@@ -337,7 +341,10 @@ def apply_rotary_pos_emb_flashatt(
    cos = cos.chunk(2, dim=-1)[0].contiguous()
    sin = sin.chunk(2, dim=-1)[0].contiguous()
    from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
    # Select the rotary kernel per platform: vLLM's bundled flash-attn kernel
    # on CUDA, upstream flash-attn's Triton rotary op on ROCm.
    if current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
    elif current_platform.is_rocm():
        from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb

    q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
    k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
@@ -398,18 +405,28 @@ class KeyeSiglipAttention(nn.Module):
            attn_backend_override=attn_backend_override,
        )
        self.use_upstream_fa = False
        if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
            torch.get_default_dtype()
        ):
            self.attn_backend = _Backend.FLASH_ATTN
            self.use_upstream_fa = True
        self.attn_backend, self.flash_attn_varlen_func = (
            maybe_get_vit_flash_attn_backend(
                self.attn_backend,
                use_upstream_fa=False,
                attn_backend_override=attn_backend_override,
            )
        )

        if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
        if self.attn_backend not in {
            _Backend.FLASH_ATTN,
            _Backend.XFORMERS,
            _Backend.ROCM_AITER_FA,
        }:
            raise RuntimeError(
                f"Keye-VL does not support {self.attn_backend} backend now."
            )
        # The CUDA flash-attn and ROCm AITER FA paths share the same varlen
        # flash-attention call in forward().
        self.is_flash_attn_backend = self.attn_backend in {
            _Backend.FLASH_ATTN,
            _Backend.ROCM_AITER_FA,
        }

    def forward(
        self,
        hidden_states: torch.Tensor,
@@ -457,15 +474,10 @@
            self.head_dim,
        )

        if self.attn_backend == _Backend.FLASH_ATTN:
            if self.use_upstream_fa:
                from flash_attn import flash_attn_varlen_func
            else:
                from vllm.vllm_flash_attn import flash_attn_varlen_func
        if self.is_flash_attn_backend:
            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
            output = flash_attn_varlen_func(
            output = self.flash_attn_varlen_func(
                q,
                k,
                v,
@@ -1542,7 +1554,7 @@ class BaseKeyeModule(nn.Module):
    dummy_inputs=KeyeDummyInputsBuilder,
)
class KeyeForConditionalGeneration(
    BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP
    BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
    def _build_projector(
        self,
@@ -1611,3 +1623,142 @@ class KeyeForConditionalGeneration(
        return tuple(
            self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos)
        )

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        hf_config: PretrainedConfig,
        image_grid_thw: list[list[int]] | torch.Tensor,
        video_grid_thw: list[list[int]] | torch.Tensor,
        context_len: int = 0,
        seq_len: int | None = None,
        second_per_grid_ts: list[float] | None = None,
        audio_feature_lengths: torch.Tensor | None = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value (Keye series)."""
        if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
            video_grid_thw = video_grid_thw[0]

        def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
            """
            Split grid_thw along the t dimension.

            Args:
                grid_thw: shape [N, 3] tensor or nested list of [t, h, w].

            Returns:
                List of [1, h, w] rows, repeated t times for each original row,
                e.g. split_thw([[2, 4, 6]]) -> [[1, 4, 6], [1, 4, 6]].
            """
            if isinstance(grid_thw, list):
                grid_thw = torch.tensor(grid_thw, dtype=torch.long)
            if grid_thw.numel() == 0:
                return []
            t, hw = grid_thw[:, 0], grid_thw[:, 1:]
            ones = torch.ones_like(hw[:, :1])  # [N, 1]
            out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
            return out.tolist()

        video_grid_thw = split_thw(video_grid_thw)

        image_token_id = hf_config.image_token_id
        video_token_id = hf_config.video_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size

        image_nums = len(image_grid_thw)
        frame_nums = len(video_grid_thw)
        llm_pos_ids_list: list = []

        st = 0
        remain_images, remain_frames = image_nums, frame_nums
        image_index, video_index = 0, 0

        for _ in range(image_nums + frame_nums):
            if remain_images > 0:
                try:
                    ed_image = input_tokens.index(image_token_id, st)
                except ValueError:
                    ed_image = len(input_tokens) + 1
            else:
                ed_image = len(input_tokens) + 1
            if remain_frames > 0:
                try:
                    ed_video = input_tokens.index(video_token_id, st)
                except ValueError:
                    ed_video = len(input_tokens) + 1
            else:
                ed_video = len(input_tokens) + 1

            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                video_index += 1
                remain_frames -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = (
                t,
                h // spatial_merge_size,
                w // spatial_merge_size,
            )
            text_len = ed - st

            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

            t_index = (
                (
                    torch.arange(llm_grid_t)
                    .view(-1, 1)
                    .expand(-1, llm_grid_h * llm_grid_w)
                )
                .long()
                .flatten()
            )
            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
                .expand(llm_grid_t, -1, llm_grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(llm_grid_w)
                .view(1, 1, -1)
                .expand(llm_grid_t, llm_grid_h, -1)
                .flatten()
            )
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
            )
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        # How far the largest 3-D position index runs ahead of (or behind) the
        # plain 1-D token count.
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        llm_positions = llm_positions[:, context_len:seq_len]

        return llm_positions, mrope_position_delta
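
For readers unfamiliar with MRoPE, here is a small self-contained sketch of the position scheme the new method produces. It is illustrative only: the token ids, grid size, and merge factor are made up, and it re-derives the text/image interleaving in plain torch for a single image rather than calling into vLLM.

# Illustrative sketch of MRoPE positions for [text, image placeholders, text].
import torch


def toy_mrope_positions(
    tokens: list[int], image_token_id: int, grid_thw: tuple[int, int, int], merge: int
) -> tuple[torch.Tensor, int]:
    t, h, w = grid_thw
    gt, gh, gw = t, h // merge, w // merge
    ed = tokens.index(image_token_id)  # first placeholder position
    blocks = []

    # Text before the image: the same position on all three axes.
    blocks.append(torch.arange(ed).view(1, -1).expand(3, -1))

    # Image block: separate t/h/w indices, offset past the preceding text.
    t_idx = torch.arange(gt).view(-1, 1).expand(-1, gh * gw).flatten()
    h_idx = torch.arange(gh).view(1, -1, 1).expand(gt, -1, gw).flatten()
    w_idx = torch.arange(gw).view(1, 1, -1).expand(gt, gh, -1).flatten()
    blocks.append(torch.stack([t_idx, h_idx, w_idx]) + ed)

    # Text after the image: resume counting from the largest index so far.
    st = ed + gt * gh * gw
    start = blocks[-1].max() + 1
    blocks.append(torch.arange(len(tokens) - st).view(1, -1).expand(3, -1) + start)

    positions = torch.cat(blocks, dim=1)
    delta = int(positions.max().item()) + 1 - len(tokens)
    return positions, delta


# Two text tokens, a 4x4 image merged 2x2 into four placeholders, one text token.
tokens = [10, 11, 777, 777, 777, 777, 12]
pos, delta = toy_mrope_positions(tokens, image_token_id=777, grid_thw=(1, 4, 4), merge=2)
print(pos)
# tensor([[0, 1, 2, 2, 2, 2, 4],
#         [0, 1, 2, 2, 3, 3, 4],
#         [0, 1, 2, 3, 2, 3, 4]])
print(delta)  # -2: the 3-D positions end two steps behind the 1-D token count

The three rows share one index over text tokens and diverge over the image block, which is what get_mrope_input_positions returns for the Keye models; the returned delta is the offset applied to positions generated after the prompt.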