diff --git a/vllm/config.py b/vllm/config.py index 7f8f93642854..f57aa4048ae9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -33,8 +33,10 @@ logger = init_logger(__name__) _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 -Task = Literal["generate", "embedding"] -TaskOption = Literal["auto", Task] +TaskOption = Literal["auto", "generate", "embedding"] + +# "draft" is only used internally for speculative decoding +_Task = Literal["generate", "embedding", "draft"] class ModelConfig: @@ -115,7 +117,7 @@ class ModelConfig: def __init__(self, model: str, - task: TaskOption, + task: Union[TaskOption, _Task], tokenizer: str, tokenizer_mode: str, trust_remote_code: bool, @@ -255,18 +257,21 @@ class ModelConfig: def _resolve_task( self, - task_option: TaskOption, + task_option: Union[TaskOption, _Task], hf_config: PretrainedConfig, - ) -> Tuple[Set[Task], Task]: + ) -> Tuple[Set[_Task], _Task]: + if task_option == "draft": + return {"draft"}, "draft" + architectures = getattr(hf_config, "architectures", []) - task_support: Dict[Task, bool] = { + task_support: Dict[_Task, bool] = { # NOTE: Listed from highest to lowest priority, # in case the model supports multiple of them "generate": ModelRegistry.is_text_generation_model(architectures), "embedding": ModelRegistry.is_embedding_model(architectures), } - supported_tasks_lst: List[Task] = [ + supported_tasks_lst: List[_Task] = [ task for task, is_supported in task_support.items() if is_supported ] supported_tasks = set(supported_tasks_lst) @@ -1002,7 +1007,7 @@ class SchedulerConfig: """ def __init__(self, - task: Task, + task: _Task, max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, @@ -1269,7 +1274,7 @@ class SpeculativeConfig: ngram_prompt_lookup_min = 0 draft_model_config = ModelConfig( model=speculative_model, - task=target_model_config.task, + task="draft", tokenizer=target_model_config.tokenizer, tokenizer_mode=target_model_config.tokenizer_mode, trust_remote_code=target_model_config.trust_remote_code,