[Misc][Refactor] Generalize linear_method to be quant_method (#4373)
This commit is contained in:
parent 603ad84815
commit a62aaf1df5
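The change in a nutshell: quantization configs now hand out a per-layer `QuantizeMethodBase`, layers store it as `quant_method`, and `apply()` replaces `apply_weights()`. Below is a minimal, self-contained sketch of that pattern; the names follow the diff, but the signatures are deliberately simplified, so this is not the actual vLLM code.

```python
# Minimal sketch (not the actual vLLM classes) of the pattern introduced here.
from typing import Optional

import torch
import torch.nn.functional as F


class QuantizeMethodBase:
    """Replaces the old LinearMethodBase as the generic base class."""

    def create_weights(self, layer: torch.nn.Module, *weight_args,
                       **extra_weight_attrs):
        raise NotImplementedError

    def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError


class UnquantizedLinearMethod(QuantizeMethodBase):
    """Fallback used when no quant_config is supplied."""

    def create_weights(self, layer, input_size, output_size, params_dtype):
        weight = torch.nn.Parameter(
            torch.empty(output_size, input_size, dtype=params_dtype),
            requires_grad=False)
        layer.register_parameter("weight", weight)

    def apply(self, layer, x, bias: Optional[torch.Tensor] = None):
        return F.linear(x, layer.weight, bias)


class LinearBase(torch.nn.Module):
    """Layers now take a quant_config and resolve their own quant_method."""

    def __init__(self, input_size: int, output_size: int, quant_config=None):
        super().__init__()
        if quant_config is None:
            self.quant_method = UnquantizedLinearMethod()
        else:
            # The config decides per layer (and may return None for layers
            # it does not quantize).
            self.quant_method = quant_config.get_quant_method(self)
        self.quant_method.create_weights(self, input_size, output_size,
                                         torch.get_default_dtype())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Was: self.linear_method.apply_weights(self, x)
        return self.quant_method.apply(self, x)
```

The hunks below apply this rename mechanically across the LoRA layers, every quantization backend (AQLM, AWQ, FP8, GPTQ, Marlin, SqueezeLLM), the model loader, and the model definitions.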
@@ -20,5 +20,5 @@ def test_load_fp16_model(vllm_runner) -> None:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.linear_method, Fp8LinearMethod)
assert isinstance(fc1.quant_method, Fp8LinearMethod)
assert fc1.weight.dtype == torch.float8_e4m3fn
@@ -50,10 +50,10 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config):
mock_agent_instance.deserialize.return_value = MagicMock()
result = load_with_tensorizer(tensorizer_config,
linear_method=mock_linear_method)
quant_method=mock_linear_method)
mock_agent.assert_called_once_with(tensorizer_config,
linear_method=mock_linear_method)
quant_method=mock_linear_method)
mock_agent_instance.deserialize.assert_called_once()
assert result == mock_agent_instance.deserialize.return_value
@ -389,10 +389,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
self.indices = base_indices
|
||||
self.indices_len = indices_len
|
||||
|
||||
def apply_weights(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer, x, bias)
|
||||
def apply(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
|
||||
_apply_lora(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
@ -416,7 +415,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
if not self.base_layer.skip_bias_add else None)
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.apply_weights(input_, bias)
|
||||
output_parallel = self.apply(input_, bias)
|
||||
if self.base_layer.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
@ -523,10 +522,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
|
||||
lora_b[1].T, non_blocking=True)
|
||||
|
||||
def apply_weights(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer, x, bias)
|
||||
def apply(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
|
||||
_apply_lora_packed_nslice(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
@ -765,10 +763,9 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
||||
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
|
||||
lora_a[2].T, non_blocking=True)
|
||||
|
||||
def apply_weights(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer, x, bias)
|
||||
def apply(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
|
||||
_apply_lora_packed_nslice(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
@ -862,9 +859,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
self.indices = base_indices
|
||||
self.indices_len = indices_len
|
||||
|
||||
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer, x)
|
||||
def apply(self, x: torch.Tensor) -> torch.Tensor:
|
||||
output = self.base_layer.quant_method.apply(self.base_layer, x)
|
||||
_apply_lora(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
@ -897,7 +893,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.apply_weights(input_parallel)
|
||||
output_parallel = self.apply(input_parallel)
|
||||
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
|
||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
|
||||
@@ -1,9 +1,8 @@
from abc import ABC, abstractmethod
from abc import abstractmethod
from typing import List, Optional
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,

@@ -12,6 +11,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)

@@ -25,7 +26,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
class LinearMethodBase(ABC):
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
@abstractmethod

@@ -50,22 +51,15 @@ class LinearMethodBase(ABC):
raise NotImplementedError
@abstractmethod
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
def process_weights_after_loading(self, layer: nn.Module) -> None:
"""Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
return
class UnquantizedLinearMethod(LinearMethodBase):
"""Linear method without quantization.
@@ -92,10 +86,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = layer.weight
if self.separate_bias_add:
if bias is not None:

@@ -104,8 +98,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
return F.linear(x, weight, bias)
class ReplicatedLinear(torch.nn.Module):
"""Replicated linear layer.
class LinearBase(torch.nn.Module):
"""Base linear layer.
Args:
input_size: input dimension of the linear layer.

@@ -113,17 +107,16 @@ class ReplicatedLinear(torch.nn.Module):
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method.
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()

@@ -134,12 +127,43 @@ class ReplicatedLinear(torch.nn.Module):
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if linear_method is None:
linear_method = UnquantizedLinearMethod()
self.linear_method = linear_method
self.linear_method.create_weights(self, self.input_size,
[self.output_size], self.input_size,
self.output_size, self.params_dtype)
if quant_config is None:
self.quant_method = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self)
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
class ReplicatedLinear(LinearBase):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
self.quant_method.create_weights(self, self.input_size,
[self.output_size], self.input_size,
self.output_size, self.params_dtype)
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype))
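On the model side the migration is mechanical: every layer constructor forwards a `QuantizationConfig` instead of a linear method, as the model hunks further down (BaiChuan, BLOOM, ChatGLM, Cohere) show. A hedged call-site sketch follows; `ExampleMLP` is illustrative only and assumes a vLLM environment with tensor parallelism initialized.

```python
from typing import Optional

from torch import nn

from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class ExampleMLP(nn.Module):
    """Illustrative module: layers now receive quant_config, not linear_method."""

    def __init__(self,
                 hidden_size: int,
                 intermediate_size: int,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        # Each parallel linear layer resolves its own quant_method from the
        # config (or falls back to the unquantized path when it is None).
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
```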
@ -149,12 +173,12 @@ class ReplicatedLinear(torch.nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
output = self.linear_method.apply_weights(self, x, bias)
|
||||
output = self.quant_method.apply(self, x, bias)
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class ColumnParallelLinear(torch.nn.Module):
|
||||
class ColumnParallelLinear(LinearBase):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
||||
@ -171,7 +195,7 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
bias can be fused with other element-wise operations. we
|
||||
skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
linear_method: (Maybe quantized) linear method.
|
||||
quant_config: Quantization configure.
|
||||
output_sizes: list of output sizes packed into one output, like for QKV
|
||||
the list would be size 3.
|
||||
"""
|
||||
@ -184,34 +208,26 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
output_sizes: Optional[List[int]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
|
||||
quant_config)
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.gather_output = gather_output
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.output_size_per_partition = divide(output_size, tp_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
if linear_method is None:
|
||||
linear_method = UnquantizedLinearMethod()
|
||||
if output_sizes is None:
|
||||
output_sizes = [output_size]
|
||||
self.linear_method = linear_method
|
||||
self.linear_method.create_weights(self,
|
||||
self.input_size,
|
||||
[x // tp_size for x in output_sizes],
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
self.params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
self.quant_method.create_weights(self,
|
||||
self.input_size,
|
||||
[x // tp_size for x in output_sizes],
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
self.params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition,
|
||||
@ -239,7 +255,7 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.linear_method.apply_weights(self, input_, bias)
|
||||
output_parallel = self.quant_method.apply(self, input_, bias)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
@ -267,7 +283,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
bias can be fused with other element-wise operations. we
|
||||
skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
linear_method: (Maybe quantized) linear method.
|
||||
quant_config: Quantization configure.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -278,13 +294,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
self.output_sizes = output_sizes
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert all(output_size % tp_size == 0 for output_size in output_sizes)
|
||||
super().__init__(input_size, sum(output_sizes), bias, gather_output,
|
||||
skip_bias_add, params_dtype, linear_method,
|
||||
skip_bias_add, params_dtype, quant_config,
|
||||
self.output_sizes)
|
||||
|
||||
def weight_loader(self,
|
||||
@ -384,7 +400,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
bias can be fused with other element-wise operations. we
|
||||
skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
linear_method: (Maybe quantized) linear method.
|
||||
quant_config: Quantization configure.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -396,7 +412,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = head_size
|
||||
@ -424,7 +440,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
]
|
||||
|
||||
super().__init__(input_size, output_size, bias, False, skip_bias_add,
|
||||
params_dtype, linear_method, output_sizes)
|
||||
params_dtype, quant_config, output_sizes)
|
||||
|
||||
def weight_loader(self,
|
||||
param: Parameter,
|
||||
@ -517,7 +533,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
class RowParallelLinear(torch.nn.Module):
|
||||
class RowParallelLinear(LinearBase):
|
||||
"""Linear layer with row parallelism.
|
||||
|
||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
||||
@ -540,7 +556,7 @@ class RowParallelLinear(torch.nn.Module):
|
||||
bias can be fused with other element-wise operations.
|
||||
We skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
linear_method: (Maybe quantized) linear method.
|
||||
quant_config: Quantization configure.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -552,32 +568,24 @@ class RowParallelLinear(torch.nn.Module):
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = True,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
|
||||
quant_config)
|
||||
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
if linear_method is None:
|
||||
linear_method = UnquantizedLinearMethod()
|
||||
self.linear_method = linear_method
|
||||
self.linear_method.create_weights(self,
|
||||
self.input_size_per_partition,
|
||||
[self.output_size],
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
self.params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
self.quant_method.create_weights(self,
|
||||
self.input_size_per_partition,
|
||||
[self.output_size],
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
self.params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
@ -616,8 +624,7 @@ class RowParallelLinear(torch.nn.Module):
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.linear_method.apply_weights(
|
||||
self, input_parallel)
|
||||
output_parallel = self.quant_method.apply(self, input_parallel)
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
|
||||
@@ -4,7 +4,7 @@ from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import FP8Config
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig

@@ -12,7 +12,7 @@ from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
QUANTIZATION_METHODS = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"fp8": FP8Config,
"fp8": Fp8Config,
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
"marlin": MarlinConfig,
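A backend is still selected by name through this registry; the only user-visible rename here is `FP8Config` to `Fp8Config`. A hedged sketch of the lookup path, where `hf_quant_config` (a parsed quantization dict) and `layer` (a `LinearBase` instance) are assumed to already exist:

```python
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

config_cls = QUANTIZATION_METHODS["fp8"]               # now Fp8Config
quant_config = config_cls.from_config(hf_quant_config)
# New in this PR: the config is asked per layer and may decline with None.
quant_method = quant_config.get_quant_method(layer)
```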
@ -9,10 +9,10 @@ import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
def get_int_dtype(nbits: int) -> torch.dtype:
|
||||
@ -207,8 +207,11 @@ class AQLMConfig(QuantizationConfig):
|
||||
return cls(in_group_size, nbits_per_codebook, num_code_books,
|
||||
out_group_size)
|
||||
|
||||
def get_linear_method(self) -> "AQLMLinearMethod":
|
||||
return AQLMLinearMethod(self)
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module) -> Optional["AQLMLinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return AQLMLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
@ -321,7 +324,7 @@ class AQLMLinearMethod(LinearMethodBase):
|
||||
layer.register_parameter("scales", scales)
|
||||
set_weight_attrs(scales, extra_weight_attrs)
|
||||
|
||||
def apply_weights(
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
|
||||
@ -4,10 +4,10 @@ import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
class AWQConfig(QuantizationConfig):
|
||||
@ -62,8 +62,11 @@ class AWQConfig(QuantizationConfig):
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
return cls(weight_bits, group_size, zero_point)
|
||||
|
||||
def get_linear_method(self) -> "AWQLinearMethod":
|
||||
return AWQLinearMethod(self)
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return AWQLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
@ -147,10 +150,10 @@ class AWQLinearMethod(LinearMethodBase):
|
||||
layer.register_parameter("scales", scales)
|
||||
set_weight_attrs(scales, extra_weight_attrs)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qweight = layer.qweight
|
||||
scales = layer.scales
|
||||
qzeros = layer.qzeros
|
||||
|
||||
@@ -2,8 +2,33 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, List
import torch
from torch import nn
from vllm.model_executor.layers.linear import LinearMethodBase
class QuantizeMethodBase(ABC):
"""Base class for different quantized methods."""
@abstractmethod
def create_weights(self, layer: torch.nn.Module, *weight_args,
**extra_weight_attrs):
"""Create weights for a layer.
The weights will be set as attributes of the layer."""
raise NotImplementedError
@abstractmethod
def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
def process_weights_after_loading(self, layer: nn.Module) -> None:
"""Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
return
class QuantizationConfig(ABC):

@@ -51,8 +76,8 @@ class QuantizationConfig(ABC):
"quantization config.")
@abstractmethod
def get_linear_method(self) -> LinearMethodBase:
"""Get the linear method to use for the quantized linear layer."""
def get_quant_method(self, layer: torch.nn.Module) -> QuantizeMethodBase:
"""Get the quantize method to use for the quantized layer."""
raise NotImplementedError
@abstractmethod
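Every backend touched below (AQLM, AWQ, GPTQ, Marlin, SqueezeLLM, FP8) implements the new hook the same way: return a linear method only for `LinearBase` layers and `None` for everything else. A skeleton of that pattern follows; `MyQuantConfig` and `MyLinearMethod` are hypothetical, and their remaining abstract methods (`create_weights`, `apply`, `from_config`, and so on) are deliberately elided.

```python
from typing import Optional

import torch

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class MyLinearMethod(LinearMethodBase):
    """Hypothetical method; create_weights()/apply() omitted for brevity."""


class MyQuantConfig(QuantizationConfig):
    """Hypothetical config; only the new per-layer hook is shown."""

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional["MyLinearMethod"]:
        # Quantize linear layers only; other module types get no method.
        if isinstance(layer, LinearBase):
            return MyLinearMethod(self)
        return None
```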
@@ -1,16 +1,17 @@
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
class FP8Config(QuantizationConfig):
class Fp8Config(QuantizationConfig):
"""Config class for FP8."""
@classmethod

@@ -33,11 +34,14 @@ class FP8Config(QuantizationConfig):
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "FP8Config":
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
return cls()
def get_linear_method(self) -> "Fp8LinearMethod":
return Fp8LinearMethod(self)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
return Fp8LinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
@ -57,7 +61,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
quant_config: The quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: FP8Config):
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
@ -86,24 +90,24 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
layer.register_parameter("weight_scaling_factor", w_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# Although the linear_method is propagated to all layers,
|
||||
# Although the quant_method is propagated to all layers,
|
||||
# only linear layers invoke "create_weights". So we check
|
||||
# whether "weight_scaling_facor" is registered to determine
|
||||
# whether the layer is a linear layer that requires quantization.
|
||||
if not hasattr(layer, "weight_scaling_factor"):
|
||||
return
|
||||
|
||||
qweight, weight_scale = per_tensor_quantize(layer.weight)
|
||||
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight)
|
||||
# torch._scaled_mm requires column-major in the second
|
||||
# input (weight), so we transpose the quantized weight.
|
||||
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
||||
layer.weight_scaling_factor.data.copy_(weight_scale)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qinput, x_scale = per_tensor_quantize(x)
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qinput, x_scale = ops.scaled_fp8_quant(x)
|
||||
output, _ = torch._scaled_mm(
|
||||
qinput,
|
||||
layer.weight,
|
||||
@ -113,27 +117,3 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
bias=bias,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
|
||||
"""Quantize a tensor using per-tensor static scaling factor.
|
||||
|
||||
Args:
|
||||
tensor: The input tensor.
|
||||
"""
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
# Calculate the scale as dtype max divided by absmax.
|
||||
# Since .abs() creates a new tensor, we use aminmax to get
|
||||
# the min and max first and then calculate the absmax.
|
||||
min_val, max_val = tensor.aminmax()
|
||||
amax = min_val.abs().max(max_val.abs())
|
||||
scale = finfo.max / amax.clamp(min=1e-12)
|
||||
# scale and clamp the tensor to bring it to
|
||||
# the representative range of float8 data type
|
||||
# (as default cast is unsaturated)
|
||||
qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
# Return both float8 data and the inverse scale (as float),
|
||||
# as both required as inputs to torch._scaled_mm
|
||||
qweight = qweight.to(torch.float8_e4m3fn)
|
||||
scale = scale.float().reciprocal()
|
||||
return qweight, scale
|
||||
|
||||
@ -7,10 +7,10 @@ import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
class GPTQConfig(QuantizationConfig):
|
||||
@ -63,8 +63,11 @@ class GPTQConfig(QuantizationConfig):
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
return cls(weight_bits, group_size, desc_act)
|
||||
|
||||
def get_linear_method(self) -> "GPTQLinearMethod":
|
||||
return GPTQLinearMethod(self)
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return GPTQLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
@ -194,10 +197,10 @@ class GPTQLinearMethod(LinearMethodBase):
|
||||
|
||||
layer.exllama_state = exllama_state
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qweight = layer.qweight
|
||||
out_shape = x.shape[:-1] + (qweight.shape[-1], )
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
@ -4,10 +4,10 @@ import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
|
||||
class MarlinConfig(QuantizationConfig):
|
||||
@ -72,8 +72,11 @@ class MarlinConfig(QuantizationConfig):
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
return cls(group_size)
|
||||
|
||||
def get_linear_method(self) -> "MarlinLinearMethod":
|
||||
return MarlinLinearMethod(self)
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return MarlinLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
@ -197,7 +200,7 @@ class MarlinLinearMethod(LinearMethodBase):
|
||||
layer.register_parameter("workspace", workspace)
|
||||
set_weight_attrs(workspace, extra_weight_attrs)
|
||||
|
||||
def apply_weights(
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
|
||||
@ -4,10 +4,10 @@ import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.utils import is_hip
|
||||
|
||||
|
||||
@ -51,14 +51,18 @@ class SqueezeLLMConfig(QuantizationConfig):
|
||||
weight_bits = cls.get_from_keys(config, ["wbits"])
|
||||
return cls(weight_bits)
|
||||
|
||||
def get_linear_method(self) -> "SqueezeLLMLinearMethod":
|
||||
return SqueezeLLMLinearMethod(self)
|
||||
def get_quant_method(
|
||||
self,
|
||||
layer: torch.nn.Module) -> Optional["SqueezeLLMLinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return SqueezeLLMLinearMethod(self)
|
||||
return
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return []
|
||||
|
||||
|
||||
class SqueezeLLMLinearMethod(LinearMethodBase):
|
||||
class SqueezeLLMLinearMethod(QuantizeMethodBase):
|
||||
"""Linear method for SqueezeLLM.
|
||||
|
||||
Args:
|
||||
@ -112,10 +116,10 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
|
||||
layer.register_parameter("lookup_table", lookup_table)
|
||||
set_weight_attrs(lookup_table, extra_weight_attrs)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qweight = layer.qweight
|
||||
lookup_table = layer.lookup_table
|
||||
out_shape = x.shape[:-1] + (qweight.shape[-1], )
|
||||
|
||||
@ -3,8 +3,7 @@ import copy
|
||||
import glob
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple,
|
||||
Type)
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -13,6 +12,8 @@ from vllm.config import (VLLM_USE_MODELSCOPE, DeviceConfig, LoadConfig,
|
||||
LoadFormat, LoRAConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig, VisionLanguageConfig)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.model_loader.tensorizer import (
|
||||
TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer,
|
||||
tensorizer_weights_iterator)
|
||||
@ -24,9 +25,6 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
pt_weights_iterator, safetensors_weights_iterator)
|
||||
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
|
||||
_VISION_MODEL_CLASSES = [
|
||||
LlavaForConditionalGeneration,
|
||||
]
|
||||
@@ -34,11 +32,10 @@ _VISION_MODEL_CLASSES = [
logger = init_logger(__name__)
def _get_linear_method(
def _get_quantization_config(
model_config: ModelConfig,
load_config: LoadConfig) -> Optional["LinearMethodBase"]:
"""Get the (maybe quantized) linear method."""
linear_method = None
load_config: LoadConfig) -> Optional[QuantizationConfig]:
"""Get the quantization config."""
if model_config.quantization is not None:
quant_config = get_quant_config(model_config, load_config)
capability = torch.cuda.get_device_capability()

@@ -55,9 +52,8 @@ def _get_linear_method(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
linear_method = quant_config.get_linear_method()
return linear_method
return quant_config
return None
def _get_model_initialization_kwargs(

@@ -85,10 +81,10 @@ def _initialize_model(
vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
"""Initialize a model with the given configurations."""
model_class = get_model_architecture(model_config)[0]
linear_method = _get_linear_method(model_config, load_config)
quant_config = _get_quantization_config(model_config, load_config)
return model_class(config=model_config.hf_config,
linear_method=linear_method,
quant_config=quant_config,
**_get_model_initialization_kwargs(
model_class, lora_config, vision_language_config))

@@ -229,9 +225,11 @@ class DefaultModelLoader(BaseModelLoader):
"fall_back_to_pt_during_load",
True)), )
for _, module in model.named_modules():
linear_method = getattr(module, "linear_method", None)
if linear_method is not None:
linear_method.process_weights_after_loading(module)
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
return model.eval()
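The loader now discovers per-module quant methods by attribute name, so any `QuantizeMethodBase` can hook the post-load step. A hedged, self-contained sketch of such a hook follows; `TransposeAfterLoad` is hypothetical, while the real example in this diff is `Fp8LinearMethod`, which re-quantizes and transposes its weight at this point.

```python
import torch
from torch import nn

from vllm.model_executor.layers.quantization.base_config import (
    QuantizeMethodBase)


class TransposeAfterLoad(QuantizeMethodBase):
    """Hypothetical method showing the post-load hook the loader invokes."""

    def create_weights(self, layer: nn.Module, input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        weight = nn.Parameter(torch.empty(output_size, input_size,
                                          dtype=params_dtype),
                              requires_grad=False)
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: nn.Module) -> None:
        # Runs once, after all checkpoint shards have been loaded, via
        # module.quant_method.process_weights_after_loading(module) above.
        layer.weight = nn.Parameter(layer.weight.data.t().contiguous(),
                                    requires_grad=False)

    def apply(self, layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
        # Weight was transposed above, so a plain matmul suffices here.
        return x @ layer.weight
```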
@ -314,11 +312,11 @@ class TensorizerLoader(BaseModelLoader):
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model_class = get_model_architecture(model_config)[0]
|
||||
linear_method = _get_linear_method(model_config,
|
||||
self.load_config)
|
||||
quant_config = _get_quantization_config(
|
||||
model_config, self.load_config)
|
||||
extra_kwargs = _get_model_initialization_kwargs(
|
||||
model_class, lora_config, vision_language_config)
|
||||
extra_kwargs["linear_method"] = linear_method
|
||||
extra_kwargs["quant_config"] = quant_config
|
||||
|
||||
tensorizer_config = copy.copy(self.tensorizer_config)
|
||||
tensorizer_config.model_class = model_class
|
||||
|
||||
@ -13,7 +13,8 @@ from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig, ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
|
||||
@ -251,7 +252,7 @@ class TensorizerAgent:
|
||||
"""
|
||||
|
||||
def __init__(self, tensorizer_config: TensorizerConfig,
|
||||
linear_method: LinearMethodBase, **extra_kwargs):
|
||||
quant_config: QuantizationConfig, **extra_kwargs):
|
||||
if tensorizer_load_fail is not None:
|
||||
raise ImportError(
|
||||
"Tensorizer is not installed. Please install tensorizer "
|
||||
@ -262,10 +263,10 @@ class TensorizerAgent:
|
||||
self.tensorizer_args = (
|
||||
self.tensorizer_config._construct_tensorizer_args())
|
||||
self.extra_kwargs = extra_kwargs
|
||||
if extra_kwargs.get("linear_method", None) is not None:
|
||||
self.linear_method = extra_kwargs["linear_method"]
|
||||
if extra_kwargs.get("quant_config", None) is not None:
|
||||
self.quant_config = extra_kwargs["quant_config"]
|
||||
else:
|
||||
self.linear_method = linear_method
|
||||
self.quant_config = quant_config
|
||||
self.model = self._init_model()
|
||||
|
||||
def _init_model(self):
|
||||
@ -274,7 +275,7 @@ class TensorizerAgent:
|
||||
with no_init_or_tensor():
|
||||
return self.tensorizer_config.model_class(
|
||||
config=model_args,
|
||||
linear_method=self.linear_method,
|
||||
quant_config=self.quant_config,
|
||||
**self.extra_kwargs)
|
||||
|
||||
def _resize_lora_embeddings(self):
|
||||
|
||||
@ -31,11 +31,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -77,17 +78,17 @@ class BaiChuanMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -110,7 +111,7 @@ class BaiChuanAttention(nn.Module):
|
||||
position_embedding: str,
|
||||
rope_theta: float = 10000,
|
||||
max_position_embeddings: int = 8192,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -132,13 +133,13 @@ class BaiChuanAttention(nn.Module):
|
||||
self.total_num_heads,
|
||||
self.total_num_heads,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
# Create the alibi slopes and slice them.
|
||||
if self.postion_embedding == "ALIBI":
|
||||
@ -184,7 +185,7 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
position_embedding: str,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
@ -196,13 +197,13 @@ class BaiChuanDecoderLayer(nn.Module):
|
||||
position_embedding=position_embedding,
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.mlp = BaiChuanMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -243,7 +244,7 @@ class BaiChuanModel(nn.Module):
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
position_embedding: str,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
@ -254,7 +255,7 @@ class BaiChuanModel(nn.Module):
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
BaiChuanDecoderLayer(config, position_embedding, linear_method)
|
||||
BaiChuanDecoderLayer(config, position_embedding, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -303,13 +304,13 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
||||
self,
|
||||
config,
|
||||
position_embedding: str,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = BaiChuanModel(config, position_embedding, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.model = BaiChuanModel(config, position_embedding, quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
@ -388,13 +389,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
):
|
||||
if config.hidden_size == 4096: # baichuan2 7b
|
||||
super().__init__(config, "ROPE", linear_method, lora_config)
|
||||
super().__init__(config, "ROPE", quant_config, lora_config)
|
||||
else: # baichuan 13b, baichuan2 13b
|
||||
super().__init__(config, "ALIBI", linear_method, lora_config)
|
||||
super().__init__(config, "ALIBI", quant_config, lora_config)
|
||||
|
||||
|
||||
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
|
||||
@ -403,7 +404,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
):
|
||||
super().__init__(config, "ROPE", linear_method, lora_config)
|
||||
super().__init__(config, "ROPE", quant_config, lora_config)
|
||||
|
||||
@ -28,10 +28,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
@ -70,7 +71,7 @@ class BloomAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -87,13 +88,13 @@ class BloomAttention(nn.Module):
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.dense = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# Create the alibi slopes and slice them.
|
||||
@ -129,21 +130,21 @@ class BloomMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.dense_h_to_4h = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
4 * hidden_size,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
quant_config = getattr(linear_method, "quant_config", None)
|
||||
quant_config = getattr(quant_config, "quant_config", None)
|
||||
self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
|
||||
self.dense_4h_to_h = RowParallelLinear(
|
||||
4 * hidden_size,
|
||||
hidden_size,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@ -158,17 +159,17 @@ class BloomBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
|
||||
self.input_layernorm = nn.LayerNorm(hidden_size,
|
||||
eps=config.layer_norm_epsilon)
|
||||
self.self_attention = BloomAttention(config, linear_method)
|
||||
self.self_attention = BloomAttention(config, quant_config)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = BloomMLP(config, linear_method)
|
||||
self.mlp = BloomMLP(config, quant_config)
|
||||
self.apply_residual_connection_post_layernorm = (
|
||||
config.apply_residual_connection_post_layernorm)
|
||||
|
||||
@ -214,7 +215,7 @@ class BloomModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
@ -229,7 +230,7 @@ class BloomModel(nn.Module):
|
||||
|
||||
# Transformer blocks
|
||||
self.h = nn.ModuleList([
|
||||
BloomBlock(config, linear_method)
|
||||
BloomBlock(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
@ -262,12 +263,12 @@ class BloomForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: BloomConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.transformer = BloomModel(config, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.transformer = BloomModel(config, quant_config)
|
||||
self.lm_head_weight = self.transformer.word_embeddings.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -13,11 +13,12 @@ from vllm.config import LoRAConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -33,7 +34,7 @@ class GLMAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -65,13 +66,13 @@ class GLMAttention(nn.Module):
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=config.add_bias_linear or config.add_qkv_bias,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.dense = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
config.hidden_size,
|
||||
bias=config.add_bias_linear,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
|
||||
@ -123,7 +124,7 @@ class GLMMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -134,7 +135,7 @@ class GLMMLP(nn.Module):
|
||||
config.hidden_size,
|
||||
[config.ffn_hidden_size] * 2,
|
||||
bias=config.add_bias_linear,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.activation_func = SiluAndMul()
|
||||
@ -144,7 +145,7 @@ class GLMMLP(nn.Module):
|
||||
config.ffn_hidden_size,
|
||||
config.hidden_size,
|
||||
bias=config.add_bias_linear,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
@ -166,7 +167,7 @@ class GLMBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.apply_residual_connection_post_layernorm = (
|
||||
@ -180,7 +181,7 @@ class GLMBlock(nn.Module):
|
||||
eps=config.layernorm_epsilon)
|
||||
|
||||
# Self attention.
|
||||
self.self_attention = GLMAttention(config, linear_method)
|
||||
self.self_attention = GLMAttention(config, quant_config)
|
||||
self.hidden_dropout = config.hidden_dropout
|
||||
|
||||
# Layernorm on the attention output
|
||||
@ -188,7 +189,7 @@ class GLMBlock(nn.Module):
|
||||
config.hidden_size, eps=config.layernorm_epsilon)
|
||||
|
||||
# MLP
|
||||
self.mlp = GLMMLP(config, linear_method)
|
||||
self.mlp = GLMMLP(config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -236,7 +237,7 @@ class GLMTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.post_layer_norm = config.post_layer_norm
|
||||
@ -246,7 +247,7 @@ class GLMTransformer(nn.Module):
|
||||
|
||||
# Transformer layers.
|
||||
self.layers = nn.ModuleList(
|
||||
[GLMBlock(config, linear_method) for i in range(self.num_layers)])
|
||||
[GLMBlock(config, quant_config) for i in range(self.num_layers)])
|
||||
|
||||
if self.post_layer_norm:
|
||||
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
|
||||
@ -281,7 +282,7 @@ class ChatGLMModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -291,7 +292,7 @@ class ChatGLMModel(nn.Module):
|
||||
self.num_layers = config.num_layers
|
||||
self.multi_query_group_num = config.multi_query_group_num
|
||||
self.kv_channels = config.kv_channels
|
||||
self.encoder = GLMTransformer(config, linear_method)
|
||||
self.encoder = GLMTransformer(config, quant_config)
|
||||
|
||||
self.output_layer = ParallelLMHead(config.padded_vocab_size,
|
||||
config.hidden_size)
|
||||
@ -333,13 +334,13 @@ class ChatGLMForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: ChatGLMConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config: ChatGLMConfig = config
|
||||
self.linear_method = linear_method
|
||||
self.transformer = ChatGLMModel(config, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.transformer = ChatGLMModel(config, quant_config)
|
||||
self.lm_head_weight = self.transformer.output_layer.weight
|
||||
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -32,11 +32,12 @@ from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -91,7 +92,7 @@ class CohereMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -101,13 +102,13 @@ class CohereMLP(nn.Module):
|
||||
self.hidden_size,
|
||||
[self.intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
self.intermediate_size,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
@ -123,7 +124,7 @@ class CohereAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: CohereConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
@ -158,13 +159,13 @@ class CohereAttention(nn.Module):
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@ -218,13 +219,13 @@ class CohereDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: CohereConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = CohereAttention(config, linear_method=linear_method)
|
||||
self.self_attn = CohereAttention(config, quant_config=quant_config)
|
||||
|
||||
self.mlp = CohereMLP(config, linear_method=linear_method)
|
||||
self.mlp = CohereMLP(config, quant_config=quant_config)
|
||||
self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
@ -257,7 +258,7 @@ class CohereModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: CohereConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -265,7 +266,7 @@ class CohereModel(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
||||
config.hidden_size)
|
||||
self.layers = nn.ModuleList([
|
||||
CohereDecoderLayer(config, linear_method=linear_method)
|
||||
CohereDecoderLayer(config, quant_config=quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = LayerNorm(param_shape=(config.hidden_size),
|
||||
@ -298,14 +299,14 @@ class CohereForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: CohereConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.quant_config = quant_config
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size,
|
||||
scale=config.logit_scale)
|
||||
self.model = CohereModel(config, linear_method)
|
||||
self.model = CohereModel(config, quant_config)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@ -9,11 +9,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -44,7 +45,7 @@ class DbrxRouter(nn.Module):
|
||||
self.num_total_experts,
|
||||
bias=False,
|
||||
params_dtype=params_dtype,
|
||||
linear_method=None,
|
||||
quant_config=None,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@ -63,7 +64,7 @@ class DbrxExperts(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().__init__()
|
||||
@ -165,7 +166,7 @@ class DbrxAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
@ -183,13 +184,13 @@ class DbrxAttention(nn.Module):
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
self.d_model,
|
||||
self.d_model,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@ -244,11 +245,11 @@ class DbrxFusedNormAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
self.attn = DbrxAttention(config, linear_method)
|
||||
self.attn = DbrxAttention(config, quant_config)
|
||||
self.norm_1 = nn.LayerNorm(self.d_model)
|
||||
self.norm_2 = nn.LayerNorm(self.d_model)
|
||||
|
||||
@ -278,11 +279,11 @@ class DbrxBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_attn_norm = DbrxFusedNormAttention(config, linear_method)
|
||||
self.ffn = DbrxExperts(config, linear_method)
|
||||
self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config)
|
||||
self.ffn = DbrxExperts(config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -307,7 +308,7 @@ class DbrxModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.wte = VocabParallelEmbedding(
|
||||
@ -315,7 +316,7 @@ class DbrxModel(nn.Module):
|
||||
config.d_model,
|
||||
)
|
||||
self.blocks = nn.ModuleList(
|
||||
[DbrxBlock(config, linear_method) for _ in range(config.n_layers)])
|
||||
[DbrxBlock(config, quant_config) for _ in range(config.n_layers)])
|
||||
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
|
||||
for module in self.modules():
|
||||
if hasattr(module, "bias") and isinstance(module.bias,
|
||||
@ -348,13 +349,13 @@ class DbrxForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: DbrxConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.quant_config = quant_config
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
self.transformer = DbrxModel(config, linear_method)
|
||||
self.transformer = DbrxModel(config, quant_config)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.d_model,
@ -29,7 +29,8 @@ import torch
from transformers import PretrainedConfig

from vllm.config import LoRAConfig
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaForCausalLM

@ -55,13 +56,13 @@ class DeciLMForCausalLM(LlamaForCausalLM):
def __init__(
self,
config: Optional[PretrainedConfig] = None,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
config.num_key_value_heads = max(config.num_key_value_heads_per_layer)
delattr(config, "num_key_value_heads_per_layer")
super().__init__(config=config,
linear_method=linear_method,
quant_config=quant_config,
lora_config=lora_config)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

@ -34,12 +34,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -56,18 +57,18 @@ class DeepseekMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
@ -86,7 +87,7 @@ class DeepseekMoE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -103,7 +104,7 @@ class DeepseekMoE(nn.Module):
|
||||
DeepseekMLP(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False)
|
||||
for idx in range(self.n_routed_experts)
|
||||
])
|
||||
@ -112,7 +113,7 @@ class DeepseekMoE(nn.Module):
|
||||
self.gate = ReplicatedLinear(config.hidden_size,
|
||||
self.n_routed_experts,
|
||||
bias=False,
|
||||
linear_method=None)
|
||||
quant_config=None)
|
||||
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = (config.moe_intermediate_size *
|
||||
@ -121,7 +122,7 @@ class DeepseekMoE(nn.Module):
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False,
|
||||
)
|
||||
|
||||
@ -177,7 +178,7 @@ class DeepseekAttention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -208,14 +209,14 @@ class DeepseekAttention(nn.Module):
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@ -251,7 +252,7 @@ class DeepseekDecoderLayer(nn.Module):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
layer_idx: int,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -266,18 +267,18 @@ class DeepseekDecoderLayer(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if (config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace
|
||||
and layer_idx % config.moe_layer_freq == 0):
|
||||
self.mlp = DeepseekMoE(config=config, linear_method=linear_method)
|
||||
self.mlp = DeepseekMoE(config=config, quant_config=quant_config)
|
||||
else:
|
||||
self.mlp = DeepseekMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -320,7 +321,7 @@ class DeepseekModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.padding_idx = config.pad_token_id
|
||||
@ -331,9 +332,7 @@ class DeepseekModel(nn.Module):
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
DeepseekDecoderLayer(config,
|
||||
layer_idx,
|
||||
linear_method=linear_method)
|
||||
DeepseekDecoderLayer(config, layer_idx, quant_config=quant_config)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -361,12 +360,12 @@ class DeepseekForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = DeepseekModel(config, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.model = DeepseekModel(config, quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
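# --- Illustrative sketch (not part of the diff above) ---
# In the MoE models in this commit (DBRX, Deepseek), the expert router/gate is
# built with quant_config=None, so the gate projection stays unquantized even
# when a quantization config is supplied for the rest of the model. The helper
# name `build_router` is hypothetical; the ReplicatedLinear call mirrors
# DbrxRouter and DeepseekMoE.gate above.
from vllm.model_executor.layers.linear import ReplicatedLinear


def build_router(hidden_size: int, num_experts: int) -> ReplicatedLinear:
    # Bias-free gate, explicitly left unquantized via quant_config=None.
    return ReplicatedLinear(hidden_size,
                            num_experts,
                            bias=False,
                            quant_config=None)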
@ -32,10 +32,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -76,7 +77,7 @@ class FalconAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -115,7 +116,7 @@ class FalconAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=config.bias,
|
||||
skip_bias_add=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
@ -129,7 +130,7 @@ class FalconAttention(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=config.bias,
|
||||
skip_bias_add=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
reduce_results=self.reduce_row_parallel_results)
|
||||
|
||||
self.use_rotary = config.rotary
|
||||
@ -192,7 +193,7 @@ class FalconMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@ -201,8 +202,8 @@ class FalconMLP(nn.Module):
|
||||
4 * hidden_size,
|
||||
bias=config.bias,
|
||||
skip_bias_add=True,
|
||||
linear_method=linear_method)
|
||||
quant_config = getattr(linear_method, "quant_config", None)
|
||||
quant_config=quant_config)
|
||||
quant_config = getattr(quant_config, "quant_config", None)
|
||||
self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
|
||||
self.reduce_row_parallel_results = not (config.new_decoder_architecture
|
||||
or config.parallel_attn)
|
||||
@ -212,7 +213,7 @@ class FalconMLP(nn.Module):
|
||||
bias=config.bias,
|
||||
skip_bias_add=True,
|
||||
reduce_results=self.reduce_row_parallel_results,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
|
||||
@ -229,13 +230,13 @@ class FalconDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.self_attention = FalconAttention(config, linear_method)
|
||||
self.mlp = FalconMLP(config, linear_method)
|
||||
self.self_attention = FalconAttention(config, quant_config)
|
||||
self.mlp = FalconMLP(config, quant_config)
|
||||
self.config = config
|
||||
|
||||
if config.new_decoder_architecture:
|
||||
@ -311,7 +312,7 @@ class FalconModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -327,7 +328,7 @@ class FalconModel(nn.Module):
|
||||
|
||||
# Transformer blocks
|
||||
self.h = nn.ModuleList([
|
||||
FalconDecoderLayer(config, linear_method)
|
||||
FalconDecoderLayer(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
@ -359,12 +360,12 @@ class FalconForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.transformer = FalconModel(config, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.transformer = FalconModel(config, quant_config)
|
||||
self.lm_head_weight = self.transformer.word_embeddings.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -27,11 +27,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import GeluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -77,17 +78,17 @@ class GemmaMLP(nn.Module):
|
||||
intermediate_size: int,
|
||||
hidden_act: Optional[str] = None,
|
||||
hidden_activation: Optional[str] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)
|
||||
|
||||
def forward(self, x):
|
||||
@ -106,7 +107,7 @@ class GemmaAttention(nn.Module):
|
||||
head_dim: int,
|
||||
max_position_embeddings: int = 8192,
|
||||
rope_theta: float = 10000,
|
||||
linear_method: Optional[LinearMethodBase] = None) -> None:
|
||||
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
@ -135,13 +136,13 @@ class GemmaAttention(nn.Module):
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@ -176,7 +177,7 @@ class GemmaDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GemmaConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -187,14 +188,14 @@ class GemmaDecoderLayer(nn.Module):
|
||||
head_dim=config.head_dim,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
rope_theta=config.rope_theta,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.mlp = GemmaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
hidden_activation=getattr(config, "hidden_activation", None),
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -235,7 +236,7 @@ class GemmaModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GemmaConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -245,7 +246,7 @@ class GemmaModel(nn.Module):
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
GemmaDecoderLayer(config, linear_method)
|
||||
GemmaDecoderLayer(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -308,14 +309,14 @@ class GemmaForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GemmaConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
del lora_config # Unused.
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = GemmaModel(config, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.model = GemmaModel(config, quant_config)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
|
||||
@ -27,10 +27,11 @@ from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
@ -44,7 +45,7 @@ class GPT2Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPT2Config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -61,13 +62,13 @@ class GPT2Attention(nn.Module):
|
||||
self.head_dim,
|
||||
total_num_heads,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale)
|
||||
|
||||
@ -90,7 +91,7 @@ class GPT2MLP(nn.Module):
|
||||
self,
|
||||
intermediate_size: int,
|
||||
config: GPT2Config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@ -98,15 +99,15 @@ class GPT2MLP(nn.Module):
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
quant_config = getattr(linear_method, "quant_config", None)
|
||||
quant_config = getattr(quant_config, "quant_config", None)
|
||||
self.act = get_act_fn(config.activation_function, quant_config,
|
||||
intermediate_size)
|
||||
|
||||
@ -122,7 +123,7 @@ class GPT2Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPT2Config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@ -130,9 +131,9 @@ class GPT2Block(nn.Module):
|
||||
hidden_size)
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPT2Attention(config, linear_method)
|
||||
self.attn = GPT2Attention(config, quant_config)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = GPT2MLP(inner_dim, config, linear_method)
|
||||
self.mlp = GPT2MLP(inner_dim, config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -163,7 +164,7 @@ class GPT2Model(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPT2Config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -174,7 +175,7 @@ class GPT2Model(nn.Module):
|
||||
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.h = nn.ModuleList([
|
||||
GPT2Block(config, linear_method)
|
||||
GPT2Block(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
@ -203,12 +204,12 @@ class GPT2LMHeadModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPT2Config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.transformer = GPT2Model(config, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.transformer = GPT2Model(config, quant_config)
|
||||
self.lm_head_weight = self.transformer.wte.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
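# --- Illustrative sketch (not part of the diff above) ---
# The GPT-2 style MLP blocks in this commit also hand a quantization config to
# get_act_fn() for the activation, mirroring the call
# get_act_fn(config.activation_function, quant_config, intermediate_size) seen
# in the hunks above. The helper name `build_act` is hypothetical.
from typing import Optional

from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


def build_act(activation: str,
              intermediate_size: int,
              quant_config: Optional[QuantizationConfig] = None):
    # Same call shape as GPT2MLP / GPTBigMLP / GPTJMLP / GPTNeoXMLP above.
    return get_act_fn(activation, quant_config, intermediate_size)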
@ -28,10 +28,11 @@ from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
@ -45,7 +46,7 @@ class GPTBigCodeAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTBigCodeConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -72,14 +73,14 @@ class GPTBigCodeAttention(nn.Module):
|
||||
total_num_heads,
|
||||
total_num_kv_heads,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
@ -111,7 +112,7 @@ class GPTBigMLP(nn.Module):
|
||||
self,
|
||||
intermediate_size: int,
|
||||
config: GPTBigCodeConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@ -119,15 +120,15 @@ class GPTBigMLP(nn.Module):
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
quant_config = getattr(linear_method, "quant_config", None)
|
||||
quant_config = getattr(quant_config, "quant_config", None)
|
||||
self.act = get_act_fn(config.activation_function, quant_config,
|
||||
intermediate_size)
|
||||
|
||||
@ -143,7 +144,7 @@ class GPTBigCodeBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTBigCodeConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@ -151,9 +152,9 @@ class GPTBigCodeBlock(nn.Module):
|
||||
hidden_size)
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPTBigCodeAttention(config, linear_method)
|
||||
self.attn = GPTBigCodeAttention(config, quant_config)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = GPTBigMLP(inner_dim, config, linear_method)
|
||||
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -184,7 +185,7 @@ class GPTBigCodeModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTBigCodeConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -195,7 +196,7 @@ class GPTBigCodeModel(nn.Module):
|
||||
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
|
||||
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
||||
self.h = nn.ModuleList([
|
||||
GPTBigCodeBlock(config, linear_method)
|
||||
GPTBigCodeBlock(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
@ -224,12 +225,12 @@ class GPTBigCodeForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTBigCodeConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.transformer = GPTBigCodeModel(config, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.transformer = GPTBigCodeModel(config, quant_config)
|
||||
self.lm_head_weight = self.transformer.wte.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -26,10 +26,11 @@ from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -44,7 +45,7 @@ class GPTJAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTJConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
@ -56,13 +57,13 @@ class GPTJAttention(nn.Module):
|
||||
self.head_size,
|
||||
self.total_num_heads,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
@ -105,21 +106,21 @@ class GPTJMLP(nn.Module):
|
||||
self,
|
||||
intermediate_size: int,
|
||||
config: GPTJConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.n_embd
|
||||
self.fc_in = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.fc_out = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
quant_config = getattr(linear_method, "quant_config", None)
|
||||
quant_config = getattr(quant_config, "quant_config", None)
|
||||
self.act = get_act_fn(config.activation_function, quant_config,
|
||||
intermediate_size)
|
||||
|
||||
@ -135,14 +136,14 @@ class GPTJBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTJConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = (4 * config.n_embd
|
||||
if config.n_inner is None else config.n_inner)
|
||||
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
||||
self.attn = GPTJAttention(config, linear_method)
|
||||
self.mlp = GPTJMLP(inner_dim, config, linear_method)
|
||||
self.attn = GPTJAttention(config, quant_config)
|
||||
self.mlp = GPTJMLP(inner_dim, config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -169,7 +170,7 @@ class GPTJModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTJConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -179,7 +180,7 @@ class GPTJModel(nn.Module):
|
||||
self.embed_dim,
|
||||
)
|
||||
self.h = nn.ModuleList(
|
||||
[GPTJBlock(config, linear_method) for _ in range(config.n_layer)])
|
||||
[GPTJBlock(config, quant_config) for _ in range(config.n_layer)])
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
def forward(
|
||||
@ -207,13 +208,13 @@ class GPTJForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTJConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.quant_config = quant_config
|
||||
assert not config.tie_word_embeddings
|
||||
self.transformer = GPTJModel(config, linear_method)
|
||||
self.transformer = GPTJModel(config, quant_config)
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.n_embd,
|
||||
|
||||
@ -26,10 +26,11 @@ from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -44,7 +45,7 @@ class GPTNeoXAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTNeoXConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
@ -63,13 +64,13 @@ class GPTNeoXAttention(nn.Module):
|
||||
self.head_size,
|
||||
self.total_num_heads,
|
||||
bias=self.bias,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.dense = RowParallelLinear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=self.bias,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
scaling = self.head_size**-0.5
|
||||
rotary_dim = int(self.head_size * config.rotary_pct)
|
||||
@ -105,20 +106,20 @@ class GPTNeoXMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTNeoXConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dense_h_to_4h = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.dense_4h_to_h = RowParallelLinear(
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
quant_config = getattr(linear_method, "quant_config", None)
|
||||
quant_config = getattr(quant_config, "quant_config", None)
|
||||
self.act = get_act_fn(config.hidden_act, quant_config,
|
||||
config.intermediate_size)
|
||||
|
||||
@ -134,7 +135,7 @@ class GPTNeoXLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTNeoXConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_parallel_residual = config.use_parallel_residual
|
||||
@ -142,8 +143,8 @@ class GPTNeoXLayer(nn.Module):
|
||||
eps=config.layer_norm_eps)
|
||||
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.attention = GPTNeoXAttention(config, linear_method)
|
||||
self.mlp = GPTNeoXMLP(config, linear_method)
|
||||
self.attention = GPTNeoXAttention(config, quant_config)
|
||||
self.mlp = GPTNeoXMLP(config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -182,7 +183,7 @@ class GPTNeoXModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: GPTNeoXConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -192,7 +193,7 @@ class GPTNeoXModel(nn.Module):
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
GPTNeoXLayer(config, linear_method)
|
||||
GPTNeoXLayer(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
|
||||
@ -223,12 +224,12 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.gpt_neox = GPTNeoXModel(config, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.gpt_neox = GPTNeoXModel(config, quant_config)
|
||||
self.embed_out = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
|
||||
@ -9,11 +9,12 @@ from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -30,17 +31,17 @@ class InternLM2MLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
self.w2 = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -63,7 +64,7 @@ class InternLM2Attention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -94,13 +95,13 @@ class InternLM2Attention(nn.Module):
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.wo = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@ -135,7 +136,7 @@ class InternLMDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -150,13 +151,13 @@ class InternLMDecoderLayer(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.feed_forward = InternLM2MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.attention_norm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -195,7 +196,7 @@ class InternLM2Model(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -206,7 +207,7 @@ class InternLM2Model(nn.Module):
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
InternLMDecoderLayer(config, linear_method)
|
||||
InternLMDecoderLayer(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -238,12 +239,12 @@ class InternLM2ForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = InternLM2Model(config, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.model = InternLM2Model(config, quant_config)
|
||||
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -29,10 +29,11 @@ from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
@ -68,7 +69,7 @@ class JAISAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: JAISConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -88,13 +89,13 @@ class JAISAttention(nn.Module):
|
||||
self.head_dim,
|
||||
total_num_heads,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
@ -128,7 +129,7 @@ class JAISMLP(nn.Module):
|
||||
self,
|
||||
intermediate_size: int,
|
||||
config: JAISConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@ -137,19 +138,19 @@ class JAISMLP(nn.Module):
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.c_fc2 = (ColumnParallelLinear(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
) if self.swiglu else None)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.act = SwiGLUActivation()
|
||||
@ -169,7 +170,7 @@ class JAISBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: JAISConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
@ -177,9 +178,9 @@ class JAISBlock(nn.Module):
|
||||
hidden_size)
|
||||
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = JAISAttention(config, linear_method)
|
||||
self.attn = JAISAttention(config, quant_config)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = JAISMLP(inner_dim, config, linear_method)
|
||||
self.mlp = JAISMLP(inner_dim, config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -210,7 +211,7 @@ class JAISModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: JAISConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -227,7 +228,7 @@ class JAISModel(nn.Module):
|
||||
else:
|
||||
self.embeddings_scale = config.mup_embeddings_scale
|
||||
self.h = nn.ModuleList([
|
||||
JAISBlock(config, linear_method)
|
||||
JAISBlock(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
@ -261,12 +262,12 @@ class JAISLMHeadModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: JAISConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.transformer = JAISModel(config, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.transformer = JAISModel(config, quant_config)
|
||||
self.lm_head_weight = self.transformer.wte.weight
|
||||
if hasattr(config, "width_scale"):
|
||||
self.output_logits_scale = config.width_scale
|
||||
|
||||
@ -33,11 +33,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -56,17 +57,17 @@ class LlamaMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QKVParallelLinear] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -89,7 +90,7 @@ class LlamaAttention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
sliding_window: Optional[int] = None,
|
||||
) -> None:
|
||||
@ -131,13 +132,13 @@ class LlamaAttention(nn.Module):
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=bias,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=bias,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@ -174,7 +175,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -199,7 +200,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
bias=attention_bias,
sliding_window=sliding_window,
)
@ -207,7 +208,7 @@ class LlamaDecoderLayer(nn.Module):
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@ -248,7 +249,7 @@ class LlamaModel(nn.Module):
def __init__(
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
@ -264,7 +265,7 @@ class LlamaModel(nn.Module):
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config, linear_method)
LlamaDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -329,13 +330,12 @@ class LlamaForCausalLM(nn.Module):
def __init__(
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = LlamaModel(config, linear_method, lora_config=lora_config)
self.model = LlamaModel(config, quant_config, lora_config=lora_config)
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
@ -9,8 +9,9 @@ from transformers import CLIPVisionModel, LlavaConfig
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VisionLanguageConfig
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@ -61,7 +62,7 @@ class LlavaForConditionalGeneration(nn.Module):
|
||||
def __init__(self,
|
||||
config: "LlavaConfig",
|
||||
vision_language_config: VisionLanguageConfig,
|
||||
linear_method: Optional["LinearMethodBase"] = None) -> None:
|
||||
quant_config: Optional["QuantizationConfig"] = None) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@ -83,8 +84,8 @@ class LlavaForConditionalGeneration(nn.Module):
|
||||
text_hidden_size=config.text_config.hidden_size,
|
||||
projector_hidden_act=config.projector_hidden_act)
|
||||
|
||||
self.linear_method = linear_method
|
||||
self.language_model = LlamaModel(config.text_config, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.language_model = LlamaModel(config.text_config, quant_config)
|
||||
self.unpadded_vocab_size = config.text_config.vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
|
||||
@ -35,12 +35,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -84,7 +85,7 @@ class MiniCPMMoE(nn.Module):
|
||||
self.num_total_experts,
|
||||
bias=False,
|
||||
params_dtype=self.params_dtype,
|
||||
linear_method=None)
|
||||
quant_config=None)
|
||||
|
||||
self.ws = nn.Parameter(
|
||||
torch.empty(self.num_total_experts,
|
||||
@ -147,17 +148,17 @@ class MiniCPMMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -180,7 +181,7 @@ class MiniCPMAttention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -211,13 +212,13 @@ class MiniCPMAttention(nn.Module):
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@ -258,7 +259,7 @@ class MiniCPMDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -274,7 +275,7 @@ class MiniCPMDecoderLayer(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.num_experts = getattr(self.config, "num_experts", 0)
|
||||
if self.num_experts == 0:
|
||||
@ -282,7 +283,7 @@ class MiniCPMDecoderLayer(nn.Module):
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
self.mlp = MiniCPMMoE(num_experts=config.num_experts,
|
||||
@ -329,7 +330,7 @@ class MiniCPMModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -345,7 +346,7 @@ class MiniCPMModel(nn.Module):
|
||||
org_num_embeddings=config.vocab_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
MiniCPMDecoderLayer(config, linear_method)
|
||||
MiniCPMDecoderLayer(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -412,15 +413,15 @@ class MiniCPMForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.num_experts = getattr(self.config, "num_experts", 0)
|
||||
self.linear_method = linear_method
|
||||
self.quant_config = quant_config
|
||||
self.model = MiniCPMModel(config,
|
||||
linear_method,
|
||||
quant_config,
|
||||
lora_config=lora_config)
|
||||
unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
|
||||
@ -27,6 +27,7 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import MixtralConfig
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
@ -34,13 +35,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.fp8 import (Fp8LinearMethod,
|
||||
per_tensor_quantize)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -69,7 +70,7 @@ class MixtralMoE(nn.Module):
|
||||
intermediate_size: int,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
|
||||
@ -79,7 +80,7 @@ class MixtralMoE(nn.Module):
|
||||
self.intermediate_size = intermediate_size // self.tp_size
|
||||
# FIXME(pcmoritz): Make this more general to support different
|
||||
# quantization schemes
|
||||
self.use_fp8 = isinstance(linear_method, Fp8LinearMethod)
|
||||
self.use_fp8 = isinstance(quant_config, Fp8Config)
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
@ -89,7 +90,7 @@ class MixtralMoE(nn.Module):
|
||||
self.num_total_experts,
|
||||
bias=False,
|
||||
params_dtype=self.params_dtype,
|
||||
linear_method=None)
|
||||
quant_config=None)
|
||||
|
||||
self.ws = nn.Parameter(
|
||||
torch.empty(self.num_total_experts,
|
||||
@ -140,10 +141,10 @@ class MixtralMoE(nn.Module):
|
||||
ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
|
||||
w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
|
||||
for expert in range(self.num_total_experts):
|
||||
ws[expert, :, :], self.ws_scale[expert] = per_tensor_quantize(
|
||||
ws[expert, :, :], self.ws_scale[expert] = ops.scaled_fp8_quant(
|
||||
self.ws.data[expert, :, :])
|
||||
w2s[expert, :, :], self.w2s_scale[
|
||||
expert] = per_tensor_quantize(self.w2s.data[expert, :, :])
|
||||
expert] = ops.scaled_fp8_quant(self.w2s.data[expert, :, :])
|
||||
self.ws = nn.Parameter(ws, requires_grad=False)
|
||||
self.w2s = nn.Parameter(w2s, requires_grad=False)
|
||||
|
||||
@ -178,7 +179,7 @@ class MixtralAttention(nn.Module):
|
||||
num_kv_heads: int,
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
sliding_window: Optional[int] = None) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -203,12 +204,12 @@ class MixtralAttention(nn.Module):
|
||||
self.rope_theta = rope_theta
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
if isinstance(linear_method, Fp8LinearMethod):
|
||||
if isinstance(quant_config, Fp8Config):
|
||||
print_warning_once(
|
||||
"For Mixtral FP8 quantization, we currently do not quantize "
|
||||
"the attention layers until their FP8 performance is improved."
|
||||
)
|
||||
linear_method = None
|
||||
quant_config = None
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
@ -216,13 +217,13 @@ class MixtralAttention(nn.Module):
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@ -259,7 +260,7 @@ class MixtralDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -272,13 +273,13 @@ class MixtralDecoderLayer(nn.Module):
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
sliding_window=config.sliding_window,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
self.block_sparse_moe = MixtralMoE(
|
||||
num_experts=config.num_local_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
@ -318,7 +319,7 @@ class MixtralModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@ -334,7 +335,7 @@ class MixtralModel(nn.Module):
|
||||
org_num_embeddings=config.vocab_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
MixtralDecoderLayer(config, linear_method=linear_method)
|
||||
MixtralDecoderLayer(config, quant_config=quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -384,14 +385,13 @@ class MixtralForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = MixtralModel(config,
|
||||
linear_method,
|
||||
quant_config,
|
||||
lora_config=lora_config)
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
|
||||
@ -34,11 +34,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -55,7 +56,7 @@ class MixtralMLP(nn.Module):
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
@ -65,15 +66,15 @@ class MixtralMLP(nn.Module):
|
||||
self.w1 = ReplicatedLinear(self.hidden_dim,
|
||||
self.ffn_dim,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
self.w2 = ReplicatedLinear(self.ffn_dim,
|
||||
self.hidden_dim,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
self.w3 = ReplicatedLinear(self.hidden_dim,
|
||||
self.ffn_dim,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
|
||||
# TODO: Use vllm's SiluAndMul
|
||||
self.act_fn = nn.SiLU()
|
||||
@ -92,7 +93,7 @@ class MixtralMoE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -115,14 +116,14 @@ class MixtralMoE(nn.Module):
|
||||
MixtralMLP(self.num_total_experts,
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
if idx in self.expert_indicies else None
|
||||
for idx in range(self.num_total_experts)
|
||||
])
|
||||
self.gate = ReplicatedLinear(config.hidden_size,
|
||||
self.num_total_experts,
|
||||
bias=False,
|
||||
linear_method=None)
|
||||
quant_config=None)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
@ -162,7 +163,7 @@ class MixtralAttention(nn.Module):
|
||||
num_kv_heads: int,
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
sliding_window: Optional[int] = None) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -193,13 +194,13 @@ class MixtralAttention(nn.Module):
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
@ -236,7 +237,7 @@ class MixtralDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -249,9 +250,9 @@ class MixtralDecoderLayer(nn.Module):
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
sliding_window=config.sliding_window,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
self.block_sparse_moe = MixtralMoE(config=config,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
@ -291,7 +292,7 @@ class MixtralModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.padding_idx = config.pad_token_id
|
||||
@ -302,7 +303,7 @@ class MixtralModel(nn.Module):
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
MixtralDecoderLayer(config, linear_method=linear_method)
|
||||
MixtralDecoderLayer(config, quant_config=quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -331,12 +332,12 @@ class MixtralForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = MixtralModel(config, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.model = MixtralModel(config, quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -11,10 +11,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
@ -42,7 +43,7 @@ class MPTAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MPTConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
@ -65,7 +66,7 @@ class MPTAttention(nn.Module):
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=not config.no_bias,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
if self.qk_ln:
|
||||
self.q_ln = nn.LayerNorm(self.d_model)
|
||||
@ -74,7 +75,7 @@ class MPTAttention(nn.Module):
|
||||
self.d_model,
|
||||
self.d_model,
|
||||
bias=not config.no_bias,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
@ -133,7 +134,7 @@ class MPTMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MPTConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.d_model
|
||||
@ -143,15 +144,15 @@ class MPTMLP(nn.Module):
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=not config.no_bias,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
quant_config = getattr(linear_method, "quant_config", None)
|
||||
quant_config = getattr(quant_config, "quant_config", None)
|
||||
self.act = get_act_fn("gelu", quant_config, intermediate_size)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=not config.no_bias,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@ -166,14 +167,14 @@ class MPTBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MPTConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.d_model
|
||||
self.norm_1 = nn.LayerNorm(hidden_size)
|
||||
self.attn = MPTAttention(config, linear_method)
|
||||
self.attn = MPTAttention(config, quant_config)
|
||||
self.norm_2 = nn.LayerNorm(hidden_size)
|
||||
self.ffn = MPTMLP(config, linear_method)
|
||||
self.ffn = MPTMLP(config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -201,7 +202,7 @@ class MPTModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MPTConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
assert config.embedding_fraction == 1.0
|
||||
@ -212,7 +213,7 @@ class MPTModel(nn.Module):
|
||||
config.d_model,
|
||||
)
|
||||
self.blocks = nn.ModuleList(
|
||||
[MPTBlock(config, linear_method) for _ in range(config.n_layers)])
|
||||
[MPTBlock(config, quant_config) for _ in range(config.n_layers)])
|
||||
self.norm_f = nn.LayerNorm(config.d_model)
|
||||
if config.no_bias:
|
||||
for module in self.modules():
|
||||
@ -246,14 +247,14 @@ class MPTForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: MPTConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
assert config.tie_word_embeddings
|
||||
self.linear_method = linear_method
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.transformer = MPTModel(config, linear_method)
|
||||
self.transformer = MPTModel(config, quant_config)
|
||||
self.lm_head_weight = self.transformer.wte.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -30,11 +30,12 @@ from transformers import OlmoConfig
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -54,7 +55,7 @@ class OlmoAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: OlmoConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -79,7 +80,7 @@ class OlmoAttention(nn.Module):
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
bias=config.attention_bias,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# Rotary embeddings.
|
||||
@ -99,7 +100,7 @@ class OlmoAttention(nn.Module):
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=config.attention_bias,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def forward(
|
||||
@ -129,7 +130,7 @@ class OlmoMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: OlmoConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -141,7 +142,7 @@ class OlmoMLP(nn.Module):
|
||||
self.hidden_size,
|
||||
[self.intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
# Activation function.
|
||||
@ -152,7 +153,7 @@ class OlmoMLP(nn.Module):
|
||||
self.intermediate_size,
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
def forward(
|
||||
@ -174,13 +175,13 @@ class OlmoDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: OlmoConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
# Attention block.
|
||||
self.self_attn = OlmoAttention(config, linear_method)
|
||||
self.self_attn = OlmoAttention(config, quant_config)
|
||||
|
||||
# MLP block.
|
||||
self.mlp = OlmoMLP(config, linear_method)
|
||||
self.mlp = OlmoMLP(config, quant_config)
|
||||
|
||||
# LayerNorm
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
@ -216,14 +217,14 @@ class OlmoModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: OlmoConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
||||
config.hidden_size)
|
||||
self.layers = nn.ModuleList([
|
||||
OlmoDecoderLayer(config, linear_method)
|
||||
OlmoDecoderLayer(config, quant_config)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = nn.LayerNorm(config.hidden_size,
|
||||
@ -270,11 +271,10 @@ class OlmoForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: OlmoConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = OlmoModel(config, linear_method)
|
||||
self.model = OlmoModel(config, quant_config)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head_weight = self.model.embed_tokens.weight
|
||||
else:
|
||||
|
||||
@ -27,11 +27,12 @@ from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
@ -60,7 +61,7 @@ class OPTAttention(nn.Module):
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
bias: bool = True,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
@ -77,13 +78,13 @@ class OPTAttention(nn.Module):
|
||||
self.head_dim,
|
||||
total_num_heads,
|
||||
bias=bias,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
bias=bias,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
@ -107,7 +108,7 @@ class OPTDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: OPTConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -116,7 +117,7 @@ class OPTDecoderLayer(nn.Module):
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.num_attention_heads,
|
||||
bias=config.enable_bias,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.do_layer_norm_before = config.do_layer_norm_before
|
||||
|
||||
@ -127,16 +128,16 @@ class OPTDecoderLayer(nn.Module):
|
||||
self.embed_dim,
|
||||
config.ffn_dim,
|
||||
bias=config.enable_bias,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
quant_config = getattr(linear_method, "quant_config", None)
|
||||
quant_config = getattr(quant_config, "quant_config", None)
|
||||
self.activation_fn = get_act_fn(config.activation_function,
|
||||
quant_config, config.ffn_dim)
|
||||
self.fc2 = RowParallelLinear(
|
||||
config.ffn_dim,
|
||||
self.embed_dim,
|
||||
bias=config.enable_bias,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(
|
||||
self.embed_dim,
|
||||
@ -181,7 +182,7 @@ class OPTDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: OPTConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -202,7 +203,7 @@ class OPTDecoder(nn.Module):
|
||||
self.project_out = ReplicatedLinear(config.hidden_size,
|
||||
config.word_embed_proj_dim,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
self.project_out = None
|
||||
|
||||
@ -210,7 +211,7 @@ class OPTDecoder(nn.Module):
|
||||
self.project_in = ReplicatedLinear(config.word_embed_proj_dim,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
self.project_in = None
|
||||
|
||||
@ -226,7 +227,7 @@ class OPTDecoder(nn.Module):
|
||||
self.final_layer_norm = None
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
OPTDecoderLayer(config, linear_method)
|
||||
OPTDecoderLayer(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
@ -259,10 +260,10 @@ class OPTModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: OPTConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.decoder = OPTDecoder(config, linear_method)
|
||||
self.decoder = OPTDecoder(config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -279,12 +280,12 @@ class OPTForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = OPTModel(config, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.model = OPTModel(config, quant_config)
|
||||
self.lm_head_weight = self.model.decoder.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -13,11 +13,12 @@ from transformers import PretrainedConfig
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -34,17 +35,17 @@ class OrionMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -67,7 +68,7 @@ class OrionAttention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -98,13 +99,13 @@ class OrionAttention(nn.Module):
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@ -139,7 +140,7 @@ class OrionDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -154,13 +155,13 @@ class OrionDecoderLayer(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.mlp = OrionMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
@ -201,7 +202,7 @@ class OrionModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -212,7 +213,7 @@ class OrionModel(nn.Module):
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
OrionDecoderLayer(config, linear_method)
|
||||
OrionDecoderLayer(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -244,12 +245,12 @@ class OrionForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = OrionModel(config, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.model = OrionModel(config, quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -45,10 +45,11 @@ from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearMethodBase,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -62,7 +63,7 @@ class PhiAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -80,12 +81,12 @@ class PhiAttention(nn.Module):
|
||||
self.head_size,
|
||||
self.total_num_heads,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.dense = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
scaling = self.head_size**-0.5
|
||||
@ -125,7 +126,7 @@ class PhiMLP(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
|
||||
n_inner = getattr(config, "n_inner", None)
|
||||
@ -134,14 +135,14 @@ class PhiMLP(nn.Module):
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
n_inner,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.fc2 = RowParallelLinear(
|
||||
n_inner,
|
||||
config.hidden_size,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
quant_config = getattr(linear_method, "quant_config", None)
|
||||
quant_config = getattr(quant_config, "quant_config", None)
|
||||
self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
@ -155,12 +156,12 @@ class PhiLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.input_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
eps=config.layer_norm_eps)
|
||||
self.self_attn = PhiAttention(config, linear_method)
|
||||
self.mlp = PhiMLP(config, linear_method)
|
||||
self.self_attn = PhiAttention(config, quant_config)
|
||||
self.mlp = PhiMLP(config, quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -186,14 +187,14 @@ class PhiModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.quant_config = quant_config
|
||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
||||
config.hidden_size)
|
||||
self.layers = nn.ModuleList([
|
||||
PhiLayer(config, linear_method)
|
||||
PhiLayer(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.final_layernorm = nn.LayerNorm(config.hidden_size,
|
||||
@ -225,12 +226,12 @@ class PhiForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.model = PhiModel(config, linear_method)
|
||||
self.model = PhiModel(config, quant_config)
|
||||
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
|
||||
@ -14,11 +14,12 @@ from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -35,17 +36,17 @@ class QWenMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str = "silu",
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
self.c_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -67,7 +68,7 @@ class QWenAttention(nn.Module):
|
||||
max_position_embeddings: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -83,13 +84,13 @@ class QWenAttention(nn.Module):
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
@ -122,7 +123,7 @@ class QWenBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
@ -134,13 +135,13 @@ class QWenBlock(nn.Module):
|
||||
config.max_position_embeddings,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
|
||||
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.mlp = QWenMLP(config.hidden_size,
|
||||
config.intermediate_size // 2,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -174,7 +175,7 @@ class QWenModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -185,7 +186,7 @@ class QWenModel(nn.Module):
|
||||
config.hidden_size,
|
||||
)
|
||||
self.h = nn.ModuleList([
|
||||
QWenBlock(config, linear_method)
|
||||
QWenBlock(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
@ -217,12 +218,12 @@ class QWenLMHeadModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.transformer = QWenModel(config, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.transformer = QWenModel(config, quant_config)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.sampler = Sampler()
|
||||
|
||||
@ -33,11 +33,12 @@ from vllm.config import LoRAConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
MergedColumnParallelLinear,
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
@ -54,17 +55,17 @@ class Qwen2MLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
quant_config=quant_config)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -86,7 +87,7 @@ class Qwen2Attention(nn.Module):
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
use_sliding_window: bool = False,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
sliding_window: Optional[int] = None) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -117,13 +118,13 @@ class Qwen2Attention(nn.Module):
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=True,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
@ -159,7 +160,7 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
layer_idx: int,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -174,13 +175,13 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
use_sliding_window=use_sliding_window,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
sliding_window=config.sliding_window)
|
||||
self.mlp = Qwen2MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
linear_method=linear_method,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -221,7 +222,7 @@ class Qwen2Model(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@ -233,7 +234,7 @@ class Qwen2Model(nn.Module):
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
Qwen2DecoderLayer(config, layer_idx, linear_method)
|
||||
Qwen2DecoderLayer(config, layer_idx, quant_config)
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
@ -286,14 +287,14 @@ class Qwen2ForCausalLM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Qwen2Config,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
) -> None:
|
||||
del lora_config
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = Qwen2Model(config, linear_method)
|
||||
self.quant_config = quant_config
|
||||
self.model = Qwen2Model(config, quant_config)
|
||||
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head_weight = self.model.embed_tokens.weight
|
||||
|
||||
@ -36,12 +36,13 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -58,18 +59,18 @@ class Qwen2MoeMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
reduce_results=reduce_results)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
@ -88,7 +89,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
@ -105,7 +106,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
Qwen2MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
reduce_results=False)
for idx in range(self.n_routed_experts)
])
@ -114,13 +115,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
self.gate = ReplicatedLinear(config.hidden_size,
self.n_routed_experts,
bias=False,
linear_method=None)
quant_config=None)
if config.shared_expert_intermediate_size > 0:
self.shared_expert = Qwen2MoeMLP(
hidden_size=config.hidden_size,
intermediate_size=config.shared_expert_intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
reduce_results=False,
)
else:
@ -186,7 +187,7 @@ class Qwen2MoeAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@ -217,14 +218,14 @@ class Qwen2MoeAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=True,
linear_method=linear_method,
quant_config=quant_config,
)

self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
quant_config=quant_config,
)

self.rotary_emb = get_rope(
@ -260,7 +261,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
self,
config: PretrainedConfig,
layer_idx: int,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -275,18 +276,18 @@ class Qwen2MoeDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
quant_config=quant_config,
)
if (config.num_experts is not None
and (layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen2MoeSparseMoeBlock(config=config,
linear_method=linear_method)
quant_config=quant_config)
else:
self.mlp = Qwen2MoeMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@ -327,7 +328,7 @@ class Qwen2MoeModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
@ -338,9 +339,7 @@ class Qwen2MoeModel(nn.Module):
config.hidden_size,
)
self.layers = nn.ModuleList([
Qwen2MoeDecoderLayer(config,
layer_idx,
linear_method=linear_method)
Qwen2MoeDecoderLayer(config, layer_idx, quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -370,12 +369,12 @@ class Qwen2MoeForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = Qwen2MoeModel(config, linear_method)
self.quant_config = quant_config
self.model = Qwen2MoeModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

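One detail in the Qwen2MoeSparseMoeBlock hunks above is worth calling out: the per-expert and shared-expert MLPs receive the shared quant_config, while the routing gate is still built with quant_config=None (previously linear_method=None), so the router projection stays unquantized. A condensed sketch of that choice; the helper name is illustrative, and constructing a real vllm layer assumes an initialized distributed environment.

from vllm.model_executor.layers.linear import ReplicatedLinear


def build_router_gate(hidden_size: int,
                      n_routed_experts: int) -> ReplicatedLinear:
    """Sketch: the MoE routing gate is pinned to quant_config=None."""
    # Expert MLPs above are built with quant_config=quant_config, but the
    # gate that produces routing logits keeps the default (unquantized)
    # linear method by passing quant_config=None explicitly.
    return ReplicatedLinear(hidden_size,
                            n_routed_experts,
                            bias=False,
                            quant_config=None)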
@ -28,11 +28,12 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -46,7 +47,7 @@ class StablelmMLP(nn.Module):

def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None:
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
@ -54,7 +55,7 @@ class StablelmMLP(nn.Module):
self.gate_up_proj = MergedColumnParallelLinear(
config.hidden_size, [config.intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.down_proj = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=False)
@ -71,7 +72,7 @@ class StablelmAttention(nn.Module):

def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None:
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
@ -109,11 +110,11 @@ class StablelmAttention(nn.Module):
self.total_num_heads,
self.total_num_key_value_heads,
self.qkv_bias,
linear_method=linear_method)
quant_config=quant_config)
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_ndims,
@ -145,11 +146,11 @@ class StablelmDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.self_attn = StablelmAttention(config)
self.mlp = StablelmMLP(config, linear_method)
self.mlp = StablelmMLP(config, quant_config)
norm_eps = getattr(config, "norm_eps",
getattr(config, "layer_norm_eps", 1e-05))
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
@ -187,14 +188,14 @@ class StableLMEpochModel(nn.Module):

def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None:
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
StablelmDecoderLayer(config, linear_method)
StablelmDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
norm_eps = getattr(config, "norm_eps",
@ -226,12 +227,12 @@ class StablelmForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = StableLMEpochModel(config, linear_method)
self.quant_config = quant_config
self.model = StableLMEpochModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

@ -28,10 +28,11 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -45,7 +46,7 @@ class Starcoder2Attention(nn.Module):

def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config

@ -79,13 +80,13 @@ class Starcoder2Attention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=self.use_bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=self.use_bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
@ -121,21 +122,21 @@ class Starcoder2MLP(nn.Module):

def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.c_fc = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
bias=config.use_bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
bias=config.use_bias,
linear_method=linear_method,
quant_config=quant_config,
)
quant_config = getattr(linear_method, "quant_config", None)
quant_config = getattr(quant_config, "quant_config", None)
self.act = get_act_fn(config.hidden_act, quant_config,
config.intermediate_size)

@ -150,12 +151,11 @@ class Starcoder2DecoderLayer(nn.Module):

def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Starcoder2Attention(config,
linear_method=linear_method)
self.mlp = Starcoder2MLP(config, linear_method=linear_method)
self.self_attn = Starcoder2Attention(config, quant_config=quant_config)
self.mlp = Starcoder2MLP(config, quant_config=quant_config)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
@ -192,7 +192,7 @@ class Starcoder2Model(nn.Module):

def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
@ -202,7 +202,7 @@ class Starcoder2Model(nn.Module):
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
Starcoder2DecoderLayer(config, linear_method=linear_method)
Starcoder2DecoderLayer(config, quant_config=quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
@ -227,10 +227,10 @@ class Starcoder2ForCausalLM(nn.Module):

def __init__(self,
config: Starcoder2Config,
linear_method: Optional[LinearMethodBase] = None):
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
self.config = config
self.model = Starcoder2Model(config, linear_method=linear_method)
self.model = Starcoder2Model(config, quant_config=quant_config)
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
if config.tie_word_embeddings:

@ -31,11 +31,12 @@ from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -52,17 +53,17 @@ class XverseMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@ -85,7 +86,7 @@ class XverseAttention(nn.Module):
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
sliding_window: Optional[int] = None,
) -> None:
@ -112,13 +113,13 @@ class XverseAttention(nn.Module):
self.total_num_heads,
self.total_num_kv_heads,
bias=bias,
linear_method=linear_method,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=bias,
linear_method=linear_method,
quant_config=quant_config,
)

self.rotary_emb = get_rope(
@ -154,7 +155,7 @@ class XverseDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@ -171,7 +172,7 @@ class XverseDecoderLayer(nn.Module):
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
quant_config=quant_config,
bias=getattr(config, "bias", False),
sliding_window=sliding_window,
)
@ -179,7 +180,7 @@ class XverseDecoderLayer(nn.Module):
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@ -220,7 +221,7 @@ class XverseModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
@ -236,7 +237,7 @@ class XverseModel(nn.Module):
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
XverseDecoderLayer(config, linear_method)
XverseDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -294,13 +295,13 @@ class XverseForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config=None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = XverseModel(config, linear_method)
self.quant_config = quant_config
self.model = XverseModel(config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()

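Taken together, the *ForCausalLM wrappers above now store the config as self.quant_config and thread it down through the inner model into every decoder layer, instead of carrying a pre-built linear method. A hedged end-to-end sketch of that plumbing follows; ExampleDecoderLayer, ExampleModel, and ExampleForCausalLM are illustrative stand-ins, not classes from this commit.

from typing import Optional

from torch import nn
from transformers import PretrainedConfig

from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class ExampleDecoderLayer(nn.Module):
    """Stand-in layer; a real layer hands quant_config to its linear sublayers."""

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        # In the real models this configures the q/k/v, output, and MLP
        # projections built inside the layer.
        self.quant_config = quant_config


class ExampleModel(nn.Module):
    """Inner model: forwards quant_config unchanged to every layer it builds."""

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.layers = nn.ModuleList([
            ExampleDecoderLayer(config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])


class ExampleForCausalLM(nn.Module):
    """Toy wrapper mirroring how the models above thread quant_config."""

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        self.config = config
        # self.quant_config replaces the old self.linear_method attribute and
        # is passed straight to the inner model.
        self.quant_config = quant_config
        self.model = ExampleModel(config, quant_config)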