mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 15:15:45 +08:00
[Optimization] Avoid repeated model architecture conversion for pooling models (#25261)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
f91480b2d4
commit
c60e6137f0
@ -322,8 +322,28 @@ class ModelConfig:
|
|||||||
factors.append(self.override_generation_config)
|
factors.append(self.override_generation_config)
|
||||||
factors.append(self.rope_scaling)
|
factors.append(self.rope_scaling)
|
||||||
factors.append(self.rope_theta)
|
factors.append(self.rope_theta)
|
||||||
|
|
||||||
# hf_config can control how the model looks!
|
# hf_config can control how the model looks!
|
||||||
factors.append(self.hf_config.to_json_string())
|
try:
|
||||||
|
hf_config_json = self.hf_config.to_json_string(use_diff=False)
|
||||||
|
except TypeError:
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
from vllm.utils.jsontree import json_map_leaves
|
||||||
|
|
||||||
|
# Handle nested HF configs with unserializable values gracefully
|
||||||
|
hf_config_json = json.dumps(
|
||||||
|
json_map_leaves(
|
||||||
|
lambda v: v.to_dict()
|
||||||
|
if isinstance(v, PretrainedConfig) else str(v),
|
||||||
|
self.hf_config.to_dict(),
|
||||||
|
),
|
||||||
|
indent=2,
|
||||||
|
sort_keys=True,
|
||||||
|
) + "\n"
|
||||||
|
|
||||||
|
factors.append(hf_config_json)
|
||||||
|
|
||||||
str_factors = str(factors)
|
str_factors = str(factors)
|
||||||
assert_hashable(str_factors)
|
assert_hashable(str_factors)
|
||||||
return hashlib.sha256(str(factors).encode()).hexdigest()
|
return hashlib.sha256(str(factors).encode()).hexdigest()
|
||||||
|
|||||||
@ -165,7 +165,11 @@ def device_loading_context(module: torch.nn.Module,
|
|||||||
# New parameters or parameters already on target device are untouched
|
# New parameters or parameters already on target device are untouched
|
||||||
|
|
||||||
|
|
||||||
def get_model_architecture(
|
_MODEL_ARCH_BY_HASH = dict[str, tuple[type[nn.Module], str]]()
|
||||||
|
"""Caches the outputs of `_get_model_architecture`."""
|
||||||
|
|
||||||
|
|
||||||
|
def _get_model_architecture(
|
||||||
model_config: ModelConfig) -> tuple[type[nn.Module], str]:
|
model_config: ModelConfig) -> tuple[type[nn.Module], str]:
|
||||||
architectures = getattr(model_config.hf_config, "architectures", [])
|
architectures = getattr(model_config.hf_config, "architectures", [])
|
||||||
|
|
||||||
@ -209,6 +213,17 @@ def get_model_architecture(
|
|||||||
return model_cls, arch
|
return model_cls, arch
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_architecture(
|
||||||
|
model_config: ModelConfig) -> tuple[type[nn.Module], str]:
|
||||||
|
key = model_config.compute_hash()
|
||||||
|
if key in _MODEL_ARCH_BY_HASH:
|
||||||
|
return _MODEL_ARCH_BY_HASH[key]
|
||||||
|
|
||||||
|
model_arch = _get_model_architecture(model_config)
|
||||||
|
_MODEL_ARCH_BY_HASH[key] = model_arch
|
||||||
|
return model_arch
|
||||||
|
|
||||||
|
|
||||||
def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
|
def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
|
||||||
return get_model_architecture(model_config)[0]
|
return get_model_architecture(model_config)[0]
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user