[Bugfix][SpecDecode] kv corruption with bonus tokens in spec decode (#9730)

Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Sungjae Lee 2024-11-06 10:45:45 +09:00 committed by GitHub
parent 966e31697b
commit 0c63c34f72
4 changed files with 159 additions and 10 deletions


@@ -5,6 +5,8 @@ from unittest.mock import MagicMock
import pytest
import torch
from vllm.attention.selector import (_Backend,
global_force_attn_backend_context_manager)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
@@ -303,6 +305,7 @@ def test_multi_step_with_batch_expansion_correct_output():
seed,
model_runner_cls=TP1DraftModelRunner,
)
multi_step_worker.set_include_gpu_probs_tensor()
worker = create_worker(
Worker,
model_name,
@@ -397,6 +400,7 @@ def test_multi_step_with_batch_expansion_incorrect_output():
seed,
model_runner_cls=TP1DraftModelRunner,
)
multi_step_worker.set_include_gpu_probs_tensor()
worker = create_worker(
Worker,
model_name,
@@ -477,6 +481,109 @@ def test_multi_step_with_batch_expansion_incorrect_output():
assert (num_mismatch > 0)
@torch.inference_mode()
@pytest.mark.parametrize('num_steps', [1, 2, 3, 4])
# The choice of backends forces the multi_step_worker to choose between
# the vanilla model_runner and TP1DraftModelRunner, so that both code
# paths are tested.
@pytest.mark.parametrize('attn_backend',
[_Backend.XFORMERS, _Backend.FLASH_ATTN])
def test_multi_step_correct_kvcache(num_steps, attn_backend):
"""Verify that the KV cache of the draft model
is correctly updated for sequences with bonus token.
"""
seed = 100
model_name = "JackFram/llama-68m"
block_size = 16
num_gpu_blocks = 2048 // block_size
batch_size = 1
with global_force_attn_backend_context_manager(attn_backend):
dtype = 'float16' if attn_backend == _Backend.FLASH_ATTN else 'float32'
multi_step_worker = create_worker(MultiStepWorker,
model_name,
block_size,
num_gpu_blocks,
seed,
model_runner_cls=TP1DraftModelRunner,
dtype=dtype)
multi_step_worker.set_include_gpu_probs_tensor()
worker = create_worker(Worker,
model_name,
block_size,
num_gpu_blocks,
seed,
dtype=dtype)
prompts = [[0] for _ in range(batch_size)]
# Generate two tokens for each sequence up front
# so that the bonus-token case can be simulated.
multi_step_continuations = [[
random.randint(0, 1000),
random.randint(0, 1000)
] for _ in prompts]
final_prompt_lens = [len(prompt) + 2 + num_steps for prompt in prompts]
seq_ids_with_bonus_token_in_last_step = set(range(batch_size))
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)
# Run multi-step.
zero_kv_cache(multi_step_worker.cache_engine)
multi_step_worker.sampler_output(execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list),
sample_len=num_steps,
seq_ids_with_bonus_token_in_last_step=
seq_ids_with_bonus_token_in_last_step)
# Run single-step repeatedly.
zero_kv_cache(worker.cache_engine)
# Generate the kv cache for the bonus token first
single_step_continuations = [c[:1] for c in multi_step_continuations]
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=single_step_continuations,
final_prompt_lens=final_prompt_lens)
single_step_output = worker.execute_model(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list))
for _ in range(num_steps):
seq_group_metadata_list = create_seq_group_metadata_from_prompts(
prompts,
num_gpu_blocks,
block_size,
continuations=multi_step_continuations,
final_prompt_lens=final_prompt_lens)
single_step_output = worker.execute_model(
execute_model_req=ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list))
for i, seq_group_output in enumerate(single_step_output[-1]):
multi_step_continuations[i].append(
seq_group_output.samples[0].output_token)
# Verify that the KV caches of the single-step and
# multi-step workers are the same.
single_step_gpu_cache = worker.cache_engine[0].gpu_cache
multi_step_gpu_cache = multi_step_worker.cache_engine[0].gpu_cache
num_layers = len(single_step_gpu_cache)
allclose = lambda a, b: torch.allclose(
a.cuda(), b.cuda(), rtol=1e-2, atol=1e-2)
for i in range(num_layers):
assert allclose(single_step_gpu_cache[i][0],
multi_step_gpu_cache[i][0])
assert allclose(single_step_gpu_cache[i][1],
multi_step_gpu_cache[i][1])
@torch.inference_mode()
def test_draft_proposals_full_speculation_len():
"""Verify Top1Proposer correctly handles case where all sequences


@@ -68,12 +68,14 @@ def create_worker(cls: Callable[..., T],
seed: int,
is_driver_worker: bool = True,
enforce_eager: bool = True,
model_runner_cls: Optional[ModelRunner] = None) -> T:
model_runner_cls: Optional[ModelRunner] = None,
dtype: Optional[str] = "auto") -> T:
engine_args = EngineArgs(
model=model_name,
seed=seed,
block_size=block_size,
enforce_eager=enforce_eager,
dtype=dtype,
)
engine_config = engine_args.create_engine_config()
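
The helper change above only threads a new optional dtype argument from create_worker into EngineArgs, so the KV-cache test can force float16 for the FLASH_ATTN backend while keeping float32 for XFORMERS. A rough sketch of the same plumbing pattern; FakeEngineArgs and make_engine_args are illustrative stand-ins, not vLLM APIs.

from typing import Optional

class FakeEngineArgs:
    """Illustrative stand-in for vLLM's EngineArgs; only stores the fields."""

    def __init__(self, model: str, seed: int, block_size: int,
                 enforce_eager: bool, dtype: str):
        self.model = model
        self.seed = seed
        self.block_size = block_size
        self.enforce_eager = enforce_eager
        self.dtype = dtype

def make_engine_args(model_name: str,
                     seed: int,
                     block_size: int,
                     enforce_eager: bool = True,
                     dtype: Optional[str] = "auto") -> FakeEngineArgs:
    # "auto" keeps the model's default precision; the new KV-cache test
    # passes "float16" when it forces the FLASH_ATTN backend.
    return FakeEngineArgs(model=model_name,
                          seed=seed,
                          block_size=block_size,
                          enforce_eager=enforce_eager,
                          dtype=dtype)

args = make_engine_args("JackFram/llama-68m", seed=100, block_size=16,
                        dtype="float16")
assert args.dtype == "float16"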


@@ -54,6 +54,8 @@ class TP1DraftModelRunner(ModelRunner):
super().__init__(*args, **kwargs)
self.indices_of_seq_with_bonus_tokens = None
def _update_sampling_metadata(self, sampling_metadata, num_seqs,
num_queries):
@@ -159,6 +161,10 @@ class TP1DraftModelRunner(ModelRunner):
# TODO: Add soft-tuning prompt adapter support
return not self.prompt_adapter_config
def set_indices_of_seq_with_bonus_tokens(self,
indices_of_seq_with_bonus_tokens):
self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens
@torch.inference_mode()
def execute_model(
self,
@@ -284,11 +290,30 @@ class TP1DraftModelRunner(ModelRunner):
model_input.sampling_metadata)
# Sample the next token.
outputs.append(
self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
))
output = self.model.sample(
logits=logits,
sampling_metadata=model_input.sampling_metadata,
)
outputs.append(output)
if model_input.attn_metadata.num_prefills == 0 \
and self.indices_of_seq_with_bonus_tokens is not None:
assert output.sampled_token_ids is not None
# output.sampled_token_ids should be of shape (num_seqs, 1)
nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape
assert num_tokens_per_seq == 1
count = 0
for i in range(nums_seqs):
bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[
count]
if i != bonus_seq_idx:
# The following might cause a CPU->GPU sync.
# However, the performance impact is negligible, as
# benchmarked on H100.
output.sampled_token_ids[
i, :] = model_input.input_tokens[bonus_seq_idx]
else:
count += 1
# Prepare inputs for the next step
if step != num_steps - 1:
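
The core of the fix is in the hunk above: on decode-only steps, rows of the expanded batch that are not listed in indices_of_seq_with_bonus_tokens (the copies without a bonus token) get their sampled token overwritten with the current input token of the corresponding bonus row, so later steps, and the KV entries they write, see the same token stream as the bonus rows. Below is a toy, self-contained sketch of just that remapping; patch_sampled_tokens is a hypothetical name, and the example assumes, as the loop above does, that every non-bonus copy precedes the bonus row it maps to.

import torch

def patch_sampled_tokens(sampled_token_ids: torch.Tensor,
                         input_tokens: torch.Tensor,
                         indices_of_seq_with_bonus_tokens: list) -> torch.Tensor:
    """Toy version of the remapping added to TP1DraftModelRunner.execute_model.

    sampled_token_ids has shape (num_seqs, 1). Rows whose index is not in
    indices_of_seq_with_bonus_tokens are treated as expanded, bonus-free
    copies; their sample is replaced with the current input token of the
    next bonus row in the list, exactly as in the hunk above.
    """
    num_seqs, num_tokens_per_seq = sampled_token_ids.shape
    assert num_tokens_per_seq == 1
    count = 0
    for i in range(num_seqs):
        bonus_seq_idx = indices_of_seq_with_bonus_tokens[count]
        if i != bonus_seq_idx:
            sampled_token_ids[i, :] = input_tokens[bonus_seq_idx]
        else:
            count += 1
    return sampled_token_ids

# Expanded batch of 4 rows; rows 1 and 3 carry bonus tokens.
sampled = torch.tensor([[11], [22], [33], [44]])
inputs = torch.tensor([101, 202, 303, 404])
patched = patch_sampled_tokens(sampled, inputs, [1, 3])
# Rows 0 and 2 now reuse the input tokens of rows 1 and 3 respectively.
assert patched.tolist() == [[202], [22], [404], [44]]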


@@ -81,6 +81,8 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
# Here we run the draft_model_runner with multi-step prepare
# on the GPU directly
expanded_request.num_steps = sample_len
self.model_runner.set_indices_of_seq_with_bonus_tokens(
indices_of_seq_with_bonus_tokens)
model_outputs = self.execute_model(
execute_model_req=expanded_request)
else:
@@ -97,7 +99,8 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
model_output = model_output[0]
self._append_new_tokens(
model_output, expanded_request.seq_group_metadata_list)
model_output, expanded_request.seq_group_metadata_list,
indices_of_seq_with_bonus_tokens)
model_outputs.append(model_output)
filtered_model_outputs = self._filter_model_output(
@@ -221,13 +224,15 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
@staticmethod
def _append_new_tokens(
model_output: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
seq_group_metadata_list: List[SequenceGroupMetadata],
indices_of_seq_with_bonus_tokens: List[int]) -> None:
"""Given model output from a single run, append the tokens to the
sequences. This is normally done outside of the worker, but it is
required if the worker is to perform multiple forward passes.
"""
for seq_group_metadata, sequence_group_outputs in zip(
seq_group_metadata_list, model_output):
count = 0
for index, (seq_group_metadata, sequence_group_outputs) in enumerate(
zip(seq_group_metadata_list, model_output)):
seq_group_metadata.is_prompt = False
for seq_output in sequence_group_outputs.samples:
@@ -237,6 +242,16 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
token_id = seq_output.output_token
token_logprob = seq_output.logprobs[token_id]
# Determine the actual token ID to be generated,
# considering bonus tokens
if index != indices_of_seq_with_bonus_tokens[count]:
bonus_seq_metadata = seq_group_metadata_list[
indices_of_seq_with_bonus_tokens[count]]
_, bonus_token_seq_data = next(
iter(bonus_seq_metadata.seq_data.items()))
token_id = bonus_token_seq_data.output_token_ids[-1]
else:
count += 1
seq.append_token_id(token_id, token_logprob.logprob)
seq.update_num_computed_tokens(1)
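
_append_new_tokens now applies the same idea to the CPU-side sequence state: for a sequence group whose index is not the current entry of indices_of_seq_with_bonus_tokens, the token appended to its draft sequence is the last output token of its bonus counterpart rather than its own sample, keeping seq_data consistent with what the runner actually fed the model. A small self-contained sketch of that bookkeeping; FakeSeqData, FakeSeqGroupMetadata and append_new_tokens are illustrative stand-ins for vLLM's SequenceData, SequenceGroupMetadata and the method above.

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class FakeSeqData:
    """Illustrative stand-in for vLLM's SequenceData."""
    output_token_ids: List[int] = field(default_factory=list)

@dataclass
class FakeSeqGroupMetadata:
    """Illustrative stand-in for SequenceGroupMetadata (one sequence each)."""
    seq_data: Dict[int, FakeSeqData]

def append_new_tokens(sampled_tokens: List[int],
                      seq_group_metadata_list: List[FakeSeqGroupMetadata],
                      indices_of_seq_with_bonus_tokens: List[int]) -> None:
    """Toy version of the bonus-aware append logic in the hunk above."""
    count = 0
    for index, (seq_group_metadata, token_id) in enumerate(
            zip(seq_group_metadata_list, sampled_tokens)):
        if index != indices_of_seq_with_bonus_tokens[count]:
            # Non-bonus copy: append the last token of its bonus counterpart
            # instead of the token this copy sampled for itself.
            bonus_seq_metadata = seq_group_metadata_list[
                indices_of_seq_with_bonus_tokens[count]]
            _, bonus_seq_data = next(iter(bonus_seq_metadata.seq_data.items()))
            token_id = bonus_seq_data.output_token_ids[-1]
        else:
            count += 1
        seq_data = next(iter(seq_group_metadata.seq_data.values()))
        seq_data.output_token_ids.append(token_id)

# Two sequence groups: index 1 carries the bonus token, index 0 is its copy.
groups = [FakeSeqGroupMetadata({0: FakeSeqData([7])}),
          FakeSeqGroupMetadata({1: FakeSeqData([7, 9])})]
append_new_tokens([111, 222], groups, [1])
# The copy at index 0 got 9 (the bonus row's last token), not its own 111.
assert groups[0].seq_data[0].output_token_ids == [7, 9]
assert groups[1].seq_data[1].output_token_ids == [7, 9, 222]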