[Bugfix] Fix missing task for speculative decoding (#9524)

This commit is contained in:
Cyrus Leung 2024-10-19 14:49:40 +08:00 committed by GitHub
parent c5eea3c8ba
commit 263d8ee150
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -33,8 +33,10 @@ logger = init_logger(__name__)
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
Task = Literal["generate", "embedding"]
TaskOption = Literal["auto", Task]
TaskOption = Literal["auto", "generate", "embedding"]
# "draft" is only used internally for speculative decoding
_Task = Literal["generate", "embedding", "draft"]
class ModelConfig:
@ -115,7 +117,7 @@ class ModelConfig:
def __init__(self,
model: str,
task: TaskOption,
task: Union[TaskOption, _Task],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
@ -255,18 +257,21 @@ class ModelConfig:
def _resolve_task(
self,
task_option: TaskOption,
task_option: Union[TaskOption, _Task],
hf_config: PretrainedConfig,
) -> Tuple[Set[Task], Task]:
) -> Tuple[Set[_Task], _Task]:
if task_option == "draft":
return {"draft"}, "draft"
architectures = getattr(hf_config, "architectures", [])
task_support: Dict[Task, bool] = {
task_support: Dict[_Task, bool] = {
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
"generate": ModelRegistry.is_text_generation_model(architectures),
"embedding": ModelRegistry.is_embedding_model(architectures),
}
supported_tasks_lst: List[Task] = [
supported_tasks_lst: List[_Task] = [
task for task, is_supported in task_support.items() if is_supported
]
supported_tasks = set(supported_tasks_lst)
@ -1002,7 +1007,7 @@ class SchedulerConfig:
"""
def __init__(self,
task: Task,
task: _Task,
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
@ -1269,7 +1274,7 @@ class SpeculativeConfig:
ngram_prompt_lookup_min = 0
draft_model_config = ModelConfig(
model=speculative_model,
task=target_model_config.task,
task="draft",
tokenizer=target_model_config.tokenizer,
tokenizer_mode=target_model_config.tokenizer_mode,
trust_remote_code=target_model_config.trust_remote_code,