mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-20 02:07:00 +08:00
[Spec Decode][CI] Add e2e test for examples/spec_decode.py and prevent breaking Acceptance Length (#24531)
Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com> Co-authored-by: Roger Wang <hey@rogerw.io> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
78f892c373
commit
5acda4cc71
@ -329,6 +329,8 @@ steps:
|
|||||||
- python3 offline_inference/basic/classify.py
|
- python3 offline_inference/basic/classify.py
|
||||||
- python3 offline_inference/basic/embed.py
|
- python3 offline_inference/basic/embed.py
|
||||||
- python3 offline_inference/basic/score.py
|
- python3 offline_inference/basic/score.py
|
||||||
|
- python3 offline_inference/spec_decode.py --test --method eagle --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048
|
||||||
|
- python3 offline_inference/spec_decode.py --test --method eagle3 --num_spec_tokens 3 --dataset-name hf --dataset-path philschmid/mt-bench --num-prompts 80 --temp 0 --top-p 1.0 --top-k -1 --tp 1 --enable-chunked-prefill --max-model-len 2048
|
||||||
|
|
||||||
- label: Platform Tests (CUDA) # 4min
|
- label: Platform Tests (CUDA) # 4min
|
||||||
timeout_in_minutes: 15
|
timeout_in_minutes: 15
|
||||||
|
|||||||
@ -49,6 +49,7 @@ def get_custom_mm_prompts(num_prompts):
|
|||||||
def parse_args():
|
def parse_args():
|
||||||
parser = FlexibleArgumentParser()
|
parser = FlexibleArgumentParser()
|
||||||
add_dataset_parser(parser)
|
add_dataset_parser(parser)
|
||||||
|
parser.add_argument("--test", action="store_true")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--method",
|
"--method",
|
||||||
type=str,
|
type=str,
|
||||||
@ -60,6 +61,7 @@ def parse_args():
|
|||||||
parser.add_argument("--tp", type=int, default=1)
|
parser.add_argument("--tp", type=int, default=1)
|
||||||
parser.add_argument("--enforce-eager", action="store_true")
|
parser.add_argument("--enforce-eager", action="store_true")
|
||||||
parser.add_argument("--enable-chunked-prefill", action="store_true")
|
parser.add_argument("--enable-chunked-prefill", action="store_true")
|
||||||
|
parser.add_argument("--max-model-len", type=int, default=16384)
|
||||||
parser.add_argument("--temp", type=float, default=0)
|
parser.add_argument("--temp", type=float, default=0)
|
||||||
parser.add_argument("--top-p", type=float, default=1.0)
|
parser.add_argument("--top-p", type=float, default=1.0)
|
||||||
parser.add_argument("--top-k", type=int, default=-1)
|
parser.add_argument("--top-k", type=int, default=-1)
|
||||||
@ -71,8 +73,7 @@ def parse_args():
|
|||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main(args):
|
||||||
args = parse_args()
|
|
||||||
args.endpoint_type = "openai-chat"
|
args.endpoint_type = "openai-chat"
|
||||||
|
|
||||||
model_dir = args.model_dir
|
model_dir = args.model_dir
|
||||||
@ -134,7 +135,7 @@ def main():
|
|||||||
gpu_memory_utilization=0.8,
|
gpu_memory_utilization=0.8,
|
||||||
speculative_config=speculative_config,
|
speculative_config=speculative_config,
|
||||||
disable_log_stats=False,
|
disable_log_stats=False,
|
||||||
max_model_len=16384,
|
max_model_len=args.max_model_len,
|
||||||
limit_mm_per_prompt={"image": 5},
|
limit_mm_per_prompt={"image": 5},
|
||||||
disable_chunked_mm_input=True,
|
disable_chunked_mm_input=True,
|
||||||
)
|
)
|
||||||
@ -198,6 +199,39 @@ def main():
|
|||||||
acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0
|
acceptance_rate = acceptance_counts[i] / num_drafts if num_drafts > 0 else 0
|
||||||
print(f"acceptance at token {i}: {acceptance_rate:.2f}")
|
print(f"acceptance at token {i}: {acceptance_rate:.2f}")
|
||||||
|
|
||||||
|
return acceptance_length
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
args = parse_args()
|
||||||
|
acceptance_length = main(args)
|
||||||
|
|
||||||
|
if args.test:
|
||||||
|
# takes ~30s to run on 1xH100
|
||||||
|
assert args.method in ["eagle", "eagle3"]
|
||||||
|
assert args.tp == 1
|
||||||
|
assert args.num_spec_tokens == 3
|
||||||
|
assert args.dataset_name == "hf"
|
||||||
|
assert args.dataset_path == "philschmid/mt-bench"
|
||||||
|
assert args.num_prompts == 80
|
||||||
|
assert args.temp == 0
|
||||||
|
assert args.top_p == 1.0
|
||||||
|
assert args.top_k == -1
|
||||||
|
assert args.enable_chunked_prefill
|
||||||
|
|
||||||
|
# check acceptance length is within 2% of expected value
|
||||||
|
rtol = 0.02
|
||||||
|
expected_acceptance_length = 2.296 if args.method == "eagle" else 2.811
|
||||||
|
|
||||||
|
assert (
|
||||||
|
acceptance_length <= (1 + rtol) * expected_acceptance_length
|
||||||
|
and acceptance_length >= (1 - rtol) * expected_acceptance_length
|
||||||
|
), (
|
||||||
|
f"acceptance_length {acceptance_length} is not "
|
||||||
|
f"within {rtol * 100}% of {expected_acceptance_length}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Test passed! Expected AL: "
|
||||||
|
f"{expected_acceptance_length}, got {acceptance_length}"
|
||||||
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user