[UX] Support nested dicts in hf_overrides (#25727)
Signed-off-by: mgoin <mgoin64@gmail.com>
Commit c6873c4e6d (parent 2111b4643c)
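With this change, values inside hf_overrides may themselves be dicts: instead of replacing an entire nested sub-config (e.g. text_config), the nested keys are merged into the loaded HF config and sibling keys are preserved. A minimal usage sketch, assuming the usual path where LLM forwards hf_overrides through EngineArgs to ModelConfig (the model name and override values are illustrative only):

from vllm import LLM

# Only max_position_embeddings inside text_config is overridden;
# the other text_config fields keep their loaded values.
llm = LLM(
    model="Qwen/Qwen2-VL-2B-Instruct",
    hf_overrides={
        "text_config": {"max_position_embeddings": 32768},
    },
)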
@@ -292,6 +292,37 @@ def test_rope_customization():
     assert longchat_model_config.max_model_len == 4096


+def test_nested_hf_overrides():
+    """Test that nested hf_overrides work correctly."""
+    # Test with a model that has text_config
+    model_config = ModelConfig(
+        "Qwen/Qwen2-VL-2B-Instruct",
+        hf_overrides={
+            "text_config": {
+                "hidden_size": 1024,
+            },
+        },
+    )
+    assert model_config.hf_config.text_config.hidden_size == 1024
+
+    # Test with deeply nested overrides
+    model_config = ModelConfig(
+        "Qwen/Qwen2-VL-2B-Instruct",
+        hf_overrides={
+            "text_config": {
+                "hidden_size": 2048,
+                "num_attention_heads": 16,
+            },
+            "vision_config": {
+                "hidden_size": 512,
+            },
+        },
+    )
+    assert model_config.hf_config.text_config.hidden_size == 2048
+    assert model_config.hf_config.text_config.num_attention_heads == 16
+    assert model_config.hf_config.vision_config.hidden_size == 512
+
+
 @pytest.mark.skipif(
     current_platform.is_rocm(), reason="Encoder Decoder models not supported on ROCm."
 )
@@ -367,6 +367,51 @@ class ModelConfig:
         assert_hashable(str_factors)
         return hashlib.sha256(str(factors).encode()).hexdigest()

+    def _update_nested(
+        self,
+        target: Union["PretrainedConfig", dict[str, Any]],
+        updates: dict[str, Any],
+    ) -> None:
+        """Recursively updates a config or dict with nested updates."""
+        for key, value in updates.items():
+            if isinstance(value, dict):
+                # Get the nested target
+                if isinstance(target, dict):
+                    nested_target = target.get(key)
+                else:
+                    nested_target = getattr(target, key, None)
+
+                # If nested target exists and can be updated recursively
+                if nested_target is not None and (
+                    isinstance(nested_target, dict)
+                    or hasattr(nested_target, "__dict__")
+                ):
+                    self._update_nested(nested_target, value)
+                    continue
+
+            # Set the value (base case)
+            if isinstance(target, dict):
+                target[key] = value
+            else:
+                setattr(target, key, value)
+
+    def _apply_dict_overrides(
+        self,
+        config: "PretrainedConfig",
+        overrides: dict[str, Any],
+    ) -> None:
+        """Apply dict overrides, handling both nested configs and dict values."""
+        from transformers import PretrainedConfig
+
+        for key, value in overrides.items():
+            attr = getattr(config, key, None)
+            if attr is not None and isinstance(attr, PretrainedConfig):
+                # It's a nested config - recursively update it
+                self._update_nested(attr, value)
+            else:
+                # It's a dict-valued parameter - set it directly
+                setattr(config, key, value)
+
     def __post_init__(
         self,
         # Multimodal config init vars
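For illustration, a standalone sketch of the merge semantics _update_nested implements (the function name and sample values below are hypothetical, not part of the diff): nested dicts descend into an existing dict or attribute-bearing target, while leaf values are assigned directly.

from types import SimpleNamespace
from typing import Any, Union


def update_nested(target: Union[object, dict[str, Any]], updates: dict[str, Any]) -> None:
    # Same recursion as above: descend when the update value is a dict and the
    # current target has something nested to merge into; otherwise assign.
    for key, value in updates.items():
        if isinstance(value, dict):
            nested = target.get(key) if isinstance(target, dict) else getattr(target, key, None)
            if nested is not None and (isinstance(nested, dict) or hasattr(nested, "__dict__")):
                update_nested(nested, value)
                continue
        if isinstance(target, dict):
            target[key] = value
        else:
            setattr(target, key, value)


cfg = SimpleNamespace(text_config=SimpleNamespace(hidden_size=1536, num_attention_heads=12))
update_nested(cfg, {"text_config": {"hidden_size": 1024}})
assert cfg.text_config.hidden_size == 1024        # overridden
assert cfg.text_config.num_attention_heads == 12  # sibling key preserved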
@@ -419,8 +464,17 @@ class ModelConfig:
         if callable(self.hf_overrides):
             hf_overrides_kw = {}
             hf_overrides_fn = self.hf_overrides
+            dict_overrides: dict[str, Any] = {}
         else:
-            hf_overrides_kw = self.hf_overrides
+            # Separate dict overrides from flat ones
+            # We'll determine how to apply dict overrides after loading the config
+            hf_overrides_kw = {}
+            dict_overrides = {}
+            for key, value in self.hf_overrides.items():
+                if isinstance(value, dict):
+                    dict_overrides[key] = value
+                else:
+                    hf_overrides_kw[key] = value
             hf_overrides_fn = None

         if self.rope_scaling:
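To make the split concrete, a hedged sketch of what this loop produces for a mixed hf_overrides dict (key names are illustrative):

hf_overrides = {
    "num_hidden_layers": 2,                # flat value -> applied while loading the config
    "text_config": {"hidden_size": 1024},  # dict value -> applied after loading
}
hf_overrides_kw = {k: v for k, v in hf_overrides.items() if not isinstance(v, dict)}
dict_overrides = {k: v for k, v in hf_overrides.items() if isinstance(v, dict)}
assert hf_overrides_kw == {"num_hidden_layers": 2}
assert dict_overrides == {"text_config": {"hidden_size": 1024}}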
@@ -478,6 +532,8 @@ class ModelConfig:
         )

         self.hf_config = hf_config
+        if dict_overrides:
+            self._apply_dict_overrides(hf_config, dict_overrides)
         self.hf_text_config = get_hf_text_config(self.hf_config)
         self.attention_chunk_size = getattr(
             self.hf_text_config, "attention_chunk_size", None
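Note on ordering: the dict overrides are applied only after hf_config has been loaded, because that is the point at which nested sub-configs such as text_config exist as PretrainedConfig objects, letting _apply_dict_overrides choose between a recursive merge and a plain attribute assignment.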