mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:54:56 +08:00
[Model] Support Mistral-Nemo (#6548)
This commit is contained in:
parent
ecdb462c24
commit
15c6a079b1
@ -89,6 +89,7 @@ class LlamaAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
@ -115,7 +116,9 @@ class LlamaAttention(nn.Module):
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
|
||||
self.head_dim = getattr(config, "head_dim",
|
||||
self.hidden_size // self.total_num_heads)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
@ -189,6 +192,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
attention_bias = getattr(config, "attention_bias", False) or getattr(
|
||||
config, "bias", False)
|
||||
self.self_attn = LlamaAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=getattr(config, "num_key_value_heads",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user