Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-25 04:45:01 +08:00)
[Bugfix] Fix M-RoPE position calculation when chunked prefill is enabled (#10388)
Signed-off-by: imkero <kerorek@outlook.com>
commit 361c29e174
parent b98d89efd4
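With chunked prefill enabled, each scheduler step prefills only the token window [context_len, seq_len) of a sequence, but `MRotaryEmbedding.get_input_positions` previously sliced its computed M-RoPE positions as `[:, context_len:]`, returning positions for every remaining token rather than just the current chunk. This commit threads `seq_len` through from the model runner, slices `[:, context_len:seq_len]` instead, and adds a test comparing chunked and non-chunked outputs. A minimal sketch of the slicing contract with toy tensors (illustrative only, not the vLLM implementation):

```python
import torch

def positions_for_chunk(full_positions: torch.Tensor,
                        context_len: int, seq_len: int) -> torch.Tensor:
    """Slice (3, total_len) M-RoPE positions down to one prefill chunk.

    A chunked-prefill step feeds only tokens [context_len, seq_len) to the
    model, so it must receive exactly seq_len - context_len positions.
    """
    return full_positions[:, context_len:seq_len]

# Toy example: 10-token prompt, current step prefills tokens 4..8.
full = torch.arange(10).view(1, -1).expand(3, -1)
chunk = positions_for_chunk(full, context_len=4, seq_len=8)
assert chunk.shape == (3, 4)        # one position per token in this chunk
assert full[:, 4:].shape == (3, 6)  # the old open-ended slice over-returns
```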
@@ -18,6 +18,7 @@ target_dtype = "half"
 
 IMAGE_PLACEHOLDER = "<|vision_start|><|image_pad|><|vision_end|>"
 VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>"
+MODEL_HIDDEN_SIZE = 1536
 
 
 def qwen2_vl_chat_template(*query):
@@ -230,7 +231,7 @@ def batch_make_video_embeddings(
     return result
 
 
-def run_test(
+def run_embedding_input_test(
     vllm_runner: Type[VllmRunner],
    inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]],
    model: str,
@@ -326,7 +327,7 @@ def test_qwen2_vl_image_embeddings_input(vllm_runner, image_assets, model,
         [],
     ) for image, prompt in zip(images, IMAGE_PROMPTS)]
 
-    run_test(
+    run_embedding_input_test(
         vllm_runner,
         inputs_per_case,
         model,
@@ -371,7 +372,7 @@ def test_qwen2_vl_multiple_image_embeddings_input(vllm_runner, image_assets,
         [],
     )]
 
-    run_test(
+    run_embedding_input_test(
         vllm_runner,
         inputs_per_case,
         model,
@@ -416,7 +417,134 @@ def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model,
         [rescale_video_size(video, factor) for factor in size_factors],
     ) for video, prompt in zip(sampled_vids, VIDEO_PROMPTS)]
 
-    run_test(
+    run_embedding_input_test(
         vllm_runner,
         inputs_per_case,
         model,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        mm_limit=1,
+        tensor_parallel_size=1,
+    )
+
+
+def run_chunked_prefill_test(
+    vllm_runner: Type[VllmRunner],
+    inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]],
+    model: str,
+    *,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    mm_limit: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    """Compare inference result between
+    chunked prefill disabled and chunked prefill enabled
+    """
+
+    # NOTE:
+    # max_model_len should be greater than image_feature_size
+    with vllm_runner(model,
+                     task="generate",
+                     max_model_len=4000,
+                     max_num_seqs=4,
+                     dtype=dtype,
+                     limit_mm_per_prompt={
+                         "image": mm_limit,
+                         "video": mm_limit
+                     },
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend
+                     ) as vllm_model:
+
+        outputs_per_case = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=images or None,
+                                                videos=videos or None)
+            for prompts, images, videos in inputs
+        ]
+
+    with vllm_runner(
+            model,
+            task="generate",
+            max_model_len=4000,
+            max_num_seqs=4,
+            dtype=dtype,
+            limit_mm_per_prompt={
+                "image": mm_limit,
+                "video": mm_limit
+            },
+            tensor_parallel_size=tensor_parallel_size,
+            distributed_executor_backend=distributed_executor_backend,
+            enable_chunked_prefill=True,
+            # should be small enough to ensure prefilling is chunked
+            max_num_batched_tokens=32,
+            mm_processor_kwargs={
+                "max_pixels": 16 * 28 * 28,
+            }) as vllm_model_chunked:
+        outputs_per_case_chunked = [
+            vllm_model_chunked.generate_greedy_logprobs(
+                prompts,
+                max_tokens,
+                num_logprobs=num_logprobs,
+                images=images or None,
+                videos=videos or None) for prompts, images, videos in inputs
+        ]
+
+    for outputs, \
+        outputs_chunked \
+            in zip(outputs_per_case,
+                   outputs_per_case_chunked):
+        check_logprobs_close(
+            outputs_0_lst=outputs,
+            outputs_1_lst=outputs_chunked,
+            name_0="non_chunked",
+            name_1="chunked",
+        )
+
+
+@pytest.mark.core_model
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_tokens", [1])
+@pytest.mark.parametrize("num_logprobs", [10])
+def test_qwen2_vl_mrope_chunked_prefill(vllm_runner, example_prompts,
+                                        model: str, dtype: str,
+                                        max_tokens: int,
+                                        num_logprobs: int) -> None:
+    """
+    Test Qwen2-VL's chunked prefill with M-RoPE
+    """
+    prompts = [
+        qwen2_vl_chat_template(IMAGE_PLACEHOLDER, prompt)
+        for prompt in example_prompts[:1]
+    ]
+
+    # 1. Qwen2-VL's M-RoPE works only when there are some multi-modal inputs,
+    #    so an image is included in the inputs
+    # 2. however, Qwen2-VL currently won't work properly
+    #    when chunked prefill is enabled and there are some multi-modal inputs,
+    #    here use a hacky way: provide a **zero-length** image to make it happy
+    #
+    # and finally we achieved:
+    # (1) chunked_prefill enabled; (2) M-RoPE works; to continue our tests
+    zero_len_image = {
+        "image_embeds": torch.empty((0, MODEL_HIDDEN_SIZE)),
+        "image_grid_thw": torch.tensor([[0, 0, 0]])
+    }
+    images = [zero_len_image] * len(prompts)
+
+    inputs_per_case: List[Tuple[List[str], PromptImageInput,
+                                PromptVideoInput]] = [
+        (prompts, images, []),
+    ]
+
+    run_chunked_prefill_test(
+        vllm_runner,
+        inputs_per_case,
+        model,
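The zero-length image in the test above works because Qwen2-VL derives the number of `<|image_pad|>` tokens from the patch grid: t·h·w vision patches merged spatial_merge_size² patches per LLM token, so a `[0, 0, 0]` grid contributes zero tokens while still routing the request through the multimodal (M-RoPE) code path. A toy check of that arithmetic, assuming Qwen2-VL's merge size of 2 (the helper below is illustrative, not vLLM's API):

```python
import torch

def num_image_tokens(grid_thw: torch.Tensor,
                     spatial_merge_size: int = 2) -> int:
    # LLM-side image token count for a t x h x w vision patch grid,
    # merged spatial_merge_size**2 patches per token.
    t, h, w = grid_thw.tolist()
    return (t * h * w) // (spatial_merge_size**2)

assert num_image_tokens(torch.tensor([0, 0, 0])) == 0  # the zero-length image
assert num_image_tokens(torch.tensor([1, 4, 4])) == 4  # a tiny real grid
```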
@@ -847,6 +847,7 @@ class MRotaryEmbedding(RotaryEmbedding):
         vision_end_token_id: int,
         spatial_merge_size: int,
         context_len: int = 0,
+        seq_len: Optional[int] = None,
     ) -> Tuple[List[List[int]], int]:
         """Get mrope input positions and delta value."""
 
@@ -921,7 +922,7 @@ class MRotaryEmbedding(RotaryEmbedding):
                 torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
 
         llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-        llm_positions = llm_positions[:, context_len:]
+        llm_positions = llm_positions[:, context_len:seq_len]
         mrope_position_delta = (llm_positions.max() + 1 -
                                 len(input_tokens)).item()
 
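Note that `seq_len` defaults to `None`, so callers that do not pass it keep the pre-fix behavior: a slice with an open end is identical to slicing with a `None` end. A quick toy check (illustrative only):

```python
import torch

positions = torch.arange(30).view(3, -1)  # toy (3, 10) position tensor
context_len = 4

# t[:, c:None] and t[:, c:] are the same slice, so existing callers of
# get_input_positions are unaffected by the new optional parameter.
assert torch.equal(positions[:, context_len:None],
                   positions[:, context_len:])
```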
@@ -700,6 +700,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
                     spatial_merge_size=hf_config.vision_config.
                     spatial_merge_size,
                     context_len=inter_data.context_lens[seq_idx],
+                    seq_len=inter_data.seq_lens[seq_idx],
                 )
 
             seq_data.mrope_position_delta = mrope_position_delta
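On the runner side, `inter_data.context_lens[seq_idx]` and `inter_data.seq_lens[seq_idx]` delimit the token window being prefilled in the current step; over successive steps these windows tile the prompt. A rough sketch of that invariant with hypothetical numbers (not vLLM's scheduler):

```python
def chunk_windows(prompt_len: int, token_budget: int):
    # Yield the (context_len, seq_len) window each prefill step covers,
    # advancing context_len to the previous step's seq_len.
    context_len = 0
    while context_len < prompt_len:
        seq_len = min(context_len + token_budget, prompt_len)
        yield context_len, seq_len
        context_len = seq_len

# 10-token prompt, 4-token budget -> three chunked-prefill steps.
assert list(chunk_windows(10, 4)) == [(0, 4), (4, 8), (8, 10)]
```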