mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-09 18:35:40 +08:00
Add util function for checking nesting of rope parameters (#31146)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
769f27e701
commit
b10d47e0e0
@ -11,7 +11,6 @@ import torch
|
|||||||
from pydantic import ConfigDict, Field, field_validator, model_validator
|
from pydantic import ConfigDict, Field, field_validator, model_validator
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||||
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
|
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
@ -29,6 +28,7 @@ from vllm.transformers_utils.config import (
|
|||||||
get_pooling_config,
|
get_pooling_config,
|
||||||
get_sentence_transformer_tokenizer_config,
|
get_sentence_transformer_tokenizer_config,
|
||||||
is_encoder_decoder,
|
is_encoder_decoder,
|
||||||
|
is_rope_parameters_nested,
|
||||||
try_get_dense_modules,
|
try_get_dense_modules,
|
||||||
try_get_generation_config,
|
try_get_generation_config,
|
||||||
try_get_safetensors_metadata,
|
try_get_safetensors_metadata,
|
||||||
@ -2125,9 +2125,7 @@ def _get_and_verify_max_len(
|
|||||||
# In Transformers v5 rope_parameters could be TypedDict or dict[str, TypedDict].
|
# In Transformers v5 rope_parameters could be TypedDict or dict[str, TypedDict].
|
||||||
# To simplify the verification, we convert it to dict[str, TypedDict].
|
# To simplify the verification, we convert it to dict[str, TypedDict].
|
||||||
rope_parameters = getattr(hf_config, "rope_parameters", None)
|
rope_parameters = getattr(hf_config, "rope_parameters", None)
|
||||||
if rope_parameters and not set(rope_parameters.keys()).issubset(
|
if rope_parameters and not is_rope_parameters_nested(rope_parameters):
|
||||||
ALLOWED_LAYER_TYPES
|
|
||||||
):
|
|
||||||
rope_parameters = {"": rope_parameters}
|
rope_parameters = {"": rope_parameters}
|
||||||
|
|
||||||
# NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE
|
# NOTE(woosuk): Gemma3's max_model_len (128K) is already scaled by RoPE
|
||||||
|
|||||||
@ -22,7 +22,6 @@ from typing import TYPE_CHECKING, Literal
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
|
|
||||||
|
|
||||||
from vllm.config.utils import getattr_iter
|
from vllm.config.utils import getattr_iter
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -32,6 +31,7 @@ from vllm.model_executor.layers.linear import (
|
|||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
|
from vllm.transformers_utils.config import is_rope_parameters_nested
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
@ -207,7 +207,7 @@ def can_enable_torch_compile(vllm_config: "VllmConfig") -> bool:
|
|||||||
rope_parameters: dict | None = getattr(text_config, "rope_parameters", None) or {}
|
rope_parameters: dict | None = getattr(text_config, "rope_parameters", None) or {}
|
||||||
if rope_parameters:
|
if rope_parameters:
|
||||||
# Nest rope_parameters if not nested already to simplify logic
|
# Nest rope_parameters if not nested already to simplify logic
|
||||||
if not set(rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
|
if not is_rope_parameters_nested(rope_parameters):
|
||||||
rope_parameters = {"": rope_parameters}
|
rope_parameters = {"": rope_parameters}
|
||||||
return all(rp["rope_type"] != "dynamic" for rp in rope_parameters.values())
|
return all(rp["rope_type"] != "dynamic" for rp in rope_parameters.values())
|
||||||
return True
|
return True
|
||||||
|
|||||||
@ -15,7 +15,6 @@ from huggingface_hub import (
|
|||||||
)
|
)
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
from transformers import GenerationConfig, PretrainedConfig
|
from transformers import GenerationConfig, PretrainedConfig
|
||||||
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
|
|
||||||
from transformers.models.auto.image_processing_auto import get_image_processor_config
|
from transformers.models.auto.image_processing_auto import get_image_processor_config
|
||||||
from transformers.models.auto.modeling_auto import (
|
from transformers.models.auto.modeling_auto import (
|
||||||
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
|
||||||
@ -44,6 +43,16 @@ from .repo_utils import (
|
|||||||
with_retry,
|
with_retry,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Transformers v5
|
||||||
|
from transformers.configuration_utils import ALLOWED_ATTENTION_LAYER_TYPES
|
||||||
|
except ImportError:
|
||||||
|
# Transformers v4
|
||||||
|
from transformers.configuration_utils import (
|
||||||
|
ALLOWED_LAYER_TYPES as ALLOWED_ATTENTION_LAYER_TYPES,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if envs.VLLM_USE_MODELSCOPE:
|
if envs.VLLM_USE_MODELSCOPE:
|
||||||
from modelscope import AutoConfig
|
from modelscope import AutoConfig
|
||||||
else:
|
else:
|
||||||
@ -104,6 +113,14 @@ _AUTO_CONFIG_KWARGS_OVERRIDES: dict[str, dict[str, Any]] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def is_rope_parameters_nested(rope_parameters: dict[str, Any]) -> bool:
|
||||||
|
"""Check if rope_parameters is nested by layer types."""
|
||||||
|
# Cannot be nested if rope_parameters is empty
|
||||||
|
if not rope_parameters:
|
||||||
|
return False
|
||||||
|
return set(rope_parameters.keys()).issubset(ALLOWED_ATTENTION_LAYER_TYPES)
|
||||||
|
|
||||||
|
|
||||||
class HFConfigParser(ConfigParserBase):
|
class HFConfigParser(ConfigParserBase):
|
||||||
def parse(
|
def parse(
|
||||||
self,
|
self,
|
||||||
@ -346,7 +363,7 @@ def patch_rope_parameters(config: PretrainedConfig) -> None:
|
|||||||
config.rope_parameters["original_max_position_embeddings"] = ompe
|
config.rope_parameters["original_max_position_embeddings"] = ompe
|
||||||
|
|
||||||
# Handle nested rope_parameters in interleaved sliding attention models
|
# Handle nested rope_parameters in interleaved sliding attention models
|
||||||
if set(config.rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
|
if is_rope_parameters_nested(config.rope_parameters):
|
||||||
for rope_parameters_layer_type in config.rope_parameters.values():
|
for rope_parameters_layer_type in config.rope_parameters.values():
|
||||||
patch_rope_parameters_dict(rope_parameters_layer_type)
|
patch_rope_parameters_dict(rope_parameters_layer_type)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user