mirror of https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 09:45:30 +08:00
[Bugfix] config.head_dim is now explicitly set to None (#18432)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
This commit is contained in:
parent 3b17ea26e4
commit 0c15c2e486
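Why the change: getattr's default is only used when the attribute is missing entirely. Newer Hugging Face configs declare head_dim explicitly and set it to None, so the old getattr(config, "head_dim", hidden_size // num_heads) pattern returns None and the fallback never applies. Below is a minimal sketch of the failure mode and the new guard; HypotheticalConfig is an illustrative stand-in, not a class from this repository.

# Stand-in for a HF PretrainedConfig that declares head_dim explicitly as None.
class HypotheticalConfig:
    hidden_size = 4096
    num_attention_heads = 32
    head_dim = None  # attribute exists, but its value is None

config = HypotheticalConfig()

# Old pattern: the getattr default only kicks in when the attribute is
# missing, so this returns None and later breaks the head-dim arithmetic.
old = getattr(config, "head_dim",
              config.hidden_size // config.num_attention_heads)
assert old is None

# New pattern used throughout this commit: treat an explicit None as "unset".
new = getattr(config, "head_dim", None)
if new is None:
    new = config.hidden_size // config.num_attention_heads
assert new == 128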
@@ -44,8 +44,9 @@ class model_aware_kv_ops_helper:
             head_size = model_config.qk_nope_head_dim + \
                         model_config.qk_rope_head_dim
         else:
-            head_size = getattr(model_config, "head_dim",
-                                int(hidden_size // num_attention_heads))
+            head_size = getattr(model_config, "head_dim", None)
+            if head_size is None:
+                head_size = int(hidden_size // num_attention_heads)
 
         return num_heads, head_size
@@ -127,8 +127,9 @@ class ExaoneAttention(nn.Module):
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
-        self.head_dim = getattr(config, "head_dim",
-                                self.hidden_size // self.total_num_heads)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -122,8 +122,9 @@ class GraniteAttention(nn.Module):
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
-        self.head_dim = getattr(config, "head_dim",
-                                self.hidden_size // self.total_num_heads)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = config.attention_multiplier
@@ -604,8 +604,9 @@ class MiniMaxText01DecoderLayer(nn.Module):
 
         rope_theta = getattr(config, "rope_theta", 10000)
 
-        head_dim = getattr(config, "head_dim",
-                           config.hidden_size // config.num_attention_heads)
+        head_dim = getattr(config, "head_dim", None)
+        if head_dim is None:
+            head_dim = config.hidden_size // config.num_attention_heads
         if hasattr(config, "max_model_len") and isinstance(
                 config.max_model_len, int):
             max_position_embeddings = min(config.max_position_embeddings,
@@ -861,8 +862,9 @@ class MiniMaxText01Model(nn.Module):
                 cache_shape=self.cache_shape)
 
         rope_theta = getattr(config, "rope_theta", 10000)
-        head_dim = getattr(config, "head_dim",
-                           config.hidden_size // config.num_attention_heads)
+        head_dim = getattr(config, "head_dim", None)
+        if head_dim is None:
+            head_dim = config.hidden_size // config.num_attention_heads
         if hasattr(config, "max_model_len") and isinstance(
                 config.max_model_len, int):
             max_position_embeddings = min(config.max_position_embeddings,
@@ -138,8 +138,9 @@ class MixtralAttention(nn.Module):
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MixtralConfig has an optional head_dim argument
-        self.head_dim = getattr(config, "head_dim",
-                                self.hidden_size // self.total_num_heads)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -193,8 +193,9 @@ class MixtralAttention(nn.Module):
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MixtralConfig has an optional head_dim argument
-        self.head_dim = getattr(config, "head_dim",
-                                self.hidden_size // self.total_num_heads)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -158,8 +158,9 @@ class NemotronAttention(nn.Module):
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
-        self.head_dim = getattr(config, "head_dim",
-                                self.hidden_size // self.total_num_heads)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -126,8 +126,9 @@ class SolarAttention(nn.Module):
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
         # MistralConfig has an optional head_dim introduced by Mistral-Nemo
-        self.head_dim = getattr(config, "head_dim",
-                                self.hidden_size // self.total_num_heads)
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
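The same three-line guard now appears in every attention module touched above. If it ever needs to live in one place, a small helper along these lines would do; this is only a sketch, and get_head_dim is a hypothetical name, not something this commit adds.

def get_head_dim(config, hidden_size: int, num_heads: int) -> int:
    # Return config.head_dim, falling back to hidden_size // num_heads when
    # the attribute is absent or explicitly set to None (hypothetical helper).
    head_dim = getattr(config, "head_dim", None)
    if head_dim is None:
        head_dim = hidden_size // num_heads
    return head_dim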