[BugFix] Fix TP support for AWQ (#1731)

This commit is contained in:
Woosuk Kwon 2023-11-20 21:42:45 -08:00 committed by GitHub
parent 4bb6b67188
commit cf35d8f3d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 14 deletions

View File

@ -6,6 +6,10 @@ import torch.nn as nn
from vllm import activation_ops
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import divide
from vllm.model_executor.utils import set_weight_attrs
class SiluAndMul(nn.Module):
@ -51,17 +55,38 @@ class ScaledActivation(nn.Module):
def __init__(
self,
act_module: nn.Module,
hidden_size: int,
params_dtype: torch.dtype,
intermediate_size: int,
input_is_parallel: bool = True,
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.act = act_module
if input_is_parallel:
tp_size = get_tensor_model_parallel_world_size()
intermediate_size_per_partition = divide(intermediate_size,
tp_size)
else:
intermediate_size_per_partition = intermediate_size
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.scales = nn.Parameter(
torch.empty(hidden_size, dtype=params_dtype, device="cuda"))
torch.empty(intermediate_size_per_partition,
dtype=params_dtype,
device="cuda"))
set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
def forward(self, x: torch.Tensor):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.act(x) / self.scales
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = param_data.shape[0]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
_ACTIVATION_REGISTRY = {
"gelu": nn.GELU(),
@ -76,6 +101,8 @@ def get_act_fn(
act_fn_name: str,
quant_config: Optional[QuantizationConfig] = None,
intermediate_size: Optional[int] = None,
input_is_parallel: bool = True,
params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
"""Get an activation function by name."""
act_fn_name = act_fn_name.lower()
@ -84,14 +111,11 @@ def get_act_fn(
f"Activation function {act_fn_name!r} is not supported.")
act_fn = _ACTIVATION_REGISTRY[act_fn_name]
if quant_config is not None and act_fn_name in quant_config.get_scaled_act_names(
):
if (quant_config is not None
and act_fn_name in quant_config.get_scaled_act_names()):
if intermediate_size is None:
raise ValueError("intermediate_size must be specified for scaled "
"activation functions.")
return ScaledActivation(
act_fn,
intermediate_size,
params_dtype=torch.get_default_dtype(),
)
return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
params_dtype)
return act_fn

View File

@ -129,9 +129,6 @@ class OPTDecoderLayer(nn.Module):
linear_method=linear_method,
)
self.do_layer_norm_before = config.do_layer_norm_before
quant_config = getattr(linear_method, "quant_config", None)
self.activation_fn = get_act_fn(config.activation_function,
quant_config, config.ffn_dim)
self.self_attn_layer_norm = nn.LayerNorm(
self.embed_dim,
@ -142,6 +139,9 @@ class OPTDecoderLayer(nn.Module):
bias=config.enable_bias,
linear_method=linear_method,
)
quant_config = getattr(linear_method, "quant_config", None)
self.activation_fn = get_act_fn(config.activation_function,
quant_config, config.ffn_dim)
self.fc2 = RowParallelLinear(
config.ffn_dim,
self.embed_dim,