mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-21 23:11:19 +08:00
[Bugfix] [Model] Missing MRoPE function definition from KeyeForConditionalGeneration (#27895)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
parent
879a06579e
commit
e2347dbf58
86
tests/models/multimodal/generation/test_keye.py
Normal file
86
tests/models/multimodal/generation/test_keye.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from dataclasses import asdict
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from PIL.Image import Image
|
||||||
|
from transformers import AutoProcessor
|
||||||
|
|
||||||
|
from vllm import LLM, EngineArgs, SamplingParams
|
||||||
|
from vllm.multimodal.utils import encode_image_base64
|
||||||
|
|
||||||
|
MODEL_NAME = "Kwai-Keye/Keye-VL-8B-Preview"
|
||||||
|
|
||||||
|
QUESTION = "What is the content of each image?"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRequestData(NamedTuple):
|
||||||
|
engine_args: EngineArgs
|
||||||
|
prompt: str
|
||||||
|
image_data: list[Image]
|
||||||
|
stop_token_ids: list[int] | None = None
|
||||||
|
chat_template: str | None = None
|
||||||
|
sampling_params: SamplingParams | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.core_model
|
||||||
|
@pytest.mark.parametrize("question", [QUESTION])
|
||||||
|
def test_keye_vl(
|
||||||
|
image_assets,
|
||||||
|
question: str,
|
||||||
|
):
|
||||||
|
images = [asset.pil_image for asset in image_assets]
|
||||||
|
|
||||||
|
image_urls = [
|
||||||
|
f"data:image/jpeg;base64,{encode_image_base64(image)}" for image in images
|
||||||
|
]
|
||||||
|
|
||||||
|
engine_args = EngineArgs(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
trust_remote_code=True,
|
||||||
|
max_model_len=8192,
|
||||||
|
max_num_seqs=5,
|
||||||
|
limit_mm_per_prompt={"image": len(image_urls)},
|
||||||
|
)
|
||||||
|
|
||||||
|
placeholders = [{"type": "image", "image": url} for url in image_urls]
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
*placeholders,
|
||||||
|
{"type": "text", "text": question},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
||||||
|
|
||||||
|
prompt = processor.apply_chat_template(
|
||||||
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
|
)
|
||||||
|
|
||||||
|
engine_args = asdict(engine_args) | {"seed": 42}
|
||||||
|
llm = LLM(**engine_args)
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=0.0, max_tokens=256, stop_token_ids=None
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs = llm.generate(
|
||||||
|
{
|
||||||
|
"prompt": prompt,
|
||||||
|
"multi_modal_data": {"image": images},
|
||||||
|
},
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("-" * 50)
|
||||||
|
for o in outputs:
|
||||||
|
generated_text = o.outputs[0].text
|
||||||
|
print(generated_text)
|
||||||
|
assert len(generated_text) > 10, (
|
||||||
|
f"Generated text is too short: {generated_text}"
|
||||||
|
)
|
||||||
|
print("-" * 50)
|
||||||
@ -17,7 +17,9 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
|
|||||||
from transformers.utils import torch_int
|
from transformers.utils import torch_int
|
||||||
|
|
||||||
from vllm.attention.backends.registry import _Backend
|
from vllm.attention.backends.registry import _Backend
|
||||||
from vllm.attention.layer import check_upstream_fa_availability
|
from vllm.attention.layer import (
|
||||||
|
maybe_get_vit_flash_attn_backend,
|
||||||
|
)
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||||
@ -56,12 +58,14 @@ from vllm.multimodal.processing import (
|
|||||||
PromptUpdate,
|
PromptUpdate,
|
||||||
)
|
)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
from .interfaces import (
|
from .interfaces import (
|
||||||
MultiModalEmbeddings,
|
MultiModalEmbeddings,
|
||||||
SupportsLoRA,
|
SupportsLoRA,
|
||||||
|
SupportsMRoPE,
|
||||||
SupportsMultiModal,
|
SupportsMultiModal,
|
||||||
SupportsPP,
|
SupportsPP,
|
||||||
)
|
)
|
||||||
@ -337,7 +341,10 @@ def apply_rotary_pos_emb_flashatt(
|
|||||||
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
||||||
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
||||||
|
|
||||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
if current_platform.is_cuda():
|
||||||
|
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
||||||
|
elif current_platform.is_rocm():
|
||||||
|
from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
|
||||||
|
|
||||||
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
|
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
|
||||||
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
|
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
|
||||||
@ -398,18 +405,28 @@ class KeyeSiglipAttention(nn.Module):
|
|||||||
attn_backend_override=attn_backend_override,
|
attn_backend_override=attn_backend_override,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.use_upstream_fa = False
|
self.attn_backend, self.flash_attn_varlen_func = (
|
||||||
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
|
maybe_get_vit_flash_attn_backend(
|
||||||
torch.get_default_dtype()
|
self.attn_backend,
|
||||||
):
|
use_upstream_fa=False,
|
||||||
self.attn_backend = _Backend.FLASH_ATTN
|
attn_backend_override=attn_backend_override,
|
||||||
self.use_upstream_fa = True
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
|
if self.attn_backend not in {
|
||||||
|
_Backend.FLASH_ATTN,
|
||||||
|
_Backend.XFORMERS,
|
||||||
|
_Backend.ROCM_AITER_FA,
|
||||||
|
}:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Keye-VL does not support {self.attn_backend} backend now."
|
f"Keye-VL does not support {self.attn_backend} backend now."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.is_flash_attn_backend = self.attn_backend in {
|
||||||
|
_Backend.FLASH_ATTN,
|
||||||
|
_Backend.ROCM_AITER_FA,
|
||||||
|
}
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@ -457,15 +474,10 @@ class KeyeSiglipAttention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.attn_backend == _Backend.FLASH_ATTN:
|
if self.is_flash_attn_backend:
|
||||||
if self.use_upstream_fa:
|
|
||||||
from flash_attn import flash_attn_varlen_func
|
|
||||||
else:
|
|
||||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
|
||||||
|
|
||||||
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
|
||||||
|
|
||||||
output = flash_attn_varlen_func(
|
output = self.flash_attn_varlen_func(
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
@ -1542,7 +1554,7 @@ class BaseKeyeModule(nn.Module):
|
|||||||
dummy_inputs=KeyeDummyInputsBuilder,
|
dummy_inputs=KeyeDummyInputsBuilder,
|
||||||
)
|
)
|
||||||
class KeyeForConditionalGeneration(
|
class KeyeForConditionalGeneration(
|
||||||
BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP
|
BaseKeyeModule, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
|
||||||
):
|
):
|
||||||
def _build_projector(
|
def _build_projector(
|
||||||
self,
|
self,
|
||||||
@ -1611,3 +1623,142 @@ class KeyeForConditionalGeneration(
|
|||||||
return tuple(
|
return tuple(
|
||||||
self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos)
|
self._process_video_embeds(video_type, video_grid_thw, pixel_values_videos)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_mrope_input_positions(
|
||||||
|
self,
|
||||||
|
input_tokens: list[int],
|
||||||
|
hf_config: PretrainedConfig,
|
||||||
|
image_grid_thw: list[list[int]] | torch.Tensor,
|
||||||
|
video_grid_thw: list[list[int]] | torch.Tensor,
|
||||||
|
context_len: int = 0,
|
||||||
|
seq_len: int | None = None,
|
||||||
|
second_per_grid_ts: list[float] | None = None,
|
||||||
|
audio_feature_lengths: torch.Tensor | None = None,
|
||||||
|
use_audio_in_video: bool = False,
|
||||||
|
) -> tuple[torch.Tensor, int]:
|
||||||
|
if isinstance(video_grid_thw, list) and len(video_grid_thw) > 0:
|
||||||
|
video_grid_thw = video_grid_thw[0]
|
||||||
|
"""Get mrope input positions and delta value (Keye series)."""
|
||||||
|
|
||||||
|
def split_thw(grid_thw: torch.Tensor | list[int]) -> list[list[int]]:
|
||||||
|
"""
|
||||||
|
Split grid_thw along the t dimension.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
grid_thw: shape [N, 3] tensor or nested list of [t, h, w].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of [1, h, w] rows, repeated t times for each original row.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(grid_thw, list):
|
||||||
|
grid_thw = torch.tensor(grid_thw, dtype=torch.long)
|
||||||
|
|
||||||
|
if grid_thw.numel() == 0:
|
||||||
|
return []
|
||||||
|
|
||||||
|
t, hw = grid_thw[:, 0], grid_thw[:, 1:]
|
||||||
|
ones = torch.ones_like(hw[:, :1]) # [N,1]
|
||||||
|
out = torch.cat([ones, hw], dim=1).repeat_interleave(t, dim=0)
|
||||||
|
return out.tolist()
|
||||||
|
|
||||||
|
video_grid_thw = split_thw(video_grid_thw)
|
||||||
|
|
||||||
|
image_token_id = hf_config.image_token_id
|
||||||
|
video_token_id = hf_config.video_token_id
|
||||||
|
spatial_merge_size = hf_config.vision_config.spatial_merge_size
|
||||||
|
|
||||||
|
image_nums = len(image_grid_thw)
|
||||||
|
frame_nums = len(video_grid_thw)
|
||||||
|
llm_pos_ids_list: list = []
|
||||||
|
|
||||||
|
st = 0
|
||||||
|
remain_images, remain_frames = image_nums, frame_nums
|
||||||
|
|
||||||
|
image_index, video_index = 0, 0
|
||||||
|
for _ in range(image_nums + frame_nums):
|
||||||
|
if remain_images > 0:
|
||||||
|
try:
|
||||||
|
ed_image = input_tokens.index(image_token_id, st)
|
||||||
|
except ValueError:
|
||||||
|
ed_image = len(input_tokens) + 1
|
||||||
|
else:
|
||||||
|
ed_image = len(input_tokens) + 1
|
||||||
|
if remain_frames > 0:
|
||||||
|
try:
|
||||||
|
ed_video = input_tokens.index(video_token_id, st)
|
||||||
|
except ValueError:
|
||||||
|
ed_video = len(input_tokens) + 1
|
||||||
|
else:
|
||||||
|
ed_video = len(input_tokens) + 1
|
||||||
|
|
||||||
|
if ed_image < ed_video:
|
||||||
|
t, h, w = (
|
||||||
|
image_grid_thw[image_index][0],
|
||||||
|
image_grid_thw[image_index][1],
|
||||||
|
image_grid_thw[image_index][2],
|
||||||
|
)
|
||||||
|
image_index += 1
|
||||||
|
remain_images -= 1
|
||||||
|
ed = ed_image
|
||||||
|
else:
|
||||||
|
t, h, w = (
|
||||||
|
video_grid_thw[video_index][0],
|
||||||
|
video_grid_thw[video_index][1],
|
||||||
|
video_grid_thw[video_index][2],
|
||||||
|
)
|
||||||
|
video_index += 1
|
||||||
|
remain_frames -= 1
|
||||||
|
ed = ed_video
|
||||||
|
|
||||||
|
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||||
|
t,
|
||||||
|
h // spatial_merge_size,
|
||||||
|
w // spatial_merge_size,
|
||||||
|
)
|
||||||
|
text_len = ed - st
|
||||||
|
|
||||||
|
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||||
|
llm_pos_ids_list.append(
|
||||||
|
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
t_index = (
|
||||||
|
(
|
||||||
|
torch.arange(llm_grid_t)
|
||||||
|
.view(-1, 1)
|
||||||
|
.expand(-1, llm_grid_h * llm_grid_w)
|
||||||
|
)
|
||||||
|
.long()
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
|
||||||
|
h_index = (
|
||||||
|
torch.arange(llm_grid_h)
|
||||||
|
.view(1, -1, 1)
|
||||||
|
.expand(llm_grid_t, -1, llm_grid_w)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
w_index = (
|
||||||
|
torch.arange(llm_grid_w)
|
||||||
|
.view(1, 1, -1)
|
||||||
|
.expand(llm_grid_t, llm_grid_h, -1)
|
||||||
|
.flatten()
|
||||||
|
)
|
||||||
|
llm_pos_ids_list.append(
|
||||||
|
torch.stack([t_index, h_index, w_index]) + text_len + st_idx
|
||||||
|
)
|
||||||
|
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
||||||
|
|
||||||
|
if st < len(input_tokens):
|
||||||
|
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
||||||
|
text_len = len(input_tokens) - st
|
||||||
|
llm_pos_ids_list.append(
|
||||||
|
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
||||||
|
mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
|
||||||
|
llm_positions = llm_positions[:, context_len:seq_len]
|
||||||
|
|
||||||
|
return llm_positions, mrope_position_delta
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user