mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 15:34:29 +08:00
[Bugfix] Fix internlm2 tensor parallel inference (#8055)
This commit is contained in:
parent
4ca65a9763
commit
dd2a6a82e3
@ -1,4 +1,5 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
from functools import partial
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -7,7 +8,10 @@ from transformers import PretrainedConfig
|
|||||||
|
|
||||||
from vllm.attention import Attention, AttentionMetadata
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
from vllm.config import CacheConfig
|
from vllm.config import CacheConfig
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
split_tensor_along_last_dim,
|
||||||
|
tensor_model_parallel_all_gather)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||||
@ -70,20 +74,21 @@ class InternLM2Attention(nn.Module):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
self.total_num_heads = num_heads
|
self.total_num_heads = num_heads
|
||||||
assert self.total_num_heads % tp_size == 0
|
assert self.total_num_heads % self.tp_size == 0
|
||||||
self.num_heads = self.total_num_heads // tp_size
|
self.num_heads = self.total_num_heads // self.tp_size
|
||||||
self.total_num_kv_heads = num_kv_heads
|
self.total_num_kv_heads = num_kv_heads
|
||||||
if self.total_num_kv_heads >= tp_size:
|
if self.total_num_kv_heads >= self.tp_size:
|
||||||
# Number of KV heads is greater than TP size, so we partition
|
# Number of KV heads is greater than TP size, so we partition
|
||||||
# the KV heads across multiple tensor parallel GPUs.
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
assert self.total_num_kv_heads % tp_size == 0
|
assert self.total_num_kv_heads % self.tp_size == 0
|
||||||
else:
|
else:
|
||||||
# Number of KV heads is less than TP size, so we replicate
|
# Number of KV heads is less than TP size, so we replicate
|
||||||
# the KV heads across multiple tensor parallel GPUs.
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
assert tp_size % self.total_num_kv_heads == 0
|
assert self.tp_size % self.total_num_kv_heads == 0
|
||||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
|
||||||
self.head_dim = hidden_size // self.total_num_heads
|
self.head_dim = hidden_size // self.total_num_heads
|
||||||
self.q_size = self.num_heads * self.head_dim
|
self.q_size = self.num_heads * self.head_dim
|
||||||
self.kv_size = self.num_kv_heads * self.head_dim
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
@ -122,11 +127,27 @@ class InternLM2Attention(nn.Module):
|
|||||||
quant_config=quant_config)
|
quant_config=quant_config)
|
||||||
|
|
||||||
def split_qkv(self, qkv: torch.Tensor):
|
def split_qkv(self, qkv: torch.Tensor):
|
||||||
qkv = qkv.view(-1, self.num_kv_heads, self.key_value_groups + 2, 128)
|
seq_len = qkv.shape[0]
|
||||||
q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=2)
|
if self.tp_size > 1:
|
||||||
q = q.reshape(-1, self.q_size)
|
qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
|
||||||
k = k.reshape(-1, self.kv_size)
|
qkv = tensor_model_parallel_all_gather(qkv)
|
||||||
v = v.reshape(-1, self.kv_size)
|
qkv = torch.split(qkv, qkv_map, dim=-1)
|
||||||
|
qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
|
||||||
|
qkv = torch.cat(qkv, dim=-1)
|
||||||
|
|
||||||
|
qkv = qkv.view(seq_len, self.total_num_kv_heads,
|
||||||
|
self.key_value_groups + 2, self.head_dim)
|
||||||
|
q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
|
||||||
|
q = q.reshape(seq_len, self.q_size * self.tp_size)
|
||||||
|
k = k.reshape(seq_len, self.kv_size * self.tp_size)
|
||||||
|
v = v.reshape(seq_len, self.kv_size * self.tp_size)
|
||||||
|
|
||||||
|
if self.tp_size > 1:
|
||||||
|
splitter = partial(split_tensor_along_last_dim,
|
||||||
|
num_partitions=self.tp_size)
|
||||||
|
q = splitter(q)[self.tp_rank]
|
||||||
|
k = splitter(k)[self.tp_rank]
|
||||||
|
v = splitter(v)[self.tp_rank]
|
||||||
return q, k, v
|
return q, k, v
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user