diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index 5e85a92304d31..e03c724baac0b 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -226,6 +226,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
+    args.backend = "vllm"
     validate_dataset(args)
     random.seed(0)
     main(args)
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 3cf7fde5cd0ec..a108cd7bf9a1e 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -81,6 +81,7 @@ class RejectionSampler(nn.Module):
         Returns:
             output_token_ids (torch.Tensor):
                 A tensor containing the final output token IDs.
+            acceptance_rate (torch.Tensor): mean of min(p, q) over the draft token positions, used as an estimate of the acceptance rate.
         '''
         assert metadata.max_spec_len <= MAX_SPEC_LEN
         # [num_tokens, vocab_size]
@@ -92,7 +93,7 @@ class RejectionSampler(nn.Module):
             sampling_metadata,
         )
 
-        output_token_ids = rejection_sample(
+        output_token_ids, output_probs = rejection_sample(
             metadata.draft_token_ids,
             metadata.num_draft_tokens,
             metadata.max_spec_len,
@@ -102,7 +103,9 @@ class RejectionSampler(nn.Module):
             bonus_token_ids,
             sampling_metadata,
         )
-        return output_token_ids
+        mask = output_probs != PLACEHOLDER_TOKEN_ID
+        acceptance_rate = output_probs[mask].mean()
+        return output_token_ids, acceptance_rate
 
     @staticmethod
     def parse_output(
@@ -170,6 +173,8 @@ def rejection_sample(
         device=device,
     )
     output_token_ids.fill_(PLACEHOLDER_TOKEN_ID)
+    output_probs = torch.empty_like(output_token_ids, dtype=torch.float32)
+    output_probs.fill_(PLACEHOLDER_TOKEN_ID)
 
     if sampling_metadata.all_greedy:
         is_greedy = None
@@ -180,6 +185,7 @@ def rejection_sample(
         target_argmax = target_probs.argmax(dim=-1)
         rejection_greedy_sample_kernel[(batch_size, )](
             output_token_ids,
+            output_probs,
             cu_num_draft_tokens,
             draft_token_ids,
             target_argmax,
@@ -216,6 +222,7 @@ def rejection_sample(
     # Rejection sampling for random sampling requests.
     rejection_random_sample_kernel[(batch_size, )](
         output_token_ids,
+        output_probs,
         cu_num_draft_tokens,
         draft_token_ids,
         draft_probs,
@@ -229,7 +236,7 @@ def rejection_sample(
         IS_NGRAM=draft_probs is None,
         num_warps=1,
     )
-    return output_token_ids
+    return output_token_ids, output_probs
 
 
 def compute_probs(
@@ -432,6 +439,7 @@ def sample_recovered_tokens(
 @triton.jit(do_not_specialize=["max_spec_len"])
 def rejection_greedy_sample_kernel(
     output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+    output_probs_ptr,  # [batch_size, max_spec_len + 1]
     cu_num_draft_tokens_ptr,  # [batch_size]
     draft_token_ids_ptr,  # [num_tokens]
     target_argmax_ptr,  # [num_tokens]
@@ -459,14 +467,16 @@ def rejection_greedy_sample_kernel(
 
     rejected = False
     for pos in range(num_draft_tokens):
+        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+        target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
         if not rejected:
-            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-            target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
             tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                      target_argmax_id)
             if draft_token_id != target_argmax_id:
                 # Reject.
                 rejected = True
+            tl.store(output_probs_ptr + req_idx * (max_spec_len + 1) + pos,
+                     not rejected)
 
     if not rejected:
         # If all tokens are accepted, append the bonus token.
@@ -480,6 +490,7 @@ def rejection_greedy_sample_kernel(
 @triton.jit(do_not_specialize=["max_spec_len"])
 def rejection_random_sample_kernel(
     output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
+    output_probs_ptr,  # [batch_size, max_spec_len + 1]
     cu_num_draft_tokens_ptr,  # [batch_size]
     draft_token_ids_ptr,  # [num_tokens]
     draft_probs_ptr,  # [num_tokens, vocab_size] or None
@@ -507,17 +518,16 @@ def rejection_random_sample_kernel(
 
     rejected = False
     for pos in range(num_draft_tokens):
+        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
+        if IS_NGRAM:
+            draft_prob = 1
+        else:
+            draft_prob = tl.load(draft_probs_ptr +
+                                 (start_idx + pos) * vocab_size +
+                                 draft_token_id)
+        target_prob = tl.load(target_probs_ptr +
+                              (start_idx + pos) * vocab_size + draft_token_id)
         if not rejected:
-            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-            if IS_NGRAM:
-                draft_prob = 1
-            else:
-                draft_prob = tl.load(draft_probs_ptr +
-                                     (start_idx + pos) * vocab_size +
-                                     draft_token_id)
-            target_prob = tl.load(target_probs_ptr +
-                                  (start_idx + pos) * vocab_size +
-                                  draft_token_id)
             uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
             # NOTE(woosuk): While the draft probability should never be 0,
             # we check it to avoid NaNs. If it happens to be 0, we reject.
@@ -530,6 +540,8 @@ def rejection_random_sample_kernel(
                 token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
             tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
                      token_id)
+            tl.store(output_probs_ptr + req_idx * (max_spec_len + 1) + pos,
+                     min(draft_prob, target_prob))
 
     if not rejected:
         # If all tokens are accepted, append the bonus token.
diff --git a/vllm/v1/spec_decode/auto_tuner.py b/vllm/v1/spec_decode/auto_tuner.py
new file mode 100644
index 0000000000000..d3f4b21d565d7
--- /dev/null
+++ b/vllm/v1/spec_decode/auto_tuner.py
@@ -0,0 +1,133 @@
+# SPDX-License-Identifier: Apache-2.0
+from vllm.v1.worker.gpu_input_batch import CachedRequestState
+
+
+class AutoTuner:
+
+    def __init__(self):
+        # Some tracking metrics
+        # for the auto-tuning process.
+        # metrics specific to ngram_proposer.
+        self.step_cnt = 0
+        self.match_cnt = 0
+        self.total_cnt = 0
+        self.past_acceptance_rates = []
+        self.past_match_ratios = []
+
+        # config
+        self.update_interval = 100
+        self.window_size = 10000
+        self.c_kv_load = 0.1
+        self.c_computation = 0.2
+        self.c_overhead = 0.3
+
+        # some cached values
+        self.last_verified_len = 0
+
+    def get_verified_len(self, batch_size: int, match_cnt: int,
+                         num_kv_tokens: int, max_draft_len: int) -> int:
+        if self.step_cnt % self.update_interval != 0:
+            return self.last_verified_len
+
+        best_verified_len = 0
+        max_goodput = -1.0
+        for i in range(max_draft_len):
+            cur_goodput, draft_time, target_time = self._predict_goodput(
+                batch_size, match_cnt, num_kv_tokens, i)
+            # print(f"Goodput for proposal len {i}: {cur_goodput}")
+            if cur_goodput > max_goodput:
+                max_goodput = cur_goodput
+                best_verified_len = i
+            else:
+                break
+
+        self.last_verified_len = best_verified_len
+        return best_verified_len
+
+    def adjust_draft_len(self, req_states: dict[str, CachedRequestState],
+                         draft_token_ids: list[list[int]]):
+        """
+        Adjust the draft length based on the verified length.
+        """
+
+        # Calculate parameters used for goodput prediction.
+        num_kv_tokens = 0
+        for req_id in req_states:
+            num_kv_tokens += req_states[req_id].num_tokens
+        batch_size = len(draft_token_ids)
+        match_cnt = 0
+        max_draft_len = 0
+
+        for i in range(batch_size):
+            if len(draft_token_ids[i]) == 0:
+                continue
+            match_cnt += 1
+            max_draft_len = max(max_draft_len, len(draft_token_ids[i]))
+        self.total_cnt += batch_size
+        self.match_cnt += match_cnt
+        self.past_match_ratios.append(match_cnt / batch_size)
+
+        # NOTE: Only collect statistics for now; the goodput-based truncation below is not enabled yet.
+        return draft_token_ids
+        # Use goodput prediction to get the verified length.
+        verified_len = self.get_verified_len(batch_size, match_cnt,
+                                             num_kv_tokens, max_draft_len)
+
+        draft_token_ids = [draft[:verified_len] for draft in draft_token_ids]
+        return draft_token_ids
+
+    def update_stats(self, acceptance_rate: float):
+        self.step_cnt += 1
+        if self.step_cnt % 20 == 0:
+            print(
+                f"Step {self.step_cnt}: "
+                f"Last acceptance rate: {acceptance_rate:.2f}, "
+                f"Last match ratio: {self.past_match_ratios[-1]:.2f}, "
+                f"Global acceptance rate: {self.acceptance_rate:.2f}, "
+                f"Global match ratio: "
+                f"{self.match_cnt / (self.total_cnt + 1e-5):.2f}")
+
+        self.past_acceptance_rates.append(acceptance_rate)
+
+    @property
+    def acceptance_rate(self):
+        window_acceptance_rates = self.past_acceptance_rates[
+            -self.window_size:]
+        return sum(window_acceptance_rates) / len(window_acceptance_rates)
+
+    def _predict_goodput(self, batch_size: int, match_cnt: int,
+                         num_kv_tokens: int,
+                         verified_len: int) -> tuple[float, float, float]:
+        """
+        Predict the goodput for a given verified length.
+        """
+        generated_len = self._predict_generated_len(batch_size, match_cnt,
+                                                    verified_len)
+        draft_time = self._predict_draft_time()
+        target_time = self._predict_target_time(batch_size, match_cnt,
+                                                num_kv_tokens, verified_len)
+        batch_time = draft_time + target_time
+        return generated_len / batch_time, draft_time, target_time
+
+    def _predict_generated_len(self, batch_size: int, match_cnt: int,
+                               verified_len: int):
+        spec_gen_len = float((1 - self.acceptance_rate**(verified_len + 1)) /
+                             (1 - self.acceptance_rate))
+        non_spec_gen_len = batch_size - match_cnt
+        return spec_gen_len + non_spec_gen_len
+
+    def _predict_draft_time(self):
+        # TODO: We need to benchmark and model this.
+        return 0
+
+    def _predict_target_time(self, batch_size: int, match_cnt: int,
+                             num_kv_tokens: int, verified_len: int):
+        kv_load_time = num_kv_tokens * self.c_kv_load
+
+        # Computation time.
+        # +1 for the input token.
+        num_batched_tokens = match_cnt * (verified_len + 1) + (batch_size -
+                                                               match_cnt)
+        computation_time = num_batched_tokens * self.c_computation
+
+        return kv_load_time + computation_time + self.c_overhead
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index c7374cc3d3306..9702278c3bb8d 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -34,6 +34,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                              ModelRunnerOutput)
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import RejectionSampler
+from vllm.v1.spec_decode.auto_tuner import AutoTuner
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
 from vllm.v1.spec_decode.utils import is_spec_decode_supported
@@ -156,6 +157,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self.use_spec_decode = False
         if self.speculative_config:
             self.use_spec_decode = True
+            self.auto_tuner = AutoTuner()
             assert self.speculative_config.method == "ngram", \
                 "Currently, only ngram spec decode is supported in V1."
             if get_pp_group().is_last_rank:
@@ -1087,13 +1089,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # separate storage from the original `logits` tensor. Therefore,
             # it is safe to update `target_logits` in place.
             target_logits = logits[spec_decode_metadata.target_logits_indices]
-            output_token_ids = self.rejection_sampler(
+            output_token_ids, acceptance_rate = self.rejection_sampler(
                 spec_decode_metadata,
                 None,  # draft_probs
                 target_logits,
                 bonus_token_ids,
                 sampling_metadata,
             )
+            self.auto_tuner.update_stats(acceptance_rate)
             sampler_output.sampled_token_ids = output_token_ids
 
         # TODO(woosuk): The following loop can be slow since it iterates over
@@ -1191,6 +1194,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 draft_token_ids.append([])
             else:
                 draft_token_ids.append(drafter_output.tolist())
+
+        draft_token_ids = self.auto_tuner.adjust_draft_len(
+            self.requests, draft_token_ids)
         return draft_token_ids
 
     def load_model(self) -> None:
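
Note on the goodput model (illustrative, not part of the diff): _predict_goodput divides the expected number of generated tokens per step by a linear time model of the target forward pass. The closed form (1 - a**(verified_len + 1)) / (1 - a) in _predict_generated_len is the geometric series 1 + a + ... + a**verified_len, i.e. the expected number of tokens emitted for a speculative request when each draft token is accepted independently with probability a (the last term covering the bonus token). The sketch below is a minimal standalone rendition of that selection logic; the cost constants mirror the placeholder defaults in the diff (0.1 / 0.2 / 0.3), and the acceptance rate and batch numbers are made up.

# Standalone sketch of the AutoTuner goodput model; constants and example
# numbers are illustrative placeholders, not benchmarked values.

def expected_spec_len(acc_rate: float, verified_len: int) -> float:
    # 1 + a + a^2 + ... + a^k, equal to (1 - a^(k+1)) / (1 - a) for a != 1.
    return sum(acc_rate**i for i in range(verified_len + 1))


def predict_goodput(batch_size: int, match_cnt: int, num_kv_tokens: int,
                    verified_len: int, acc_rate: float,
                    c_kv_load: float = 0.1, c_computation: float = 0.2,
                    c_overhead: float = 0.3) -> float:
    # Tokens generated per step: the geometric series for speculative
    # requests plus one token for each request without a draft.
    generated_len = expected_spec_len(acc_rate, verified_len) \
        + (batch_size - match_cnt)
    # Target step time: KV-cache load + compute over the batched tokens
    # (verified_len drafts + 1 input token per matched request) + overhead.
    num_batched_tokens = (match_cnt * (verified_len + 1)
                          + (batch_size - match_cnt))
    target_time = (num_kv_tokens * c_kv_load
                   + num_batched_tokens * c_computation + c_overhead)
    return generated_len / target_time


if __name__ == "__main__":
    # Pick the verified length with the best predicted goodput, similar to
    # the greedy loop in AutoTuner.get_verified_len.
    best = max(range(8),
               key=lambda k: predict_goodput(batch_size=32, match_cnt=24,
                                             num_kv_tokens=50_000,
                                             verified_len=k, acc_rate=0.7))
    print(f"best verified_len: {best}")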