[Misc] Minor patch for draft model runner (#6523)

This commit is contained in:
Cody Yu 2024-07-17 23:06:21 -07:00 committed by GitHub
parent 61e592747c
commit 8a74c68bd1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -15,8 +15,12 @@ from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata,
logger = init_logger(__name__)
# A flag to enable debug prints for the updated input tensors
# before each step.
debug_advance_input = False
enable_gpu_advance_step = True
# A flag to allow the GPU advance step for the draft model runner.
# Set to False to disable it when debugging.
allow_gpu_advance_step = True
class TP1DraftModelRunner(ModelRunner):
@ -196,7 +200,7 @@ class TP1DraftModelRunner(ModelRunner):
3. No LORA
4. No prompt_adapter_config
"""
if not enable_gpu_advance_step:
if not allow_gpu_advance_step:
return False
# We allow multi-step GPU only in decode mode