[Bugfix] [Encoder-Decoder] Bugfix for encoder-specific metadata construction during decode of encoder-decoder models. (#8545)

sroy745 authored on 2024-09-18 19:24:15 -07:00, committed by GitHub
parent 4c34ce8916
commit 3118f63385
2 changed files with 69 additions and 31 deletions
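In short: during decode, the runner built the encoder-side metadata (encoder sequence lengths and cross-attention block tables) once per sequence *group*, but decode-phase batch rows are per *sequence*. When a group holds several sequences (e.g. parallel sampling), the encoder metadata came up short. A minimal standalone sketch of the mismatch, using a hypothetical SeqGroup stand-in rather than vLLM's actual classes:

# Hypothetical stand-in for a sequence group; not vLLM's real class.
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class SeqGroup:
    seq_data: Dict[int, str]              # one entry per decoding sequence
    encoder_seq_len: int
    cross_block_table: Optional[List[int]]

# One group containing two decoding sequences that share one encoder prompt.
groups = [SeqGroup({0: "a", 1: "b"}, encoder_seq_len=7, cross_block_table=[2])]

# Buggy construction: one entry per group -> 1 row.
buggy = [g.encoder_seq_len for g in groups]

# Fixed construction: one entry per sequence in the group -> 2 rows.
fixed = [g.encoder_seq_len for g in groups for _ in g.seq_data]

decoder_rows = sum(len(g.seq_data) for g in groups)
assert len(fixed) == decoder_rows      # matches the decode batch
assert len(buggy) != decoder_rows      # demonstrates the original mismatch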

tests/worker/test_encoder_decoder_model_runner.py

@@ -273,7 +273,8 @@ def test_prepare_prompt(batch_size):
                     "unsupported for encoder/ "
                     "decoder models")
 @pytest.mark.parametrize("batch_size", BATCH_SIZES)
-def test_prepare_decode(batch_size):
+@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
+def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
     '''
     Test the ability of the encoder/decoder model runner subclass to
     produce decode-phase model inputs & attention metadata.
@@ -288,6 +289,7 @@ def test_prepare_decode(batch_size):
 
     Arguments:
     * batch_size
+    * multiple_seqs_per_seq_group
     * backend_name: The attention backend under test
     * enforce_eager: Enforce eager mode if True (i.e. no CUDAGraph)
     '''
@@ -305,22 +307,29 @@ def test_prepare_decode(batch_size):
     seq_lens: List[int] = []
     encoder_seq_lens: List[int] = []
     seq_group_metadata_list: List[SequenceGroupMetadata] = []
-    block_tables = {0: [1]}
+    block_tables = {
+        0: [1],
+        1: [3]
+    } if multiple_seqs_per_seq_group else {
+        0: [1]
+    }
     cross_block_table = [2]
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_lens.append(seq_len)
         seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_lens.append(encoder_seq_len)
         encoder_seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
-            seq_data={0: seq_data},
+            seq_data={
+                0: seq_data,
+                1: seq_data
+            } if multiple_seqs_per_seq_group else {0: seq_data},
             sampling_params=SamplingParams(temperature=0),
             block_tables=block_tables,
             encoder_seq_data=encoder_seq_data,
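An aside on the new test shape (a sketch with plain dicts, not the real SequenceGroupMetadata): when multiple_seqs_per_seq_group is True, two decoding sequences share a single encoder prompt and a single cross-attention block table, while each sequence keeps its own self-attention block table.

multiple_seqs_per_seq_group = True
seq_data = "dummy"

group = {
    "seq_data": {0: seq_data, 1: seq_data} if multiple_seqs_per_seq_group
    else {0: seq_data},
    "block_tables": {0: [1], 1: [3]} if multiple_seqs_per_seq_group
    else {0: [1]},
    "cross_block_table": [2],
}
# Every decoding sequence needs its own self-attention block table...
assert set(group["seq_data"]) == set(group["block_tables"])
# ...but the whole group shares one cross-attention block table.
assert group["cross_block_table"] == [2]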
@@ -328,6 +337,10 @@ def test_prepare_decode(batch_size):
         )
         assert seq_group_metadata.token_chunk_size == 1
         seq_group_metadata_list.append(seq_group_metadata)
+        seq_lens.extend(
+            [seq_len for _ in range(len(seq_group_metadata.seq_data))])
+        encoder_seq_lens.extend(
+            [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))])
 
     # Build
     # * Decoder model inputs
@@ -398,19 +411,24 @@ def test_prepare_decode(batch_size):
     # Verify block tables are correct for prompts
     # - Decoder self-attention
-    expected = torch.tensor(
-        [block_tables[0] for _ in range(len(seq_group_metadata_list))],
-        dtype=torch.int32,
-        device=model_runner.device)
+    flattened_block_tables = [
+        block_table for block_table in block_tables.values()
+    ]
+    expected = torch.tensor(flattened_block_tables *
+                            len(seq_group_metadata_list),
+                            dtype=torch.int32,
+                            device=model_runner.device)
     assert torch.equal(
         attn_metadata.block_tables,
         expected,
     )
     # - Encoder/decoder cross-attention
-    expected = torch.tensor(
-        [cross_block_table for _ in range(len(seq_group_metadata_list))],
-        dtype=torch.int32,
-        device=model_runner.device)
+    expected = torch.tensor([
+        cross_block_table for seq_group_metadata in seq_group_metadata_list
+        for _ in range(len(seq_group_metadata.seq_data))
+    ],
+                            dtype=torch.int32,
+                            device=model_runner.device)
     assert torch.equal(
         attn_metadata.cross_block_tables,
         expected,
@@ -474,7 +492,8 @@ def test_prepare_decode(batch_size):
 @pytest.mark.parametrize("batch_size", list(range(1, 257)))
-def test_prepare_decode_cuda_graph(batch_size):
+@pytest.mark.parametrize("multiple_seqs_per_seq_group", [True, False])
+def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     """
     Tests that for encoder-decoder models with CUDA Graph capture and replay
     enabled, the tensors used during the decode phase are correctly padded
@@ -489,32 +508,45 @@ def test_prepare_decode_cuda_graph(batch_size):
         enable_chunked_prefill=False,
         enforce_eager=False,
     )
+    block_tables = {
+        0: [1],
+        1: [3]
+    } if multiple_seqs_per_seq_group else {
+        0: [1]
+    }
     seq_lens: List[int] = []
     encoder_seq_lens: List[int] = []
     seq_group_metadata_list: List[SequenceGroupMetadata] = []
-    block_tables = {0: [1]}
     cross_block_table = [2]
+    expanded_batch_size = 0
     for i in range(batch_size):
         # make sure all tokens fit into one block
         seq_len = i % (model_runner.block_size - 1) + 1
-        seq_lens.append(seq_len)
         seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(seq_len))))
         encoder_seq_len = (i + 1) % (model_runner.block_size - 1) + 1
-        encoder_seq_lens.append(encoder_seq_len)
         encoder_seq_data = SequenceData(
             array(VLLM_TOKEN_ID_ARRAY_TYPE, (range(encoder_seq_len))))
         seq_group_metadata = SequenceGroupMetadata(
             request_id=f"test_{i}",
             is_prompt=False,
-            seq_data={0: seq_data},
+            seq_data={
+                0: seq_data,
+                1: seq_data
+            } if multiple_seqs_per_seq_group else {0: seq_data},
             sampling_params=SamplingParams(temperature=0),
             block_tables=block_tables,
             encoder_seq_data=encoder_seq_data,
             cross_block_table=cross_block_table,
         )
         assert seq_group_metadata.token_chunk_size == 1
+        seq_lens.extend(
+            [seq_len for _ in range(len(seq_group_metadata.seq_data))])
+        encoder_seq_lens.extend(
+            [encoder_seq_len for _ in range(len(seq_group_metadata.seq_data))])
+        expanded_batch_size = expanded_batch_size + len(
+            seq_group_metadata.seq_data)
         seq_group_metadata_list.append(seq_group_metadata)
 
     model_input = model_runner.prepare_model_input(seq_group_metadata_list)
@@ -530,8 +562,8 @@ def test_prepare_decode_cuda_graph(batch_size):
     # With CUDA Graph capture and replay enabled, the decoder and encoder
     # input sequences will be padded. Create the expected padded tensors
     # accordingly.
-    graph_batch_size = _get_graph_batch_size(batch_size)
-    cuda_graph_pad_size = graph_batch_size - batch_size
+    graph_batch_size = _get_graph_batch_size(expanded_batch_size)
+    cuda_graph_pad_size = graph_batch_size - expanded_batch_size
     padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
     padded_encoder_seq_lens = encoder_seq_lens + list(
         itertools.repeat(1, cuda_graph_pad_size))
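The key point of this hunk: the CUDA graph pad must be computed from the expanded (per-sequence) batch size, not the number of groups. A sketch of the padding arithmetic; _get_graph_batch_size is internal to vLLM, so a stand-in is used here under the assumption that captured sizes follow the usual ladder (1, 2, 4, then multiples of 8):

import itertools

def graph_batch_size(batch: int) -> int:
    # Stand-in: smallest captured graph size >= batch (assumed ladder).
    if batch <= 2:
        return batch
    return 4 if batch <= 4 else (batch + 7) // 8 * 8

expanded_batch_size = 6            # e.g. 3 groups x 2 sequences each
seq_lens = [5, 5, 9, 9, 2, 2]      # per-sequence decode lengths

pad = graph_batch_size(expanded_batch_size) - expanded_batch_size
padded_seq_lens = seq_lens + list(itertools.repeat(1, pad))
assert len(padded_seq_lens) == graph_batch_size(expanded_batch_size)  # 8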
@@ -560,10 +592,13 @@ def test_prepare_decode_cuda_graph(batch_size):
     # Verify block tables are correct for prompts
     # - Decoder self-attention. Pad the block tables as expected.
-    expected = [block_tables[0] for _ in range(batch_size)]
-    expected.extend([[] for _ in range(cuda_graph_pad_size)])
+    flattened_block_tables = [
+        block_table for _ in range(len(seq_group_metadata_list))
+        for block_table in block_tables.values()
+    ]
+    flattened_block_tables.extend([[] for _ in range(cuda_graph_pad_size)])
     expected = make_tensor_with_pad(
-        expected,
+        flattened_block_tables,
         max_len=64,
         pad=0,
         dtype=torch.int32,
@@ -575,7 +610,10 @@ def test_prepare_decode_cuda_graph(batch_size):
     )
     # - Encoder/decoder cross-attention. Pad the cross-attention block tables
     # as expected.
-    expected = [cross_block_table for _ in range(len(seq_group_metadata_list))]
+    expected = [
+        cross_block_table for seq_group_metadata in seq_group_metadata_list
+        for _ in range(len(seq_group_metadata.seq_data))
+    ]
     expected.extend([[] for _ in range(cuda_graph_pad_size)])
     expected = make_tensor_with_pad(
         expected,

vllm/worker/enc_dec_model_runner.py

@@ -435,18 +435,18 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
             encoder_input_tokens_tensor = self._empty_long_tensor()
             encoder_input_positions_tensor = self._empty_long_tensor()
             cross_slot_mapping_tensor = self._empty_long_tensor()
         # Extract cross-attention block tables &
         # seq len from each sequence group metadata.
         # Cross-attention block tables are empty
         # during vLLM memory profiling.
         cross_block_tables = []
         for seq_group_metadata in seq_group_metadata_list:
-            encoder_seq_lens.append(
-                seq_group_metadata.encoder_seq_data.get_len())
-            cross_block_table = seq_group_metadata.cross_block_table
-            cross_block_tables.append([] if (
-                cross_block_table is None) else cross_block_table)
+            for _ in range(len(seq_group_metadata.seq_data)):
+                encoder_seq_lens.append(
+                    seq_group_metadata.encoder_seq_data.get_len())
+                cross_block_table = seq_group_metadata.cross_block_table
+                cross_block_tables.append([] if (
+                    cross_block_table is None) else cross_block_table)
 
         if (model_input.attn_metadata is not None
                 and model_input.attn_metadata.use_cuda_graph):
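This last hunk is the actual fix: the per-group appends are wrapped in a loop over the group's sequences, so encoder_seq_lens and cross_block_tables get one entry per decoding sequence. An illustration of the fixed loop shape with plain dicts (not vLLM objects):

# Each group contributes len(seq_data) identical encoder entries, one per
# decoding sequence, so cross-attention metadata matches the batch rows.
seq_group_metadata_list = [
    {"seq_data": {0: ..., 1: ...}, "encoder_len": 7, "cross_block_table": [2]},
    {"seq_data": {0: ...}, "encoder_len": 5, "cross_block_table": None},
]

encoder_seq_lens, cross_block_tables = [], []
for g in seq_group_metadata_list:
    for _ in range(len(g["seq_data"])):
        encoder_seq_lens.append(g["encoder_len"])
        cross_block_tables.append(
            [] if g["cross_block_table"] is None else g["cross_block_table"])

assert encoder_seq_lens == [7, 7, 5]        # 3 rows for 3 sequences
assert cross_block_tables == [[2], [2], []]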