mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-16 05:35:01 +08:00)
[Bugfix] Fix missing task for speculative decoding (#9524)
parent c5eea3c8ba
commit 263d8ee150
vllm/config.py
@@ -33,8 +33,10 @@ logger = init_logger(__name__)
 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
 
-Task = Literal["generate", "embedding"]
-TaskOption = Literal["auto", Task]
+TaskOption = Literal["auto", "generate", "embedding"]
+
+# "draft" is only used internally for speculative decoding
+_Task = Literal["generate", "embedding", "draft"]
 
 
 class ModelConfig:
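Note on the hunk above: the old public alias `Task` is replaced by a private `_Task` that adds the internal "draft" value, while `TaskOption` is inlined to the flattened form of the old nested `Literal["auto", Task]` (nested Literals flatten, so the user-facing type is unchanged). A standalone sketch of the typing pattern, not vLLM source:

from typing import Literal, Union, get_args

TaskOption = Literal["auto", "generate", "embedding"]  # public, user-facing
_Task = Literal["generate", "embedding", "draft"]      # resolved tasks; "draft" is internal-only

def describe(task: Union[TaskOption, _Task]) -> str:
    # Internal call sites (e.g. the speculative-decoding path) may pass
    # "draft"; user-supplied input is still typed as TaskOption.
    return f"task={task!r}"

print(get_args(TaskOption))  # ('auto', 'generate', 'embedding')
print(get_args(_Task))       # ('generate', 'embedding', 'draft')

The `Union[TaskOption, _Task]` spelling is the same widening that the next two hunks apply to the `ModelConfig.__init__` and `_resolve_task` signatures.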
@@ -115,7 +117,7 @@ class ModelConfig:
 
     def __init__(self,
                  model: str,
-                 task: TaskOption,
+                 task: Union[TaskOption, _Task],
                  tokenizer: str,
                  tokenizer_mode: str,
                  trust_remote_code: bool,
@@ -255,18 +257,21 @@ class ModelConfig:
 
     def _resolve_task(
         self,
-        task_option: TaskOption,
+        task_option: Union[TaskOption, _Task],
         hf_config: PretrainedConfig,
-    ) -> Tuple[Set[Task], Task]:
+    ) -> Tuple[Set[_Task], _Task]:
+        if task_option == "draft":
+            return {"draft"}, "draft"
+
         architectures = getattr(hf_config, "architectures", [])
 
-        task_support: Dict[Task, bool] = {
+        task_support: Dict[_Task, bool] = {
             # NOTE: Listed from highest to lowest priority,
             # in case the model supports multiple of them
             "generate": ModelRegistry.is_text_generation_model(architectures),
             "embedding": ModelRegistry.is_embedding_model(architectures),
         }
-        supported_tasks_lst: List[Task] = [
+        supported_tasks_lst: List[_Task] = [
             task for task, is_supported in task_support.items() if is_supported
         ]
         supported_tasks = set(supported_tasks_lst)
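The three added lines are the behavioral core of this hunk: "draft" bypasses the registry-backed support table entirely, since a draft model's architecture may not be registered as a text-generation or embedding model. A minimal sketch of that control flow with the `ModelRegistry` lookups stubbed out (`resolve_task` and its stubs below are illustrative, not vLLM's implementation):

from typing import Dict, List, Literal, Set, Tuple, Union

TaskOption = Literal["auto", "generate", "embedding"]
_Task = Literal["generate", "embedding", "draft"]

def resolve_task(task_option: Union[TaskOption, _Task],
                 architectures: List[str]) -> Tuple[Set[_Task], _Task]:
    # Short-circuit for speculative draft models: skip registry resolution.
    if task_option == "draft":
        return {"draft"}, "draft"

    # Stand-ins for ModelRegistry.is_text_generation_model(...) and
    # ModelRegistry.is_embedding_model(...).
    task_support: Dict[_Task, bool] = {
        "generate": any(a.endswith("ForCausalLM") for a in architectures),
        "embedding": any("EmbeddingModel" in a for a in architectures),
    }
    supported: Set[_Task] = {t for t, ok in task_support.items() if ok}
    if task_option == "auto":
        # Dict insertion order encodes priority: "generate" wins if both match.
        # (next() raises here if nothing matches; the real code raises a
        # proper error for unsupported models.)
        return supported, next(t for t, ok in task_support.items() if ok)
    return supported, task_option

print(resolve_task("draft", ["TinyDraftModel"]))   # ({'draft'}, 'draft')
print(resolve_task("auto", ["LlamaForCausalLM"]))  # ({'generate'}, 'generate')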
@@ -1002,7 +1007,7 @@ class SchedulerConfig:
     """
 
     def __init__(self,
-                 task: Task,
+                 task: _Task,
                  max_num_batched_tokens: Optional[int],
                  max_num_seqs: int,
                  max_model_len: int,
@@ -1269,7 +1274,7 @@ class SpeculativeConfig:
             ngram_prompt_lookup_min = 0
         draft_model_config = ModelConfig(
             model=speculative_model,
-            task=target_model_config.task,
+            task="draft",
             tokenizer=target_model_config.tokenizer,
             tokenizer_mode=target_model_config.tokenizer_mode,
             trust_remote_code=target_model_config.trust_remote_code,
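This last hunk is the bugfix named in the commit title: the draft `ModelConfig` used to inherit `target_model_config.task`, which sent the draft model's architecture through `_resolve_task` and could fail for draft-only models; pinning `task="draft"` routes it through the new short-circuit instead. A usage sketch in the style of the speculative-decoding examples from the vLLM docs of this period (model names and numbers illustrative):

from vllm import LLM, SamplingParams

# With this fix, bringing up a draft model no longer depends on the draft
# architecture passing generate/embedding task resolution.
llm = LLM(
    model="facebook/opt-6.7b",
    speculative_model="facebook/opt-125m",  # becomes a ModelConfig with task="draft"
    num_speculative_tokens=5,
)

outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
print(outputs[0].outputs[0].text)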