[Bugfix] Fix missing task for speculative decoding (#9524)
commit 263d8ee150
parent c5eea3c8ba
@@ -33,8 +33,10 @@ logger = init_logger(__name__)
 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
 
-Task = Literal["generate", "embedding"]
-TaskOption = Literal["auto", Task]
+TaskOption = Literal["auto", "generate", "embedding"]
+
+# "draft" is only used internally for speculative decoding
+_Task = Literal["generate", "embedding", "draft"]
 
 
 class ModelConfig:
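For context on this hunk: the exported Task alias is gone; the user-facing TaskOption is now spelled out directly, and a private _Task alias adds the "draft" member that only speculative decoding uses. A minimal standalone sketch (plain Python, not vLLM code) of how the two aliases relate:

# Standalone sketch, not vLLM code: the user-facing alias keeps its
# original members, while the private alias adds "draft".
from typing import Literal, get_args

TaskOption = Literal["auto", "generate", "embedding"]  # accepted from users
_Task = Literal["generate", "embedding", "draft"]      # used internally

assert "draft" not in get_args(TaskOption)
assert "draft" in get_args(_Task)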
@@ -115,7 +117,7 @@ class ModelConfig:
 
     def __init__(self,
                  model: str,
-                 task: TaskOption,
+                 task: Union[TaskOption, _Task],
                  tokenizer: str,
                  tokenizer_mode: str,
                  trust_remote_code: bool,
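The Union widening in this hunk exists because ModelConfig.__init__ now has two kinds of callers: user-facing entry points still pass a TaskOption (possibly "auto"), while SpeculativeConfig passes the internal "draft" value. A sketch of the typing effect, using a hypothetical make_config function:

# Hypothetical function, not vLLM's API: with a plain TaskOption
# annotation, a type checker would reject the "draft" call below;
# the Union admits both caller styles.
from typing import Literal, Union

TaskOption = Literal["auto", "generate", "embedding"]
_Task = Literal["generate", "embedding", "draft"]

def make_config(task: Union[TaskOption, _Task]) -> str:
    return task

make_config("auto")   # user-facing caller
make_config("draft")  # internal caller (speculative decoding)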
@@ -255,18 +257,21 @@ class ModelConfig:
 
     def _resolve_task(
         self,
-        task_option: TaskOption,
+        task_option: Union[TaskOption, _Task],
         hf_config: PretrainedConfig,
-    ) -> Tuple[Set[Task], Task]:
+    ) -> Tuple[Set[_Task], _Task]:
+        if task_option == "draft":
+            return {"draft"}, "draft"
+
         architectures = getattr(hf_config, "architectures", [])
 
-        task_support: Dict[Task, bool] = {
+        task_support: Dict[_Task, bool] = {
             # NOTE: Listed from highest to lowest priority,
             # in case the model supports multiple of them
             "generate": ModelRegistry.is_text_generation_model(architectures),
             "embedding": ModelRegistry.is_embedding_model(architectures),
         }
-        supported_tasks_lst: List[Task] = [
+        supported_tasks_lst: List[_Task] = [
             task for task, is_supported in task_support.items() if is_supported
         ]
         supported_tasks = set(supported_tasks_lst)
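The early return added here is the core of the fix: a draft model's task is decided before any architecture lookup, so it no longer has to be registered as a generation or embedding model. A standalone sketch of the resolution flow, with a plain dict standing in for the ModelRegistry checks (resolve_task_sketch and task_support are hypothetical names):

# Standalone sketch of the flow above; the dict stands in for the
# ModelRegistry.is_*_model(architectures) checks.
from typing import Dict, List, Literal, Set, Tuple

_Task = Literal["generate", "embedding", "draft"]

def resolve_task_sketch(
    task_option: str,
    task_support: Dict[_Task, bool],
) -> Tuple[Set[_Task], _Task]:
    # "draft" short-circuits before any architecture inspection.
    if task_option == "draft":
        return {"draft"}, "draft"

    supported: List[_Task] = [
        task for task, ok in task_support.items() if ok
    ]
    if not supported:
        raise ValueError("model supports no vLLM task")
    # The first supported entry wins, mirroring the priority NOTE above.
    return set(supported), supported[0]

# resolve_task_sketch("draft", {}) returns ({"draft"}, "draft") even
# though the stand-in registry reports no supported tasks.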
@@ -1002,7 +1007,7 @@ class SchedulerConfig:
     """
 
     def __init__(self,
-                 task: Task,
+                 task: _Task,
                  max_num_batched_tokens: Optional[int],
                  max_num_seqs: int,
                  max_model_len: int,
@@ -1269,7 +1274,7 @@ class SpeculativeConfig:
             ngram_prompt_lookup_min = 0
         draft_model_config = ModelConfig(
             model=speculative_model,
-            task=target_model_config.task,
+            task="draft",
             tokenizer=target_model_config.tokenizer,
             tokenizer_mode=target_model_config.tokenizer_mode,
             trust_remote_code=target_model_config.trust_remote_code,
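This last hunk is where the bug lived: the draft ModelConfig used to inherit target_model_config.task, so resolving the draft model's task depended on the draft model's registered architectures; pinning task="draft" routes it through the early return shown above instead. A hedged usage example of the path this fixes; model names are placeholders, and the speculative_model / num_speculative_tokens engine arguments reflect vLLM's API around this commit and may differ in other versions:

# Hedged example: a target model plus a small draft model for
# speculative decoding. Both model names are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-13b-chat-hf",  # target model (placeholder)
    speculative_model="JackFram/llama-68m",  # draft model (placeholder)
    num_speculative_tokens=5,
)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)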