mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:04:53 +08:00
[Bugfix] handle alignment of arguments in convert_sparse_cross_attention_mask_to_dense (#12347)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com> Signed-off-by: Wallas Santos <wallashss@ibm.com> Co-authored-by: Wallas Santos <wallashss@ibm.com>
This commit is contained in:
parent
ef001d98ef
commit
036ca94c25
@ -1,11 +1,15 @@
|
||||
from typing import List, Optional, Tuple, Type, overload
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer,
|
||||
BatchEncoding)
|
||||
|
||||
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
|
||||
global_force_attn_backend_context_manager)
|
||||
from vllm.model_executor.models.mllama import (MLLAMA_IMAGE_TOKEN_ID,
|
||||
MllamaForConditionalGeneration)
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
from vllm.sequence import SampleLogprobs
|
||||
|
||||
@ -33,6 +37,29 @@ models = [
|
||||
"meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
]
|
||||
|
||||
# Indices for inputs
|
||||
TEXT_ONLY = '0'
|
||||
IMAGE_AT_BEG = '1'
|
||||
IMAGE_AT_MIDDLE = '2'
|
||||
TWO_IMAGES = '3'
|
||||
|
||||
# Input tokenized
|
||||
prompt_data = {
|
||||
# Tell me a story
|
||||
TEXT_ONLY: [41551, 757, 264, 3446],
|
||||
# <|image|> What's the content of this image
|
||||
IMAGE_AT_BEG:
|
||||
[MLLAMA_IMAGE_TOKEN_ID, 3639, 596, 279, 2262, 315, 420, 2217, 220],
|
||||
# Hello <|image|>What' the content of this image
|
||||
IMAGE_AT_MIDDLE:
|
||||
[9906, 220, MLLAMA_IMAGE_TOKEN_ID, 3923, 6, 279, 2262, 315, 420, 2217],
|
||||
#<|image|>Is there a duck in this image?<|image|>What's the animal in this image? # noqa: E501
|
||||
TWO_IMAGES: [
|
||||
MLLAMA_IMAGE_TOKEN_ID, 3957, 1070, 264, 37085, 304, 420, 2217, 30,
|
||||
MLLAMA_IMAGE_TOKEN_ID, 3923, 596, 279, 10065, 304, 420, 2217, 30
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def vllm_to_hf_output(vllm_output: Tuple[List[int], str,
|
||||
Optional[SampleLogprobs]],
|
||||
@ -365,3 +392,184 @@ def test_models_interleaved_images(hf_runner, vllm_runner, image_assets, model,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
||||
|
||||
|
||||
@large_gpu_test(min_gb=48)
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
|
||||
def test_regression(vllm_runner, image_assets, model, dtype, max_tokens,
|
||||
num_logprobs, attn_backend: _Backend) -> None:
|
||||
|
||||
stop_sign = image_assets[0].pil_image
|
||||
|
||||
with global_force_attn_backend_context_manager(attn_backend), vllm_runner(
|
||||
model,
|
||||
dtype=dtype,
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
tensor_parallel_size=1,
|
||||
enforce_eager=True,
|
||||
limit_mm_per_prompt={"image":
|
||||
_LIMIT_IMAGE_PER_PROMPT}) as vllm_model:
|
||||
|
||||
# Regression tests for https://github.com/vllm-project/vllm/issues/10648
|
||||
|
||||
# Number of image tags is greater than the number of images provided
|
||||
prompt = "<|begin_of_text|><|image|><|image|> Compare the two images" # noqa: E501
|
||||
image = stop_sign
|
||||
with pytest.raises(ValueError):
|
||||
vllm_model.generate_greedy_logprobs([prompt],
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
images=[image])
|
||||
|
||||
# Batch of a text-only and image request that requires cross-attention
|
||||
prompts = [
|
||||
"What is the capital of spain?",
|
||||
"Text before the image...<|image|>What is in the image?", # noqa: E501
|
||||
]
|
||||
images = [
|
||||
None,
|
||||
[stop_sign],
|
||||
]
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
images=images)
|
||||
|
||||
# Test the reverse order too for good measure
|
||||
prompts = [
|
||||
"<|begin_of_text|>Text before the image...<|image|>What is in the image?", # noqa: E501
|
||||
"<|begin_of_text|>Hello!",
|
||||
]
|
||||
images = [
|
||||
[stop_sign],
|
||||
None,
|
||||
]
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs,
|
||||
images=images)
|
||||
|
||||
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.parametrize(
|
||||
"input_indices_and_output",
|
||||
# inputs, (cross_attention_mask, kv_range_for_decode)
|
||||
[([TEXT_ONLY], (None, None)), ([IMAGE_AT_BEG], (None, None)),
|
||||
([TEXT_ONLY, IMAGE_AT_BEG], (None, None)),
|
||||
([IMAGE_AT_MIDDLE], ((10, 12), [[0, 6]])),
|
||||
([TEXT_ONLY, IMAGE_AT_MIDDLE], ((14, 12), [[0, 6]])),
|
||||
([TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE],
|
||||
((23, 24), [[0, 6], [6, 12]])),
|
||||
([IMAGE_AT_MIDDLE, TEXT_ONLY], ((14, 12), [[0, 6]])),
|
||||
([TWO_IMAGES], ((18, 12), [[6, 12]])),
|
||||
([TEXT_ONLY, TWO_IMAGES], ((22, 12), [[6, 12]]))])
|
||||
def test_get_cross_attention_mask(input_indices_and_output) -> None:
|
||||
|
||||
input_indices, expected_output = input_indices_and_output
|
||||
|
||||
sequences = [torch.tensor(prompt_data[i]) for i in input_indices]
|
||||
num_tiles = [[2, 2] if i != TEXT_ONLY else [] for i in input_indices
|
||||
if i != TEXT_ONLY]
|
||||
input = torch.cat(sequences)
|
||||
|
||||
seq_lens = [len(s) for s in sequences]
|
||||
|
||||
attn_data = FlashAttentionMetadata(
|
||||
seq_lens=seq_lens,
|
||||
# Dummy values
|
||||
enable_kv_scales_calculation=False,
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=0,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
seq_lens_tensor=0,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=0,
|
||||
context_lens_tensor=None,
|
||||
block_tables=None,
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
|
||||
dummy: dict[str, str] = {}
|
||||
|
||||
cross_attention_mask, kv_range_for_decode = MllamaForConditionalGeneration\
|
||||
.get_cross_attention_mask(dummy,
|
||||
input,
|
||||
attn_data,
|
||||
num_tiles=num_tiles,
|
||||
num_tokens_per_tile=3,
|
||||
dtype=torch.bfloat16)
|
||||
|
||||
expected_cross_attention_mask, expected_kv_range_for_decode = \
|
||||
expected_output
|
||||
|
||||
assert kv_range_for_decode == expected_kv_range_for_decode
|
||||
if expected_cross_attention_mask is not None:
|
||||
assert cross_attention_mask is not None
|
||||
assert cross_attention_mask.shape == expected_cross_attention_mask
|
||||
else:
|
||||
assert cross_attention_mask is None
|
||||
|
||||
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.parametrize(
|
||||
"input_indices",
|
||||
[[TEXT_ONLY], [IMAGE_AT_BEG], [TEXT_ONLY, IMAGE_AT_BEG], [IMAGE_AT_MIDDLE],
|
||||
[TEXT_ONLY, IMAGE_AT_MIDDLE], [TEXT_ONLY, IMAGE_AT_BEG, IMAGE_AT_MIDDLE],
|
||||
[IMAGE_AT_MIDDLE, TEXT_ONLY], [TWO_IMAGES], [TEXT_ONLY, TWO_IMAGES]])
|
||||
def test_get_full_text_row_masked_out_mask(input_indices) -> None:
|
||||
|
||||
sequences = [torch.tensor(prompt_data[i]) for i in input_indices]
|
||||
|
||||
seq_lens = [len(s) for s in sequences]
|
||||
|
||||
num_prefill_tokens = sum(seq_lens)
|
||||
|
||||
# TEXT_ONLY is zero, so it will be masked out,
|
||||
# other instances should not be.
|
||||
encoder_seq_lens = [int(i) for i in input_indices]
|
||||
|
||||
attn_data = FlashAttentionMetadata(
|
||||
seq_lens=seq_lens,
|
||||
encoder_seq_lens=encoder_seq_lens,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
# Dummy values
|
||||
enable_kv_scales_calculation=False,
|
||||
num_prefills=0,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=0,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
seq_lens_tensor=0,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=0,
|
||||
context_lens_tensor=None,
|
||||
block_tables=None,
|
||||
use_cuda_graph=False,
|
||||
)
|
||||
|
||||
dummy: dict[str, str] = {}
|
||||
|
||||
full_text_row_masked_out_mask = MllamaForConditionalGeneration\
|
||||
.get_full_text_row_masked_out_mask(dummy,
|
||||
attn_data,
|
||||
torch.get_default_device())
|
||||
|
||||
full_text_row_masked_out_mask = full_text_row_masked_out_mask.squeeze()
|
||||
full_text_row_masked_out_mask = full_text_row_masked_out_mask.tolist()
|
||||
|
||||
idx = 0
|
||||
assert len(full_text_row_masked_out_mask) == num_prefill_tokens
|
||||
for i, seq_len in enumerate(seq_lens):
|
||||
must_be_masked = input_indices[i] != TEXT_ONLY
|
||||
for _ in range(seq_len):
|
||||
assert full_text_row_masked_out_mask[idx] == must_be_masked, \
|
||||
f"full_text_row_masked_out_mask[{idx}] must be " \
|
||||
f"'{must_be_masked}' "
|
||||
idx += 1
|
||||
|
||||
@ -1485,14 +1485,23 @@ def convert_sparse_cross_attention_mask_to_dense(
|
||||
total_length = sum(lengths)
|
||||
total_tiles = sum([sum(tiles) for tiles in num_tiles])
|
||||
dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
|
||||
# A list of ranges, range[i] = [start, end] means
|
||||
# if the i-th sample has N tiles in total, the tiles[start, end]
|
||||
# will be used for cross-attention decoding.
|
||||
# A list of ranges, range[i] = [start, end] means that the i-th image will
|
||||
# use tiles[start, end] for cross-attention decoding.
|
||||
tile_range_for_decode = []
|
||||
|
||||
seq_start = 0
|
||||
tile_start = 0
|
||||
for masks, tiles, length in zip(sparse_mask, num_tiles, lengths):
|
||||
|
||||
# sparse_mask has an [] entry for each sequence that does not have images,
|
||||
# but num_tiles does not have these entries...
|
||||
num_tiles_idx = 0
|
||||
for masks, length in zip(sparse_mask, lengths):
|
||||
if len(masks) == 0:
|
||||
# Text only
|
||||
continue
|
||||
|
||||
tiles = num_tiles[num_tiles_idx]
|
||||
num_tiles_idx += 1
|
||||
ts, td = -1, 0
|
||||
for mask, tile in zip(masks, tiles):
|
||||
if len(mask) != 2:
|
||||
@ -1512,6 +1521,7 @@ def convert_sparse_cross_attention_mask_to_dense(
|
||||
assert td != 0
|
||||
tile_range_for_decode.append((ts, ts + td))
|
||||
seq_start += length
|
||||
assert num_tiles_idx == len(num_tiles)
|
||||
|
||||
return dense_mask, tile_range_for_decode
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user