[Bugfix][SpecDecode] kv corruption with bonus tokens in spec decode (#9730)
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
commit 0c63c34f72 (parent 966e31697b)
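For readers outside the speculative-decoding code path: a "bonus token" is the extra token the target model emits when it accepts all of the draft model's proposals, so the draft model has no KV-cache entry for it when the next proposal round starts. The sketch below is a hypothetical, simplified illustration of that situation (plain Python, no vLLM APIs); the diff that follows shows how the multi-step draft runner is taught to account for it.

```python
# Hypothetical illustration of where a "bonus token" comes from in
# speculative decoding (not vLLM code). If the target model accepts all
# k draft tokens, it also samples one extra ("bonus") token that the
# draft model never processed, so the draft model's KV cache is one
# token behind at the start of the next proposal round.
prompt = [1, 4]
draft_proposals = [5, 9, 2]       # k = 3 tokens proposed by the draft model
accepted = [5, 9, 2]              # the target model accepted all of them
bonus_token = 7                   # extra token sampled by the target model

sequence = prompt + accepted + [bonus_token]
# The draft model holds KV entries for [1, 4, 5, 9, 2] but not for 7;
# this commit makes the multi-step draft runner update its KV cache
# consistently for such bonus tokens.
print(sequence)  # [1, 4, 5, 9, 2, 7]
```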
@@ -5,6 +5,8 @@ from unittest.mock import MagicMock
 import pytest
 import torch
 
+from vllm.attention.selector import (_Backend,
+                                     global_force_attn_backend_context_manager)
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
@@ -303,6 +305,7 @@ def test_multi_step_with_batch_expansion_correct_output():
         seed,
         model_runner_cls=TP1DraftModelRunner,
     )
+    multi_step_worker.set_include_gpu_probs_tensor()
     worker = create_worker(
         Worker,
         model_name,
@@ -397,6 +400,7 @@ def test_multi_step_with_batch_expansion_incorrect_output():
         seed,
         model_runner_cls=TP1DraftModelRunner,
     )
+    multi_step_worker.set_include_gpu_probs_tensor()
     worker = create_worker(
         Worker,
         model_name,
@@ -477,6 +481,109 @@ def test_multi_step_with_batch_expansion_incorrect_output():
     assert (num_mismatch > 0)
 
 
+@torch.inference_mode()
+@pytest.mark.parametrize('num_steps', [1, 2, 3, 4])
+# The choice of backends forces the multi_step_worker to choose between
+# the vanilla model_runner and TP1DraftModelRunner, so that we can test
+# both code paths.
+@pytest.mark.parametrize('attn_backend',
+                         [_Backend.XFORMERS, _Backend.FLASH_ATTN])
+def test_multi_step_correct_kvcache(num_steps, attn_backend):
+    """Verify that the KV cache of the draft model
+    is correctly updated for sequences with a bonus token.
+    """
+    seed = 100
+    model_name = "JackFram/llama-68m"
+
+    block_size = 16
+    num_gpu_blocks = 2048 // block_size
+    batch_size = 1
+
+    with global_force_attn_backend_context_manager(attn_backend):
+        dtype = 'float16' if attn_backend == _Backend.FLASH_ATTN else 'float32'
+        multi_step_worker = create_worker(MultiStepWorker,
+                                          model_name,
+                                          block_size,
+                                          num_gpu_blocks,
+                                          seed,
+                                          model_runner_cls=TP1DraftModelRunner,
+                                          dtype=dtype)
+        multi_step_worker.set_include_gpu_probs_tensor()
+        worker = create_worker(Worker,
+                               model_name,
+                               block_size,
+                               num_gpu_blocks,
+                               seed,
+                               dtype=dtype)
+
+        prompts = [[0] for _ in range(batch_size)]
+        # Pre-generate two tokens for each sequence so that we can
+        # simulate the bonus token case.
+        multi_step_continuations = [[
+            random.randint(0, 1000),
+            random.randint(0, 1000)
+        ] for _ in prompts]
+        final_prompt_lens = [len(prompt) + 2 + num_steps for prompt in prompts]
+
+        seq_ids_with_bonus_token_in_last_step = set(range(batch_size))
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            continuations=multi_step_continuations,
+            final_prompt_lens=final_prompt_lens)
+
+        # Run multi-step.
+        zero_kv_cache(multi_step_worker.cache_engine)
+        multi_step_worker.sampler_output(execute_model_req=ExecuteModelRequest(
+            seq_group_metadata_list=seq_group_metadata_list),
+                                         sample_len=num_steps,
+                                         seq_ids_with_bonus_token_in_last_step=
+                                         seq_ids_with_bonus_token_in_last_step)
+
+        # Run single-step repeatedly.
+        zero_kv_cache(worker.cache_engine)
+        # Generate the KV cache for the bonus token first.
+        single_step_continuations = [c[:1] for c in multi_step_continuations]
+        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+            prompts,
+            num_gpu_blocks,
+            block_size,
+            continuations=single_step_continuations,
+            final_prompt_lens=final_prompt_lens)
+        single_step_output = worker.execute_model(
+            execute_model_req=ExecuteModelRequest(
+                seq_group_metadata_list=seq_group_metadata_list))
+        for _ in range(num_steps):
+            seq_group_metadata_list = create_seq_group_metadata_from_prompts(
+                prompts,
+                num_gpu_blocks,
+                block_size,
+                continuations=multi_step_continuations,
+                final_prompt_lens=final_prompt_lens)
+
+            single_step_output = worker.execute_model(
+                execute_model_req=ExecuteModelRequest(
+                    seq_group_metadata_list=seq_group_metadata_list))
+
+            for i, seq_group_output in enumerate(single_step_output[-1]):
+                multi_step_continuations[i].append(
+                    seq_group_output.samples[0].output_token)
+
+        # Verify that the KV caches of the single-step and
+        # multi-step workers are the same.
+        single_step_gpu_cache = worker.cache_engine[0].gpu_cache
+        multi_step_gpu_cache = multi_step_worker.cache_engine[0].gpu_cache
+        num_layers = len(single_step_gpu_cache)
+        allclose = lambda a, b: torch.allclose(
+            a.cuda(), b.cuda(), rtol=1e-2, atol=1e-2)
+        for i in range(num_layers):
+            assert allclose(single_step_gpu_cache[i][0],
+                            multi_step_gpu_cache[i][0])
+            assert allclose(single_step_gpu_cache[i][1],
+                            multi_step_gpu_cache[i][1])
+
+
 @torch.inference_mode()
 def test_draft_proposals_full_speculation_len():
     """Verify Top1Proposer correctly handles case where all sequences
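The verification at the end of the new test is a per-layer comparison of the two workers' GPU KV caches at relaxed tolerances. The snippet below is a minimal standalone sketch of that check; the dummy tensors and helper names (`make_dummy_cache`, `caches_match`) are illustrative stand-ins for `worker.cache_engine[0].gpu_cache`, not vLLM's actual paged-KV layout, and the comparison runs on CPU here.

```python
import torch


def make_dummy_cache(num_layers=2, num_blocks=8, block_size=16,
                     num_heads=4, head_size=8, seed=0):
    """Stand-in for worker.cache_engine[0].gpu_cache: one (key, value)
    tensor pair per layer. Shapes are placeholders, not vLLM's layout."""
    torch.manual_seed(seed)
    return [(torch.randn(num_blocks, block_size, num_heads, head_size),
             torch.randn(num_blocks, block_size, num_heads, head_size))
            for _ in range(num_layers)]


def caches_match(cache_a, cache_b, rtol=1e-2, atol=1e-2):
    """Per-layer, per-tensor comparison mirroring the asserts above."""
    for (key_a, value_a), (key_b, value_b) in zip(cache_a, cache_b):
        if not torch.allclose(key_a, key_b, rtol=rtol, atol=atol):
            return False
        if not torch.allclose(value_a, value_b, rtol=rtol, atol=atol):
            return False
    return True


# Caches built from the same seed match; a differently seeded one does not.
assert caches_match(make_dummy_cache(seed=0), make_dummy_cache(seed=0))
assert not caches_match(make_dummy_cache(seed=0), make_dummy_cache(seed=1))
```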
@@ -68,12 +68,14 @@ def create_worker(cls: Callable[..., T],
                   seed: int,
                   is_driver_worker: bool = True,
                   enforce_eager: bool = True,
-                  model_runner_cls: Optional[ModelRunner] = None) -> T:
+                  model_runner_cls: Optional[ModelRunner] = None,
+                  dtype: Optional[str] = "auto") -> T:
     engine_args = EngineArgs(
         model=model_name,
         seed=seed,
         block_size=block_size,
         enforce_eager=enforce_eager,
+        dtype=dtype,
     )
     engine_config = engine_args.create_engine_config()
 
@@ -54,6 +54,8 @@ class TP1DraftModelRunner(ModelRunner):
 
         super().__init__(*args, **kwargs)
 
+        self.indices_of_seq_with_bonus_tokens = None
+
     def _update_sampling_metadata(self, sampling_metadata, num_seqs,
                                   num_queries):
 
@@ -159,6 +161,10 @@ class TP1DraftModelRunner(ModelRunner):
         # TODO: Add soft-tuning prompt adapter support
         return not self.prompt_adapter_config
 
+    def set_indices_of_seq_with_bonus_tokens(self,
+                                             indices_of_seq_with_bonus_tokens):
+        self.indices_of_seq_with_bonus_tokens = indices_of_seq_with_bonus_tokens
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -284,11 +290,30 @@ class TP1DraftModelRunner(ModelRunner):
                                                model_input.sampling_metadata)
 
             # Sample the next token.
-            outputs.append(
-                self.model.sample(
-                    logits=logits,
-                    sampling_metadata=model_input.sampling_metadata,
-                ))
+            output = self.model.sample(
+                logits=logits,
+                sampling_metadata=model_input.sampling_metadata,
+            )
+            outputs.append(output)
+
+            if model_input.attn_metadata.num_prefills == 0 \
+                and self.indices_of_seq_with_bonus_tokens is not None:
+                assert output.sampled_token_ids is not None
+                # output.sampled_token_ids should be of shape (num_seqs, 1)
+                nums_seqs, num_tokens_per_seq = output.sampled_token_ids.shape
+                assert num_tokens_per_seq == 1
+                count = 0
+                for i in range(nums_seqs):
+                    bonus_seq_idx = self.indices_of_seq_with_bonus_tokens[
+                        count]
+                    if i != bonus_seq_idx:
+                        # The following might cause a CPU->GPU sync, but the
+                        # performance impact is negligible, as benchmarked
+                        # on H100.
+                        output.sampled_token_ids[
+                            i, :] = model_input.input_tokens[bonus_seq_idx]
+                    else:
+                        count += 1
 
             # Prepare inputs for the next step
             if step != num_steps - 1:
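The bonus-token handling added to `execute_model` above reduces to: for every expanded sequence that does not carry a bonus token, overwrite its single sampled token id with the input token of the bonus-carrying sequence that `count` currently points at, so the next step feeds consistent tokens into the KV cache. Below is a minimal standalone sketch of that substitution; the tensor arguments are stand-ins for `output.sampled_token_ids` and `model_input.input_tokens`, and the function name is hypothetical.

```python
from typing import List

import torch


def overwrite_non_bonus_tokens(sampled_token_ids: torch.Tensor,
                               input_tokens: torch.Tensor,
                               bonus_indices: List[int]) -> torch.Tensor:
    """Mirror of the loop above, with simplified inputs.

    sampled_token_ids: shape (num_seqs, 1), as asserted in the real code.
    input_tokens: shape (num_seqs,), stand-in for model_input.input_tokens.
    bonus_indices: sorted stand-in for indices_of_seq_with_bonus_tokens.
    """
    num_seqs, num_tokens_per_seq = sampled_token_ids.shape
    assert num_tokens_per_seq == 1
    count = 0
    for i in range(num_seqs):
        bonus_seq_idx = bonus_indices[count]
        if i != bonus_seq_idx:
            # Rows without a bonus token reuse the bonus sequence's input.
            sampled_token_ids[i, :] = input_tokens[bonus_seq_idx]
        else:
            count += 1
    return sampled_token_ids


# Toy example: sequences 1 and 3 carry bonus tokens; rows 0 and 2 are rewritten
# with the inputs of sequences 1 and 3 respectively (101 and 103).
sampled = torch.tensor([[10], [11], [12], [13]])
inputs = torch.tensor([100, 101, 102, 103])
print(overwrite_non_bonus_tokens(sampled, inputs, bonus_indices=[1, 3]))
```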
@@ -81,6 +81,8 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
             # Here we run the draft_model_runner with multi-step prepare
             # on the GPU directly
             expanded_request.num_steps = sample_len
+            self.model_runner.set_indices_of_seq_with_bonus_tokens(
+                indices_of_seq_with_bonus_tokens)
             model_outputs = self.execute_model(
                 execute_model_req=expanded_request)
         else:
@@ -97,7 +99,8 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
                 model_output = model_output[0]
 
                 self._append_new_tokens(
-                    model_output, expanded_request.seq_group_metadata_list)
+                    model_output, expanded_request.seq_group_metadata_list,
+                    indices_of_seq_with_bonus_tokens)
                 model_outputs.append(model_output)
 
         filtered_model_outputs = self._filter_model_output(
@@ -221,13 +224,15 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
     @staticmethod
     def _append_new_tokens(
             model_output: List[SamplerOutput],
-            seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
+            seq_group_metadata_list: List[SequenceGroupMetadata],
+            indices_of_seq_with_bonus_tokens: List[int]) -> None:
         """Given model output from a single run, append the tokens to the
         sequences. This is normally done outside of the worker, but it is
         required if the worker is to perform multiple forward passes.
         """
-        for seq_group_metadata, sequence_group_outputs in zip(
-                seq_group_metadata_list, model_output):
+        count = 0
+        for index, (seq_group_metadata, sequence_group_outputs) in enumerate(
+                zip(seq_group_metadata_list, model_output)):
             seq_group_metadata.is_prompt = False
 
             for seq_output in sequence_group_outputs.samples:
@@ -237,6 +242,16 @@ class MultiStepWorker(Worker, ProposerWorkerBase):
 
                 token_id = seq_output.output_token
                 token_logprob = seq_output.logprobs[token_id]
+                # Determine the actual token ID to be generated,
+                # considering bonus tokens
+                if index != indices_of_seq_with_bonus_tokens[count]:
+                    bonus_seq_metadata = seq_group_metadata_list[
+                        indices_of_seq_with_bonus_tokens[count]]
+                    _, bonus_token_seq_data = next(
+                        iter(bonus_seq_metadata.seq_data.items()))
+                    token_id = bonus_token_seq_data.output_token_ids[-1]
+                else:
+                    count += 1
 
                 seq.append_token_id(token_id, token_logprob.logprob)
                 seq.update_num_computed_tokens(1)
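The same count-based walk shows up on the CPU side in `_append_new_tokens`: an expanded sequence group that does not carry a bonus token appends the last output token of its bonus-carrying counterpart instead of its own sample. A simplified, self-contained sketch of that selection follows; plain lists stand in for `SequenceGroupMetadata` / `SamplerOutput`, and the function name is hypothetical.

```python
from typing import List


def select_token_ids(sampled_token_ids: List[int],
                     last_output_token_ids: List[int],
                     bonus_indices: List[int]) -> List[int]:
    """For each expanded group `index`, keep its own sampled token only if it
    carries a bonus token; otherwise reuse the last output token of the
    bonus-carrying group currently tracked by `count`.

    sampled_token_ids: stand-in for seq_output.output_token per group.
    last_output_token_ids: stand-in for seq_data.output_token_ids[-1] per group.
    bonus_indices: sorted stand-in for indices_of_seq_with_bonus_tokens.
    """
    selected = []
    count = 0
    for index, own_token in enumerate(sampled_token_ids):
        if index != bonus_indices[count]:
            # Non-bonus groups inherit the bonus group's latest output token.
            selected.append(last_output_token_ids[bonus_indices[count]])
        else:
            selected.append(own_token)
            count += 1
    return selected


# Groups 1 and 3 carry bonus tokens; groups 0 and 2 take their counterparts'
# most recent output token instead of their own sample.
assert select_token_ids(sampled_token_ids=[7, 8, 9, 10],
                        last_output_token_ids=[70, 80, 90, 100],
                        bonus_indices=[1, 3]) == [80, 8, 100, 10]
```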