mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-28 00:53:40 +08:00
TP/quantization/weight loading refactor part 2 - Refactor quantized linear logic and extend quantization support to all models (#1622)
Refactor the tensor parallelism, quantization, and weight-loading codes. Summary of the new features enabled by this PR: - **All models** are able to be quantized with AWQ and SqueezeLLM, and [soon GPTQ](https://github.com/vllm-project/vllm/pull/1580). - Model loading code became much simpler. - Support model parallelism for all MQA/GQA models when the number of key/value heads is smaller than the tensor parallel size.
This commit is contained in:
parent
660a7fcfa4
commit
7076fa1c9f
@ -140,8 +140,8 @@ class ModelConfig:
|
||||
# FIXME(woosuk): This may not be true for all models.
|
||||
return self.hf_config.hidden_size // self.hf_config.num_attention_heads
|
||||
|
||||
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
|
||||
"""Returns the number of KV heads per GPU worker."""
|
||||
def get_total_num_kv_heads(self) -> int:
|
||||
"""Returns the total number of KV heads."""
|
||||
# For GPTBigCode & Falcon:
|
||||
# NOTE: for falcon, when new_decoder_architecture is True, the
|
||||
# multi_query flag is ignored and we use n_head_kv for the number of
|
||||
@ -155,23 +155,34 @@ class ModelConfig:
|
||||
# Multi-query attention, only one KV head.
|
||||
# Currently, tensor parallelism is not supported in this case.
|
||||
return 1
|
||||
# For Falcon:
|
||||
if getattr(self.hf_config, "n_head_kv", None) is not None:
|
||||
return (self.hf_config.n_head_kv //
|
||||
parallel_config.tensor_parallel_size)
|
||||
if getattr(self.hf_config, "num_kv_heads", None) is not None:
|
||||
return (self.hf_config.num_kv_heads //
|
||||
parallel_config.tensor_parallel_size)
|
||||
# For LLaMA-2:
|
||||
if getattr(self.hf_config, "num_key_value_heads", None) is not None:
|
||||
return (self.hf_config.num_key_value_heads //
|
||||
parallel_config.tensor_parallel_size)
|
||||
# For ChatGLM-2:
|
||||
if getattr(self.hf_config, "multi_query_group_num", None) is not None:
|
||||
return (self.hf_config.multi_query_group_num //
|
||||
parallel_config.tensor_parallel_size)
|
||||
total_num_attention_heads = self.hf_config.num_attention_heads
|
||||
return total_num_attention_heads // parallel_config.tensor_parallel_size
|
||||
|
||||
attributes = [
|
||||
# For Falcon:
|
||||
"n_head_kv",
|
||||
"num_kv_heads",
|
||||
# For LLaMA-2:
|
||||
"num_key_value_heads",
|
||||
# For ChatGLM:
|
||||
"multi_query_group_num",
|
||||
]
|
||||
for attr in attributes:
|
||||
num_kv_heads = getattr(self.hf_config, attr, None)
|
||||
if num_kv_heads is not None:
|
||||
return num_kv_heads
|
||||
|
||||
# For non-grouped-query attention models, the number of KV heads is
|
||||
# equal to the number of attention heads.
|
||||
return self.hf_config.num_attention_heads
|
||||
|
||||
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
|
||||
"""Returns the number of KV heads per GPU."""
|
||||
total_num_kv_heads = self.get_total_num_kv_heads()
|
||||
# If tensor parallelism is used, we divide the number of KV heads by
|
||||
# the tensor parallel size. We will replicate the KV heads in the
|
||||
# case where the number of KV heads is smaller than the tensor
|
||||
# parallel size so each GPU has at least one KV head.
|
||||
return max(1,
|
||||
total_num_kv_heads // parallel_config.tensor_parallel_size)
|
||||
|
||||
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
|
||||
total_num_hidden_layers = self.hf_config.num_hidden_layers
|
||||
|
||||
@ -142,10 +142,10 @@ class RequestTracker:
|
||||
|
||||
self._request_streams[request_id].finish()
|
||||
|
||||
def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]:
|
||||
def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
|
||||
"""Get the new requests and finished requests to be
|
||||
sent to the engine."""
|
||||
new_requests: List[dict] = []
|
||||
new_requests: List[Dict] = []
|
||||
finished_requests: Set[str] = set()
|
||||
|
||||
while not self._finished_requests.empty():
|
||||
|
||||
541
vllm/model_executor/layers/linear.py
Normal file
541
vllm/model_executor/layers/linear.py
Normal file
@ -0,0 +1,541 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
|
||||
from vllm.model_executor.parallel_utils.utils import (
|
||||
divide, split_tensor_along_last_dim)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LinearMethodBase(ABC):
|
||||
"""Base class for different (maybe quantized) linear methods."""
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
|
||||
"""Create weights for a linear layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(self,
|
||||
weights: Dict[str, torch.Tensor],
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""Apply the weights to the input tensor."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class UnquantizedLinearMethod(LinearMethodBase):
|
||||
"""Linear method without quantization.
|
||||
|
||||
Args:
|
||||
separate_bias_add: If true, add bias separately after matrix
|
||||
multiplication.
|
||||
"""
|
||||
|
||||
def __init__(self, separate_bias_add: bool = False):
|
||||
self.separate_bias_add = separate_bias_add
|
||||
|
||||
def create_weights(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
|
||||
weight = Parameter(torch.empty(output_size,
|
||||
input_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
return {"weight": weight}
|
||||
|
||||
def apply_weights(self,
|
||||
weights: Dict[str, torch.Tensor],
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
weight = weights["weight"]
|
||||
if self.separate_bias_add:
|
||||
if bias:
|
||||
return F.linear(x, weight) + bias
|
||||
return F.linear(x, weight)
|
||||
return F.linear(x, weight, bias)
|
||||
|
||||
|
||||
class ReplicatedLinear(torch.nn.Module):
|
||||
"""Replicated linear layer.
|
||||
|
||||
Args:
|
||||
input_size: input dimension of the linear layer.
|
||||
output_size: output dimension of the linear layer.
|
||||
bias: If true, add bias.
|
||||
skip_bias_add: If true, skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
linear_method: (Maybe quantized) linear method.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.skip_bias_add = skip_bias_add
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
if linear_method is None:
|
||||
linear_method = UnquantizedLinearMethod()
|
||||
self.linear_method = linear_method
|
||||
self.linear_weights = self.linear_method.create_weights(
|
||||
self.input_size, self.output_size, self.params_dtype)
|
||||
for name, weight in self.linear_weights.items():
|
||||
self.register_parameter(name, weight)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=self.params_dtype))
|
||||
set_weight_attrs(self.bias, {"output_dim": 0})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
output = self.linear_method.apply_weights(self.linear_weights, x, bias)
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class ColumnParallelLinear(torch.nn.Module):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
||||
its second dimension as A = [A_1, ..., A_p].
|
||||
|
||||
Args:
|
||||
input_size: first dimension of matrix A.
|
||||
output_size: second dimension of matrix A.
|
||||
bias: If true, add bias.
|
||||
gather_output: If true, call all-gather on output and make Y available
|
||||
to all GPUs, otherwise, every GPU will have its output
|
||||
which is Y_i = XA_i
|
||||
skip_bias_add: This was added to enable performance optimizations where
|
||||
bias can be fused with other element-wise operations. we
|
||||
skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
linear_method: (Maybe quantized) linear method.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.gather_output = gather_output
|
||||
# Divide the weight matrix along the last dimension.
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.output_size_per_partition = divide(output_size, tp_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
if linear_method is None:
|
||||
linear_method = UnquantizedLinearMethod()
|
||||
self.linear_method = linear_method
|
||||
self.linear_weights = self.linear_method.create_weights(
|
||||
self.input_size, self.output_size_per_partition, self.params_dtype)
|
||||
for name, weight in self.linear_weights.items():
|
||||
self.register_parameter(name, weight)
|
||||
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
param_data = param.data
|
||||
if output_dim is not None:
|
||||
shard_size = param_data.shape[output_dim]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, input_):
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.linear_method.apply_weights(
|
||||
self.linear_weights, input_, bias)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
"""Packed linear layers with column parallelism.
|
||||
|
||||
Similar to ColumnParallelLinear, but the weight matrix is concatenated
|
||||
along the output dimension. When the weight matrix is loaded, the
|
||||
different partitions are sharded separately.
|
||||
|
||||
Args:
|
||||
input_size: input dimension of the linear layer.
|
||||
output_sizes: list of output dimensions of the linear layer.
|
||||
bias: If true, add bias.
|
||||
gather_output: If true, call all-gather on output and make the output
|
||||
available to all GPUs, otherwise, every GPU will have
|
||||
its own output.
|
||||
skip_bias_add: This was added to enable performance optimizations where
|
||||
bias can be fused with other element-wise operations. we
|
||||
skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
linear_method: (Maybe quantized) linear method.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_sizes: List[int],
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
self.output_sizes = output_sizes
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert all(output_size % tp_size == 0 for output_size in output_sizes)
|
||||
super().__init__(input_size, sum(output_sizes), bias, gather_output,
|
||||
skip_bias_add, params_dtype, linear_method)
|
||||
|
||||
def weight_loader(self,
|
||||
param: Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[int] = None):
|
||||
param_data = param.data
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already packed.
|
||||
if output_dim is None:
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
return
|
||||
current_shard_offset = 0
|
||||
shard_offsets = []
|
||||
for i, output_size in enumerate(self.output_sizes):
|
||||
shard_offsets.append((i, current_shard_offset, output_size))
|
||||
current_shard_offset += output_size
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
output_dim, shard_offset, shard_size)
|
||||
self.weight_loader(param, loaded_weight_shard, shard_id)
|
||||
return
|
||||
|
||||
assert loaded_shard_id < len(self.output_sizes)
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if output_dim is not None:
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
|
||||
shard_size = self.output_sizes[loaded_shard_id] // tp_size
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
param_data = param_data.narrow(output_dim, shard_offset,
|
||||
shard_size)
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
else:
|
||||
logger.warning(
|
||||
"Loading a weight without `output_dim` attribute in "
|
||||
"MergedColumnParallelLinear, assume the weight is "
|
||||
"the same for all partitions.")
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
class QKVParallelLinear(ColumnParallelLinear):
|
||||
"""Linear layers for the attention's QKV transformation.
|
||||
|
||||
Linear layers for the linear transformation of the query, key, and value
|
||||
vectors in the attention layer. The weight matrix is concatenated along
|
||||
the output dimension. The layer is parallelized along the head dimension.
|
||||
When the number of key/value heads is smaller than the number of query
|
||||
heads (e.g., multi-query/grouped-query attention), the key/value head may
|
||||
be replicated while the query heads are partitioned.
|
||||
|
||||
Args:
|
||||
hidden_size: input hidden state size of the transformer.
|
||||
head_size: size of each attention head.
|
||||
total_num_heads: total number of attention query heads.
|
||||
total_num_kv_heads: total number of attention key/value heads. If
|
||||
None, assume total_num_kv_heads = total_num_heads.
|
||||
bias: If true, add bias.
|
||||
skip_bias_add: This was added to enable performance optimizations where
|
||||
bias can be fused with other element-wise operations. we
|
||||
skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
linear_method: (Maybe quantized) linear method.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
head_size: int,
|
||||
total_num_heads: int,
|
||||
total_num_kv_heads: Optional[int] = None,
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = head_size
|
||||
self.total_num_heads = total_num_heads
|
||||
if total_num_kv_heads is None:
|
||||
total_num_kv_heads = total_num_heads
|
||||
self.total_num_kv_heads = total_num_kv_heads
|
||||
# Divide the weight matrix along the last dimension.
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads = divide(self.total_num_heads, tp_size)
|
||||
if tp_size >= self.total_num_kv_heads:
|
||||
self.num_kv_heads = 1
|
||||
self.num_kv_head_replicas = divide(tp_size,
|
||||
self.total_num_kv_heads)
|
||||
else:
|
||||
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
|
||||
self.num_kv_head_replicas = 1
|
||||
input_size = self.hidden_size
|
||||
output_size = (self.num_heads +
|
||||
2 * self.num_kv_heads) * tp_size * self.head_size
|
||||
super().__init__(input_size, output_size, bias, False, skip_bias_add,
|
||||
params_dtype, linear_method)
|
||||
|
||||
def weight_loader(self,
|
||||
param: Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
loaded_shard_id: Optional[str] = None):
|
||||
param_data = param.data
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already packed.
|
||||
if output_dim is None:
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
return
|
||||
shard_offsets = [
|
||||
# (shard_id, shard_offset, shard_size)
|
||||
("q", 0, self.total_num_heads * self.head_size),
|
||||
("k", self.total_num_heads * self.head_size,
|
||||
self.total_num_kv_heads * self.head_size),
|
||||
("v", (self.total_num_heads + self.total_num_kv_heads) *
|
||||
self.head_size, self.total_num_kv_heads * self.head_size),
|
||||
]
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
output_dim, shard_offset, shard_size)
|
||||
self.weight_loader(param, loaded_weight_shard, shard_id)
|
||||
return
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
assert loaded_shard_id in ["q", "k", "v"]
|
||||
if output_dim is not None:
|
||||
if loaded_shard_id == "q":
|
||||
shard_offset = 0
|
||||
shard_size = self.num_heads * self.head_size
|
||||
elif loaded_shard_id == "k":
|
||||
shard_offset = self.num_heads * self.head_size
|
||||
shard_size = self.num_kv_heads * self.head_size
|
||||
elif loaded_shard_id == "v":
|
||||
shard_offset = (self.num_heads +
|
||||
self.num_kv_heads) * self.head_size
|
||||
shard_size = self.num_kv_heads * self.head_size
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
param_data = param_data.narrow(output_dim, shard_offset,
|
||||
shard_size)
|
||||
shard_id = tp_rank // self.num_kv_head_replicas
|
||||
start_idx = shard_id * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
else:
|
||||
logger.warning(
|
||||
"Loading a weight without `output_dim` attribute in "
|
||||
"QKVParallelLinear, assume the weight is the same "
|
||||
"for all partitions.")
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
class RowParallelLinear(torch.nn.Module):
|
||||
"""Linear layer with row parallelism.
|
||||
|
||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
||||
its first dimension and X along its second dimension as:
|
||||
- -
|
||||
| A_1 |
|
||||
| . |
|
||||
A = | . | X = [X_1, ..., X_p]
|
||||
| . |
|
||||
| A_p |
|
||||
- -
|
||||
Arguments:
|
||||
input_size: first dimension of matrix A.
|
||||
output_size: second dimension of matrix A.
|
||||
bias: If true, add bias. Note that bias is not parallelized.
|
||||
input_is_parallel: If true, we assume that the input is already
|
||||
split across the GPUs and we do not split
|
||||
again.
|
||||
skip_bias_add: This was added to enable performance optimization where
|
||||
bias can be fused with other element-wise operations.
|
||||
We skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
linear_method: (Maybe quantized) linear method.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
input_is_parallel: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = True,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
if linear_method is None:
|
||||
linear_method = UnquantizedLinearMethod()
|
||||
self.linear_method = linear_method
|
||||
self.linear_weights = self.linear_method.create_weights(
|
||||
self.input_size_per_partition, self.output_size, self.params_dtype)
|
||||
for name, weight in self.linear_weights.items():
|
||||
self.register_parameter(name, weight)
|
||||
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
|
||||
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
"results can lead to incorrect results")
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
input_dim = getattr(param, "input_dim", None)
|
||||
param_data = param.data
|
||||
if input_dim is not None:
|
||||
shard_size = param_data.shape[input_dim]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
|
||||
shard_size)
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, input_):
|
||||
# Set up backprop all-reduce.
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.linear_method.apply_weights(
|
||||
self.linear_weights, input_parallel)
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output_ = output_parallel
|
||||
|
||||
if not self.skip_bias_add:
|
||||
output = output_ + self.bias if self.bias is not None else output_
|
||||
output_bias = None
|
||||
else:
|
||||
output = output_
|
||||
output_bias = self.bias
|
||||
return output, output_bias
|
||||
22
vllm/model_executor/layers/quantization/__init__.py
Normal file
22
vllm/model_executor/layers/quantization/__init__.py
Normal file
@ -0,0 +1,22 @@
|
||||
from typing import Type
|
||||
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
|
||||
_QUANTIZATION_CONFIG_REGISTRY = {
|
||||
"awq": AWQConfig,
|
||||
"squeezellm": SqueezeLLMConfig,
|
||||
}
|
||||
|
||||
|
||||
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
|
||||
raise ValueError(f"Invalid quantization method: {quantization}")
|
||||
return _QUANTIZATION_CONFIG_REGISTRY[quantization]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"QuantizationConfig",
|
||||
"get_quantization_config",
|
||||
]
|
||||
155
vllm/model_executor/layers/quantization/awq.py
Normal file
155
vllm/model_executor/layers/quantization/awq.py
Normal file
@ -0,0 +1,155 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import quantization_ops
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
|
||||
|
||||
class AWQConfig(QuantizationConfig):
|
||||
"""Config class for AWQ.
|
||||
|
||||
Reference: https://arxiv.org/abs/2306.00978
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
zero_point: bool,
|
||||
) -> None:
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
|
||||
if self.weight_bits != 4:
|
||||
raise ValueError(
|
||||
"Currently, only 4-bit weight quantization is supported for "
|
||||
f"AWQ, but got {self.weight_bits} bits.")
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"AWQConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"zero_point={self.zero_point})")
|
||||
|
||||
def get_name(self) -> str:
|
||||
return "awq"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
def get_min_capability(self) -> int:
|
||||
# The AWQ kernel only supports Turing or newer GPUs.
|
||||
return 75
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
return [
|
||||
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
|
||||
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
return cls(weight_bits, group_size, zero_point)
|
||||
|
||||
def get_linear_method(self) -> "AWQLinearMethod":
|
||||
return AWQLinearMethod(self)
|
||||
|
||||
|
||||
class AWQLinearMethod(LinearMethodBase):
|
||||
"""Linear method for AWQ.
|
||||
|
||||
Args:
|
||||
quant_config: The AWQ quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AWQConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
|
||||
if input_size % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
if output_size % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
qweight = Parameter(
|
||||
torch.empty(
|
||||
input_size,
|
||||
output_size // self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qweight, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 1,
|
||||
"packed_dim": 1,
|
||||
"pack_factor": self.quant_config.pack_factor,
|
||||
})
|
||||
qzeros = Parameter(
|
||||
torch.empty(
|
||||
input_size // self.quant_config.group_size,
|
||||
output_size // self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qzeros, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 1,
|
||||
"packed_dim": 1,
|
||||
"pack_factor": self.quant_config.pack_factor,
|
||||
})
|
||||
scales = Parameter(
|
||||
torch.empty(
|
||||
input_size // self.quant_config.group_size,
|
||||
output_size,
|
||||
device="cuda",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(scales, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 1,
|
||||
})
|
||||
return {
|
||||
"qweight": qweight,
|
||||
"qzeros": qzeros,
|
||||
"scales": scales,
|
||||
}
|
||||
|
||||
def apply_weights(self,
|
||||
weights: Dict[str, torch.Tensor],
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qweight = weights["qweight"]
|
||||
qzeros = weights["qzeros"]
|
||||
scales = weights["scales"]
|
||||
pack_factor = self.quant_config.pack_factor
|
||||
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out = quantization_ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
|
||||
pack_factor)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.reshape(out_shape)
|
||||
56
vllm/model_executor/layers/quantization/base_config.py
Normal file
56
vllm/model_executor/layers/quantization/base_config.py
Normal file
@ -0,0 +1,56 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
|
||||
|
||||
class QuantizationConfig(ABC):
|
||||
"""Base class for quantization configs."""
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> str:
|
||||
"""Name of the quantization method."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
"""List of supported activation dtypes."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_min_capability(self) -> int:
|
||||
"""Minimum GPU capability to support the quantization method.
|
||||
|
||||
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
|
||||
This requirement is due to the custom CUDA kernels used by the
|
||||
quantization method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
"""List of filenames to search for in the model directory."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
|
||||
"""Create a config class from the model's quantization config."""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
|
||||
"""Get a value from the model's quantization config."""
|
||||
for key in keys:
|
||||
if key in config:
|
||||
return config[key]
|
||||
raise ValueError(f"Cannot find any of {keys} in the model's "
|
||||
"quantization config.")
|
||||
|
||||
@abstractmethod
|
||||
def get_linear_method(self) -> LinearMethodBase:
|
||||
"""Get the linear method to use for the quantized linear layer."""
|
||||
raise NotImplementedError
|
||||
121
vllm/model_executor/layers/quantization/squeezellm.py
Normal file
121
vllm/model_executor/layers/quantization/squeezellm.py
Normal file
@ -0,0 +1,121 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import quantization_ops
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
|
||||
|
||||
class SqueezeLLMConfig(QuantizationConfig):
|
||||
"""Config class for SqueezeLLM.
|
||||
|
||||
Reference: https://arxiv.org/pdf/2306.07629
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
) -> None:
|
||||
self.weight_bits = weight_bits
|
||||
|
||||
if self.weight_bits != 4:
|
||||
raise ValueError(
|
||||
"Currently, only 4-bit weight quantization is supported for "
|
||||
f"SqueezeLLM, but got {self.weight_bits} bits.")
|
||||
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"SqueezeLLMConfig(weight_bits={self.weight_bits})"
|
||||
|
||||
def get_name(self) -> str:
|
||||
return "squeezellm"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
def get_min_capability(self) -> int:
|
||||
return 70
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
return ["quant_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["wbits"])
|
||||
return cls(weight_bits)
|
||||
|
||||
def get_linear_method(self) -> "SqueezeLLMLinearMethod":
|
||||
return SqueezeLLMLinearMethod(self)
|
||||
|
||||
|
||||
class SqueezeLLMLinearMethod(LinearMethodBase):
|
||||
"""Linear method for SqueezeLLM.
|
||||
|
||||
Args:
|
||||
quant_config: The SqueezeLLM quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: SqueezeLLMConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, input_size: int, output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
|
||||
if input_size % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
qweight = Parameter(
|
||||
torch.empty(
|
||||
input_size // self.quant_config.pack_factor,
|
||||
output_size,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qweight, {
|
||||
"input_dim": 0,
|
||||
"output_dim": 1,
|
||||
"packed_dim": 0,
|
||||
"pack_factor": self.quant_config.pack_factor,
|
||||
})
|
||||
lookup_table = Parameter(
|
||||
torch.empty(
|
||||
output_size,
|
||||
self.quant_config.weight_bits**2,
|
||||
device="cuda",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(lookup_table, {
|
||||
"output_dim": 0,
|
||||
})
|
||||
return {
|
||||
"qweight": qweight,
|
||||
"lookup_table": lookup_table,
|
||||
}
|
||||
|
||||
def apply_weights(self,
|
||||
weights: Dict[str, torch.Tensor],
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qweight = weights["qweight"]
|
||||
lookup_table = weights["lookup_table"]
|
||||
out_shape = x.shape[:-1] + (qweight.shape[-1], )
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
# NOTE: The output tensor should be zero-initialized.
|
||||
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
|
||||
quantization_ops.squeezellm_gemm(reshaped_x, qweight, out,
|
||||
lookup_table)
|
||||
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.reshape(out_shape)
|
||||
@ -1,41 +0,0 @@
|
||||
from vllm.model_executor.layers.quantized_linear.awq import (
|
||||
AWQColumnParallelLinear, AWQRowParallelLinear)
|
||||
from vllm.model_executor.layers.quantized_linear.squeezellm import (
|
||||
SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear)
|
||||
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
|
||||
_QUANTIZED_LINEAR_REGISTRY = {
|
||||
"awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
|
||||
"squeezellm":
|
||||
(SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear),
|
||||
}
|
||||
|
||||
|
||||
class ParallelLinear:
|
||||
|
||||
@classmethod
|
||||
def column(cls, *args, **kwargs) -> ColumnParallelLinear:
|
||||
quant_config = kwargs.get("quant_config", None)
|
||||
if quant_config is None:
|
||||
return ColumnParallelLinear(*args, **kwargs)
|
||||
|
||||
name = quant_config.get_name()
|
||||
if name not in _QUANTIZED_LINEAR_REGISTRY:
|
||||
raise ValueError(f"No quantized linear is found for {name}")
|
||||
|
||||
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0]
|
||||
return quant_linear_cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def row(cls, *args, **kwargs) -> RowParallelLinear:
|
||||
quant_config = kwargs.get("quant_config", None)
|
||||
if quant_config is None:
|
||||
return RowParallelLinear(*args, **kwargs)
|
||||
|
||||
name = quant_config.get_name()
|
||||
if name not in _QUANTIZED_LINEAR_REGISTRY:
|
||||
raise ValueError(f"No quantized linear is found for {name}")
|
||||
|
||||
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
|
||||
return quant_linear_cls(*args, **kwargs)
|
||||
@ -1,106 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import quantization_ops
|
||||
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
|
||||
|
||||
class AWQColumnParallelLinear(ColumnParallelLinear):
|
||||
|
||||
def create_weights(self, dtype: torch.dtype) -> None:
|
||||
assert self.input_size % self.quant_config.group_size == 0
|
||||
if self.output_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
"The tensor parallel size is not aligned with the quantized "
|
||||
"weight shape. Please use a different tensor parallel size.")
|
||||
self.qweight = Parameter(
|
||||
torch.empty(
|
||||
self.input_size,
|
||||
self.output_size_per_partition //
|
||||
self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.qzeros = Parameter(
|
||||
torch.empty(
|
||||
self.input_size // self.quant_config.group_size,
|
||||
self.output_size_per_partition //
|
||||
self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.scales = Parameter(
|
||||
torch.empty(
|
||||
self.input_size // self.quant_config.group_size,
|
||||
self.output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
pack_factor = self.quant_config.pack_factor
|
||||
out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, ))
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
|
||||
self.qzeros, pack_factor)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.reshape(out_shape)
|
||||
|
||||
|
||||
class AWQRowParallelLinear(RowParallelLinear):
|
||||
|
||||
def create_weights(self, dtype: torch.dtype) -> None:
|
||||
assert self.output_size % self.quant_config.pack_factor == 0
|
||||
if self.input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"The tensor parallel size is not aligned with the quantized "
|
||||
"weight shape. Please use a different tensor parallel size.")
|
||||
self.qweight = Parameter(
|
||||
torch.empty(
|
||||
self.input_size_per_partition,
|
||||
self.output_size // self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.qzeros = Parameter(
|
||||
torch.empty(
|
||||
self.input_size_per_partition // self.quant_config.group_size,
|
||||
self.output_size // self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.scales = Parameter(
|
||||
torch.empty(
|
||||
self.input_size_per_partition // self.quant_config.group_size,
|
||||
self.output_size,
|
||||
device="cuda",
|
||||
dtype=dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||
pack_factor = self.quant_config.pack_factor
|
||||
out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, ))
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
|
||||
self.qzeros, pack_factor)
|
||||
return out.reshape(out_shape)
|
||||
@ -1,84 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import quantization_ops
|
||||
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
|
||||
|
||||
class SqueezeLLMColumnParallelLinear(ColumnParallelLinear):
|
||||
|
||||
def create_weights(self, dtype: torch.dtype) -> None:
|
||||
assert self.input_size % self.quant_config.pack_factor == 0
|
||||
self.qweight = Parameter(
|
||||
torch.empty(
|
||||
self.input_size // self.quant_config.pack_factor,
|
||||
self.output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.lookup_table = Parameter(
|
||||
torch.empty(
|
||||
self.output_size_per_partition,
|
||||
self.quant_config.weight_bits**2,
|
||||
device="cuda",
|
||||
dtype=dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
# NOTE: The output tensor should be zero-initialized.
|
||||
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
|
||||
quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
|
||||
self.lookup_table)
|
||||
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.reshape(out_shape)
|
||||
|
||||
|
||||
class SqueezeLLMRowParallelLinear(RowParallelLinear):
|
||||
|
||||
def create_weights(self, dtype: torch.dtype) -> None:
|
||||
if self.input_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
"The tensor parallel size is not aligned with the quantized "
|
||||
"weight shape. Please use a different tensor parallel size.")
|
||||
self.qweight = Parameter(
|
||||
torch.empty(
|
||||
self.input_size_per_partition // self.quant_config.pack_factor,
|
||||
self.output_size,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.lookup_table = Parameter(
|
||||
torch.empty(
|
||||
self.output_size,
|
||||
self.quant_config.weight_bits**2,
|
||||
device="cuda",
|
||||
dtype=dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (self.qweight.shape[-1], )
|
||||
# NOTE: The output tensor should be zero-initialized.
|
||||
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
|
||||
quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out,
|
||||
self.lookup_table)
|
||||
return out.reshape(out_shape)
|
||||
139
vllm/model_executor/layers/vocab_parallel_embedding.py
Normal file
139
vllm/model_executor/layers/vocab_parallel_embedding.py
Normal file
@ -0,0 +1,139 @@
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.utils import divide
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
|
||||
"""Pad the vocab size to the given value."""
|
||||
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
|
||||
|
||||
|
||||
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int,
|
||||
rank: int) -> Sequence[int]:
|
||||
index_f = rank * per_partition_vocab_size
|
||||
index_l = index_f + per_partition_vocab_size
|
||||
return index_f, index_l
|
||||
|
||||
|
||||
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
|
||||
world_size: int) -> Sequence[int]:
|
||||
per_partition_vocab_size = divide(global_vocab_size, world_size)
|
||||
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
|
||||
rank)
|
||||
|
||||
|
||||
class VocabParallelEmbedding(torch.nn.Module):
|
||||
"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
|
||||
make sure it is divisible by the number of model parallel GPUs.
|
||||
|
||||
Args:
|
||||
num_embeddings: vocabulary size.
|
||||
embedding_dim: size of hidden state.
|
||||
params_dtype: type of the parameters.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
params_dtype: Optional[torch.dtype] = None):
|
||||
super().__init__()
|
||||
|
||||
# Keep the input dimensions.
|
||||
self.num_embeddings = num_embeddings
|
||||
self.num_embeddings_padded = pad_vocab_size(num_embeddings)
|
||||
self.embedding_dim = embedding_dim
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
# Divide the weight matrix along the vocaburaly dimension.
|
||||
self.vocab_start_index, self.vocab_end_index = (
|
||||
vocab_range_from_global_vocab_size(
|
||||
self.num_embeddings_padded, get_tensor_model_parallel_rank(),
|
||||
self.tp_size))
|
||||
self.num_embeddings_per_partition = (self.vocab_end_index -
|
||||
self.vocab_start_index)
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.num_embeddings_per_partition,
|
||||
self.embedding_dim,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.weight, {
|
||||
"parallel_dim": 0,
|
||||
"weight_loader": self.weight_loader
|
||||
})
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
parallel_dim = param.parallel_dim
|
||||
assert loaded_weight.shape[parallel_dim] == self.num_embeddings
|
||||
loaded_weight = loaded_weight[self.vocab_start_index:self.
|
||||
vocab_end_index]
|
||||
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, input_):
|
||||
if self.tp_size > 1:
|
||||
# Build the mask.
|
||||
input_mask = ((input_ < self.vocab_start_index) |
|
||||
(input_ >= self.vocab_end_index))
|
||||
# Mask the input.
|
||||
masked_input = input_.clone() - self.vocab_start_index
|
||||
masked_input[input_mask] = 0
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = F.embedding(masked_input, self.weight)
|
||||
# Mask the output embedding.
|
||||
if self.tp_size > 1:
|
||||
output_parallel[input_mask, :] = 0.0
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
return output
|
||||
|
||||
|
||||
class ParallelLMHead(VocabParallelEmbedding):
|
||||
"""Parallelized LM head.
|
||||
|
||||
Output logits weight matrices used in the Sampler. The weight and bias
|
||||
tensors are padded to make sure they are divisible by the number of
|
||||
model parallel GPUs.
|
||||
|
||||
Args:
|
||||
num_embeddings: vocabulary size.
|
||||
embedding_dim: size of hidden state.
|
||||
bias: whether to use bias.
|
||||
params_dtype: type of the parameters.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
bias: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None):
|
||||
super().__init__(num_embeddings, embedding_dim, params_dtype)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.num_embeddings_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"parallel_dim": 0,
|
||||
"weight_loader": self.weight_loader
|
||||
})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, input_):
|
||||
del input_
|
||||
raise RuntimeError("LMHead's weights should be used in the sampler.")
|
||||
@ -37,13 +37,6 @@ _MODEL_REGISTRY = {
|
||||
"YiForCausalLM": YiForCausalLM,
|
||||
}
|
||||
|
||||
# FIXME(woosuk): Remove this once all models support quantization.
|
||||
_MODEL_CLASSES_SUPPORT_QUANTIZATION = [
|
||||
LlamaForCausalLM,
|
||||
MistralForCausalLM,
|
||||
YiForCausalLM,
|
||||
]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _set_default_torch_dtype(dtype: torch.dtype):
|
||||
@ -67,12 +60,9 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
def get_model(model_config: ModelConfig) -> nn.Module:
|
||||
model_class = _get_model_architecture(model_config.hf_config)
|
||||
|
||||
# Get the quantization config.
|
||||
quant_config = None
|
||||
# Get the (maybe quantized) linear method.
|
||||
linear_method = None
|
||||
if model_config.quantization is not None:
|
||||
if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
|
||||
raise ValueError(
|
||||
f"Quantization is not supported for {model_class}.")
|
||||
quant_config = get_quant_config(model_config.quantization,
|
||||
model_config.model,
|
||||
model_config.download_dir)
|
||||
@ -90,14 +80,12 @@ def get_model(model_config: ModelConfig) -> nn.Module:
|
||||
f"{model_config.dtype} is not supported for quantization "
|
||||
f"method {model_config.quantization}. Supported dtypes: "
|
||||
f"{supported_dtypes}")
|
||||
linear_method = quant_config.get_linear_method()
|
||||
|
||||
with _set_default_torch_dtype(model_config.dtype):
|
||||
# Create a model instance.
|
||||
# The weights will be initialized as empty tensors.
|
||||
if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
|
||||
model = model_class(model_config.hf_config, quant_config)
|
||||
else:
|
||||
model = model_class(model_config.hf_config)
|
||||
model = model_class(model_config.hf_config, linear_method)
|
||||
if model_config.load_format == "dummy":
|
||||
model = model.cuda()
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
|
||||
@ -33,15 +33,17 @@ from torch import nn
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs.aquila import AquilaConfig
|
||||
|
||||
@ -55,20 +57,17 @@ class AquilaMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
2 * intermediate_size,
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
linear_method=linear_method)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -111,6 +110,7 @@ class AquilaAttention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
max_position_embeddings: int = 8192,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -128,29 +128,29 @@ class AquilaAttention(nn.Module):
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
rotary_dim=self.head_dim,
|
||||
base=self.rope_theta,
|
||||
max_position=self.max_position_embeddings,
|
||||
rotary_dim=self.head_dim,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
rope_scaling=rope_scaling)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -171,7 +171,11 @@ class AquilaAttention(nn.Module):
|
||||
|
||||
class AquilaDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: AquilaConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: AquilaConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
@ -185,11 +189,13 @@ class AquilaDecoderLayer(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rope_scaling=rope_scaling,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.mlp = AquilaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.input_layernorm = AquilaRMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -226,19 +232,22 @@ class AquilaDecoderLayer(nn.Module):
|
||||
|
||||
class AquilaModel(nn.Module):
|
||||
|
||||
def __init__(self, config: AquilaConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: AquilaConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
#vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers)
|
||||
AquilaDecoderLayer(config, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@ -271,17 +280,16 @@ class AquilaModel(nn.Module):
|
||||
|
||||
class AquilaForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = AquilaModel(config)
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.lm_head = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
)
|
||||
self.linear_method = linear_method
|
||||
self.model = AquilaModel(config, linear_method)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
@ -298,79 +306,33 @@ class AquilaForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
|
||||
]
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||
kv_proj_shard_size = (self.config.hidden_size //
|
||||
self.config.num_attention_heads *
|
||||
self.config.num_key_value_heads // tp_size)
|
||||
attention_weight_specs = [
|
||||
# (weight_name, shard_size, offset)
|
||||
("q_proj", q_proj_shard_size, 0),
|
||||
("k_proj", kv_proj_shard_size, q_proj_shard_size),
|
||||
("v_proj", kv_proj_shard_size,
|
||||
q_proj_shard_size + kv_proj_shard_size),
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
state_dict = self.state_dict()
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
is_attention_weight = False
|
||||
for weight_name, shard_size, offset in attention_weight_specs:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "qkv_proj")]
|
||||
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[offset:offset + shard_size]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_attention_weight = True
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
if is_attention_weight:
|
||||
continue
|
||||
|
||||
is_gate_up_weight = False
|
||||
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_gate_up_weight = True
|
||||
break
|
||||
if is_gate_up_weight:
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tensor_model_parallel_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tensor_model_parallel_rank)
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@ -30,18 +30,20 @@ from torch import nn
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE,
|
||||
PagedAttentionWithALiBi)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
|
||||
|
||||
@ -80,20 +82,17 @@ class BaiChuanMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
2 * intermediate_size,
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
linear_method=linear_method)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -116,6 +115,7 @@ class BaiChuanAttention(nn.Module):
|
||||
position_embedding: str,
|
||||
rope_theta: float = 10000,
|
||||
max_position_embeddings: int = 8192,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -131,17 +131,19 @@ class BaiChuanAttention(nn.Module):
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
self.W_pack = ColumnParallelLinear(
|
||||
self.W_pack = QKVParallelLinear(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_heads,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
# Create the alibi slopes and slice them.
|
||||
if self.postion_embedding == "ALIBI":
|
||||
@ -188,7 +190,10 @@ class BaiChuanAttention(nn.Module):
|
||||
|
||||
class BaiChuanDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: BaiChuanConfig, position_embedding: str):
|
||||
def __init__(self,
|
||||
config: BaiChuanConfig,
|
||||
position_embedding: str,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
@ -200,11 +205,13 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
position_embedding=position_embedding,
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.mlp = BaiChuanMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -241,7 +248,10 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
|
||||
class BaiChuanModel(nn.Module):
|
||||
|
||||
def __init__(self, config: BaiChuanConfig, position_embedding: str):
|
||||
def __init__(self,
|
||||
config: BaiChuanConfig,
|
||||
position_embedding: str,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
@ -252,7 +262,7 @@ class BaiChuanModel(nn.Module):
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
BaiChuanDecoderLayer(config, position_embedding)
|
||||
BaiChuanDecoderLayer(config, position_embedding, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -285,16 +295,15 @@ class BaiChuanModel(nn.Module):
|
||||
|
||||
class BaiChuanBaseForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config, position_embedding: str):
|
||||
def __init__(self,
|
||||
config,
|
||||
position_embedding: str,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = BaiChuanModel(config, position_embedding)
|
||||
self.lm_head = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
)
|
||||
self.linear_method = linear_method
|
||||
self.model = BaiChuanModel(config, position_embedding, linear_method)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
@ -311,79 +320,46 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = []
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
|
||||
if "W_pack" in name:
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // total_num_heads
|
||||
num_heads = total_num_heads // tp_world_size
|
||||
head_start = tp_rank * num_heads
|
||||
head_end = (tp_rank + 1) * num_heads
|
||||
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size, hidden_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
|
||||
is_gate_up_weight = False
|
||||
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
|
||||
(tp_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_gate_up_weight = True
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
if is_gate_up_weight:
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tp_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(
|
||||
param,
|
||||
loaded_weight,
|
||||
name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tp_rank,
|
||||
)
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config, "ALIBI")
|
||||
def __init__(self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
super().__init__(config, "ALIBI", linear_method)
|
||||
|
||||
|
||||
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config, "ROPE")
|
||||
def __init__(self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
super().__init__(config, "ROPE", linear_method)
|
||||
|
||||
@ -30,14 +30,17 @@ from transformers import BloomConfig
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -70,7 +73,11 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
|
||||
class BloomAttention(nn.Module):
|
||||
|
||||
def __init__(self, config: BloomConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.total_num_heads = config.n_head
|
||||
@ -81,17 +88,18 @@ class BloomAttention(nn.Module):
|
||||
assert self.total_num_heads % tp_world_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_world_size
|
||||
|
||||
self.query_key_value = ColumnParallelLinear(
|
||||
self.query_key_value = QKVParallelLinear(
|
||||
self.hidden_size,
|
||||
3 * self.hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.dense = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
# Create the alibi slopes and slice them.
|
||||
@ -125,19 +133,23 @@ class BloomAttention(nn.Module):
|
||||
|
||||
class BloomMLP(nn.Module):
|
||||
|
||||
def __init__(self, config: BloomConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.dense_h_to_4h = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
4 * hidden_size,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.act = get_act_fn("gelu")
|
||||
self.dense_4h_to_h = RowParallelLinear(
|
||||
4 * hidden_size,
|
||||
hidden_size,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@ -149,16 +161,20 @@ class BloomMLP(nn.Module):
|
||||
|
||||
class BloomBlock(nn.Module):
|
||||
|
||||
def __init__(self, config: BloomConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
|
||||
self.input_layernorm = nn.LayerNorm(hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
self.self_attention = BloomAttention(config)
|
||||
self.self_attention = BloomAttention(config, linear_method)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = BloomMLP(config)
|
||||
self.mlp = BloomMLP(config, linear_method)
|
||||
self.apply_residual_connection_post_layernorm = (
|
||||
config.apply_residual_connection_post_layernorm)
|
||||
|
||||
@ -203,7 +219,11 @@ class BloomBlock(nn.Module):
|
||||
|
||||
class BloomModel(nn.Module):
|
||||
|
||||
def __init__(self, config: BloomConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
@ -216,8 +236,10 @@ class BloomModel(nn.Module):
|
||||
self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
# Transformer blocks
|
||||
self.h = nn.ModuleList(
|
||||
[BloomBlock(config) for _ in range(config.num_hidden_layers)])
|
||||
self.h = nn.ModuleList([
|
||||
BloomBlock(config, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
# Final Layer Norm
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
@ -251,12 +273,15 @@ class BloomModel(nn.Module):
|
||||
|
||||
class BloomForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config: BloomConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transformer = BloomModel(config)
|
||||
# TODO(zhuohan): create a new weight after implementing pipeline
|
||||
# parallelism
|
||||
self.linear_method = linear_method
|
||||
self.transformer = BloomModel(config, linear_method)
|
||||
self.lm_head_weight = self.transformer.word_embeddings.weight
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
@ -274,55 +299,36 @@ class BloomForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"word_embeddings.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias"
|
||||
]
|
||||
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if name == "lm_head.weight":
|
||||
# Since hidden_states are parallelized, we need to
|
||||
# load lm_head.weight in parallel.
|
||||
self._column_parallel_weights.append(name)
|
||||
# If lm_head is provided, use it instead.
|
||||
param = self.lm_head_weight
|
||||
else:
|
||||
if not name.startswith("transformer."):
|
||||
name = "transformer." + name
|
||||
param = state_dict[name]
|
||||
continue
|
||||
if not name.startswith("transformer."):
|
||||
name = "transformer." + name
|
||||
param = params_dict[name]
|
||||
|
||||
if "query_key_value" in name:
|
||||
# NOTE(woosuk): BLOOM's fused QKV has the shape of
|
||||
# [num_heads * 3 * head_size, hidden_size], while the
|
||||
# required shape is [3 * num_heads * head_size, hidden_size].
|
||||
# NOTE: BLOOM's fused QKV's output_dim has the shape of
|
||||
# (num_heads * 3 * head_size), while the
|
||||
# required shape is (3 * num_heads * head_size).
|
||||
# Thus, we need weight conversion.
|
||||
shard_size = param.shape[0]
|
||||
start = shard_size * tp_rank
|
||||
end = shard_size * (tp_rank + 1)
|
||||
loaded_weight = loaded_weight[start:end]
|
||||
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
num_heads = self.config.num_attention_heads
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // num_heads
|
||||
if "query_key_value.weight" in name:
|
||||
loaded_weight = loaded_weight.view(-1, 3, head_size,
|
||||
hidden_size)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
elif "query_key_value.bias" in name:
|
||||
loaded_weight = loaded_weight.view(-1, 3, head_size)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
else:
|
||||
raise ValueError(f"Unexpected weight name: {name}")
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights, tp_rank)
|
||||
if output_dim is not None:
|
||||
loaded_weight_shape = loaded_weight.shape
|
||||
loaded_weight = loaded_weight.view(
|
||||
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
|
||||
loaded_weight_shape[output_dim + 1:])
|
||||
loaded_weight = loaded_weight.transpose(
|
||||
output_dim, output_dim + 1)
|
||||
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
|
||||
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@ -6,32 +6,28 @@
|
||||
The input of the model is flattened to a 1D tensor of tokens. The model uses
|
||||
InputMetadata to extract the original 2D shape of the input.
|
||||
"""
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights,
|
||||
)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
|
||||
from vllm.model_executor.parallel_utils.layers import (
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.sequence import SequenceOutputs
|
||||
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -39,7 +35,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
class GLMAttention(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
@ -50,25 +50,33 @@ class GLMAttention(nn.Module):
|
||||
self.total_num_kv_heads = (config.multi_query_group_num
|
||||
if config.multi_query_attention else
|
||||
config.num_attention_heads)
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
self.num_kv_heads = self.total_num_kv_heads // tp_size
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = config.hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.query_key_value = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||
self.query_key_value = QKVParallelLinear(
|
||||
self.hidden_size,
|
||||
self.head_dim,
|
||||
bias=config.add_qkv_bias,
|
||||
gather_output=False,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=config.add_bias_linear or config.add_qkv_bias,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.dense = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
config.hidden_size,
|
||||
bias=config.add_bias_linear,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
@ -78,7 +86,6 @@ class GLMAttention(nn.Module):
|
||||
rotary_dim=self.head_dim // 2,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
is_neox_style=False,
|
||||
# is_glm_style=True
|
||||
)
|
||||
|
||||
def forward(
|
||||
@ -117,17 +124,21 @@ class GLMMLP(nn.Module):
|
||||
state back into h hidden dimension.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.add_bias = config.add_bias_linear
|
||||
|
||||
# Project to 4h.
|
||||
self.dense_h_to_4h = ColumnParallelLinear(
|
||||
self.dense_h_to_4h = MergedColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.ffn_hidden_size * 2,
|
||||
[config.ffn_hidden_size] * 2,
|
||||
bias=config.add_bias_linear,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
self.activation_func = SiluAndMul()
|
||||
@ -137,7 +148,7 @@ class GLMMLP(nn.Module):
|
||||
config.ffn_hidden_size,
|
||||
config.hidden_size,
|
||||
bias=config.add_bias_linear,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
@ -159,6 +170,7 @@ class GLMBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.apply_residual_connection_post_layernorm = (
|
||||
@ -172,7 +184,7 @@ class GLMBlock(nn.Module):
|
||||
eps=config.layernorm_epsilon)
|
||||
|
||||
# Self attention.
|
||||
self.self_attention = GLMAttention(config)
|
||||
self.self_attention = GLMAttention(config, linear_method)
|
||||
self.hidden_dropout = config.hidden_dropout
|
||||
|
||||
# Layernorm on the attention output
|
||||
@ -180,7 +192,7 @@ class GLMBlock(nn.Module):
|
||||
config.hidden_size, eps=config.layernorm_epsilon)
|
||||
|
||||
# MLP
|
||||
self.mlp = GLMMLP(config)
|
||||
self.mlp = GLMMLP(config, linear_method)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -227,7 +239,11 @@ class GLMBlock(nn.Module):
|
||||
class GLMTransformer(nn.Module):
|
||||
"""Transformer class."""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.post_layer_norm = config.post_layer_norm
|
||||
|
||||
@ -236,7 +252,7 @@ class GLMTransformer(nn.Module):
|
||||
|
||||
# Transformer layers.
|
||||
self.layers = nn.ModuleList(
|
||||
[GLMBlock(config) for i in range(self.num_layers)])
|
||||
[GLMBlock(config, linear_method) for i in range(self.num_layers)])
|
||||
|
||||
if self.post_layer_norm:
|
||||
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
||||
@ -274,7 +290,11 @@ class GLMTransformer(nn.Module):
|
||||
|
||||
class ChatGLMModel(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
|
||||
@ -283,15 +303,10 @@ class ChatGLMModel(nn.Module):
|
||||
self.num_layers = config.num_layers
|
||||
self.multi_query_group_num = config.multi_query_group_num
|
||||
self.kv_channels = config.kv_channels
|
||||
self.encoder = GLMTransformer(config)
|
||||
self.encoder = GLMTransformer(config, linear_method)
|
||||
|
||||
self.output_layer = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.padded_vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
params_dtype=config.torch_dtype,
|
||||
)
|
||||
self.output_layer = ParallelLMHead(config.padded_vocab_size,
|
||||
config.hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -317,10 +332,15 @@ class ChatGLMModel(nn.Module):
|
||||
|
||||
class ChatGLMForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config: ChatGLMConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: ChatGLMConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config: ChatGLMConfig = config
|
||||
self.transformer = ChatGLMModel(config)
|
||||
self.linear_method = linear_method
|
||||
self.transformer = ChatGLMModel(config, linear_method)
|
||||
self.lm_head_weight = self.transformer.output_layer.weight
|
||||
self.sampler = Sampler(config.padded_vocab_size)
|
||||
|
||||
@ -331,78 +351,26 @@ class ChatGLMForCausalLM(nn.Module):
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> Dict[int, SequenceOutputs]:
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.transformer(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"output_layer.weight",
|
||||
"embedding.weight",
|
||||
]
|
||||
_row_parallel_weights = ["dense_4h_to_h", "self_attention.dense"]
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None,
|
||||
):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
q_proj_shard_size = self.config.hidden_size // tp_size
|
||||
kv_proj_shard_size = (self.config.hidden_size //
|
||||
self.config.num_attention_heads *
|
||||
self.config.multi_query_group_num // tp_size)
|
||||
|
||||
mlp_hidden_shard_size = self.config.ffn_hidden_size // tp_size
|
||||
|
||||
state_dict = self.state_dict()
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_pos_emb.inv_freq" in name:
|
||||
continue
|
||||
if "word_embeddings" in name:
|
||||
name = name.replace(".word_embeddings", "")
|
||||
|
||||
if name in state_dict:
|
||||
param = state_dict[name]
|
||||
if "query_key_value" in name:
|
||||
q_offset = q_proj_shard_size * tp_rank
|
||||
k_offset = (q_proj_shard_size * tp_size +
|
||||
kv_proj_shard_size * tp_rank)
|
||||
v_offset = (q_proj_shard_size * tp_size +
|
||||
kv_proj_shard_size * (tp_size + tp_rank))
|
||||
wq = loaded_weight[q_offset:q_offset + q_proj_shard_size]
|
||||
wk = loaded_weight[k_offset:k_offset + kv_proj_shard_size]
|
||||
wv = loaded_weight[v_offset:v_offset + kv_proj_shard_size]
|
||||
loaded_weight = torch.cat([wq, wk, wv], dim=0)
|
||||
param.data.copy_(loaded_weight)
|
||||
continue
|
||||
|
||||
if "dense_h_to_4h" in name:
|
||||
w_gate = loaded_weight[mlp_hidden_shard_size *
|
||||
tp_rank:mlp_hidden_shard_size *
|
||||
(tp_rank + 1)]
|
||||
w_proj = loaded_weight[mlp_hidden_shard_size *
|
||||
(tp_size +
|
||||
tp_rank):mlp_hidden_shard_size *
|
||||
(tp_size + tp_rank + 1)]
|
||||
loaded_weight = torch.cat([w_gate, w_proj], dim=0)
|
||||
param.data.copy_(loaded_weight)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(
|
||||
param,
|
||||
loaded_weight,
|
||||
name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tp_rank,
|
||||
)
|
||||
elif name == "transformer.rotary_pos_emb.inv_freq":
|
||||
continue
|
||||
else:
|
||||
print("Warning never found tensor's name:", name)
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@ -30,17 +30,19 @@ from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.attention import (PagedAttention,
|
||||
PagedAttentionWithALiBi,
|
||||
PagedAttentionWithRoPE)
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
|
||||
hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs import RWConfig
|
||||
|
||||
@ -48,19 +50,6 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
FalconConfig = Union[HF_FalconConfig, RWConfig]
|
||||
|
||||
|
||||
# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during
|
||||
# training, this means that there's one additional quantization to bfloat16
|
||||
# between the operations. In order not to degrade the quality of our HF-port,
|
||||
# we keep these characteristics in the final model.
|
||||
class FalconLinear(nn.Linear):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = x @ self.weight.T
|
||||
if self.bias is None:
|
||||
return hidden_states
|
||||
return hidden_states + self.bias
|
||||
|
||||
|
||||
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
|
||||
base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
|
||||
@ -86,7 +75,11 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
|
||||
|
||||
class FalconAttention(nn.Module):
|
||||
|
||||
def __init__(self, config: FalconConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -103,41 +96,29 @@ class FalconAttention(nn.Module):
|
||||
|
||||
if self.new_decoder_architecture:
|
||||
self.total_num_kv_heads = config.num_kv_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_kv_heads = self.total_num_kv_heads // tp_size
|
||||
self.query_key_value = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||
self.head_dim,
|
||||
bias=config.bias,
|
||||
gather_output=False,
|
||||
skip_bias_add=True,
|
||||
)
|
||||
elif self.multi_query:
|
||||
self.total_num_kv_heads = 1
|
||||
self.num_kv_heads = 1
|
||||
self.query = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
self.total_num_heads * self.head_dim,
|
||||
bias=config.bias,
|
||||
gather_output=False,
|
||||
skip_bias_add=True,
|
||||
)
|
||||
self.key_value = FalconLinear(self.hidden_size,
|
||||
2 * self.head_dim,
|
||||
bias=config.bias)
|
||||
else:
|
||||
self.total_num_kv_heads = self.total_num_heads
|
||||
self.num_kv_heads = self.num_heads
|
||||
self.query_key_value = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||
self.head_dim,
|
||||
bias=config.bias,
|
||||
gather_output=False,
|
||||
skip_bias_add=True,
|
||||
)
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
|
||||
self.query_key_value = QKVParallelLinear(
|
||||
self.hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=config.bias,
|
||||
skip_bias_add=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
|
||||
@ -149,7 +130,6 @@ class FalconAttention(nn.Module):
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=config.bias,
|
||||
input_is_parallel=True,
|
||||
skip_bias_add=True,
|
||||
reduce_results=self.reduce_row_parallel_results)
|
||||
|
||||
@ -196,18 +176,10 @@ class FalconAttention(nn.Module):
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
if not self.new_decoder_architecture and self.multi_query:
|
||||
q, bias = self.query(hidden_states)
|
||||
if bias is not None:
|
||||
q += bias
|
||||
kv = self.key_value(hidden_states)
|
||||
k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
|
||||
else:
|
||||
qkv, bias = self.query_key_value(hidden_states)
|
||||
if bias is not None:
|
||||
qkv += bias
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
|
||||
dim=-1)
|
||||
qkv, bias = self.query_key_value(hidden_states)
|
||||
if bias is not None:
|
||||
qkv += bias
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
k_cache, v_cache = kv_cache
|
||||
if self.use_rotary:
|
||||
attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
|
||||
@ -221,15 +193,19 @@ class FalconAttention(nn.Module):
|
||||
|
||||
class FalconMLP(nn.Module):
|
||||
|
||||
def __init__(self, config: FalconConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
|
||||
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
|
||||
4 * hidden_size,
|
||||
bias=config.bias,
|
||||
gather_output=False,
|
||||
skip_bias_add=True)
|
||||
skip_bias_add=True,
|
||||
linear_method=linear_method)
|
||||
self.act = nn.GELU()
|
||||
self.reduce_row_parallel_results = not (config.new_decoder_architecture
|
||||
or config.parallel_attn)
|
||||
@ -237,9 +213,9 @@ class FalconMLP(nn.Module):
|
||||
4 * hidden_size,
|
||||
hidden_size,
|
||||
bias=config.bias,
|
||||
input_is_parallel=True,
|
||||
skip_bias_add=True,
|
||||
reduce_results=self.reduce_row_parallel_results)
|
||||
reduce_results=self.reduce_row_parallel_results,
|
||||
linear_method=linear_method)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
|
||||
@ -253,12 +229,16 @@ class FalconMLP(nn.Module):
|
||||
|
||||
class FalconDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: FalconConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.self_attention = FalconAttention(config)
|
||||
self.mlp = FalconMLP(config)
|
||||
self.self_attention = FalconAttention(config, linear_method)
|
||||
self.mlp = FalconMLP(config, linear_method)
|
||||
self.config = config
|
||||
|
||||
if config.new_decoder_architecture:
|
||||
@ -334,7 +314,11 @@ class FalconDecoderLayer(nn.Module):
|
||||
|
||||
class FalconModel(nn.Module):
|
||||
|
||||
def __init__(self, config: FalconConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -349,7 +333,8 @@ class FalconModel(nn.Module):
|
||||
|
||||
# Transformer blocks
|
||||
self.h = nn.ModuleList([
|
||||
FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)
|
||||
FalconDecoderLayer(config, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
# Final Layer Norm
|
||||
@ -383,15 +368,18 @@ class FalconModel(nn.Module):
|
||||
|
||||
class FalconForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config: FalconConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transformer = FalconModel(config)
|
||||
self.lm_head = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
self.linear_method = linear_method
|
||||
self.transformer = FalconModel(config, linear_method)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
@ -415,89 +403,44 @@ class FalconForCausalLM(nn.Module):
|
||||
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"word_embeddings.weight", "lm_head.weight", "dense_h_to_4h.weight",
|
||||
"dense_h_to_4h.bias"
|
||||
]
|
||||
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tp_size = (get_tensor_model_parallel_world_size())
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
hidden_size = self.config.hidden_size
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
num_heads = total_num_heads // tp_size
|
||||
head_size = hidden_size // total_num_heads
|
||||
head_start = tp_rank * num_heads
|
||||
head_end = (tp_rank + 1) * num_heads
|
||||
if self.config.new_decoder_architecture:
|
||||
total_num_kv_heads = self.config.num_kv_heads
|
||||
num_kv_heads = total_num_kv_heads // tp_size
|
||||
separated_q_kv = False
|
||||
kv_head_start = tp_rank * num_kv_heads
|
||||
kv_head_end = (tp_rank + 1) * num_kv_heads
|
||||
elif self.config.multi_query:
|
||||
total_num_kv_heads = 1
|
||||
num_kv_heads = 1
|
||||
separated_q_kv = True
|
||||
kv_head_start = 0
|
||||
kv_head_end = 1
|
||||
else:
|
||||
total_num_kv_heads = total_num_heads
|
||||
num_kv_heads = total_num_kv_heads // tp_size
|
||||
separated_q_kv = False
|
||||
kv_head_start = tp_rank * num_kv_heads
|
||||
kv_head_end = (tp_rank + 1) * num_kv_heads
|
||||
num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
|
||||
state_dict = self.state_dict()
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
param = params_dict[name]
|
||||
if "query_key_value" in name:
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
loaded_weight_size = loaded_weight.size()
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
loaded_weight_shape = loaded_weight.shape
|
||||
loaded_weight = loaded_weight.view(
|
||||
total_num_kv_heads, num_query_heads_per_kv_head + 2,
|
||||
head_size, *loaded_weight_size[1:])
|
||||
loaded_weight_shape[:output_dim] +
|
||||
(total_num_kv_heads, num_query_heads_per_kv_head + 2, -1) +
|
||||
loaded_weight_shape[output_dim + 1:])
|
||||
wq = loaded_weight.narrow(
|
||||
output_dim + 1, 0, num_query_heads_per_kv_head).reshape(
|
||||
*loaded_weight_shape[:output_dim], -1,
|
||||
*loaded_weight_shape[output_dim + 1:])
|
||||
wk = loaded_weight.narrow(
|
||||
output_dim + 1, num_query_heads_per_kv_head,
|
||||
1).reshape(*loaded_weight_shape[:output_dim], -1,
|
||||
*loaded_weight_shape[output_dim + 1:])
|
||||
wv = loaded_weight.narrow(
|
||||
output_dim + 1, num_query_heads_per_kv_head + 1,
|
||||
1).reshape(*loaded_weight_shape[:output_dim], -1,
|
||||
*loaded_weight_shape[output_dim + 1:])
|
||||
loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
|
||||
|
||||
wq = loaded_weight[:, :-2].reshape(-1, *loaded_weight_size[1:])
|
||||
wk = loaded_weight[:, [-2]].reshape(-1,
|
||||
*loaded_weight_size[1:])
|
||||
wv = loaded_weight[:, [-1]].reshape(-1,
|
||||
*loaded_weight_size[1:])
|
||||
|
||||
wq = wq[head_size * head_start:head_size * head_end]
|
||||
wk = wk[head_size * kv_head_start:head_size * kv_head_end]
|
||||
wv = wv[head_size * kv_head_start:head_size * kv_head_end]
|
||||
|
||||
if separated_q_kv:
|
||||
loaded_weight_q = wq
|
||||
loaded_weight_kv = torch.cat([wk, wv], dim=0)
|
||||
q_weight_name = name.replace("query_key_value", "query")
|
||||
kv_weight_name = name.replace("query_key_value",
|
||||
"key_value")
|
||||
load_tensor_parallel_weights(state_dict[q_weight_name],
|
||||
loaded_weight_q,
|
||||
q_weight_name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tp_rank)
|
||||
load_tensor_parallel_weights(state_dict[kv_weight_name],
|
||||
loaded_weight_kv,
|
||||
kv_weight_name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tp_rank)
|
||||
continue
|
||||
else:
|
||||
loaded_weight = torch.cat([wq, wk, wv], dim=0)
|
||||
|
||||
param = state_dict[name]
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights, tp_rank)
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@ -30,15 +30,17 @@ from transformers import GPT2Config
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention import PagedAttention
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -46,7 +48,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
class GPT2Attention(nn.Module):
|
||||
|
||||
def __init__(self, config: GPT2Config):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPT2Config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
total_num_heads = config.num_attention_heads
|
||||
@ -57,17 +63,18 @@ class GPT2Attention(nn.Module):
|
||||
self.head_dim = self.hidden_size // total_num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.c_attn = ColumnParallelLinear(
|
||||
self.c_attn = QKVParallelLinear(
|
||||
self.hidden_size,
|
||||
3 * self.hidden_size,
|
||||
self.head_dim,
|
||||
total_num_heads,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
@ -95,6 +102,7 @@ class GPT2MLP(nn.Module):
|
||||
self,
|
||||
intermediate_size: int,
|
||||
config: GPT2Config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@ -102,13 +110,13 @@ class GPT2MLP(nn.Module):
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.act = get_act_fn(config.activation_function)
|
||||
|
||||
@ -121,16 +129,20 @@ class GPT2MLP(nn.Module):
|
||||
|
||||
class GPT2Block(nn.Module):
|
||||
|
||||
def __init__(self, config: GPT2Config):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPT2Config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
|
||||
hidden_size)
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPT2Attention(config)
|
||||
self.attn = GPT2Attention(config, linear_method)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = GPT2MLP(inner_dim, config)
|
||||
self.mlp = GPT2MLP(inner_dim, config, linear_method)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -160,24 +172,23 @@ class GPT2Block(nn.Module):
|
||||
|
||||
class GPT2Model(nn.Module):
|
||||
|
||||
def __init__(self, config: GPT2Config):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPT2Config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
assert not config.add_cross_attention
|
||||
assert not config.scale_attn_by_inverse_layer_idx
|
||||
assert not config.reorder_and_upcast_attn
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
# Optimization: While the vocab size of GPT-2 is 50257, we extend it
|
||||
# to 50304 in order to make it divisible by 64.
|
||||
# This improves performance since GPUs are faster if the dimension
|
||||
# is divisible by 64. In addition, it allows us to shard the embedding
|
||||
# layer across 2, 4, 8, or more GPUs.
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
|
||||
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.h = nn.ModuleList(
|
||||
[GPT2Block(config) for _ in range(config.num_hidden_layers)])
|
||||
self.h = nn.ModuleList([
|
||||
GPT2Block(config, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
@ -207,12 +218,15 @@ class GPT2Model(nn.Module):
|
||||
|
||||
class GPT2LMHeadModel(nn.Module):
|
||||
|
||||
def __init__(self, config: GPT2Config):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPT2Config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transformer = GPT2Model(config)
|
||||
# TODO(zhuohan): create a new weight after implementing pipeline
|
||||
# parallelism
|
||||
self.linear_method = linear_method
|
||||
self.transformer = GPT2Model(config, linear_method)
|
||||
self.lm_head_weight = self.transformer.wte.weight
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
@ -230,19 +244,12 @@ class GPT2LMHeadModel(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
|
||||
_row_parallel_weights = ["c_proj.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "lm_head.weight" in name:
|
||||
@ -253,53 +260,19 @@ class GPT2LMHeadModel(nn.Module):
|
||||
# Skip attention mask.
|
||||
# NOTE: "c_attn.bias" should not be skipped.
|
||||
continue
|
||||
|
||||
if not name.startswith("transformer."):
|
||||
name = "transformer." + name
|
||||
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
|
||||
param = params_dict[name]
|
||||
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
|
||||
# Because of this, we need to transpose the weights.
|
||||
# Note(zhuohan): the logic below might break quantized models.
|
||||
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
|
||||
if conv1d_weight_name not in name:
|
||||
continue
|
||||
if not name.endswith(".weight"):
|
||||
continue
|
||||
loaded_weight = loaded_weight.t()
|
||||
param = state_dict[name]
|
||||
|
||||
if name == "transformer.wte.weight":
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tensor_model_parallel_rank)
|
||||
continue
|
||||
|
||||
# For the fused QKV linear layer, manually shard the weights.
|
||||
if "c_attn" in name:
|
||||
# GPT-2's fused QKV has the shape of
|
||||
# [3 * num_heads * head_size, hidden_size].
|
||||
# When tensor parallelism is used, we shard the weights along
|
||||
# the head dimension.
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // total_num_heads
|
||||
num_heads = total_num_heads // tensor_model_parallel_world_size
|
||||
head_start = tensor_model_parallel_rank * num_heads
|
||||
head_end = (tensor_model_parallel_rank + 1) * num_heads
|
||||
|
||||
if name.endswith(".weight"):
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size, hidden_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
elif name.endswith(".bias"):
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :]
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
else:
|
||||
raise ValueError(f"Unexpected parameter name {name}")
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tensor_model_parallel_rank)
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@ -31,15 +31,17 @@ from transformers import GPTBigCodeConfig
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention import PagedAttention
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -47,7 +49,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
class GPTBigCodeAttention(nn.Module):
|
||||
|
||||
def __init__(self, config: GPTBigCodeConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTBigCodeConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
total_num_heads = config.num_attention_heads
|
||||
@ -61,32 +67,26 @@ class GPTBigCodeAttention(nn.Module):
|
||||
|
||||
self.multi_query = config.multi_query
|
||||
if self.multi_query:
|
||||
total_num_kv_heads = 1
|
||||
self.num_kv_heads = 1
|
||||
self.kv_dim = self.head_dim
|
||||
self.c_attn_q = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
)
|
||||
self.c_attn_kv = nn.Linear(self.hidden_size,
|
||||
2 * self.kv_dim,
|
||||
bias=True)
|
||||
else:
|
||||
total_num_kv_heads = total_num_heads
|
||||
self.num_kv_heads = self.num_heads
|
||||
self.kv_dim = self.num_kv_heads * self.head_dim
|
||||
self.c_attn = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size + 2 * self.kv_dim,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
)
|
||||
self.kv_dim = self.head_dim * self.num_kv_heads
|
||||
self.c_attn = QKVParallelLinear(
|
||||
self.hidden_size,
|
||||
self.head_dim,
|
||||
total_num_heads,
|
||||
total_num_kv_heads,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
@ -100,17 +100,14 @@ class GPTBigCodeAttention(nn.Module):
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
if self.multi_query:
|
||||
q, _ = self.c_attn_q(hidden_states)
|
||||
kv = self.c_attn_kv(hidden_states)
|
||||
k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1)
|
||||
else:
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.split([
|
||||
qkv, _ = self.c_attn(hidden_states)
|
||||
q, k, v = qkv.split(
|
||||
[
|
||||
self.hidden_size // self.tensor_model_parallel_world_size,
|
||||
self.kv_dim, self.kv_dim
|
||||
],
|
||||
dim=-1)
|
||||
dim=-1,
|
||||
)
|
||||
key_cache, value_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, key_cache, value_cache,
|
||||
input_metadata, cache_event)
|
||||
@ -124,6 +121,7 @@ class GPTBigMLP(nn.Module):
|
||||
self,
|
||||
intermediate_size: int,
|
||||
config: GPTBigCodeConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@ -131,13 +129,13 @@ class GPTBigMLP(nn.Module):
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.act = get_act_fn(config.activation_function)
|
||||
|
||||
@ -150,16 +148,20 @@ class GPTBigMLP(nn.Module):
|
||||
|
||||
class GPTBigCodeBlock(nn.Module):
|
||||
|
||||
def __init__(self, config: GPTBigCodeConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTBigCodeConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
inner_dim = (config.n_inner if config.n_inner is not None else 4 *
|
||||
hidden_size)
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPTBigCodeAttention(config)
|
||||
self.attn = GPTBigCodeAttention(config, linear_method)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = GPTBigMLP(inner_dim, config)
|
||||
self.mlp = GPTBigMLP(inner_dim, config, linear_method)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -189,23 +191,23 @@ class GPTBigCodeBlock(nn.Module):
|
||||
|
||||
class GPTBigCodeModel(nn.Module):
|
||||
|
||||
def __init__(self, config: GPTBigCodeConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTBigCodeConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
assert not config.add_cross_attention
|
||||
|
||||
self.embed_dim = config.hidden_size
|
||||
|
||||
# Optimization: While the vocab size of GPT-2 is 50257, we extend it
|
||||
# to 50304 in order to make it divisible by 64.
|
||||
# This improves performance since GPUs are faster if the dimension
|
||||
# is divisible by 64. In addition, it allows us to shard the embedding
|
||||
# layer across 2, 4, 8, or more GPUs.
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim)
|
||||
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.h = nn.ModuleList(
|
||||
[GPTBigCodeBlock(config) for _ in range(config.num_hidden_layers)])
|
||||
self.h = nn.ModuleList([
|
||||
GPTBigCodeBlock(config, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
@ -235,12 +237,15 @@ class GPTBigCodeModel(nn.Module):
|
||||
|
||||
class GPTBigCodeForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config: GPTBigCodeConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTBigCodeConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transformer = GPTBigCodeModel(config)
|
||||
# TODO(zhuohan): create a new weight after implementing pipeline
|
||||
# parallelism
|
||||
self.linear_method = linear_method
|
||||
self.transformer = GPTBigCodeModel(config, linear_method)
|
||||
self.lm_head_weight = self.transformer.wte.weight
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
@ -258,89 +263,21 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = ["c_fc.weight", "c_fc.bias"]
|
||||
_row_parallel_weights = ["c_proj.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "lm_head.weight" in name:
|
||||
# GPT-2 ties the weights of the embedding layer and the final
|
||||
# linear layer.
|
||||
continue
|
||||
if ".attn.bias" in name:
|
||||
# Skip attention mask.
|
||||
# NOTE: "c_attn.bias" should not be skipped.
|
||||
continue
|
||||
|
||||
if not name.startswith("transformer."):
|
||||
name = "transformer." + name
|
||||
|
||||
# For the fused QKV linear layer, manually shard the weights.
|
||||
if "c_attn" in name:
|
||||
# GPT-2's fused QKV has the shape of
|
||||
# [3 * num_heads * head_size, hidden_size].
|
||||
# When tensor parallelism is used, we shard the weights along
|
||||
# the head dimension.
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
total_num_kv_heads = (1 if self.config.multi_query else
|
||||
total_num_heads)
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // total_num_heads
|
||||
total_kv_size = head_size * total_num_kv_heads
|
||||
num_heads = total_num_heads // tensor_model_parallel_world_size
|
||||
head_start = tensor_model_parallel_rank * num_heads
|
||||
head_end = (tensor_model_parallel_rank + 1) * num_heads
|
||||
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
wq, wk, wv = torch.split(
|
||||
loaded_weight, [hidden_size, total_kv_size, total_kv_size],
|
||||
dim=0)
|
||||
|
||||
wq = wq[head_size * head_start:head_size * head_end]
|
||||
if not self.config.multi_query:
|
||||
# Split the heads when using normal multi-head attention
|
||||
wk = wk[head_size * head_start:head_size * head_end]
|
||||
wv = wv[head_size * head_start:head_size * head_end]
|
||||
loaded_weight = torch.cat([wq, wk, wv], dim=0)
|
||||
else:
|
||||
# For multi-query attention, we split the query
|
||||
# but replicate the key and value.
|
||||
loaded_weight_q = wq
|
||||
loaded_weight_kv = torch.cat([wk, wv], dim=0)
|
||||
q_weight_name = name.replace("c_attn", "c_attn_q")
|
||||
kv_weight_name = name.replace("c_attn", "c_attn_kv")
|
||||
load_tensor_parallel_weights(state_dict[q_weight_name],
|
||||
loaded_weight_q,
|
||||
q_weight_name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tensor_model_parallel_rank)
|
||||
load_tensor_parallel_weights(state_dict[kv_weight_name],
|
||||
loaded_weight_kv,
|
||||
kv_weight_name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tensor_model_parallel_rank)
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
|
||||
if name == "transformer.wte.weight":
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tensor_model_parallel_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tensor_model_parallel_rank)
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@ -29,14 +29,17 @@ from transformers import GPTJConfig
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -44,23 +47,28 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
class GPTJAttention(nn.Module):
|
||||
|
||||
def __init__(self, config: GPTJConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTJConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.total_num_heads
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
config.hidden_size,
|
||||
3 * config.hidden_size,
|
||||
self.head_size,
|
||||
self.total_num_heads,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
@ -102,18 +110,23 @@ class GPTJAttention(nn.Module):
|
||||
|
||||
class GPTJMLP(nn.Module):
|
||||
|
||||
def __init__(self, intermediate_size: int, config: GPTJConfig):
|
||||
def __init__(
|
||||
self,
|
||||
intermediate_size: int,
|
||||
config: GPTJConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.n_embd
|
||||
self.fc_in = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.fc_out = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.act = get_act_fn(config.activation_function)
|
||||
|
||||
@ -126,15 +139,19 @@ class GPTJMLP(nn.Module):
|
||||
|
||||
class GPTJBlock(nn.Module):
|
||||
|
||||
def __init__(self, config: GPTJConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTJConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
if config.n_inner is None:
|
||||
inner_dim = 4 * config.n_embd
|
||||
else:
|
||||
inner_dim = config.n_inner
|
||||
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPTJAttention(config)
|
||||
self.mlp = GPTJMLP(inner_dim, config)
|
||||
self.attn = GPTJAttention(config, linear_method)
|
||||
self.mlp = GPTJMLP(inner_dim, config, linear_method)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -160,7 +177,11 @@ class GPTJBlock(nn.Module):
|
||||
|
||||
class GPTJModel(nn.Module):
|
||||
|
||||
def __init__(self, config: GPTJConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTJConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.n_embd
|
||||
@ -169,7 +190,7 @@ class GPTJModel(nn.Module):
|
||||
self.embed_dim,
|
||||
)
|
||||
self.h = nn.ModuleList(
|
||||
[GPTJBlock(config) for _ in range(config.n_layer)])
|
||||
[GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
@ -200,15 +221,20 @@ class GPTJModel(nn.Module):
|
||||
|
||||
class GPTJForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config: GPTJConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTJConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
assert not config.tie_word_embeddings
|
||||
self.transformer = GPTJModel(config)
|
||||
self.lm_head = ColumnParallelLinear(
|
||||
config.n_embd,
|
||||
self.transformer = GPTJModel(config, linear_method)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
gather_output=False,
|
||||
config.n_embd,
|
||||
bias=True,
|
||||
)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
@ -226,43 +252,33 @@ class GPTJForCausalLM(nn.Module):
|
||||
input_metadata, self.lm_head.bias)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"wte.weight", "fc_in.weight", "fc_in.bias", "lm_head.weight",
|
||||
"lm_head.bias"
|
||||
]
|
||||
_row_parallel_weights = ["out_proj.weight", "fc_out.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "attn.bias" in name or "attn.masked_bias" in name:
|
||||
continue
|
||||
|
||||
is_attention_weight = False
|
||||
for stride_id, att_weight_name in enumerate(
|
||||
["q_proj", "k_proj", "v_proj"]):
|
||||
if att_weight_name not in name:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
|
||||
shard_size = param.shape[0] // 3
|
||||
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
|
||||
(tp_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_attention_weight = True
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
if is_attention_weight:
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights, tp_rank)
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@ -29,14 +29,17 @@ from transformers import GPTNeoXConfig
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -44,7 +47,11 @@ KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
class GPTNeoXAttention(nn.Module):
|
||||
|
||||
def __init__(self, config: GPTNeoXConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTNeoXConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -56,15 +63,16 @@ class GPTNeoXAttention(nn.Module):
|
||||
self.num_heads = (self.total_num_heads //
|
||||
tensor_model_parallel_world_size)
|
||||
|
||||
self.query_key_value = ColumnParallelLinear(
|
||||
self.query_key_value = QKVParallelLinear(
|
||||
config.hidden_size,
|
||||
3 * config.hidden_size,
|
||||
gather_output=False,
|
||||
self.head_size,
|
||||
self.total_num_heads,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.dense = RowParallelLinear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
scaling = self.head_size**-0.5
|
||||
@ -100,17 +108,21 @@ class GPTNeoXAttention(nn.Module):
|
||||
|
||||
class GPTNeoXMLP(nn.Module):
|
||||
|
||||
def __init__(self, config: GPTNeoXConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTNeoXConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dense_h_to_4h = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.dense_4h_to_h = RowParallelLinear(
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.act = get_act_fn(config.hidden_act)
|
||||
|
||||
@ -123,15 +135,19 @@ class GPTNeoXMLP(nn.Module):
|
||||
|
||||
class GPTNeoXLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: GPTNeoXConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTNeoXConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_parallel_residual = config.use_parallel_residual
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.attention = GPTNeoXAttention(config)
|
||||
self.mlp = GPTNeoXMLP(config)
|
||||
self.attention = GPTNeoXAttention(config, linear_method)
|
||||
self.mlp = GPTNeoXMLP(config, linear_method)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -169,7 +185,11 @@ class GPTNeoXLayer(nn.Module):
|
||||
|
||||
class GPTNeoXModel(nn.Module):
|
||||
|
||||
def __init__(self, config: GPTNeoXConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTNeoXConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@ -177,8 +197,10 @@ class GPTNeoXModel(nn.Module):
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layers = nn.ModuleList([
|
||||
GPTNeoXLayer(config, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
@ -210,15 +232,18 @@ class GPTNeoXModel(nn.Module):
|
||||
|
||||
class GPTNeoXForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.gpt_neox = GPTNeoXModel(config)
|
||||
self.embed_out = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
self.linear_method = linear_method
|
||||
self.gpt_neox = GPTNeoXModel(config, linear_method)
|
||||
self.embed_out = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
@ -236,50 +261,35 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight",
|
||||
"dense_h_to_4h.bias"
|
||||
]
|
||||
_row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if ("attention.bias" in name or "attention.masked_bias" in name
|
||||
or "rotary_emb.inv_freq" in name):
|
||||
continue
|
||||
param = state_dict[name]
|
||||
if "query_key_value" in name:
|
||||
# NOTE(woosuk): GPT-NeoX's fused QKV has the shape of
|
||||
# [num_heads * 3 * head_size, hidden_size], while the
|
||||
# required shape is [3 * num_heads * head_size, hidden_size].
|
||||
# Thus, we need weight conversion.
|
||||
shard_size = param.shape[0]
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param = params_dict[name]
|
||||
|
||||
if "query_key_value" in name:
|
||||
# NOTE: GPT-NeoX's fused QKV's output_dim has the shape of
|
||||
# (num_heads * 3 * head_size), while the
|
||||
# required shape is (3 * num_heads * head_size).
|
||||
# Thus, we need weight conversion.
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
num_heads = self.config.num_attention_heads
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // num_heads
|
||||
if "query_key_value.weight" in name:
|
||||
loaded_weight = loaded_weight.view(-1, 3, head_size,
|
||||
hidden_size)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
elif "query_key_value.bias" in name:
|
||||
loaded_weight = loaded_weight.view(-1, 3, head_size)
|
||||
loaded_weight = loaded_weight.transpose(0, 1)
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
else:
|
||||
raise ValueError(f"Unexpected weight name: {name}")
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tensor_model_parallel_rank)
|
||||
if output_dim is not None:
|
||||
loaded_weight_shape = loaded_weight.shape
|
||||
loaded_weight = loaded_weight.view(
|
||||
loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
|
||||
loaded_weight_shape[output_dim + 1:])
|
||||
loaded_weight = loaded_weight.transpose(
|
||||
output_dim, output_dim + 1)
|
||||
loaded_weight = loaded_weight.reshape(loaded_weight_shape)
|
||||
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@ -9,15 +9,17 @@ from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.weight_utils import (
|
||||
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
|
||||
load_tensor_parallel_weights)
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -30,20 +32,17 @@ class InternLMMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
2 * intermediate_size,
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
linear_method=linear_method)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -65,6 +64,7 @@ class InternLMAttention(nn.Module):
|
||||
bias: bool,
|
||||
rope_theta: float = 10000,
|
||||
max_position_embeddings: int = 8192,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -79,17 +79,18 @@ class InternLMAttention(nn.Module):
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
3 * self.total_num_heads * self.head_dim,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
bias=bias,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
@ -118,7 +119,11 @@ class InternLMAttention(nn.Module):
|
||||
|
||||
class InternLMDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
@ -130,11 +135,13 @@ class InternLMDecoderLayer(nn.Module):
|
||||
bias=config.bias,
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.mlp = InternLMMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -171,7 +178,11 @@ class InternLMDecoderLayer(nn.Module):
|
||||
|
||||
class InternLMModel(nn.Module):
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
@ -183,7 +194,7 @@ class InternLMModel(nn.Module):
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
InternLMDecoderLayer(config)
|
||||
InternLMDecoderLayer(config, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -216,17 +227,16 @@ class InternLMModel(nn.Module):
|
||||
|
||||
class InternLMForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = InternLMModel(config)
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.lm_head = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
)
|
||||
self.linear_method = linear_method
|
||||
self.model = InternLMModel(config, linear_method)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
@ -243,69 +253,33 @@ class InternLMForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
|
||||
]
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
param = state_dict[name]
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tensor_model_parallel_rank)
|
||||
continue
|
||||
|
||||
is_attention_weight = False
|
||||
for stride_id, att_weight_name in enumerate(
|
||||
["q_proj", "k_proj", "v_proj"]):
|
||||
if att_weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
|
||||
shard_size = param.shape[0] // 3
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_attention_weight = True
|
||||
break
|
||||
if is_attention_weight:
|
||||
continue
|
||||
|
||||
is_gate_up_weight = False
|
||||
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_gate_up_weight = True
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
if is_gate_up_weight:
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tensor_model_parallel_rank)
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@ -33,17 +33,19 @@ from transformers import LlamaConfig
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.quantized_linear import ParallelLinear
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
|
||||
from vllm.model_executor.quantization_utils import QuantizationConfig
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -56,19 +58,17 @@ class LlamaMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = ParallelLinear.column(hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
quant_config=quant_config)
|
||||
self.down_proj = ParallelLinear.row(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config)
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -91,7 +91,7 @@ class LlamaAttention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -109,7 +109,6 @@ class LlamaAttention(nn.Module):
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
num_kv_heads_replicas = max(1, tp_size // self.total_num_kv_heads)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
@ -117,21 +116,19 @@ class LlamaAttention(nn.Module):
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = ParallelLinear.column(
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
(self.total_num_heads +
|
||||
2 * self.total_num_kv_heads * num_kv_heads_replicas) *
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
quant_config=quant_config,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.o_proj = ParallelLinear.row(
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
@ -165,11 +162,10 @@ class LlamaDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
@ -181,13 +177,13 @@ class LlamaDecoderLayer(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.mlp = LlamaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -227,20 +223,18 @@ class LlamaModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
vocab_size,
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
LlamaDecoderLayer(config, quant_config)
|
||||
LlamaDecoderLayer(config, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -276,19 +270,13 @@ class LlamaForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = LlamaModel(config, quant_config)
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
# NOTE: The LM head is not quantized.
|
||||
self.lm_head = ParallelLinear.column(config.hidden_size,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
quant_config=None)
|
||||
self.linear_method = linear_method
|
||||
self.model = LlamaModel(config, linear_method)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
@ -305,124 +293,33 @@ class LlamaForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_layers = []
|
||||
_row_parallel_layers = ["o_proj", "down_proj"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
if self.quant_config is None:
|
||||
col_weight_suffixes = ["weight"]
|
||||
row_weight_suffixes = ["weight"]
|
||||
else:
|
||||
col_weight_suffixes = (
|
||||
self.quant_config.get_col_parallel_tensor_names())
|
||||
row_weight_suffixes = (
|
||||
self.quant_config.get_row_parallel_tensor_names())
|
||||
|
||||
column_parallel_weights: List[str] = []
|
||||
for layer in self._column_parallel_layers:
|
||||
for suffix in col_weight_suffixes:
|
||||
column_parallel_weights.append(f"{layer}.{suffix}")
|
||||
row_parallel_weights: List[str] = []
|
||||
for layer in self._row_parallel_layers:
|
||||
for suffix in row_weight_suffixes:
|
||||
row_parallel_weights.append(f"{layer}.{suffix}")
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||
num_kv_heads_replicas = max(1,
|
||||
tp_size // self.config.num_key_value_heads)
|
||||
num_kv_heads_per_gpu = max(1,
|
||||
self.config.num_key_value_heads // tp_size)
|
||||
kv_proj_shard_size = (self.config.hidden_size //
|
||||
self.config.num_attention_heads *
|
||||
num_kv_heads_per_gpu)
|
||||
attention_weight_specs = [
|
||||
# (weight_name, shard_size, offset)
|
||||
("q_proj", q_proj_shard_size, 0),
|
||||
("k_proj", kv_proj_shard_size, q_proj_shard_size),
|
||||
("v_proj", kv_proj_shard_size,
|
||||
q_proj_shard_size + kv_proj_shard_size),
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
state_dict = self.state_dict()
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
packed_dim = None
|
||||
is_transposed = False
|
||||
if self.quant_config is not None:
|
||||
packed_dim = self.quant_config.get_packed_dim(name)
|
||||
is_transposed = self.quant_config.is_transposed(name)
|
||||
if is_transposed:
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
loaded_weight = loaded_weight.T
|
||||
|
||||
is_attention_weight = False
|
||||
for weight_name, shard_size, offset in attention_weight_specs:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "qkv_proj")]
|
||||
if is_transposed:
|
||||
param = param.T
|
||||
|
||||
if packed_dim is not None:
|
||||
shard_dim = 0 if not is_transposed else 1
|
||||
if packed_dim == shard_dim:
|
||||
shard_size //= self.quant_config.pack_factor
|
||||
offset //= self.quant_config.pack_factor
|
||||
|
||||
if weight_name in ["k_proj", "v_proj"]:
|
||||
shard_id = tp_rank // num_kv_heads_replicas
|
||||
else:
|
||||
shard_id = tp_rank
|
||||
loaded_weight = loaded_weight[shard_size *
|
||||
shard_id:shard_size *
|
||||
(shard_id + 1)]
|
||||
param_slice = param.data[offset:offset + shard_size]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_attention_weight = True
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
if is_attention_weight:
|
||||
continue
|
||||
|
||||
is_gate_up_weight = False
|
||||
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
if is_transposed:
|
||||
param = param.T
|
||||
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
|
||||
(tp_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_gate_up_weight = True
|
||||
break
|
||||
if is_gate_up_weight:
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
if is_transposed:
|
||||
param = param.T
|
||||
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tp_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
column_parallel_weights,
|
||||
row_parallel_weights, tp_rank)
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@ -33,17 +33,19 @@ from transformers import MistralConfig
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.quantized_linear import ParallelLinear
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
|
||||
from vllm.model_executor.quantization_utils import QuantizationConfig
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -56,19 +58,17 @@ class MistralMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = ParallelLinear.column(hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
quant_config=quant_config)
|
||||
self.down_proj = ParallelLinear.row(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config)
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -89,7 +89,7 @@ class MistralAttention(nn.Module):
|
||||
num_kv_heads: int,
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
sliding_window: Optional[int] = None) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -98,8 +98,15 @@ class MistralAttention(nn.Module):
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
self.num_kv_heads = self.total_num_kv_heads // tp_size
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
@ -107,20 +114,19 @@ class MistralAttention(nn.Module):
|
||||
self.rope_theta = rope_theta
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
self.qkv_proj = ParallelLinear.column(
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
quant_config=quant_config,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.o_proj = ParallelLinear.row(
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
self.head_dim,
|
||||
@ -153,7 +159,7 @@ class MistralDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MistralConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -165,13 +171,13 @@ class MistralDecoderLayer(nn.Module):
|
||||
max_position=config.max_position_embeddings,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
quant_config=quant_config,
|
||||
linear_method=linear_method,
|
||||
sliding_window=config.sliding_window)
|
||||
self.mlp = MistralMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -211,20 +217,19 @@ class MistralModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MistralConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
vocab_size,
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
MistralDecoderLayer(config, quant_config)
|
||||
MistralDecoderLayer(config, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -260,19 +265,13 @@ class MistralForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MistralConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = MistralModel(config, quant_config)
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
# NOTE: The LM head is not quantized.
|
||||
self.lm_head = ParallelLinear.column(config.hidden_size,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
quant_config=None)
|
||||
self.linear_method = linear_method
|
||||
self.model = MistralModel(config, linear_method)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
@ -289,118 +288,33 @@ class MistralForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_layers = []
|
||||
_row_parallel_layers = ["o_proj", "down_proj"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
if self.quant_config is None:
|
||||
col_weight_suffixes = ["weight"]
|
||||
row_weight_suffixes = ["weight"]
|
||||
else:
|
||||
col_weight_suffixes = (
|
||||
self.quant_config.get_col_parallel_tensor_names())
|
||||
row_weight_suffixes = (
|
||||
self.quant_config.get_row_parallel_tensor_names())
|
||||
|
||||
column_parallel_weights: List[str] = []
|
||||
for layer in self._column_parallel_layers:
|
||||
for suffix in col_weight_suffixes:
|
||||
column_parallel_weights.append(f"{layer}.{suffix}")
|
||||
row_parallel_weights: List[str] = []
|
||||
for layer in self._row_parallel_layers:
|
||||
for suffix in row_weight_suffixes:
|
||||
row_parallel_weights.append(f"{layer}.{suffix}")
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||
kv_proj_shard_size = (self.config.hidden_size //
|
||||
self.config.num_attention_heads *
|
||||
self.config.num_key_value_heads // tp_size)
|
||||
attention_weight_specs = [
|
||||
# (weight_name, shard_size, offset)
|
||||
("q_proj", q_proj_shard_size, 0),
|
||||
("k_proj", kv_proj_shard_size, q_proj_shard_size),
|
||||
("v_proj", kv_proj_shard_size,
|
||||
q_proj_shard_size + kv_proj_shard_size),
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
state_dict = self.state_dict()
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
packed_dim = None
|
||||
is_transposed = False
|
||||
if self.quant_config is not None:
|
||||
packed_dim = self.quant_config.get_packed_dim(name)
|
||||
is_transposed = self.quant_config.is_transposed(name)
|
||||
if is_transposed:
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
loaded_weight = loaded_weight.T
|
||||
|
||||
is_attention_weight = False
|
||||
for weight_name, shard_size, offset in attention_weight_specs:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "qkv_proj")]
|
||||
if is_transposed:
|
||||
param = param.T
|
||||
|
||||
if packed_dim is not None:
|
||||
shard_dim = 0 if not is_transposed else 1
|
||||
if packed_dim == shard_dim:
|
||||
shard_size //= self.quant_config.pack_factor
|
||||
offset //= self.quant_config.pack_factor
|
||||
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[offset:offset + shard_size]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_attention_weight = True
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
if is_attention_weight:
|
||||
continue
|
||||
|
||||
is_gate_up_weight = False
|
||||
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
if is_transposed:
|
||||
param = param.T
|
||||
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_gate_up_weight = True
|
||||
break
|
||||
if is_gate_up_weight:
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
if is_transposed:
|
||||
param = param.T
|
||||
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tensor_model_parallel_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
column_parallel_weights,
|
||||
row_parallel_weights,
|
||||
tensor_model_parallel_rank)
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@ -10,15 +10,17 @@ from transformers import MptConfig
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
|
||||
hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -39,7 +41,11 @@ def _get_alibi_slopes(
|
||||
|
||||
class MptAttention(nn.Module):
|
||||
|
||||
def __init__(self, config: MptConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: MptConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
self.total_num_heads = config.n_heads
|
||||
@ -49,11 +55,13 @@ class MptAttention(nn.Module):
|
||||
assert not config.attn_config.prefix_lm
|
||||
assert config.attn_config.alibi
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
# pylint: disable=invalid-name
|
||||
self.Wqkv = QKVParallelLinear(
|
||||
self.d_model,
|
||||
3 * self.d_model,
|
||||
self.d_model // self.total_num_heads,
|
||||
self.total_num_heads,
|
||||
bias=not config.no_bias,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
if self.qk_ln:
|
||||
self.q_ln = nn.LayerNorm(self.d_model)
|
||||
@ -62,7 +70,7 @@ class MptAttention(nn.Module):
|
||||
self.d_model,
|
||||
self.d_model,
|
||||
bias=not config.no_bias,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
@ -91,7 +99,7 @@ class MptAttention(nn.Module):
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
del position_ids # unused.
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
qkv, _ = self.Wqkv(hidden_states)
|
||||
if self.clip_qkv is not None:
|
||||
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
@ -107,7 +115,11 @@ class MptAttention(nn.Module):
|
||||
|
||||
class MptMLP(nn.Module):
|
||||
|
||||
def __init__(self, config: MptConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: MptConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.d_model
|
||||
expansion_ratio = config.expansion_ratio
|
||||
@ -116,14 +128,14 @@ class MptMLP(nn.Module):
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=not config.no_bias,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.act = get_act_fn("gelu")
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=not config.no_bias,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@ -135,13 +147,17 @@ class MptMLP(nn.Module):
|
||||
|
||||
class MptBlock(nn.Module):
|
||||
|
||||
def __init__(self, config: MptConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: MptConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.d_model
|
||||
self.norm_1 = nn.LayerNorm(hidden_size)
|
||||
self.attn = MptAttention(config)
|
||||
self.attn = MptAttention(config, linear_method)
|
||||
self.norm_2 = nn.LayerNorm(hidden_size)
|
||||
self.ffn = MptMLP(config)
|
||||
self.ffn = MptMLP(config, linear_method)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -168,7 +184,11 @@ class MptBlock(nn.Module):
|
||||
|
||||
class MptModel(nn.Module):
|
||||
|
||||
def __init__(self, config: MptConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: MptConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
assert config.embedding_fraction == 1.0
|
||||
assert config.norm_type == "low_precision_layernorm"
|
||||
@ -178,7 +198,7 @@ class MptModel(nn.Module):
|
||||
config.d_model,
|
||||
)
|
||||
self.blocks = nn.ModuleList(
|
||||
[MptBlock(config) for _ in range(config.n_layers)])
|
||||
[MptBlock(config, linear_method) for _ in range(config.n_layers)])
|
||||
self.norm_f = nn.LayerNorm(config.d_model)
|
||||
if config.no_bias:
|
||||
for module in self.modules():
|
||||
@ -215,14 +235,17 @@ class MptModel(nn.Module):
|
||||
|
||||
class MptForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config: MptConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: MptConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
assert config.tie_word_embeddings
|
||||
self.linear_method = linear_method
|
||||
|
||||
self.transformer = MptModel(config)
|
||||
# TODO(zhuohan): create a new weight after implementing pipeline
|
||||
# parallelism
|
||||
self.transformer = MptModel(config, linear_method)
|
||||
self.lm_head_weight = self.transformer.wte.weight
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
@ -240,45 +263,15 @@ class MptForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = ["wte.weight", "up_proj.weight", "up_proj.bias"]
|
||||
_row_parallel_weights = ["out_proj.weight", "down_proj.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "Wqkv" in name:
|
||||
# NOTE(woosuk): MPT's fused QKV has the shape of
|
||||
# [3 * num_heads * head_size, hidden_size].
|
||||
# When tensor model parallelism is used, we need to shard
|
||||
# the weight along the hidden dimension.
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // total_num_heads
|
||||
num_heads = total_num_heads // tp_world_size
|
||||
head_start = tp_rank * num_heads
|
||||
head_end = (tp_rank + 1) * num_heads
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
if name.endswith(".weight"):
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size, hidden_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
elif name.endswith(".bias"):
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :]
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
else:
|
||||
raise ValueError(f"Unexpected parameter name {name}")
|
||||
name = name.replace("Wqkv", "qkv_proj")
|
||||
param = state_dict[name]
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights, tp_rank)
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@ -30,14 +30,18 @@ from transformers import OPTConfig
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention import PagedAttention
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -63,6 +67,7 @@ class OPTAttention(nn.Module):
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
bias: bool = True,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
@ -74,17 +79,18 @@ class OPTAttention(nn.Module):
|
||||
self.head_dim = embed_dim // total_num_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
embed_dim,
|
||||
3 * embed_dim,
|
||||
self.head_dim,
|
||||
total_num_heads,
|
||||
bias=bias,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
bias=bias,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
@ -108,7 +114,11 @@ class OPTAttention(nn.Module):
|
||||
|
||||
class OPTDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: OPTConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: OPTConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -116,6 +126,7 @@ class OPTDecoderLayer(nn.Module):
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.num_attention_heads,
|
||||
bias=config.enable_bias,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.do_layer_norm_before = config.do_layer_norm_before
|
||||
self.activation_fn = get_act_fn(config.activation_function)
|
||||
@ -127,13 +138,13 @@ class OPTDecoderLayer(nn.Module):
|
||||
self.embed_dim,
|
||||
config.ffn_dim,
|
||||
bias=config.enable_bias,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.fc2 = RowParallelLinear(
|
||||
config.ffn_dim,
|
||||
self.embed_dim,
|
||||
bias=config.enable_bias,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(
|
||||
self.embed_dim,
|
||||
@ -177,7 +188,11 @@ class OPTDecoderLayer(nn.Module):
|
||||
|
||||
class OPTDecoder(nn.Module):
|
||||
|
||||
def __init__(self, config: OPTConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: OPTConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
@ -194,16 +209,18 @@ class OPTDecoder(nn.Module):
|
||||
|
||||
# Project out & in will be replicated if they exist.
|
||||
if config.word_embed_proj_dim != config.hidden_size:
|
||||
self.project_out = nn.Linear(config.hidden_size,
|
||||
config.word_embed_proj_dim,
|
||||
bias=False)
|
||||
self.project_out = ReplicatedLinear(config.hidden_size,
|
||||
config.word_embed_proj_dim,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
else:
|
||||
self.project_out = None
|
||||
|
||||
if config.word_embed_proj_dim != config.hidden_size:
|
||||
self.project_in = nn.Linear(config.word_embed_proj_dim,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
else:
|
||||
self.project_in = None
|
||||
|
||||
@ -218,8 +235,10 @@ class OPTDecoder(nn.Module):
|
||||
else:
|
||||
self.final_layer_norm = None
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.layers = nn.ModuleList([
|
||||
OPTDecoderLayer(config, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -253,9 +272,13 @@ class OPTDecoder(nn.Module):
|
||||
|
||||
class OPTModel(nn.Module):
|
||||
|
||||
def __init__(self, config: OPTConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: OPTConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.decoder = OPTDecoder(config)
|
||||
self.decoder = OPTDecoder(config, linear_method)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -271,12 +294,15 @@ class OPTModel(nn.Module):
|
||||
|
||||
class OPTForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = OPTModel(config)
|
||||
# TODO(zhuohan): create a new weight after implementing pipeline
|
||||
# parallelism
|
||||
self.linear_method = linear_method
|
||||
self.model = OPTModel(config, linear_method)
|
||||
self.lm_head_weight = self.model.decoder.embed_tokens.weight
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
@ -294,48 +320,31 @@ class OPTForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"embed_tokens.weight", "fc1.weight", "fc1.bias"
|
||||
]
|
||||
_row_parallel_weights = ["out_proj.weight", "fc2.weight"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "lm_head.weight" in name:
|
||||
continue
|
||||
|
||||
if name.startswith("decoder."):
|
||||
name = "model." + name
|
||||
|
||||
is_attention_weight = False
|
||||
for stride_id, att_weight_name in enumerate(
|
||||
["q_proj", "k_proj", "v_proj"]):
|
||||
if att_weight_name not in name:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(att_weight_name, "qkv_proj")]
|
||||
shard_size = param.shape[0] // 3
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
(tensor_model_parallel_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_attention_weight = True
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
if is_attention_weight:
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tensor_model_parallel_rank)
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@ -15,24 +15,19 @@ from torch import nn
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor,
|
||||
hf_model_weights_iterator,
|
||||
load_padded_tensor_parallel_vocab,
|
||||
load_tensor_parallel_weights,
|
||||
)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.layers import (
|
||||
VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs.qwen import QWenConfig
|
||||
|
||||
@ -46,20 +41,17 @@ class QWenMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str = "silu",
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
2 * intermediate_size,
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
linear_method=linear_method)
|
||||
self.c_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -74,12 +66,15 @@ class QWenMLP(nn.Module):
|
||||
|
||||
class QWenAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
max_position_embeddings: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
max_position_embeddings: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
|
||||
@ -90,18 +85,18 @@ class QWenAttention(nn.Module):
|
||||
tensor_model_parallel_world_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
self.c_attn = ColumnParallelLinear(
|
||||
self.c_attn = QKVParallelLinear(
|
||||
hidden_size,
|
||||
3 * hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
@ -134,7 +129,11 @@ class QWenAttention(nn.Module):
|
||||
|
||||
class QWenBlock(nn.Module):
|
||||
|
||||
def __init__(self, config: QWenConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: QWenConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
@ -144,11 +143,14 @@ class QWenBlock(nn.Module):
|
||||
config.num_attention_heads,
|
||||
config.max_position_embeddings,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling)
|
||||
rope_scaling=rope_scaling,
|
||||
linear_method=linear_method)
|
||||
|
||||
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2)
|
||||
self.mlp = QWenMLP(config.hidden_size,
|
||||
config.intermediate_size // 2,
|
||||
linear_method=linear_method)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -180,18 +182,23 @@ class QWenBlock(nn.Module):
|
||||
|
||||
class QWenModel(nn.Module):
|
||||
|
||||
def __init__(self, config: QWenConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: QWenConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.wte = VocabParallelEmbedding(
|
||||
vocab_size,
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.h = nn.ModuleList(
|
||||
[QWenBlock(config) for _ in range(config.num_hidden_layers)])
|
||||
self.h = nn.ModuleList([
|
||||
QWenBlock(config, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
@ -222,17 +229,16 @@ class QWenModel(nn.Module):
|
||||
|
||||
class QWenLMHeadModel(nn.Module):
|
||||
|
||||
def __init__(self, config: QWenConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: QWenConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transformer = QWenModel(config)
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.lm_head = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
)
|
||||
self.linear_method = linear_method
|
||||
self.transformer = QWenModel(config, linear_method)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
@ -249,75 +255,30 @@ class QWenLMHeadModel(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = []
|
||||
_row_parallel_weights = ["c_proj.weight"]
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None,
|
||||
):
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "w2", 0),
|
||||
("gate_up_proj", "w1", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
|
||||
if "c_attn" in name:
|
||||
total_num_heads = self.config.num_attention_heads
|
||||
hidden_size = self.config.hidden_size
|
||||
head_size = hidden_size // total_num_heads
|
||||
num_heads = total_num_heads // tp_world_size
|
||||
head_start = tp_rank * num_heads
|
||||
head_end = (tp_rank + 1) * num_heads
|
||||
|
||||
if "weight" in name:
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size, hidden_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
|
||||
loaded_weight = loaded_weight.reshape(-1, hidden_size)
|
||||
elif "bias" in name:
|
||||
loaded_weight = loaded_weight.view(3, total_num_heads,
|
||||
head_size)
|
||||
loaded_weight = loaded_weight[:, head_start:head_end, :]
|
||||
loaded_weight = loaded_weight.reshape(-1)
|
||||
|
||||
is_gate_up_weight = False
|
||||
for stride_id, weight_name in enumerate(["w2", "w1"]):
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
|
||||
(tp_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_gate_up_weight = True
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
if is_gate_up_weight:
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
|
||||
if "wte" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tp_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(
|
||||
param,
|
||||
loaded_weight,
|
||||
name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
tp_rank,
|
||||
)
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@ -33,17 +33,19 @@ from vllm.transformers_utils.configs.yi import YiConfig
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.quantized_linear import ParallelLinear
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
|
||||
from vllm.model_executor.quantization_utils import QuantizationConfig
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab)
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -56,19 +58,17 @@ class YiMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = ParallelLinear.column(hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
quant_config=quant_config)
|
||||
self.down_proj = ParallelLinear.row(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config)
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -91,7 +91,7 @@ class YiAttention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -109,7 +109,6 @@ class YiAttention(nn.Module):
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
num_kv_heads_replicas = max(1, tp_size // self.total_num_kv_heads)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
@ -117,21 +116,19 @@ class YiAttention(nn.Module):
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = ParallelLinear.column(
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
(self.total_num_heads +
|
||||
2 * self.total_num_kv_heads * num_kv_heads_replicas) *
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
quant_config=quant_config,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.o_proj = ParallelLinear.row(
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
@ -165,11 +162,10 @@ class YiDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: YiConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
@ -181,13 +177,13 @@ class YiDecoderLayer(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.mlp = YiMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.ln1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.ln2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -225,20 +221,18 @@ class YiModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: YiConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
vocab_size,
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
YiDecoderLayer(config, quant_config)
|
||||
YiDecoderLayer(config, linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -274,19 +268,13 @@ class YiForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: YiConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = YiModel(config, quant_config)
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
# NOTE: The LM head is not quantized.
|
||||
self.lm_head = ParallelLinear.column(config.hidden_size,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
quant_config=None)
|
||||
self.linear_method = linear_method
|
||||
self.model = YiModel(config, linear_method)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
@ -303,124 +291,33 @@ class YiForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_layers = []
|
||||
_row_parallel_layers = ["o_proj", "down_proj"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
if self.quant_config is None:
|
||||
col_weight_suffixes = ["weight"]
|
||||
row_weight_suffixes = ["weight"]
|
||||
else:
|
||||
col_weight_suffixes = (
|
||||
self.quant_config.get_col_parallel_tensor_names())
|
||||
row_weight_suffixes = (
|
||||
self.quant_config.get_row_parallel_tensor_names())
|
||||
|
||||
column_parallel_weights: List[str] = []
|
||||
for layer in self._column_parallel_layers:
|
||||
for suffix in col_weight_suffixes:
|
||||
column_parallel_weights.append(f"{layer}.{suffix}")
|
||||
row_parallel_weights: List[str] = []
|
||||
for layer in self._row_parallel_layers:
|
||||
for suffix in row_weight_suffixes:
|
||||
row_parallel_weights.append(f"{layer}.{suffix}")
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||
num_kv_heads_replicas = max(1,
|
||||
tp_size // self.config.num_key_value_heads)
|
||||
num_kv_heads_per_gpu = max(1,
|
||||
self.config.num_key_value_heads // tp_size)
|
||||
kv_proj_shard_size = (self.config.hidden_size //
|
||||
self.config.num_attention_heads *
|
||||
num_kv_heads_per_gpu)
|
||||
attention_weight_specs = [
|
||||
# (weight_name, shard_size, offset)
|
||||
("q_proj", q_proj_shard_size, 0),
|
||||
("k_proj", kv_proj_shard_size, q_proj_shard_size),
|
||||
("v_proj", kv_proj_shard_size,
|
||||
q_proj_shard_size + kv_proj_shard_size),
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
state_dict = self.state_dict()
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
packed_dim = None
|
||||
is_transposed = False
|
||||
if self.quant_config is not None:
|
||||
packed_dim = self.quant_config.get_packed_dim(name)
|
||||
is_transposed = self.quant_config.is_transposed(name)
|
||||
if is_transposed:
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
loaded_weight = loaded_weight.T
|
||||
|
||||
is_attention_weight = False
|
||||
for weight_name, shard_size, offset in attention_weight_specs:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "qkv_proj")]
|
||||
if is_transposed:
|
||||
param = param.T
|
||||
|
||||
if packed_dim is not None:
|
||||
shard_dim = 0 if not is_transposed else 1
|
||||
if packed_dim == shard_dim:
|
||||
shard_size //= self.quant_config.pack_factor
|
||||
offset //= self.quant_config.pack_factor
|
||||
|
||||
if weight_name in ["k_proj", "v_proj"]:
|
||||
shard_id = tp_rank // num_kv_heads_replicas
|
||||
else:
|
||||
shard_id = tp_rank
|
||||
loaded_weight = loaded_weight[shard_size *
|
||||
shard_id:shard_size *
|
||||
(shard_id + 1)]
|
||||
param_slice = param.data[offset:offset + shard_size]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_attention_weight = True
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
if is_attention_weight:
|
||||
continue
|
||||
|
||||
is_gate_up_weight = False
|
||||
for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
if is_transposed:
|
||||
param = param.T
|
||||
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[shard_size * tp_rank:shard_size *
|
||||
(tp_rank + 1)]
|
||||
param_slice = param.data[shard_size * stride_id:shard_size *
|
||||
(stride_id + 1)]
|
||||
assert param_slice.shape == loaded_weight.shape
|
||||
param_slice.copy_(loaded_weight)
|
||||
is_gate_up_weight = True
|
||||
break
|
||||
if is_gate_up_weight:
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
if is_transposed:
|
||||
param = param.T
|
||||
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
tp_rank)
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
column_parallel_weights,
|
||||
row_parallel_weights, tp_rank)
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
@ -1,303 +0,0 @@
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Adapted from
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Parts of the code here are adapted from PyTorch
|
||||
# repo: https://github.com/pytorch/pytorch
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.quantization_utils import QuantizationConfig
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
|
||||
|
||||
from vllm.model_executor.parallel_utils.utils import (
|
||||
divide,
|
||||
VocabUtility,
|
||||
split_tensor_along_last_dim,
|
||||
)
|
||||
|
||||
|
||||
class VocabParallelEmbedding(torch.nn.Module):
|
||||
"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
This is mainly adapted from torch.nn.Embedding and all the default
|
||||
values are kept.
|
||||
Arguments:
|
||||
num_embeddings: vocabulary size.
|
||||
embedding_dim: size of hidden state.
|
||||
params_dtype: type of the parameters.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
params_dtype: Optional[torch.dtype] = None):
|
||||
super().__init__()
|
||||
|
||||
# Keep the input dimensions.
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
# TODO: Handle vocab padding here.
|
||||
# Divide the weight matrix along the vocaburaly dimension.
|
||||
self.vocab_start_index, self.vocab_end_index = (
|
||||
VocabUtility.vocab_range_from_global_vocab_size(
|
||||
self.num_embeddings, get_tensor_model_parallel_rank(),
|
||||
self.tp_size))
|
||||
self.num_embeddings_per_partition = (self.vocab_end_index -
|
||||
self.vocab_start_index)
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.num_embeddings_per_partition,
|
||||
self.embedding_dim,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
|
||||
def forward(self, input_):
|
||||
if self.tp_size > 1:
|
||||
# Build the mask.
|
||||
input_mask = ((input_ < self.vocab_start_index) |
|
||||
(input_ >= self.vocab_end_index))
|
||||
# Mask the input.
|
||||
masked_input = input_.clone() - self.vocab_start_index
|
||||
masked_input[input_mask] = 0
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = F.embedding(masked_input, self.weight)
|
||||
# Mask the output embedding.
|
||||
if self.tp_size > 1:
|
||||
output_parallel[input_mask, :] = 0.0
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
return output
|
||||
|
||||
|
||||
class ColumnParallelLinear(torch.nn.Module):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
||||
its second dimension as A = [A_1, ..., A_p].
|
||||
|
||||
Arguments:
|
||||
input_size: first dimension of matrix A.
|
||||
output_size: second dimension of matrix A.
|
||||
|
||||
Keyword Arguments
|
||||
bias: If true, add bias
|
||||
gather_output: If true, call all-gather on output and make Y available
|
||||
to all GPUs, otherwise, every GPU will have its output
|
||||
which is Y_i = XA_i
|
||||
skip_bias_add: This was added to enable performance optimizations where
|
||||
bias can be fused with other element-wise operations. we
|
||||
skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configuration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
gather_output: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.gather_output = gather_output
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.output_size_per_partition = divide(output_size, self.tp_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.quant_config = quant_config
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
# Parameters.
|
||||
# NOTE: torch.nn.functional.linear performs XA^T + b and as a result
|
||||
# we allocate the transpose.
|
||||
self.create_weights(params_dtype)
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
def create_weights(self, dtype: torch.dtype) -> None:
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.output_size_per_partition,
|
||||
self.input_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=dtype))
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
return F.linear(x, self.weight, bias)
|
||||
|
||||
def forward(self, input_):
|
||||
"""Forward of ColumnParallelLinear
|
||||
|
||||
Args:
|
||||
input_: Tensor whose last dimension is `input_size`.
|
||||
|
||||
Returns:
|
||||
- output
|
||||
- bias
|
||||
"""
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
input_parallel = input_
|
||||
# Matrix multiply.
|
||||
output_parallel = self.apply_weights(input_parallel, bias)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class RowParallelLinear(torch.nn.Module):
|
||||
"""Linear layer with row parallelism.
|
||||
|
||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
||||
its first dimension and X along its second dimension as:
|
||||
- -
|
||||
| A_1 |
|
||||
| . |
|
||||
A = | . | X = [X_1, ..., X_p]
|
||||
| . |
|
||||
| A_p |
|
||||
- -
|
||||
Arguments:
|
||||
input_size: first dimension of matrix A.
|
||||
output_size: second dimension of matrix A.
|
||||
|
||||
Keyword Arguments:
|
||||
bias: If true, add bias. Note that bias is not parallelized.
|
||||
input_is_parallel: If true, we assume that the input is already
|
||||
split across the GPUs and we do not split
|
||||
again.
|
||||
skip_bias_add: This was added to enable performance optimization where
|
||||
bias can be fused with other element-wise operations.
|
||||
We skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configuration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
input_is_parallel: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.create_weights(params_dtype)
|
||||
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError('When not reduce the results, adding bias to the '
|
||||
'results can lead to incorrect results')
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
|
||||
# Always initialize bias to zero.
|
||||
with torch.no_grad():
|
||||
self.bias.zero_()
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
def create_weights(self, dtype: torch.dtype) -> None:
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.output_size,
|
||||
self.input_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=dtype))
|
||||
|
||||
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return F.linear(x, self.weight)
|
||||
|
||||
def forward(self, input_):
|
||||
"""Forward of RowParallelLinear
|
||||
|
||||
Args:
|
||||
input_: tensor whose last dimension is `input_size`. If
|
||||
`input_is_parallel` is set, then the last dimension
|
||||
is `input_size // tp_size`.
|
||||
|
||||
Returns:
|
||||
- output
|
||||
- bias
|
||||
"""
|
||||
# Set up backprop all-reduce.
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
# TODO: simplify code below
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.apply_weights(input_parallel)
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output_ = output_parallel
|
||||
|
||||
if not self.skip_bias_add:
|
||||
output = output_ + self.bias if self.bias is not None else output_
|
||||
output_bias = None
|
||||
else:
|
||||
output = output_
|
||||
output_bias = self.bias
|
||||
return output, output_bias
|
||||
@ -2,7 +2,7 @@
|
||||
# Adapted from
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
from typing import List, Sequence
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
@ -24,7 +24,7 @@ def split_tensor_along_last_dim(
|
||||
tensor: torch.Tensor,
|
||||
num_partitions: int,
|
||||
contiguous_split_chunks: bool = False,
|
||||
) -> List[torch.Tensor]:
|
||||
) -> Sequence[torch.Tensor]:
|
||||
""" Split a tensor along its last dimension.
|
||||
|
||||
Arguments:
|
||||
@ -46,25 +46,3 @@ def split_tensor_along_last_dim(
|
||||
return tuple(chunk.contiguous() for chunk in tensor_list)
|
||||
|
||||
return tensor_list
|
||||
|
||||
|
||||
class VocabUtility:
|
||||
""" Split the vocabulary into `world_size` chunks and return the first
|
||||
and last index of the vocabulary belonging to the `rank`
|
||||
partition: Note that indices in [fist, last)
|
||||
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def vocab_range_from_per_partition_vocab_size(
|
||||
per_partition_vocab_size: int, rank: int) -> Sequence[int]:
|
||||
index_f = rank * per_partition_vocab_size
|
||||
index_l = index_f + per_partition_vocab_size
|
||||
return index_f, index_l
|
||||
|
||||
@staticmethod
|
||||
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
|
||||
world_size: int) -> Sequence[int]:
|
||||
per_partition_vocab_size = divide(global_vocab_size, world_size)
|
||||
return VocabUtility.vocab_range_from_per_partition_vocab_size(
|
||||
per_partition_vocab_size, rank)
|
||||
|
||||
@ -1,22 +0,0 @@
|
||||
from typing import Type
|
||||
|
||||
from vllm.model_executor.quantization_utils.awq import AWQConfig
|
||||
from vllm.model_executor.quantization_utils.base import QuantizationConfig
|
||||
from vllm.model_executor.quantization_utils.squeezellm import SqueezeLLMConfig
|
||||
|
||||
_QUANTIZATION_REGISTRY = {
|
||||
"awq": AWQConfig,
|
||||
"squeezellm": SqueezeLLMConfig,
|
||||
}
|
||||
|
||||
|
||||
def get_quant_class(quantization: str) -> Type[QuantizationConfig]:
|
||||
if quantization not in _QUANTIZATION_REGISTRY:
|
||||
raise ValueError(f"Invalid quantization method: {quantization}")
|
||||
return _QUANTIZATION_REGISTRY[quantization]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"QuantizationConfig",
|
||||
"get_quant_class",
|
||||
]
|
||||
@ -1,76 +0,0 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.quantization_utils.base import QuantizationConfig
|
||||
|
||||
|
||||
class AWQConfig(QuantizationConfig):
|
||||
"""Config class for AWQ.
|
||||
|
||||
Reference: https://arxiv.org/abs/2306.00978
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
zero_point: bool,
|
||||
) -> None:
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
|
||||
if self.weight_bits != 4:
|
||||
raise ValueError(
|
||||
"Currently, only 4-bit weight quantization is supported for "
|
||||
f"AWQ, but got {self.weight_bits} bits.")
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"AWQConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"zero_point={self.zero_point})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "awq"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# The AWQ kernel only supports Turing or newer GPUs.
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return [
|
||||
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
|
||||
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
return cls(weight_bits, group_size, zero_point)
|
||||
|
||||
@classmethod
|
||||
def get_packed_tensors(cls) -> Dict[str, int]:
|
||||
return {"qweight": 1, "qzeros": 1}
|
||||
|
||||
@classmethod
|
||||
def get_transposed_tensor_names(cls) -> List[str]:
|
||||
return ["qweight", "qzeros", "scales"]
|
||||
|
||||
@classmethod
|
||||
def get_col_parallel_tensor_names(cls) -> List[str]:
|
||||
return ["qweight", "qzeros", "scales"]
|
||||
|
||||
@classmethod
|
||||
def get_row_parallel_tensor_names(cls) -> List[str]:
|
||||
return ["qweight", "qzeros", "scales"]
|
||||
@ -1,85 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class QuantizationConfig:
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
"""Name of the quantization method."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
"""List of supported activation dtypes."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
"""Minimum GPU capability to support the quantization method.
|
||||
|
||||
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
|
||||
This requirement is due to the custom CUDA kernels used by the
|
||||
quantization method.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
"""List of filenames to search for in the model directory."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
|
||||
"""Create a config class from the model's quantization config."""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
|
||||
"""Get a value from the model's quantization config."""
|
||||
for key in keys:
|
||||
if key in config:
|
||||
return config[key]
|
||||
raise ValueError(f"Cannot find any of {keys} in the model's "
|
||||
"quantization config.")
|
||||
|
||||
@classmethod
|
||||
def get_packed_tensors(cls) -> Dict[str, int]:
|
||||
"""Returns a dictionary of packed tensor names and their pack dims."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_packed_dim(cls, tensor_name: str) -> Optional[int]:
|
||||
"""Returns the pack dim of a tensor if it is packed.
|
||||
|
||||
A tensor is considered packed if each element in the tensor is a
|
||||
packed representation of multiple elements in the original tensor.
|
||||
For example, an INT32 element in the tensor may represent 8 INT4
|
||||
elements in the original tensor.
|
||||
If the tensor is not packed, returns None.
|
||||
"""
|
||||
packed_tensors = cls.get_packed_tensors()
|
||||
for packed_tensor_name, pack_dim in packed_tensors.items():
|
||||
if packed_tensor_name in tensor_name:
|
||||
return pack_dim
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_transposed_tensor_names(cls) -> List[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def is_transposed(cls, tensor_name: str) -> bool:
|
||||
"""Returns True if a tensor is transposed relative to nn.Linear.weight.
|
||||
"""
|
||||
return any(tag in tensor_name
|
||||
for tag in cls.get_transposed_tensor_names())
|
||||
|
||||
@classmethod
|
||||
def get_col_parallel_tensor_names(cls) -> List[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_row_parallel_tensor_names(cls) -> List[str]:
|
||||
raise NotImplementedError
|
||||
@ -1,65 +0,0 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.quantization_utils.base import QuantizationConfig
|
||||
|
||||
|
||||
class SqueezeLLMConfig(QuantizationConfig):
|
||||
"""Config class for SqueezeLLM.
|
||||
|
||||
Reference: https://arxiv.org/pdf/2306.07629
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
) -> None:
|
||||
self.weight_bits = weight_bits
|
||||
|
||||
if self.weight_bits != 4:
|
||||
raise ValueError(
|
||||
"Currently, only 4-bit weight quantization is supported for "
|
||||
f"SqueezeLLM, but got {self.weight_bits} bits.")
|
||||
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"SqueezeLLMConfig(weight_bits={self.weight_bits})"
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "squeezellm"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["quant_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["wbits"])
|
||||
return cls(weight_bits)
|
||||
|
||||
@classmethod
|
||||
def get_packed_tensors(cls) -> Dict[str, int]:
|
||||
return {"qweight": 0}
|
||||
|
||||
@classmethod
|
||||
def get_transposed_tensor_names(cls) -> List[str]:
|
||||
return ["qweight"]
|
||||
|
||||
@classmethod
|
||||
def get_col_parallel_tensor_names(cls) -> List[str]:
|
||||
return ["qweight", "lookup_table"]
|
||||
|
||||
@classmethod
|
||||
def get_row_parallel_tensor_names(cls) -> List[str]:
|
||||
return ["qweight"]
|
||||
@ -1,5 +1,6 @@
|
||||
"""Utils for model executor."""
|
||||
import random
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -11,3 +12,24 @@ def set_random_seed(seed: int) -> None:
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def set_weight_attrs(
|
||||
weight: torch.Tensor,
|
||||
weight_attrs: Optional[Dict[str, Any]],
|
||||
):
|
||||
"""Set attributes on a weight tensor.
|
||||
|
||||
This method is used to set attributes on a weight tensor. This method
|
||||
will not overwrite existing attributes.
|
||||
|
||||
Args:
|
||||
weight: The weight tensor.
|
||||
weight_attrs: A dictionary of attributes to set on the weight tensor.
|
||||
"""
|
||||
if weight_attrs is None:
|
||||
return
|
||||
for key, value in weight_attrs.items():
|
||||
assert not hasattr(
|
||||
weight, key), (f"Overwriting existing tensor attribute: {key}")
|
||||
setattr(weight, key, value)
|
||||
|
||||
@ -13,8 +13,8 @@ import torch
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.quantization_utils import get_quant_class
|
||||
from vllm.model_executor.quantization_utils.base import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization import (get_quantization_config,
|
||||
QuantizationConfig)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -98,7 +98,7 @@ def get_quant_config(
|
||||
hf_folder = model_name_or_path
|
||||
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
|
||||
|
||||
quant_cls = get_quant_class(quantization)
|
||||
quant_cls = get_quantization_config(quantization)
|
||||
quant_config_files = [
|
||||
f for f in config_files if any(
|
||||
f.endswith(x) for x in quant_cls.get_config_filenames())
|
||||
@ -237,7 +237,7 @@ def hf_model_weights_iterator(
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
for name in f.keys():
|
||||
param = f.get_slice(name)
|
||||
yield name, param
|
||||
yield name, convert_pyslice_to_tensor(param)
|
||||
else:
|
||||
for bin_file in hf_weights_files:
|
||||
state = torch.load(bin_file, map_location="cpu")
|
||||
@ -262,46 +262,10 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
|
||||
def load_padded_tensor_parallel_vocab(
|
||||
param: torch.Tensor,
|
||||
loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
|
||||
tensor_model_parallel_rank: int,
|
||||
) -> None:
|
||||
shard_size = param.shape[0]
|
||||
start_idx = tensor_model_parallel_rank * shard_size
|
||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||
loaded_weight = loaded_weight[start_idx:end_idx]
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
param[:loaded_weight.shape[0]].copy_(loaded_weight)
|
||||
|
||||
|
||||
def load_tensor_parallel_weights(
|
||||
param: torch.Tensor,
|
||||
loaded_weight: Any, # `torch.Tensor` or `PySafeSlice`
|
||||
param_name: str,
|
||||
column_parallel_weight_names: List[str],
|
||||
row_parallel_weight_names: List[str],
|
||||
tensor_model_parallel_rank: int,
|
||||
) -> None:
|
||||
for p in column_parallel_weight_names:
|
||||
if p in param_name:
|
||||
shard_size = param.shape[0]
|
||||
start_idx = tensor_model_parallel_rank * shard_size
|
||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||
loaded_weight = loaded_weight[start_idx:end_idx]
|
||||
break
|
||||
for p in row_parallel_weight_names:
|
||||
if p in param_name:
|
||||
shard_size = param.shape[1]
|
||||
start_idx = tensor_model_parallel_rank * shard_size
|
||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||
loaded_weight = loaded_weight[:, start_idx:end_idx]
|
||||
break
|
||||
|
||||
loaded_weight = convert_pyslice_to_tensor(loaded_weight)
|
||||
assert param.shape == loaded_weight.shape, (
|
||||
f"{param_name} shape mismatch between model and checkpoint: "
|
||||
f"{param.shape} != {loaded_weight.shape}")
|
||||
def default_weight_loader(param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor) -> None:
|
||||
"""Default weight loader."""
|
||||
assert param.size() == loaded_weight.size()
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user