mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 08:04:58 +08:00
[BugFix] Fix TP support for AWQ (#1731)
This commit is contained in:
parent
4bb6b67188
commit
cf35d8f3d7
@ -6,6 +6,10 @@ import torch.nn as nn
|
||||
|
||||
from vllm import activation_ops
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.utils import divide
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
class SiluAndMul(nn.Module):
|
||||
@ -51,17 +55,38 @@ class ScaledActivation(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
act_module: nn.Module,
|
||||
hidden_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
intermediate_size: int,
|
||||
input_is_parallel: bool = True,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.act = act_module
|
||||
if input_is_parallel:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
intermediate_size_per_partition = divide(intermediate_size,
|
||||
tp_size)
|
||||
else:
|
||||
intermediate_size_per_partition = intermediate_size
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.scales = nn.Parameter(
|
||||
torch.empty(hidden_size, dtype=params_dtype, device="cuda"))
|
||||
torch.empty(intermediate_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
device="cuda"))
|
||||
set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.act(x) / self.scales
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
param_data = param.data
|
||||
shard_size = param_data.shape[0]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
_ACTIVATION_REGISTRY = {
|
||||
"gelu": nn.GELU(),
|
||||
@ -76,6 +101,8 @@ def get_act_fn(
|
||||
act_fn_name: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
intermediate_size: Optional[int] = None,
|
||||
input_is_parallel: bool = True,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
) -> nn.Module:
|
||||
"""Get an activation function by name."""
|
||||
act_fn_name = act_fn_name.lower()
|
||||
@ -84,14 +111,11 @@ def get_act_fn(
|
||||
f"Activation function {act_fn_name!r} is not supported.")
|
||||
|
||||
act_fn = _ACTIVATION_REGISTRY[act_fn_name]
|
||||
if quant_config is not None and act_fn_name in quant_config.get_scaled_act_names(
|
||||
):
|
||||
if (quant_config is not None
|
||||
and act_fn_name in quant_config.get_scaled_act_names()):
|
||||
if intermediate_size is None:
|
||||
raise ValueError("intermediate_size must be specified for scaled "
|
||||
"activation functions.")
|
||||
return ScaledActivation(
|
||||
act_fn,
|
||||
intermediate_size,
|
||||
params_dtype=torch.get_default_dtype(),
|
||||
)
|
||||
return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
|
||||
params_dtype)
|
||||
return act_fn
|
||||
|
||||
@ -129,9 +129,6 @@ class OPTDecoderLayer(nn.Module):
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.do_layer_norm_before = config.do_layer_norm_before
|
||||
quant_config = getattr(linear_method, "quant_config", None)
|
||||
self.activation_fn = get_act_fn(config.activation_function,
|
||||
quant_config, config.ffn_dim)
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(
|
||||
self.embed_dim,
|
||||
@ -142,6 +139,9 @@ class OPTDecoderLayer(nn.Module):
|
||||
bias=config.enable_bias,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
quant_config = getattr(linear_method, "quant_config", None)
|
||||
self.activation_fn = get_act_fn(config.activation_function,
|
||||
quant_config, config.ffn_dim)
|
||||
self.fc2 = RowParallelLinear(
|
||||
config.ffn_dim,
|
||||
self.embed_dim,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user