[Misc] Make LayerBlockType a Literal instead of Enum (#27658)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-10-29 00:23:35 +08:00 committed by GitHub
parent a8c02fb5bf
commit f5710ef02a
3 changed files with 11 additions and 24 deletions
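In short, the commit replaces the enum.Enum definition of LayerBlockType with a typing.Literal alias of the allowed block-type strings, so call sites compare plain strings instead of unwrapping enum members via .value. A minimal sketch of the before/after, simplified from the diff below (count_layers is an illustrative helper, not part of vLLM):

import enum
from typing import Literal

# Old style: an Enum whose members had to be unwrapped with .value when
# comparing against the strings stored in hf_config.
class LayerBlockTypeEnum(enum.Enum):
    attention = "attention"
    mamba = "mamba"

# New style: a Literal alias; the allowed values are ordinary strings.
LayerBlockType = Literal["attention", "linear_attention", "mamba"]

def count_layers(layer_types: list[str], block_type: LayerBlockType = "attention") -> int:
    # No .value indirection is needed when comparing plain strings.
    return sum(t == block_type for t in layer_types)

print(count_layers(["attention", "mamba", "attention"]))  # prints 2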

View File

@@ -41,7 +41,6 @@ from vllm.transformers_utils.config import (
)
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils import LayerBlockType
from vllm.utils.import_utils import LazyLoader
from vllm.utils.torch_utils import common_broadcastable_dtype
@@ -91,6 +90,7 @@ LogprobsMode = Literal[
]
HfOverrides = dict[str, Any] | Callable[[PretrainedConfig], PretrainedConfig]
ModelImpl = Literal["auto", "vllm", "transformers", "terratorch"]
LayerBlockType = Literal["attention", "linear_attention", "mamba"]
_RUNNER_TASKS: dict[RunnerType, list[TaskOption]] = {
"generate": ["generate", "transcription"],
@@ -1433,11 +1433,11 @@ class ModelConfig:
def get_num_layers_by_block_type(
self,
parallel_config: ParallelConfig,
block_type: LayerBlockType = LayerBlockType.attention,
block_type: LayerBlockType = "attention",
) -> int:
# This function relies on 'layers_block_type' in hf_config;
# for models without this attribute we need workarounds like the ones below
attn_block_type = block_type == LayerBlockType.attention
attn_block_type = block_type == "attention"
is_transformer = (
not self.is_hybrid and not self.has_noops and not self.is_attention_free
)
@@ -1469,9 +1469,7 @@
)
else:
return self.get_num_layers(parallel_config)
return sum(
t == block_type.value for t in layers_block_type_value[start:end]
)
return sum(t == block_type for t in layers_block_type_value[start:end])
# Hybrid model Minimax
attn_type_list = getattr(self.hf_config, "attn_type_list", None)
@@ -1481,19 +1479,16 @@
# Hybrid model Qwen3Next
layer_types_value = getattr(self.hf_config, "layer_types", None)
if layer_types_value is not None:
if getattr(block_type, "value", block_type) == "attention":
if block_type == "attention":
return sum(
t == "full_attention" for t in layer_types_value[start:end]
)
elif getattr(block_type, "value", block_type) == "linear_attention":
elif block_type == "linear_attention":
return sum(
t == "linear_attention" for t in layer_types_value[start:end]
)
else:
return sum(
t == getattr(block_type, "value", block_type)
for t in layer_types_value[start:end]
)
return sum(t == block_type for t in layer_types_value[start:end])
if (
layers_block_type_value is None
@@ -1501,10 +1496,9 @@
and layer_types_value is None
):
raise ValueError(
"The model is an hybrid without a"
"layers_block_type or an attn_type_list, or a layer_types "
"in the hf_config, cannot determine the num of "
f"{block_type.value} layers"
"The model is an hybrid without a layers_block_type or an "
"attn_type_list, or a layer_types in the hf_config, "
f"cannot determine the num of {block_type} layers"
)
def get_mamba_chunk_size(self) -> int | None:
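With the Literal alias in place, callers of get_num_layers_by_block_type pass plain strings, as the TPU runner diff further below does. A hedged usage sketch, assuming model_config and parallel_config objects are already constructed:

num_attn_layers = model_config.get_num_layers_by_block_type(parallel_config, "attention")
num_mamba_layers = model_config.get_num_layers_by_block_type(parallel_config, "mamba")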

View File

@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import enum
import inspect
import uuid
import warnings
@@ -67,11 +66,6 @@ STR_INVALID_VAL: str = "INVALID"
T = TypeVar("T")
class LayerBlockType(enum.Enum):
attention = "attention"
mamba = "mamba"
def random_uuid() -> str:
return str(uuid.uuid4().hex)

View File

@@ -53,7 +53,6 @@ from vllm.multimodal.inputs import (
from vllm.multimodal.utils import group_mm_kwargs_by_modality
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import LayerBlockType
from vllm.utils.math_utils import cdiv, prev_power_of_2
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.pallas import (
@@ -212,7 +211,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Model-related.
self.num_attn_layers = model_config.get_num_layers_by_block_type(
parallel_config, LayerBlockType.attention
parallel_config, "attention"
)
self.num_query_heads = model_config.get_num_attention_heads(parallel_config)
self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)