Fix head_dim not existing in all model configs (Transformers backend) (#14141)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
commit 91373a0d15
parent 848a6438ae
Author: Harry Mellor
Date:   2025-03-03 17:48:11 +00:00
Committed by: GitHub
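
The root cause: the Transformers backend read config.head_dim straight off the Hugging Face config, and not every model config defines that attribute. Below is a minimal sketch of the failure mode and the usual fallback (head size = hidden size / number of attention heads); resolve_head_size and the example configs are illustrative stand-ins, not vLLM's actual ModelConfig.get_head_size() implementation.

from types import SimpleNamespace


def resolve_head_size(hf_config) -> int:
    """Per-head dimension, with a fallback when `head_dim` is absent."""
    head_dim = getattr(hf_config, "head_dim", None)
    if head_dim is not None:
        return head_dim
    # Common convention: hidden size split evenly across attention heads.
    return hf_config.hidden_size // hf_config.num_attention_heads


# One config defines head_dim explicitly, the other does not.
with_head_dim = SimpleNamespace(head_dim=128, hidden_size=4096, num_attention_heads=32)
without_head_dim = SimpleNamespace(hidden_size=4096, num_attention_heads=32)

print(resolve_head_size(with_head_dim))     # 128
print(resolve_head_size(without_head_dim))  # 128, via the fallback
# The old code's direct `config.head_dim` lookup raised AttributeError for
# configs like `without_head_dim`, which is the crash this commit removes.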


@@ -25,7 +25,6 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 from vllm.attention import Attention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.distributed.utils import divide
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                ReplicatedLinear,
@@ -128,10 +127,12 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         model_config = vllm_config.model_config
+        parallel_config = vllm_config.parallel_config
         self.config = config
-        self.vocab_size = config.vocab_size
-        self.unpadded_vocab_size = config.vocab_size
+        self.vocab_size = model_config.get_vocab_size()
+        self.unpadded_vocab_size = model_config.get_vocab_size()

         self.model: PreTrainedModel = AutoModel.from_config(
             self.config,
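
Reading the vocabulary size through model_config instead of the raw HF config follows the same idea: let the ModelConfig accessors normalise over config layouts. A rough sketch of that pattern, assuming a hypothetical get_vocab_size helper that prefers a nested text_config when present; this is not vLLM's implementation.

from types import SimpleNamespace


def get_vocab_size(hf_config) -> int:
    # Prefer the nested text config when it exists (e.g. multimodal models),
    # otherwise read the attribute off the top-level config.
    text_config = getattr(hf_config, "text_config", hf_config)
    return text_config.vocab_size


text_only = SimpleNamespace(vocab_size=32000)
multimodal = SimpleNamespace(text_config=SimpleNamespace(vocab_size=151936))
print(get_vocab_size(text_only))   # 32000
print(get_vocab_size(multimodal))  # 151936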
@@ -145,15 +146,17 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
         self.apply_base_model_tp_plan(self.model)

         # Attention modifications (assumes 1 attention op per hidden layer)
-        tp_size = get_tensor_model_parallel_world_size()
+        num_heads = model_config.get_num_attention_heads(parallel_config)
+        head_size = model_config.get_head_size()
+        num_kv_heads = model_config.get_num_kv_heads(parallel_config)
         self.attention_instances = [
             Attention(
-                num_heads=divide(config.num_attention_heads, tp_size),
-                head_size=config.head_dim,
+                num_heads=num_heads,
+                head_size=head_size,
                 # NOTE: We use Llama scale as default, if it's set by
                 # Transformers, it's updated in vllm_flash_attention_forward
-                scale=config.head_dim**-0.5,
-                num_kv_heads=divide(config.num_key_value_heads, tp_size),
+                scale=head_size**-0.5,
+                num_kv_heads=num_kv_heads,
                 cache_config=cache_config,
                 quant_config=self.quant_config,
                 prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
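
The manual divide(..., tp_size) calls disappear because the ModelConfig accessors return per-rank values directly. The toy classes below illustrate that division of labour; they are hypothetical stand-ins for vLLM's ModelConfig and ParallelConfig, not the real classes, and the KV-head replication rule is an assumption made for the example.

from dataclasses import dataclass


@dataclass
class ToyParallelConfig:
    tensor_parallel_size: int


@dataclass
class ToyModelConfig:
    total_num_attention_heads: int
    total_num_kv_heads: int

    def get_num_attention_heads(self, parallel_config: ToyParallelConfig) -> int:
        # The tensor-parallel division happens inside the accessor, so callers
        # never need to divide by the TP size themselves.
        return self.total_num_attention_heads // parallel_config.tensor_parallel_size

    def get_num_kv_heads(self, parallel_config: ToyParallelConfig) -> int:
        # Assumed rule: KV heads are replicated rather than split once there
        # are fewer of them than tensor-parallel ranks.
        return max(1, self.total_num_kv_heads // parallel_config.tensor_parallel_size)


parallel = ToyParallelConfig(tensor_parallel_size=4)
model = ToyModelConfig(total_num_attention_heads=32, total_num_kv_heads=8)
print(model.get_num_attention_heads(parallel))  # 8 query heads per rank
print(model.get_num_kv_heads(parallel))         # 2 KV heads per rank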
@@ -163,7 +166,7 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
         self.replace_vocab_embed_class(self.model)

         # ForCausalLM modifications
-        self.lm_head = ParallelLMHead(config.vocab_size,
+        self.lm_head = ParallelLMHead(self.vocab_size,
                                       config.hidden_size,
                                       quant_config=self.quant_config,
                                       prefix=maybe_prefix(prefix, "lm_head"))
@@ -172,7 +175,7 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size, logit_scale)
+                                                self.vocab_size, logit_scale)
         self.sampler = get_sampler()

     def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
@@ -203,12 +206,12 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
         new_module = VocabParallelEmbedding(
             self.vocab_size,
             self.config.hidden_size,
-            org_num_embeddings=self.config.vocab_size,
+            org_num_embeddings=self.vocab_size,
             quant_config=None,
         )
         log_replacement("input embedding", self.model.get_input_embeddings(),
                         new_module)
-        self.model.set_input_embeddings(new_module)
+        module.set_input_embeddings(new_module)

     def forward(
         self,
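
The last hunk also makes replace_vocab_embed_class act on the module argument rather than always on self.model, so the swap lands on whichever submodule the helper was handed. A minimal sketch of that contract, with hypothetical names:

import torch.nn as nn


def replace_input_embeddings(module: nn.Module, new_embed: nn.Embedding) -> None:
    # Swap embeddings on the module that was passed in, not on a fixed
    # top-level attribute, so nested backbones are handled too.
    module.set_input_embeddings(new_embed)


class ToyBackbone(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(100, 16)

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        self.embed_tokens = value


backbone = ToyBackbone()
replace_input_embeddings(backbone, nn.Embedding(128, 16))
print(backbone.embed_tokens.num_embeddings)  # 128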