diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index ccbc6c022f1f9..0fdd0f5e4d8f4 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -189,6 +189,9 @@ class BenchmarkDataset(ABC): """ if len(requests) < num_requests: random.seed(self.random_seed) + logger.info("Current number of requests: %d", len(requests)) + logger.info("Oversampled requests to reach %d total samples.", + num_requests) additional = random.choices(requests, k=num_requests - len(requests)) requests.extend(additional) @@ -793,7 +796,7 @@ class AIMODataset(HuggingFaceDataset): sampled_requests = [] dynamic_output = output_len is None - for item in self.data: + for i, item in enumerate(self.data): if len(sampled_requests) >= num_requests: break prompt, completion = item['problem'], item["solution"] diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 1f65277e1bfeb..19528e417a4a0 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -57,9 +57,9 @@ def run_vllm( sampling_params.append( SamplingParams( n=n, - temperature=1.0, + temperature=0, top_p=1.0, - ignore_eos=True, + ignore_eos=False, max_tokens=request.expected_output_len, detokenize=not disable_detokenize, )) @@ -123,9 +123,9 @@ def run_vllm_chat( sampling_params.append( SamplingParams( n=n, - temperature=1.0, + temperature=0, top_p=1.0, - ignore_eos=True, + ignore_eos=False, max_tokens=request.expected_output_len, detokenize=not disable_detokenize, )) @@ -167,9 +167,9 @@ async def run_vllm_async( sampling_params.append( SamplingParams( n=n, - temperature=1.0, + temperature=0, top_p=1.0, - ignore_eos=True, + ignore_eos=False, max_tokens=request.expected_output_len, detokenize=not disable_detokenize, )) diff --git a/benchmarks/run.sh b/benchmarks/run.sh new file mode 100644 index 0000000000000..8b35a237807a7 --- /dev/null +++ b/benchmarks/run.sh @@ -0,0 +1,136 @@ +# python benchmarks/benchmark_throughput.py \ +# --model meta-llama/Meta-Llama-3-8B-Instruct \ +# --dataset-name sonnet \ +# --dataset-path /data/lily/batch-sd/benchmarks/sonnet.txt \ +# --prefix-len 0 \ +# --output-len 512 \ +# --num-prompts 200 \ +# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}' + + +# python benchmarks/benchmark_throughput.py \ +# --model meta-llama/Meta-Llama-3-8B-Instruct \ +# --dataset-name sharegpt \ +# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \ +# --prefix-len 0 \ +# --output-len 512 \ +# --num-prompts 200 \ +# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}' + +# python benchmarks/benchmark_throughput.py \ +# --model meta-llama/Meta-Llama-3-8B-Instruct \ +# --dataset-name hf \ +# --dataset-path likaixin/InstructCoder \ +# --prefix-len 0 \ +# --output-len 512 \ +# --num-prompts 200 \ +# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}' + + +# python benchmarks/benchmark_throughput.py \ +# --model meta-llama/Meta-Llama-3-8B-Instruct \ +# --dataset-name sonnet \ +# --dataset-path /data/lily/batch-sd/benchmarks/sonnet.txt \ +# --prefix-len 0 \ +# --output-len 512 \ +# --num-prompts 200 \ +# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 20}' + +# python benchmarks/benchmark_throughput.py \ +# --model meta-llama/Meta-Llama-3-8B-Instruct \ +# --dataset-name sharegpt \ +# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \ +# --prefix-len 0 \ +# --output-len 512 \ +# --num-prompts 200 \ +# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 20}' + + +# python benchmarks/benchmark_throughput.py \ +# --model meta-llama/Meta-Llama-3-8B-Instruct \ +# --dataset-name hf \ +# --dataset-path likaixin/InstructCoder \ +# --prefix-len 0 \ +# --output-len 512 \ +# --num-prompts 200 \ +# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B", "num_speculative_tokens": 20}' + + +# python benchmarks/benchmark_throughput.py \ +# --model meta-llama/Meta-Llama-3.1-8B-Instruct \ +# --dataset-name hf \ +# --dataset-path likaixin/InstructCoder \ +# --prefix-len 0 \ +# --output-len 512 \ +# --num-prompts 200 \ +# --speculative_config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}' + + + +# python benchmarks/benchmark_throughput.py \ +# --model meta-llama/Meta-Llama-3.1-8B-Instruct \ +# --dataset-name sharegpt \ +# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \ +# --prefix-len 0 \ +# --output-len 512 \ +# --num-prompts 200 \ +# --speculative_config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}' + +# python benchmarks/benchmark_throughput.py \ +# --model meta-llama/Meta-Llama-3.1-8B-Instruct \ +# --dataset-name sonnet \ +# --dataset-path /data/lily/batch-sd/benchmarks/sonnet.txt \ +# --prefix-len 0 \ +# --output-len 512 \ +# --num-prompts 200 \ +# --speculative_config '{"method": "eagle3", "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}' + + +# python benchmarks/benchmark_throughput.py \ +# --model meta-llama/Meta-Llama-3.1-8B-Instruct \ +# --dataset-name hf \ +# --dataset-path likaixin/InstructCoder \ +# --prefix-len 0 \ +# --output-len 512 \ +# --num-prompts 200 \ +# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}' + +# python benchmarks/benchmark_throughput.py \ +# --model meta-llama/Meta-Llama-3.1-8B-Instruct \ +# --dataset-name sharegpt \ +# --dataset-path /data/lily/ShareGPT_V3_unfiltered_cleaned_split.json \ +# --prefix-len 0 \ +# --output-len 512 \ +# --num-prompts 200 \ +# --speculative_config '{"method": "eagle", "model": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", "num_speculative_tokens": 20}' + +# python benchmarks/benchmark_throughput.py \ +# --model meta-llama/Meta-Llama-3.1-8B-Instruct \ +# --dataset-name hf \ +# --dataset-path likaixin/InstructCoder \ +# --prefix-len 0 \ +# --output-len 512 \ +# --num-prompts 200 \ +# --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}' + + +# python benchmarks/benchmark_throughput.py \ +# --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ +# --dataset-name hf \ +# --dataset-path AI-MO/aimo-validation-aime \ +# --prefix-len 0 \ +# --output-len 5120 \ +# --num-prompts 90 \ +# --speculative_config '{"method": "eagle3", "num_speculative_tokens": 20, "model": "yuhuili/EAGLE3-DeepSeek-R1-Distill-LLaMA-8B"}' + + + +python benchmarks/benchmark_throughput.py \ + --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \ + --dataset-name hf \ + --dataset-path AI-MO/aimo-validation-aime \ + --prefix-len 0 \ + --output-len 5120 \ + --num-prompts 90 \ + --speculative_config '{"method": "ngram", "num_speculative_tokens": 20, "prompt_lookup_min": 2, "prompt_lookup_max": 5}' + diff --git a/benchmarks/visualize/vis_acc.py b/benchmarks/visualize/vis_acc.py new file mode 100644 index 0000000000000..d28c13009f961 --- /dev/null +++ b/benchmarks/visualize/vis_acc.py @@ -0,0 +1,57 @@ +import json +import seaborn as sns +import matplotlib.pyplot as plt +from transformers import AutoTokenizer + + +model = "r1-distill-llama-8B" +MODEL_TO_NAMES = { + "r1-distill-llama-8B" : "deepseek-ai/DeepSeek-R1-Distill-Llama-8B" +} +method = "ngram" +dataset = "aime" +datapath = f"/data/lily/batch-sd/data/{model}/{method}_{dataset}_acceptance_stats.jsonl" +tokenizer = AutoTokenizer.from_pretrained(MODEL_TO_NAMES[model], use_fast=False) + +def cleanup(data): + # Remove the prefill phase + data = data[1:] + # Cap the maximum value to 10 + data = [min(x, 10) for x in data] + return data + +def load_data(datapath): + acceptance_stats = [] + with open(datapath, "r") as f: + lines = f.readlines() + for line in lines: + data = json.loads(line) + acceptance_stats.append(cleanup(data['acc'])) + print("Input:", tokenizer.decode(data['prompt_token_ids'])) + print("Output:", tokenizer.decode(data['generated_token_ids'])) + print("=============================================") + + # Pad the acceptance stats to the same length + max_length = max(len(stats) for stats in acceptance_stats) + for i in range(len(acceptance_stats)): + acceptance_stats[i] += [-2] * (max_length - len(acceptance_stats[i])) + + print(f"Load {len(acceptance_stats)} with max length {max_length}") + return acceptance_stats + +acceptance_stats = load_data(datapath) + + +fig, ax = plt.subplots() +sns.heatmap(acceptance_stats, cmap="YlGnBu") +plt.xlabel("Position") +plt.ylabel("Request ID") +# Add Y-axis labels on the right +ax2 = ax.twinx() +ax2.set_ylim(ax.get_ylim()) # Match y-axis range +ax2.set_yticks([]) # Remove right tick marks if undesired +ax2.set_ylabel("# of Accepted Tokens", labelpad=10) # Set right y-axis label + + +plt.tight_layout() +plt.savefig(f"figures/{model}/{method}_{dataset}_acceptance_stats.png") diff --git a/benchmarks/visualize/vis_acc_diff.py b/benchmarks/visualize/vis_acc_diff.py new file mode 100644 index 0000000000000..1b45d4ccd44e6 --- /dev/null +++ b/benchmarks/visualize/vis_acc_diff.py @@ -0,0 +1,69 @@ +import json +import seaborn as sns +import matplotlib.pyplot as plt +from matplotlib.colors import LinearSegmentedColormap + +model = "llama3.1-8B" +dataset = "instructcode" +method1 = "eagle" +method2 = "eagle3" + +def get_datapath(method): + datapath = f"/data/lily/batch-sd/data/{model}/{method}_{dataset}_acceptance_stats.jsonl" + return datapath + +def cleanup(data): + # Remove the prefill phase + data = data[1:] + # Cap the maximum value to 10 + data = [min(x, 10) for x in data] + return data + +def load_data(datapath): + acceptance_stats = {} + with open(datapath, "r") as f: + lines = f.readlines() + for line in lines: + data = json.loads(line) + key = hash(tuple(data['prompt_token_ids'])) + acceptance_stats[key] = cleanup(data['acc']) + # Pad the acceptance stats to the same length + max_length = max(len(stats) for k, stats in acceptance_stats.items()) + + for key in acceptance_stats: + acceptance_stats[key] += [-2] * (max_length - len(acceptance_stats[key])) + + print(f"Load {len(acceptance_stats)} with max length {max_length} from {datapath}") + return acceptance_stats + +def diff(acceptance_stats1, acceptance_stats2): + diff = {} + for key in acceptance_stats1: + if key in acceptance_stats2: + diff[key] = [a - b for a, b in zip(acceptance_stats1[key], acceptance_stats2[key])] + return diff + +datapath_1 = get_datapath(method1) +datapath_2 = get_datapath(method2) +acceptance_stats_1 = load_data(datapath_1) +acceptance_stats_2 = load_data(datapath_2) +acceptance_stats_diff = diff(acceptance_stats_1, acceptance_stats_2) + +acceptance_stats = list(acceptance_stats_diff.values()) + + +fig, ax = plt.subplots() +colors = ["red", "white", "blue"] +custom_cmap = LinearSegmentedColormap.from_list("custom", colors, N=256) +sns.heatmap(acceptance_stats, cmap=custom_cmap, center=0) +plt.xlabel("Position") +plt.ylabel("Request ID") +# Add Y-axis labels on the right +ax2 = ax.twinx() +ax2.set_ylim(ax.get_ylim()) # Match y-axis range +ax2.set_yticks([]) # Remove right tick marks if undesired +ax2.set_ylabel("# of Accepted Tokens", labelpad=10) # Set right y-axis label +plt.title(f"Diff between {method2} - {method1} acceptance stats for {dataset}") + +plt.tight_layout() +plt.savefig(f"figures/{model}/diff_{method2}_{method1}_{dataset}_acceptance_stats.png") diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 21711c9292f9f..945cdec624da6 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -28,6 +28,7 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.structured_output import StructuredOutputManager +import json logger = init_logger(__name__) @@ -632,6 +633,7 @@ class Scheduler(SchedulerInterface): logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens + self.acceptance_stats = model_runner_output.acceptance_stats new_running: list[Request] = [] outputs: list[EngineCoreOutput] = [] @@ -789,6 +791,18 @@ class Scheduler(SchedulerInterface): self._free_request(request) def _free_request(self, request: Request) -> None: + req_id = request.request_id + data = self.acceptance_stats.pop(req_id) + with open('acceptance_stats.jsonl', 'a') as f: + f.write(json.dumps({ + "id": req_id, + "acc": data, + "prompt_token_ids": request.prompt_token_ids, + "generated_token_ids": request.output_token_ids._x + })) + f.write('\n') + + assert request.is_finished() self.kv_cache_manager.free(request) self.kv_cache_manager.free_block_hashes(request) diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 2732b933c28a0..45e7cd43641fc 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -99,6 +99,8 @@ class ModelRunnerOutput: # [prompt_len, num_prompt_logprobs] # [prompt_len] prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] + + acceptance_stats: Optional[dict[str, list]] = None EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e3d8b94fe9d7e..ddb9233695220 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -49,6 +49,7 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin from .utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) +import json if TYPE_CHECKING: import xgrammar as xgr @@ -281,6 +282,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): device="cpu", pin_memory=self.pin_memory) self.seq_lens_np = self.seq_lens_cpu.numpy() + + self.acceptance_stats = {} def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler @@ -1004,7 +1007,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> Union[ModelRunnerOutput, torch.Tensor]: + ) -> Union[ModelRunnerOutput, torch.Tensor]: # Update KVConnector with the KVConnector metadata forward(). if has_kv_transfer_group(): get_kv_transfer_group().bind_connector_metadata( @@ -1187,6 +1190,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): sampled_token_ids, self.input_batch.vocab_size, ) + for i, token_ids in enumerate(valid_sampled_token_ids): + req_id = self.input_batch.req_ids[i] + if req_id not in self.acceptance_stats: + self.acceptance_stats[req_id] = [] + self.acceptance_stats[req_id].append(len(token_ids)) + # Force 1 generated token per request. + for i, token_ids in enumerate(valid_sampled_token_ids): + valid_sampled_token_ids[i] = token_ids[:1] + # Mask out the sampled tokens that should not be sampled. for i in discard_sampled_tokens_req_indices: valid_sampled_token_ids[i].clear() @@ -1285,6 +1297,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): spec_token_ids=spec_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, + acceptance_stats=self.acceptance_stats, ) def generate_draft_token_ids(