mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-26 22:54:40 +08:00
[Spec Decode][CI] Add e2e test for examples/spec_decode.py and prevent breaking Acceptance Length (#24531)
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
parent
24e8222745
commit
867ecdd1c8
@ -329,6 +329,8 @@ steps:
|
||||
- python3 offline_inference/basic/classify.py
|
||||
- python3 offline_inference/basic/embed.py
|
||||
- python3 offline_inference/basic/score.py
|
||||
- python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048
|
||||
- python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048
|
||||
|
||||
- label: Platform Tests (CUDA) # 4min
|
||||
timeout_in_minutes: 15
|
||||
|
||||
@ -49,6 +49,7 @@ def get_custom_mm_prompts(num_prompts):
|
||||
def parse_args():
|
||||
parser = FlexibleArgumentParser()
|
||||
add_dataset_parser(parser)
|
||||
parser.add_argument("--test", action="store_true")
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
@ -60,6 +61,7 @@ def parse_args():
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--enforce-eager", action="store_true")
|
||||
parser.add_argument("--enable-chunked-prefill", action="store_true")
|
||||
parser.add_argument("--max-model-len", type=int, default=16384)
|
||||
parser.add_argument("--temp", type=float, default=0)
|
||||
parser.add_argument("--top-p", type=float, default=1.0)
|
||||
parser.add_argument("--top-k", type=int, default=-1)
|
||||
@ -71,8 +73,7 @@ def parse_args():
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
def main(args):
|
||||
args.endpoint_type = "openai-chat"
|
||||
|
||||
model_dir = args.model_dir
|
||||
@ -134,7 +135,7 @@ def main():
|
||||
gpu_memory_utilization=0.8,
|
||||
speculative_config=speculative_config,
|
||||
disable_log_stats=False,
|
||||
max_model_len=16384,
|
||||
max_model_len=args.max_model_len,
|
||||
limit_mm_per_prompt={"image": 5},
|
||||
disable_chunked_mm_input=True,
|
||||
)
|
||||
@ -198,6 +199,39 @@ def main():
|
||||
acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0
|
||||
print(f"acceptance at token {i}: {acceptance_rate:.2f}")
|
||||
|
||||
return acceptance_length
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
args = parse_args()
|
||||
acceptance_length = main(args)
|
||||
|
||||
if args.test:
|
||||
# takes ~30s to run on 1xH100
|
||||
assert args.method in ["eagle", "eagle3"]
|
||||
assert args.tp == 1
|
||||
assert args.num_spec_tokens == 3
|
||||
assert args.dataset_name == "hf"
|
||||
assert args.dataset_path == "philschmid/mt-bench"
|
||||
assert args.num_prompts == 80
|
||||
assert args.temp == 0
|
||||
assert args.top_p == 1.0
|
||||
assert args.top_k == -1
|
||||
assert args.enable_chunked_prefill
|
||||
|
||||
# check acceptance length is within 2% of expected value
|
||||
rtol = 0.02
|
||||
expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811
|
||||
|
||||
assert (
|
||||
acceptance_length <= (1 + rtol) * expected_acceptance_length
|
||||
and acceptance_length >= (1 - rtol) * expected_acceptance_length
|
||||
), (
|
||||
f"acceptance_length {acceptance_length} is not "
|
||||
f"within {rtol * 100}% of {expected_acceptance_length}"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Test passed! Expected AL: "
|
||||
f"{expected_acceptance_length}, got {acceptance_length}"
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user