diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index 0b0ce9828a74d..b1c9c9af6e235 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -44,8 +44,9 @@ class model_aware_kv_ops_helper:
             head_size = model_config.qk_nope_head_dim + \
                 model_config.qk_rope_head_dim
         else:
-            head_size = getattr(model_config, "head_dim",
-                                int(hidden_size // num_attention_heads))
+            head_size = getattr(model_config, "head_dim", None)
+            if head_size is None:
+                head_size = int(hidden_size // num_attention_heads)
 
         return num_heads, head_size
 
diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py
index 4ffd06319684c..838560692bcf5 100644
--- a/vllm/model_executor/models/exaone.py
+++ b/vllm/model_executor/models/exaone.py
@@ -127,8 +127,9 @@ class ExaoneAttention(nn.Module):
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
-        self.head_dim = getattr(config, "head_dim",
-                                self.hidden_size // self.total_num_heads)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py
index c49db653f735a..3524d036db222 100644
--- a/vllm/model_executor/models/granite.py
+++ b/vllm/model_executor/models/granite.py
@@ -122,8 +122,9 @@ class GraniteAttention(nn.Module):
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
-        self.head_dim = getattr(config, "head_dim",
-                                self.hidden_size // self.total_num_heads)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = config.attention_multiplier
diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py
index 0285402dadf7f..7724e52c1ce1b 100644
--- a/vllm/model_executor/models/minimax_text_01.py
+++ b/vllm/model_executor/models/minimax_text_01.py
@@ -604,8 +604,9 @@ class MiniMaxText01DecoderLayer(nn.Module):
 
         rope_theta = getattr(config, "rope_theta", 10000)
-        head_dim = getattr(config, "head_dim",
-                           config.hidden_size // config.num_attention_heads)
+        head_dim = getattr(config, "head_dim", None)
+        if head_dim is None:
+            head_dim = config.hidden_size // config.num_attention_heads
         if hasattr(config, "max_model_len") and isinstance(
                 config.max_model_len, int):
             max_position_embeddings = min(config.max_position_embeddings,
@@ -861,8 +862,9 @@ class MiniMaxText01Model(nn.Module):
                                                  cache_shape=self.cache_shape)
 
         rope_theta = getattr(config, "rope_theta", 10000)
-        head_dim = getattr(config, "head_dim",
-                           config.hidden_size // config.num_attention_heads)
+        head_dim = getattr(config, "head_dim", None)
+        if head_dim is None:
+            head_dim = config.hidden_size // config.num_attention_heads
         if hasattr(config, "max_model_len") and isinstance(
                 config.max_model_len, int):
             max_position_embeddings = min(config.max_position_embeddings,
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 4823808e89067..9bc7a16153e1f 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -138,8 +138,9 @@ class MixtralAttention(nn.Module):
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MixtralConfig has an optional head_dim argument
-        self.head_dim = getattr(config, "head_dim",
-                                self.hidden_size // self.total_num_heads)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py
index f096f6a7996dc..8220200d270c2 100644
--- a/vllm/model_executor/models/mixtral_quant.py
+++ b/vllm/model_executor/models/mixtral_quant.py
@@ -193,8 +193,9 @@ class MixtralAttention(nn.Module):
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MixtralConfig has an optional head_dim argument
-        self.head_dim = getattr(config, "head_dim",
-                                self.hidden_size // self.total_num_heads)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py
index c5c5155a2df56..d0999e30e1ba4 100644
--- a/vllm/model_executor/models/nemotron.py
+++ b/vllm/model_executor/models/nemotron.py
@@ -158,8 +158,9 @@ class NemotronAttention(nn.Module):
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
-        self.head_dim = getattr(config, "head_dim",
-                                self.hidden_size // self.total_num_heads)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py
index 53e5274aa5740..fcd17cc1c2ba4 100644
--- a/vllm/model_executor/models/solar.py
+++ b/vllm/model_executor/models/solar.py
@@ -126,8 +126,9 @@ class SolarAttention(nn.Module):
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
-        self.head_dim = getattr(config, "head_dim",
-                                self.hidden_size // self.total_num_heads)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
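
The change is the same in every file: `getattr(config, "head_dim", default)` only falls back when the attribute is *missing*, so a config that defines `head_dim` but leaves it as `None` silently propagates `None` into the head-size math. Below is a minimal sketch (not part of the patch) of the failure mode and the fix; the `SimpleNamespace` stand-in is hypothetical and only mimics a config object that exposes `head_dim = None`.

```python
from types import SimpleNamespace

# Hypothetical config: head_dim exists but is explicitly None.
config = SimpleNamespace(hidden_size=4096, num_attention_heads=32, head_dim=None)

# Old pattern: the computed default is ignored because the attribute exists,
# so head_dim stays None and later arithmetic (num_heads * head_dim) breaks.
head_dim_old = getattr(config, "head_dim",
                       config.hidden_size // config.num_attention_heads)
assert head_dim_old is None

# Patched pattern: fall back whenever the attribute is missing *or* None.
head_dim_new = getattr(config, "head_dim", None)
if head_dim_new is None:
    head_dim_new = config.hidden_size // config.num_attention_heads
assert head_dim_new == 128
```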