[Core] Support multiple tasks per model (#20771)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent c1acd6d7d4
commit 020f58abcd
@@ -54,7 +54,7 @@ def test_get_field():
        ("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
        ("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
        ("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
        ("openai/whisper-small", "transcription", "transcription"),
        ("openai/whisper-small", "generate", "transcription"),
    ],
)
def test_auto_task(model_id, expected_runner_type, expected_task):
@@ -69,7 +69,11 @@ def test_auto_task(model_id, expected_runner_type, expected_task):
    )

    assert config.runner_type == expected_runner_type
    assert config.task == expected_task

    if config.runner_type == "pooling":
        assert config.task == expected_task
    else:
        assert expected_task in config.supported_tasks


@pytest.mark.parametrize(
@@ -98,11 +102,50 @@ def test_score_task(model_id, expected_runner_type, expected_task):
    assert config.task == expected_task


@pytest.mark.parametrize(("model_id", "expected_runner_type", "expected_task"),
                         [
                             ("Qwen/Qwen2.5-1.5B-Instruct", "draft", "auto"),
                         ])
def test_draft_task(model_id, expected_runner_type, expected_task):
    config = ModelConfig(
        model_id,
        runner="draft",
        tokenizer=model_id,
        seed=0,
        dtype="float16",
    )

    assert config.runner_type == expected_runner_type
    assert config.task == expected_task


@pytest.mark.parametrize(
    ("model_id", "expected_runner_type", "expected_task"),
    [
        ("openai/whisper-small", "generate", "transcription"),
    ],
)
def test_transcription_task(model_id, expected_runner_type, expected_task):
    config = ModelConfig(
        model_id,
        task="transcription",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
    )

    assert config.runner_type == expected_runner_type
    assert config.task == expected_task


@pytest.mark.parametrize(("model_id", "bad_task"), [
    ("Qwen/Qwen2.5-Math-RM-72B", "generate"),
    ("Qwen/Qwen3-0.6B", "transcription"),
])
def test_incorrect_task(model_id, bad_task):
    with pytest.raises(ValueError, match=r"does not support the .* task"):
    with pytest.raises(ValueError, match=r"does not support task=.*"):
        ModelConfig(
            model_id,
            task=bad_task,
vllm/config.py (258 changed lines)
@@ -91,24 +91,19 @@ logger = init_logger(__name__)
ConfigT = TypeVar("ConfigT", bound=ConfigType)

TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
                     "score", "reward", "transcription"]
                     "score", "reward", "transcription", "draft"]

_ResolvedTask = Literal["generate", "embed", "classify", "reward", "draft",
                        "transcription"]
_ResolvedTask = Literal["generate", "transcription", "pooling", "embed",
                        "classify", "reward", "draft"]

RunnerType = Literal["generate", "pooling", "draft", "transcription"]
RunnerOption = Literal["auto", "generate", "pooling", "draft"]

RunnerType = Literal["generate", "pooling", "draft"]

_RUNNER_TASKS: dict[RunnerType, list[_ResolvedTask]] = {
    "generate": ["generate"],
    "pooling": ["embed", "classify", "reward"],
    "draft": ["draft"],
    "transcription": ["transcription"],
}

_TASK_RUNNER: dict[_ResolvedTask, RunnerType] = {
    task: runner
    for runner, tasks in _RUNNER_TASKS.items()
    for task in tasks
    "generate": ["generate", "transcription"],
    "pooling": ["pooling", "embed", "classify", "reward"],
    "draft": [],
}
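For orientation, the new mapping reads "each runner type serves a list of tasks": the "generate" runner now also serves transcription, and all pooling tasks share the "pooling" runner. A minimal standalone sketch (illustrative only, not part of this diff) of inverting that table:

# Mirrors the new mapping above.
RUNNER_TASKS: dict[str, list[str]] = {
    "generate": ["generate", "transcription"],
    "pooling": ["pooling", "embed", "classify", "reward"],
    "draft": [],
}

def runner_for_task(task: str) -> str:
    # Invert the runner -> tasks table; the draft runner is selected
    # explicitly via runner="draft", so it serves no task here.
    for runner, tasks in RUNNER_TASKS.items():
        if task in tasks:
            return runner
    raise ValueError(f"No runner serves task={task!r}")

assert runner_for_task("transcription") == "generate"
assert runner_for_task("classify") == "pooling"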
@@ -234,11 +229,14 @@ class ModelConfig:
    """Name or path of the Hugging Face model to use. It is also used as the
    content for `model_name` tag in metrics output when `served_model_name` is
    not specified."""
    task: Literal[TaskOption, Literal["draft"]] = "auto"
    """The task to use the model for. Each vLLM instance only supports one
    task, even if the same model can be used for multiple tasks. When the model
    only supports one task, "auto" can be used to select it; otherwise, you
    must specify explicitly which task to use."""
    runner: RunnerOption = "auto"
    """The type of model runner to use. Each vLLM instance only supports one
    model runner, even if the same model can be used for multiple types."""
    task: TaskOption = "auto"
    """The task to use the model for. If the model supports more than one
    model runner, this is used to select which model runner to run.

    Note that the model may support other tasks using the same model runner."""
    tokenizer: SkipValidation[str] = None  # type: ignore
    """Name or path of the Hugging Face tokenizer to use. If unspecified, model
    name or path will be used."""
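The split between `runner` and `task` is exercised by the tests earlier in this diff; a minimal sketch of the new selection (mirroring `test_draft_task` above; constructor arguments beyond those shown in the tests are not assumed):

from vllm.config import ModelConfig

# The draft model for speculative decoding is now selected via
# runner="draft" instead of task="draft".
config = ModelConfig(
    "Qwen/Qwen2.5-1.5B-Instruct",
    runner="draft",
    tokenizer="Qwen/Qwen2.5-1.5B-Instruct",
    seed=0,
    dtype="float16",
)
assert config.runner_type == "draft"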
@@ -553,10 +551,41 @@ class ModelConfig:
        self.hf_image_processor_config = get_hf_image_processor_config(
            self.model, hf_token=self.hf_token, revision=self.revision)

        supported_tasks, task = self._resolve_task(self.task)
        self.supported_tasks = supported_tasks
        self.task = task
        if self.task in ("draft", "generate"):
        # For pooling models, self.task is used to indicate the
        # user-selected task
        if self.task == "score":
            if self.registry.is_cross_encoder_model(self.architectures):
                self.task = "classify"
            else:
                self.task = "embed"
        elif self.task == "embedding":
            msg = ("The 'embedding' task has been renamed to 'embed', please "
                   "use the new name. The old name will be removed in v1.0.")
            warnings.warn(msg, DeprecationWarning, stacklevel=2)

            self.task = "embed"

        all_supported_tasks = self._get_supported_tasks(self.task)
        logger.debug("Tasks supported by runner type: %s", all_supported_tasks)
        supported_runner_types = self._get_supported_runner_types(
            all_supported_tasks)
        runner_type = self._resolve_runner(self.runner, self.task,
                                           supported_runner_types,
                                           all_supported_tasks)

        logger.debug("Selected runner type: %s", runner_type)
        # For pooling models, self.task is used to indicate the
        # user-selected task
        if runner_type == "pooling" and self.task == "auto":
            selected_task = all_supported_tasks[runner_type][-1]
            assert selected_task != "pooling"
            self.task = selected_task
        self.supported_runner_types = supported_runner_types
        self.runner_type = runner_type
        self.supported_tasks = all_supported_tasks[runner_type]

        if self.runner_type in ("draft",
                                "generate") and self.task != "transcription":
            self.truncation_side = "left"
        else:
            self.truncation_side = "right"
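A worked example of what the resolution above produces, based on the expectations in `test_auto_task` earlier in this diff (a sketch, not a trace of the implementation):

from vllm.config import ModelConfig

config = ModelConfig(
    "openai/whisper-small",
    tokenizer="openai/whisper-small",
    seed=0,
    dtype="float16",
)
# Whisper now resolves to the "generate" runner, and transcription is
# among the tasks that runner serves for this model.
assert config.runner_type == "generate"
assert "transcription" in config.supported_tasks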
@@ -780,11 +809,10 @@ class ModelConfig:
                f"one of {get_args(TokenizerMode)}.")
        self.tokenizer_mode = tokenizer_mode

    def _get_preferred_task(
    def _get_preferred_pooling_task(
        self,
        architectures: list[str],
        supported_tasks: set[_ResolvedTask],
    ) -> Optional[_ResolvedTask]:
    ) -> _ResolvedTask:
        model_id = self.model
        if get_pooling_config(model_id, self.revision):
            return "embed"
@@ -795,92 +823,136 @@ class ModelConfig:

        suffix_to_preferred_task: list[tuple[str, _ResolvedTask]] = [
            # Other models follow this pattern
            ("ForCausalLM", "generate"),
            ("ForConditionalGeneration", "generate"),
            ("ForSequenceClassification", "classify"),
            ("ChatModel", "generate"),
            ("LMHeadModel", "generate"),
            ("EmbeddingModel", "embed"),
            ("RewardModel", "reward"),
        ]
        _, arch = self.registry.inspect_model_cls(architectures)

        for suffix, pref_task in suffix_to_preferred_task:
            if arch.endswith(suffix) and pref_task in supported_tasks:
            if arch.endswith(suffix):
                return pref_task

        return None
        return "embed"

    def _resolve_task(
    def _get_supported_generation_tasks(
        self,
        task_option: Literal[TaskOption, Literal["draft"]],
    ) -> tuple[set[_ResolvedTask], _ResolvedTask]:
        if task_option == "draft":
            return {"draft"}, "draft"

        task_option: TaskOption,
    ) -> list[_ResolvedTask]:
        registry = self.registry
        architectures = self.architectures

        runner_support: dict[RunnerType, bool] = {
            # NOTE: Listed from highest to lowest priority,
            # in case the model supports multiple of them
            "transcription": registry.is_transcription_model(architectures),
            "generate": registry.is_text_generation_model(architectures),
            "pooling": registry.is_pooling_model(architectures),
        if registry.is_transcription_only_model(architectures):
            return ["transcription"]

        supported_tasks = list[_ResolvedTask]()
        if registry.is_text_generation_model(architectures):
            supported_tasks.append("generate")

        if registry.is_transcription_model(architectures):
            supported_tasks.append("transcription")

        return supported_tasks

    def _get_supported_pooling_tasks(
        self,
        task_option: TaskOption,
    ) -> list[_ResolvedTask]:
        registry = self.registry
        architectures = self.architectures

        supported_tasks = list[_ResolvedTask]()
        if registry.is_pooling_model(architectures):
            supported_tasks.append("pooling")

            # For now, users must specify the task (other than "pooling")
            # to use for pooling models
            if task_option == "auto":
                preferred_task = self._get_preferred_pooling_task(
                    architectures)

                supported_tasks.append(preferred_task)
            elif task_option in _RUNNER_TASKS["pooling"]:
                supported_tasks.append(cast(_ResolvedTask, task_option))

        return supported_tasks

    def _get_supported_tasks(
        self,
        task_option: TaskOption,
    ) -> dict[RunnerType, list[_ResolvedTask]]:
        return {
            "generate": self._get_supported_generation_tasks(task_option),
            "pooling": self._get_supported_pooling_tasks(task_option),
            "draft": ["draft"]
        }
        supported_runner_types_lst: list[RunnerType] = [
            runner_type
            for runner_type, is_supported in runner_support.items()
            if is_supported
        ]

        supported_tasks_lst: list[_ResolvedTask] = [
            task for runner_type in supported_runner_types_lst
            for task in _RUNNER_TASKS[runner_type]
        ]
        supported_tasks = set(supported_tasks_lst)
    def _get_supported_runner_types(
        self,
        supported_tasks: dict[RunnerType, list[_ResolvedTask]],
    ) -> set[RunnerType]:
        return {
            runner
            for runner, runner_tasks in supported_tasks.items()
            if len(runner_tasks) > 0
        }

        if task_option == "auto":
            selected_task = next(iter(supported_tasks_lst))
    def _resolve_runner(
        self,
        runner_option: RunnerOption,
        task_option: TaskOption,
        supported_runner_types: set[RunnerType],
        supported_tasks: dict[RunnerType, list[_ResolvedTask]],
    ) -> RunnerType:
        if not supported_runner_types:
            raise ValueError("This model does not support any model runners!")

            if len(supported_tasks_lst) > 1:
                preferred_task = self._get_preferred_task(
                    architectures, supported_tasks)
                if preferred_task is not None:
                    selected_task = preferred_task
        if runner_option != "auto":
            if runner_option not in supported_runner_types:
                raise ValueError(
                    f"This model does not support runner={runner_option!r}. "
                    f"Available runners: {supported_runner_types}")

                logger.info(
                    "This model supports multiple tasks: %s. "
                    "Defaulting to '%s'.", supported_tasks, selected_task)
        else:
            if task_option == "score":
                if not runner_support["pooling"]:
                    msg = (f"This model does not support the '{task_option}' "
                           f"task. Supported tasks: {supported_tasks}")
                    raise ValueError(msg)
                if self.registry.is_cross_encoder_model(architectures):
                    task_option = "classify"
                else:
                    task_option = "embed"
            return runner_option

        if task_option != "auto":
            for runner, runner_tasks in supported_tasks.items():
                if task_option in runner_tasks:
                    return runner
            else:
                # Aliases
                if task_option == "embedding":
                    msg = ("The 'embedding' task has been renamed to "
                           "'embed', please use the new name. The old name "
                           "will be removed in v1.0.")
                    warnings.warn(msg, DeprecationWarning, stacklevel=2)
            task_runner: RunnerType = next(
                runner for runner, tasks in _RUNNER_TASKS.items()
                if task_option in tasks)
            raise ValueError(
                f"This model does not support task={task_option!r}. "
                f"Available tasks for runner={task_runner!r}: "
                f"{supported_tasks[task_runner]}")

                    task_option = "embed"
        suffix_to_preferred_runner: list[tuple[str, RunnerType]] = [
            ("ForCausalLM", "generate"),
            ("ForConditionalGeneration", "generate"),
            ("ChatModel", "generate"),
            ("LMHeadModel", "generate"),
            ("ForSequenceClassification", "pooling"),
            ("EmbeddingModel", "pooling"),
            ("RewardModel", "pooling"),
        ]
        _, arch = self.registry.inspect_model_cls(self.architectures)

            if task_option not in supported_tasks:
                msg = (
                    f"This model does not support the '{task_option}' task. "
                    f"Supported tasks: {supported_tasks}")
                raise ValueError(msg)
        for suffix, pref_runner in suffix_to_preferred_runner:
            if arch.endswith(suffix) and pref_runner in supported_runner_types:
                return pref_runner

            selected_task = task_option
        if "classify" in supported_tasks.get("pooling", []):
            # When multiple pooling tasks are present, default to
            # pooling (eg cross-encoder) for non-standard architectures.
            return "pooling"
        if "generate" in supported_runner_types:
            return "generate"
        if "pooling" in supported_runner_types:
            return "pooling"

        return supported_tasks, selected_task
        raise AssertionError("This line should not be reached")

    def _parse_quant_hf_config(self):
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
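The precedence implemented by `_resolve_runner` above can be summarized as: an explicit `runner=` wins, then an explicit `task=`, then the architecture-suffix heuristic, then a generic fallback. A simplified standalone sketch (it omits the cross-encoder special case and the error handling):

def resolve_runner(runner_option, task_option, arch, supported):
    # supported: runner type -> list of tasks that runner can serve
    available = {r for r, tasks in supported.items() if tasks}
    if runner_option != "auto":
        return runner_option
    if task_option != "auto":
        return next(r for r, tasks in supported.items() if task_option in tasks)
    for suffix, runner in [("ForCausalLM", "generate"),
                           ("ForSequenceClassification", "pooling")]:
        if arch.endswith(suffix) and runner in available:
            return runner
    return "generate" if "generate" in available else "pooling"

assert resolve_runner("auto", "auto", "LlamaForCausalLM",
                      {"generate": ["generate"], "pooling": []}) == "generate"
assert resolve_runner("auto", "embed", "BertModel",
                      {"generate": [], "pooling": ["pooling", "embed"]}) == "pooling"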
@@ -1449,14 +1521,6 @@ class ModelConfig:
    def use_mla(self) -> bool:
        return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE

    @property
    def supported_runner_types(self) -> set[RunnerType]:
        return {_TASK_RUNNER[task] for task in self.supported_tasks}

    @property
    def runner_type(self) -> RunnerType:
        return _TASK_RUNNER[cast(_ResolvedTask, self.task)]

    @property
    def is_v1_compatible(self) -> bool:
        architectures = getattr(self.hf_config, "architectures", [])
@@ -2694,7 +2758,7 @@ class SpeculativeConfig:
        if self.model is not None:
            self.draft_model_config = ModelConfig(
                model=self.model,
                task="draft",
                runner="draft",
                tokenizer=self.target_model_config.tokenizer,
                tokenizer_mode=self.target_model_config.tokenizer_mode,
                trust_remote_code=self.target_model_config.
@@ -454,20 +454,19 @@ class LLM:
            considered legacy and may be deprecated in the future. You should
            instead pass them via the `inputs` parameter.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ["generate", "transcription"]:
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
                "LLM.generate() is only supported for generative models."
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
            if "generate" in model_config.supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")
                    "Please initialize vLLM using `--task generate` or "
                    "`--task transcription`.")

            raise ValueError(" ".join(messages))
|
||||
considered legacy and may be deprecated in the future. You should
|
||||
instead pass them via the `inputs` parameter.
|
||||
"""
|
||||
runner_type = self.llm_engine.model_config.runner_type
|
||||
model_config = self.llm_engine.model_config
|
||||
runner_type = model_config.runner_type
|
||||
if runner_type != "pooling":
|
||||
messages = ["LLM.encode() is only supported for pooling models."]
|
||||
|
||||
supported_runner_types = self.llm_engine.model_config \
|
||||
.supported_runner_types
|
||||
if "pooling" in supported_runner_types:
|
||||
if "pooling" in model_config.supported_runner_types:
|
||||
messages.append(
|
||||
"Your model supports the 'pooling' runner, but is "
|
||||
f"currently initialized for the '{runner_type}' runner. "
|
||||
@ -1119,13 +1117,13 @@ class LLM:
|
||||
# Use default pooling params.
|
||||
pooling_params = PoolingParams()
|
||||
elif isinstance(pooling_params, PoolingParams):
|
||||
pooling_params.verify(self.llm_engine.model_config)
|
||||
pooling_params.verify(model_config)
|
||||
else:
|
||||
for pooling_param in pooling_params:
|
||||
pooling_param.verify(self.llm_engine.model_config)
|
||||
pooling_param.verify(model_config)
|
||||
|
||||
tokenization_kwargs: dict[str, Any] = {}
|
||||
_validate_truncation_size(self.llm_engine.model_config.max_model_len,
|
||||
tokenization_kwargs = dict[str, Any]()
|
||||
_validate_truncation_size(model_config.max_model_len,
|
||||
truncate_prompt_tokens, tokenization_kwargs)
|
||||
|
||||
self._validate_and_add_requests(
|
||||
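The pooling path guarded above is driven through `LLM.encode()` and its wrappers; a usage sketch (the model name is only an example of an embedding model):

from vllm import LLM

llm = LLM(model="intfloat/e5-mistral-7b-instruct", task="embed")
outputs = llm.embed(["The capital of France is Paris."])
print(len(outputs[0].outputs.embedding))  # embedding dimension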
@@ -1178,9 +1176,10 @@ class LLM:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.
        """
        if self.llm_engine.model_config.task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")
        model_config = self.llm_engine.model_config
        if "embed" not in model_config.supported_tasks:
            raise ValueError("Embedding API is not supported by this model. "
                             "Please set `--task embed`.")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
@@ -1223,9 +1222,11 @@ class LLM:
            A list of `ClassificationRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.
        """
        if self.llm_engine.model_config.task != "classify":
        model_config = self.llm_engine.model_config
        if "classify" not in model_config.supported_tasks:
            raise ValueError(
                "Classification API is only enabled for `--task classify`")
                "Classification API is not supported by this model. "
                "Please set `--task classify`.")

        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
@@ -1392,13 +1393,12 @@ class LLM:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.
        """
        runner_type = self.llm_engine.model_config.runner_type
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
            if "pooling" in model_config.supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
@@ -1407,12 +1407,13 @@ class LLM:

            raise ValueError(" ".join(messages))

        if self.llm_engine.model_config.task not in ("embed", "classify"):
            raise ValueError("Score API is only enabled for "
                             "`--task embed or --task classify`.")
        if all(t not in model_config.supported_tasks
               for t in ("embed", "classify")):
            raise ValueError("Score API is not supported by this model. "
                             "Please set `--task embed` or `--task classify`.")

        if (self.llm_engine.model_config.task == "classify"
                and self.llm_engine.model_config.hf_config.num_labels != 1):
        if (model_config.task == "classify"
                and getattr(model_config.hf_config, "num_labels", 0) != 1):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        # the tokenizer for models such as
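A usage sketch of the scoring path checked above, assuming a cross-encoder such as BAAI/bge-reranker-v2-m3 (which resolves to the "classify" task with num_labels == 1):

from vllm import LLM

llm = LLM(model="BAAI/bge-reranker-v2-m3", task="score")
outputs = llm.score("What is the capital of France?",
                    ["Paris is the capital of France.", "The sky is blue."])
print([o.outputs.score for o in outputs])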
@@ -1520,7 +1520,7 @@ async def init_app_state(
        reasoning_parser=args.reasoning_parser,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
        enable_force_include_usage=args.enable_force_include_usage,
    ) if model_config.runner_type == "generate" else None
    ) if "generate" in model_config.supported_tasks else None
    state.openai_serving_chat = OpenAIServingChat(
        engine_client,
        model_config,
@@ -1537,7 +1537,7 @@ async def init_app_state(
        reasoning_parser=args.reasoning_parser,
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
        enable_force_include_usage=args.enable_force_include_usage,
    ) if model_config.runner_type == "generate" else None
    ) if "generate" in model_config.supported_tasks else None
    state.openai_serving_completion = OpenAIServingCompletion(
        engine_client,
        model_config,
@@ -1545,7 +1545,7 @@ async def init_app_state(
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_force_include_usage=args.enable_force_include_usage,
    ) if model_config.runner_type == "generate" else None
    ) if "generate" in model_config.supported_tasks else None
    state.openai_serving_pooling = OpenAIServingPooling(
        engine_client,
        model_config,
@@ -1553,7 +1553,7 @@ async def init_app_state(
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
    ) if model_config.runner_type == "pooling" else None
    ) if "pooling" in model_config.supported_tasks else None
    state.openai_serving_embedding = OpenAIServingEmbedding(
        engine_client,
        model_config,
@@ -1561,22 +1561,24 @@ async def init_app_state(
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
    ) if model_config.task == "embed" else None
    ) if "embed" in model_config.supported_tasks else None
    state.openai_serving_classification = ServingClassification(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger,
    ) if model_config.task == "classify" else None
    ) if "classify" in model_config.supported_tasks else None

    enable_serving_reranking = (model_config.task == "classify" and getattr(
        model_config.hf_config, "num_labels", 0) == 1)
    enable_serving_reranking = ("classify" in model_config.supported_tasks
                                and getattr(model_config.hf_config,
                                            "num_labels", 0) == 1)
    state.openai_serving_scores = ServingScores(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger) if (
            model_config.task == "embed" or enable_serving_reranking) else None
        request_logger=request_logger,
    ) if ("embed" in model_config.supported_tasks
          or enable_serving_reranking) else None

    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
@@ -1591,13 +1593,13 @@ async def init_app_state(
        model_config,
        state.openai_serving_models,
        request_logger=request_logger,
    ) if model_config.runner_type == "transcription" else None
    ) if "transcription" in model_config.supported_tasks else None
    state.openai_serving_translation = OpenAIServingTranslation(
        engine_client,
        model_config,
        state.openai_serving_models,
        request_logger=request_logger,
    ) if model_config.runner_type == "transcription" else None
    ) if "transcription" in model_config.supported_tasks else None
    state.task = model_config.task

    state.enable_server_load_tracking = args.enable_server_load_tracking
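The pattern above replaces equality checks on a single resolved task with membership checks against `supported_tasks`, so one server can expose every API family its model supports. A generic standalone sketch of that gating (endpoint names are illustrative, not the server's actual wiring):

def build_handlers(supported_tasks: list[str]) -> dict[str, bool]:
    return {
        "chat/completions": "generate" in supported_tasks,
        "audio/transcriptions": "transcription" in supported_tasks,
        "embeddings": "embed" in supported_tasks,
        "classify": "classify" in supported_tasks,
    }

# A model that supports both generation and transcription now gets both
# endpoint families from a single server instance.
print(build_handlers(["generate", "transcription"]))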
@@ -348,7 +348,7 @@ async def main(args):
        chat_template=None,
        chat_template_content_format="auto",
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
    ) if model_config.runner_type == "generate" else None
    ) if "generate" in model_config.supported_tasks else None
    openai_serving_embedding = OpenAIServingEmbedding(
        engine,
        model_config,
@@ -356,17 +356,19 @@ async def main(args):
        request_logger=request_logger,
        chat_template=None,
        chat_template_content_format="auto",
    ) if model_config.task == "embed" else None
    ) if "embed" in model_config.supported_tasks else None

    enable_serving_reranking = (model_config.task == "classify" and getattr(
        model_config.hf_config, "num_labels", 0) == 1)
    enable_serving_reranking = ("classify" in model_config.supported_tasks
                                and getattr(model_config.hf_config,
                                            "num_labels", 0) == 1)

    openai_serving_scores = (ServingScores(
    openai_serving_scores = ServingScores(
        engine,
        model_config,
        openai_serving_models,
        request_logger=request_logger,
    ) if (model_config.task == "embed" or enable_serving_reranking) else None)
    ) if ("embed" in model_config.supported_tasks
          or enable_serving_reranking) else None

    tracker = BatchProgressTracker()
    logger.info("Reading batch from %s...", args.input_file)
@@ -694,6 +694,12 @@ class SupportsTranscription(Protocol):

    supports_transcription: ClassVar[Literal[True]] = True

    supports_transcription_only: ClassVar[bool] = False
    """
    Transcription models can opt out of text generation by setting this to
    `True`.
    """

    @classmethod
    def get_generation_prompt(cls, audio: np.ndarray,
                              stt_config: SpeechToTextConfig, language: str,
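A sketch of how a model class would opt out of free-form text generation using the new flag (mirrors the Whisper change later in this diff; the rest of the class body is elided and the class name is hypothetical):

from typing import ClassVar, Literal

class MyAudioOnlyModel:  # would also implement the SupportsTranscription protocol
    supports_transcription: ClassVar[Literal[True]] = True
    supports_transcription_only: ClassVar[bool] = True  # transcription, but no plain generation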
@@ -284,6 +284,7 @@ class _ModelInfo:
    is_hybrid: bool
    has_noops: bool
    supports_transcription: bool
    supports_transcription_only: bool
    supports_v0_only: bool

    @staticmethod
@@ -299,6 +300,8 @@ class _ModelInfo:
            is_attention_free=is_attention_free(model),
            is_hybrid=is_hybrid(model),
            supports_transcription=supports_transcription(model),
            supports_transcription_only=(supports_transcription(model) and
                                         model.supports_transcription_only),
            supports_v0_only=supports_v0_only(model),
            has_noops=has_noops(model),
        )
@@ -573,6 +576,13 @@ class _ModelRegistry:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_transcription

    def is_transcription_only_model(
        self,
        architectures: Union[str, list[str]],
    ) -> bool:
        model_cls, _ = self.inspect_model_cls(architectures)
        return model_cls.supports_transcription_only

    def is_v1_compatible(
        self,
        architectures: Union[str, list[str]],
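Usage sketch of the new registry helper, assuming the module-level ModelRegistry singleton exported from vllm.model_executor.models:

from vllm.model_executor.models import ModelRegistry

print(ModelRegistry.is_transcription_only_model(["WhisperForConditionalGeneration"]))  # True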
@@ -772,6 +772,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
        ".fc2.": ".mlp.fc2."
    })

    # Whisper only supports audio-conditioned generation.
    supports_transcription_only = True

    @classmethod
    def validate_language(cls, language: str) -> bool:
        if language in ISO639_1_SUPPORTED_LANGS: