mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-09 21:42:13 +08:00
dsd draft
This commit is contained in:
parent
f0ca3a6142
commit
50e2788383
@ -226,6 +226,7 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args.tokenizer is None:
|
if args.tokenizer is None:
|
||||||
args.tokenizer = args.model
|
args.tokenizer = args.model
|
||||||
|
args.backend = "vllm"
|
||||||
validate_dataset(args)
|
validate_dataset(args)
|
||||||
random.seed(0)
|
random.seed(0)
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
@ -81,6 +81,7 @@ class RejectionSampler(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
output_token_ids (torch.Tensor):
|
output_token_ids (torch.Tensor):
|
||||||
A tensor containing the final output token IDs.
|
A tensor containing the final output token IDs.
|
||||||
|
acceptance_rate: min(p, q)
|
||||||
'''
|
'''
|
||||||
assert metadata.max_spec_len <= MAX_SPEC_LEN
|
assert metadata.max_spec_len <= MAX_SPEC_LEN
|
||||||
# [num_tokens, vocab_size]
|
# [num_tokens, vocab_size]
|
||||||
@ -92,7 +93,7 @@ class RejectionSampler(nn.Module):
|
|||||||
sampling_metadata,
|
sampling_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
output_token_ids = rejection_sample(
|
output_token_ids, output_probs = rejection_sample(
|
||||||
metadata.draft_token_ids,
|
metadata.draft_token_ids,
|
||||||
metadata.num_draft_tokens,
|
metadata.num_draft_tokens,
|
||||||
metadata.max_spec_len,
|
metadata.max_spec_len,
|
||||||
@ -102,7 +103,9 @@ class RejectionSampler(nn.Module):
|
|||||||
bonus_token_ids,
|
bonus_token_ids,
|
||||||
sampling_metadata,
|
sampling_metadata,
|
||||||
)
|
)
|
||||||
return output_token_ids
|
mask = output_probs != PLACEHOLDER_TOKEN_ID
|
||||||
|
acceptance_rate = output_probs[mask].mean()
|
||||||
|
return output_token_ids, acceptance_rate
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def parse_output(
|
def parse_output(
|
||||||
@ -170,6 +173,8 @@ def rejection_sample(
|
|||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
|
output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
|
||||||
|
output_probs = torch.empty_like(output_token_ids, dtype=torch.float32)
|
||||||
|
output_probs.fill_(PLACEHOLDER_TOKEN_ID)
|
||||||
|
|
||||||
if sampling_metadata.all_greedy:
|
if sampling_metadata.all_greedy:
|
||||||
is_greedy = None
|
is_greedy = None
|
||||||
@ -180,6 +185,7 @@ def rejection_sample(
|
|||||||
target_argmax = target_probs.argmax(dim=-1)
|
target_argmax = target_probs.argmax(dim=-1)
|
||||||
rejection_greedy_sample_kernel[(batch_size, )](
|
rejection_greedy_sample_kernel[(batch_size, )](
|
||||||
output_token_ids,
|
output_token_ids,
|
||||||
|
output_probs,
|
||||||
cu_num_draft_tokens,
|
cu_num_draft_tokens,
|
||||||
draft_token_ids,
|
draft_token_ids,
|
||||||
target_argmax,
|
target_argmax,
|
||||||
@ -216,6 +222,7 @@ def rejection_sample(
|
|||||||
# Rejection sampling for random sampling requests.
|
# Rejection sampling for random sampling requests.
|
||||||
rejection_random_sample_kernel[(batch_size, )](
|
rejection_random_sample_kernel[(batch_size, )](
|
||||||
output_token_ids,
|
output_token_ids,
|
||||||
|
output_probs,
|
||||||
cu_num_draft_tokens,
|
cu_num_draft_tokens,
|
||||||
draft_token_ids,
|
draft_token_ids,
|
||||||
draft_probs,
|
draft_probs,
|
||||||
@ -229,7 +236,7 @@ def rejection_sample(
|
|||||||
IS_NGRAM=draft_probs is None,
|
IS_NGRAM=draft_probs is None,
|
||||||
num_warps=1,
|
num_warps=1,
|
||||||
)
|
)
|
||||||
return output_token_ids
|
return output_token_ids, output_probs
|
||||||
|
|
||||||
|
|
||||||
def compute_probs(
|
def compute_probs(
|
||||||
@ -432,6 +439,7 @@ def sample_recovered_tokens(
|
|||||||
@triton.jit(do_not_specialize=["max_spec_len"])
|
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||||
def rejection_greedy_sample_kernel(
|
def rejection_greedy_sample_kernel(
|
||||||
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||||
|
output_probs_ptr, # [batch_size, max_spec_len + 1]
|
||||||
cu_num_draft_tokens_ptr, # [batch_size]
|
cu_num_draft_tokens_ptr, # [batch_size]
|
||||||
draft_token_ids_ptr, # [num_tokens]
|
draft_token_ids_ptr, # [num_tokens]
|
||||||
target_argmax_ptr, # [num_tokens]
|
target_argmax_ptr, # [num_tokens]
|
||||||
@ -459,14 +467,16 @@ def rejection_greedy_sample_kernel(
|
|||||||
|
|
||||||
rejected = False
|
rejected = False
|
||||||
for pos in range(num_draft_tokens):
|
for pos in range(num_draft_tokens):
|
||||||
|
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||||
|
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
|
||||||
if not rejected:
|
if not rejected:
|
||||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
|
||||||
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
|
|
||||||
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
|
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
|
||||||
target_argmax_id)
|
target_argmax_id)
|
||||||
if draft_token_id != target_argmax_id:
|
if draft_token_id != target_argmax_id:
|
||||||
# Reject.
|
# Reject.
|
||||||
rejected = True
|
rejected = True
|
||||||
|
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
|
||||||
|
not rejected)
|
||||||
|
|
||||||
if not rejected:
|
if not rejected:
|
||||||
# If all tokens are accepted, append the bonus token.
|
# If all tokens are accepted, append the bonus token.
|
||||||
@ -480,6 +490,7 @@ def rejection_greedy_sample_kernel(
|
|||||||
@triton.jit(do_not_specialize=["max_spec_len"])
|
@triton.jit(do_not_specialize=["max_spec_len"])
|
||||||
def rejection_random_sample_kernel(
|
def rejection_random_sample_kernel(
|
||||||
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
|
||||||
|
output_probs_ptr, # [batch_size, max_spec_len + 1]
|
||||||
cu_num_draft_tokens_ptr, # [batch_size]
|
cu_num_draft_tokens_ptr, # [batch_size]
|
||||||
draft_token_ids_ptr, # [num_tokens]
|
draft_token_ids_ptr, # [num_tokens]
|
||||||
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
draft_probs_ptr, # [num_tokens, vocab_size] or None
|
||||||
@ -507,17 +518,16 @@ def rejection_random_sample_kernel(
|
|||||||
|
|
||||||
rejected = False
|
rejected = False
|
||||||
for pos in range(num_draft_tokens):
|
for pos in range(num_draft_tokens):
|
||||||
|
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||||
|
if IS_NGRAM:
|
||||||
|
draft_prob = 1
|
||||||
|
else:
|
||||||
|
draft_prob = tl.load(draft_probs_ptr +
|
||||||
|
(start_idx + pos) * vocab_size +
|
||||||
|
draft_token_id)
|
||||||
|
target_prob = tl.load(target_probs_ptr +
|
||||||
|
(start_idx + pos) * vocab_size + draft_token_id)
|
||||||
if not rejected:
|
if not rejected:
|
||||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
|
||||||
if IS_NGRAM:
|
|
||||||
draft_prob = 1
|
|
||||||
else:
|
|
||||||
draft_prob = tl.load(draft_probs_ptr +
|
|
||||||
(start_idx + pos) * vocab_size +
|
|
||||||
draft_token_id)
|
|
||||||
target_prob = tl.load(target_probs_ptr +
|
|
||||||
(start_idx + pos) * vocab_size +
|
|
||||||
draft_token_id)
|
|
||||||
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
|
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
|
||||||
# NOTE(woosuk): While the draft probability should never be 0,
|
# NOTE(woosuk): While the draft probability should never be 0,
|
||||||
# we check it to avoid NaNs. If it happens to be 0, we reject.
|
# we check it to avoid NaNs. If it happens to be 0, we reject.
|
||||||
@ -530,6 +540,8 @@ def rejection_random_sample_kernel(
|
|||||||
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
|
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
|
||||||
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
|
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
|
||||||
token_id)
|
token_id)
|
||||||
|
tl.store(output_probs_ptr + req_idx * (max_spec_len + 1) + pos,
|
||||||
|
min(draft_prob, target_prob))
|
||||||
|
|
||||||
if not rejected:
|
if not rejected:
|
||||||
# If all tokens are accepted, append the bonus token.
|
# If all tokens are accepted, append the bonus token.
|
||||||
|
|||||||
133
vllm/v1/spec_decode/auto_tuner.py
Normal file
133
vllm/v1/spec_decode/auto_tuner.py
Normal file
@ -0,0 +1,133 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState
|
||||||
|
|
||||||
|
|
||||||
|
class AutoTuner:
    """Collect speculative-decoding statistics and choose a draft length.

    The tuner records how many requests in a batch received a draft
    (``adjust_draft_len``) and the acceptance rate reported after each
    verification step (``update_stats``). From these it can predict the
    goodput (expected generated tokens divided by predicted step time) for
    each candidate number of draft tokens and pick the best one.

    Truncating drafts is opt-in: with the default ``enable_tuning=False``
    the tuner only gathers statistics and returns drafts unchanged.
    """

    def __init__(self, enable_tuning: bool = False):
        """Initialize counters, configuration and cached values.

        Args:
            enable_tuning: When True, ``adjust_draft_len`` truncates every
                draft to the length chosen by ``get_verified_len``. When
                False (default), drafts are returned untouched.
        """
        # Tracking metrics for the auto-tuning process. The match counters
        # are specific to the ngram proposer.
        self.step_cnt = 0
        self.match_cnt = 0
        self.total_cnt = 0
        self.past_acceptance_rates: list[float] = []
        self.past_match_ratios: list[float] = []

        # config
        self.update_interval = 100
        self.window_size = 10000
        # Coefficients of the linear cost model in `_predict_target_time`.
        # NOTE(review): units/calibration are not visible here -- confirm.
        self.c_kv_load = 0.1
        self.c_computation = 0.2
        self.c_overhead = 0.3
        self.enable_tuning = enable_tuning

        # some cached values
        self.last_verified_len = 0
        # False until `get_verified_len` has computed a value at least
        # once, so the initial 0 above is never mistaken for a prediction.
        self._has_verified_len = False

    def get_verified_len(self, batch_size: int, match_cnt: int,
                         num_kv_tokens: int, max_draft_len: int) -> int:
        """Return the number of draft tokens predicted to maximize goodput.

        The prediction is recomputed on the first call and whenever
        ``step_cnt`` is a multiple of ``update_interval``; otherwise the
        cached result is returned.

        Args:
            batch_size: Number of requests in the batch.
            match_cnt: Number of requests that have a non-empty draft.
            num_kv_tokens: Total number of tokens across request states.
            max_draft_len: Longest draft in the batch.

        Returns:
            A length in ``[0, max_draft_len]`` (or the cached length).
        """
        if (self._has_verified_len
                and self.step_cnt % self.update_interval != 0):
            return self.last_verified_len

        best_verified_len = 0
        max_goodput = -1.0
        # Candidates include `max_draft_len` itself so the full draft can
        # be kept when verifying every token is predicted to be best.
        for candidate in range(max_draft_len + 1):
            cur_goodput, _, _ = self._predict_goodput(
                batch_size, match_cnt, num_kv_tokens, candidate)
            # Stop at the first candidate that does not improve goodput.
            if cur_goodput <= max_goodput:
                break
            max_goodput = cur_goodput
            best_verified_len = candidate

        self.last_verified_len = best_verified_len
        self._has_verified_len = True
        return best_verified_len

    def adjust_draft_len(self, req_states: "dict[str, CachedRequestState]",
                         draft_token_ids: list[list[int]]) -> list[list[int]]:
        """Record match statistics and optionally truncate the drafts.

        Args:
            req_states: Request states keyed by request id; only
                ``num_tokens`` of each state is read, and only when tuning
                is enabled.
            draft_token_ids: One list of draft token ids per request; an
                empty list means no draft was proposed for that request.

        Returns:
            ``draft_token_ids`` unchanged when tuning is disabled, the
            batch is empty, or no acceptance rate has been recorded yet;
            otherwise a new list with every draft cut to the verified
            length.
        """
        batch_size = len(draft_token_ids)
        if batch_size == 0:
            # Nothing to record; also avoids dividing by zero below.
            return draft_token_ids

        draft_lens = [len(draft) for draft in draft_token_ids]
        match_cnt = sum(1 for draft_len in draft_lens if draft_len > 0)
        max_draft_len = max(draft_lens)

        self.total_cnt += batch_size
        self.match_cnt += match_cnt
        self._append_windowed(self.past_match_ratios,
                              match_cnt / batch_size)

        # Without any observed acceptance rate there is nothing to base a
        # prediction on, so keep the proposer's drafts as they are.
        if not self.enable_tuning or not self.past_acceptance_rates:
            return draft_token_ids

        # Calculate parameters used for goodput prediction.
        num_kv_tokens = sum(state.num_tokens for state in req_states.values())

        # Use goodput prediction to get the verified length.
        verified_len = self.get_verified_len(batch_size, match_cnt,
                                             num_kv_tokens, max_draft_len)
        return [draft[:verified_len] for draft in draft_token_ids]

    def update_stats(self, acceptance_rate: float):
        """Record the acceptance rate of one step and report periodically.

        Every 20th call prints a one-line summary of the latest and the
        windowed statistics.

        Args:
            acceptance_rate: Acceptance rate observed for the last step.
        """
        self.step_cnt += 1
        # Record first so the global figure printed below includes this
        # step and the history is never empty when it is averaged.
        self._append_windowed(self.past_acceptance_rates, acceptance_rate)

        if self.step_cnt % 20 == 0:
            last_match_ratio = (self.past_match_ratios[-1]
                                if self.past_match_ratios else 0.0)
            global_match_ratio = self.match_cnt / (self.total_cnt + 1e-5)
            print(
                f"Step {self.step_cnt}: "
                f"Last acceptance rate: {acceptance_rate:.2f}",
                f"Last match ratio: {last_match_ratio:.2f}",
                f"Global acceptance rate: {self.acceptance_rate:.2f}",
                f"Global match ratio: {global_match_ratio:.2f}",
            )

    @property
    def acceptance_rate(self):
        """Mean of the last ``window_size`` acceptance rates (0.0 if none)."""
        window_acceptance_rates = self.past_acceptance_rates[-self.
                                                             window_size:]
        if not window_acceptance_rates:
            return 0.0
        return sum(window_acceptance_rates) / len(window_acceptance_rates)

    def _append_windowed(self, history: list, value) -> None:
        """Append ``value`` and drop entries older than ``window_size``."""
        history.append(value)
        overflow = len(history) - self.window_size
        if self.window_size > 0 and overflow > 0:
            del history[:overflow]

    def _predict_goodput(self, batch_size: int, match_cnt: int,
                         num_kv_tokens: int,
                         verified_len: int) -> tuple[float, float, float]:
        """
        Predict the goodput for a given verified length.

        Returns:
            ``(goodput, draft_time, target_time)`` where goodput is the
            predicted generated length divided by the predicted step time.
        """
        generated_len = self._predict_generated_len(batch_size, match_cnt,
                                                    verified_len)
        draft_time = self._predict_draft_time()
        target_time = self._predict_target_time(batch_size, match_cnt,
                                                num_kv_tokens, verified_len)
        batch_time = draft_time + target_time
        return generated_len / batch_time, draft_time, target_time

    def _predict_generated_len(self, batch_size: int, match_cnt: int,
                               verified_len: int):
        """Predict how many tokens the whole batch generates in one step.

        Each of the ``match_cnt`` requests with a draft is expected to
        produce ``sum(rate**i for i in range(verified_len + 1))`` tokens;
        every other request produces exactly one.
        """
        rate = self.acceptance_rate
        if rate >= 1.0:
            # Limit of the geometric series; the closed form below would
            # divide by zero.
            per_request_len = float(verified_len + 1)
        else:
            per_request_len = float(
                (1 - rate**(verified_len + 1)) / (1 - rate))
        non_spec_gen_len = batch_size - match_cnt
        return match_cnt * per_request_len + non_spec_gen_len

    def _predict_draft_time(self):
        """Predicted drafting time; currently always 0."""
        # TODO: We need to benchmark and model this.
        return 0

    def _predict_target_time(self, batch_size: int, match_cnt: int,
                             num_kv_tokens: int, verified_len: int):
        """Predict the target-model step time with a linear cost model.

        The estimate is the sum of a term proportional to
        ``num_kv_tokens``, a term proportional to the number of batched
        tokens, and the constant ``c_overhead``.
        """
        kv_load_time = num_kv_tokens * self.c_kv_load

        # Computation time
        # +1 for the input token.
        num_batched_tokens = (match_cnt * (verified_len + 1) +
                              (batch_size - match_cnt))
        computation_time = num_batched_tokens * self.c_computation

        return kv_load_time + computation_time + self.c_overhead
|
||||||
@ -34,6 +34,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
|||||||
ModelRunnerOutput)
|
ModelRunnerOutput)
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
from vllm.v1.sample.rejection_sampler import RejectionSampler
|
||||||
|
from vllm.v1.spec_decode.auto_tuner import AutoTuner
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||||
from vllm.v1.spec_decode.utils import is_spec_decode_supported
|
from vllm.v1.spec_decode.utils import is_spec_decode_supported
|
||||||
@ -156,6 +157,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.use_spec_decode = False
|
self.use_spec_decode = False
|
||||||
if self.speculative_config:
|
if self.speculative_config:
|
||||||
self.use_spec_decode = True
|
self.use_spec_decode = True
|
||||||
|
self.auto_tuner = AutoTuner()
|
||||||
assert self.speculative_config.method == "ngram", \
|
assert self.speculative_config.method == "ngram", \
|
||||||
"Currently, only ngram spec decode is supported in V1."
|
"Currently, only ngram spec decode is supported in V1."
|
||||||
if get_pp_group().is_last_rank:
|
if get_pp_group().is_last_rank:
|
||||||
@ -1087,13 +1089,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
# separate storage from the original `logits` tensor. Therefore,
|
# separate storage from the original `logits` tensor. Therefore,
|
||||||
# it is safe to update `target_logits` in place.
|
# it is safe to update `target_logits` in place.
|
||||||
target_logits = logits[spec_decode_metadata.target_logits_indices]
|
target_logits = logits[spec_decode_metadata.target_logits_indices]
|
||||||
output_token_ids = self.rejection_sampler(
|
output_token_ids, acceptance_rate = self.rejection_sampler(
|
||||||
spec_decode_metadata,
|
spec_decode_metadata,
|
||||||
None, # draft_probs
|
None, # draft_probs
|
||||||
target_logits,
|
target_logits,
|
||||||
bonus_token_ids,
|
bonus_token_ids,
|
||||||
sampling_metadata,
|
sampling_metadata,
|
||||||
)
|
)
|
||||||
|
self.auto_tuner.update_stats(acceptance_rate)
|
||||||
sampler_output.sampled_token_ids = output_token_ids
|
sampler_output.sampled_token_ids = output_token_ids
|
||||||
|
|
||||||
# TODO(woosuk): The following loop can be slow since it iterates over
|
# TODO(woosuk): The following loop can be slow since it iterates over
|
||||||
@ -1191,6 +1194,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
draft_token_ids.append([])
|
draft_token_ids.append([])
|
||||||
else:
|
else:
|
||||||
draft_token_ids.append(drafter_output.tolist())
|
draft_token_ids.append(drafter_output.tolist())
|
||||||
|
|
||||||
|
draft_token_ids = self.auto_tuner.adjust_draft_len(
|
||||||
|
self.requests, draft_token_ids)
|
||||||
return draft_token_ids
|
return draft_token_ids
|
||||||
|
|
||||||
def load_model(self) -> None:
|
def load_model(self) -> None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user