[Core] Allow disabling TP sharding for parallel Linear layer (#23024)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent 6432739ef1
commit 53b19ccdd5
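This commit threads a new disable_tp flag through LinearBase and its parallel subclasses, so a caller can keep a layer unsharded (the layer then reports tp_size == 1 and tp_rank == 0) instead of swapping in ReplicatedLinear / MergedReplicatedLinear classes. Below is a minimal sketch of how a data-parallel vision MLP might opt out of TP sharding after this change; it only mirrors the constructor signatures touched in this diff, and the sizes, prefix and quant_config=None values are illustrative placeholders.

# Illustrative sketch only -- mirrors the constructor signatures in this diff.
# Sizes, prefix and quant_config are placeholders, and vLLM's tensor-parallel
# process group must already be initialized before these layers are built.
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               RowParallelLinear)

use_data_parallel = True          # e.g. a ViT tower replicated on every rank
hidden_size, intermediate_size = 1024, 4096
prefix = "visual.blocks.0.mlp"    # hypothetical module prefix

gate_up_proj = MergedColumnParallelLinear(
    input_size=hidden_size,
    output_sizes=[intermediate_size] * 2,
    bias=False,
    quant_config=None,
    prefix=f"{prefix}.gate_up_proj",
    disable_tp=use_data_parallel,  # keep the full, unsharded weight on every rank
)
down_proj = RowParallelLinear(
    intermediate_size,
    hidden_size,
    bias=False,
    quant_config=None,
    prefix=f"{prefix}.down_proj",
    disable_tp=use_data_parallel,  # layer reports tp_size == 1, tp_rank == 0
)

The BitsAndBytes loader change below records such modules in tp_disabled_modules so their quantized weights are also loaded without TP sharding.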
@@ -223,6 +223,7 @@ class LinearBase(CustomOp):
         quant_config: Quantization configure.
         prefix: Prefix for parameter names.
         return_bias: If true, return bias together with outputs in forward pass.
+        disable_tp: If true, tensor parallelism will be disabled for this layer.
     """

     def __init__(
@@ -235,6 +236,7 @@ class LinearBase(CustomOp):
         prefix: str = "",
         *,
         return_bias: bool = True,
+        disable_tp: bool = False,
     ):
         super().__init__()

@@ -254,6 +256,17 @@ class LinearBase(CustomOp):
             self.quant_method = quant_config.get_quant_method(self,
                                                               prefix=prefix)
         self.return_bias = return_bias
+        self.disable_tp = disable_tp
+        self.tp_rank = (get_tensor_model_parallel_rank()
+                        if not disable_tp else 0)
+        self.tp_size = (get_tensor_model_parallel_world_size()
+                        if not disable_tp else 1)
+
+    def __post_init__(self):
+        for param in self.parameters():
+            if isinstance(param, BasevLLMParameter):
+                param.tp_rank = self.tp_rank
+                param.tp_size = self.tp_size


 @CustomOp.register("replicated_linear")
@@ -270,6 +283,7 @@ class ReplicatedLinear(LinearBase):
         prefix: The name of the layer in the state dict, including all parents
             (e.g. model.layers.0.qkv_proj)
         return_bias: If true, return bias together with outputs in forward pass.
+        disable_tp: Take no effect for replicated linear layers.
     """

     def __init__(
@@ -283,26 +297,21 @@ class ReplicatedLinear(LinearBase):
         prefix: str = "",
         *,
         return_bias: bool = True,
+        disable_tp: bool = False,
     ):
-        # If MergedReplicatedLinear, use output size of each partition.
-        if hasattr(self, "output_sizes"):
-            self.output_partition_sizes = self.output_sizes
-        else:
-            self.output_partition_sizes = [output_size]
-
         super().__init__(input_size,
                          output_size,
                          skip_bias_add,
                          params_dtype,
                          quant_config,
                          prefix=prefix,
-                         return_bias=return_bias)
+                         return_bias=return_bias,
+                         disable_tp=disable_tp)

         # All the linear layer supports quant method.
         assert self.quant_method is not None
         self.quant_method.create_weights(self,
-                                         self.input_size,
-                                         self.output_partition_sizes,
+                                         self.input_size, [self.output_size],
                                          self.input_size,
                                          self.output_size,
                                          self.params_dtype,
@@ -358,74 +367,6 @@ class ReplicatedLinear(LinearBase):
         return s


-class MergedReplicatedLinear(ReplicatedLinear):
-    """Replicated linear layer.
-
-    Args:
-        input_size: input dimension of the linear layer.
-        output_sizes: list of output dimensions of the linear layer.
-        bias: If true, add bias.
-        skip_bias_add: If true, skip adding bias but instead return it.
-        params_dtype: Data type for the parameters.
-        quant_config: Quantization configure.
-        prefix: The name of the layer in the state dict, including all parents
-            (e.g. model.layers.0.qkv_proj)
-        return_bias: If true, return bias together with outputs in forward pass.
-    """
-
-    def __init__(
-        self,
-        input_size: int,
-        output_sizes: list[int],
-        bias: bool = True,
-        skip_bias_add: bool = False,
-        params_dtype: Optional[torch.dtype] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        *,
-        return_bias: bool = True,
-    ):
-        self.output_sizes = output_sizes
-        super().__init__(input_size,
-                         sum(output_sizes),
-                         bias,
-                         skip_bias_add,
-                         params_dtype,
-                         quant_config,
-                         prefix=prefix,
-                         return_bias=return_bias)
-
-    def weight_loader(self,
-                      param: Union[Parameter, BasevLLMParameter],
-                      loaded_weight: torch.Tensor,
-                      loaded_shard_id: Optional[int] = None):
-        assert loaded_shard_id is not None
-        assert loaded_shard_id < len(self.output_sizes)
-
-        if isinstance(param, BlockQuantScaleParameter):
-            from vllm.model_executor.layers.quantization.fp8 import (
-                Fp8LinearMethod, Fp8MoEMethod)
-            assert self.quant_method is not None
-            assert isinstance(self.quant_method,
-                              (Fp8LinearMethod, Fp8MoEMethod))
-            weight_block_size = self.quant_method.quant_config.weight_block_size
-            assert weight_block_size is not None
-            block_n, _ = weight_block_size[0], weight_block_size[1]
-            shard_offset = (
-                (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
-                block_n)
-            shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
-                          block_n)
-        elif isinstance(param, PerTensorScaleParameter):
-            shard_offset = loaded_shard_id
-            shard_size = 1
-        else:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id])
-            shard_size = self.output_sizes[loaded_shard_id]
-
-        param.data[shard_offset:shard_offset + shard_size] = loaded_weight
-
-
 @CustomOp.register("column_parallel_linear")
 class ColumnParallelLinear(LinearBase):
     """Linear layer with column parallelism.
@@ -449,6 +390,8 @@ class ColumnParallelLinear(LinearBase):
             the list would be size 3.
         prefix: The name of the layer in the state dict, including all parents
             (e.g. model.layers.0.qkv_proj)
+        return_bias: If true, return bias together with outputs in forward pass.
+        disable_tp: If true, weights matrix won't be sharded through tp rank.
     """

     def __init__(
@@ -464,9 +407,13 @@ class ColumnParallelLinear(LinearBase):
         prefix: str = "",
         *,
         return_bias: bool = True,
+        disable_tp: bool = False,
     ):
         # Divide the weight matrix along the last dimension.
-        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = (get_tensor_model_parallel_rank()
+                        if not disable_tp else 0)
+        self.tp_size = (get_tensor_model_parallel_world_size()
+                        if not disable_tp else 1)
         self.input_size_per_partition = input_size
         self.output_size_per_partition = divide(output_size, self.tp_size)
         self.output_partition_sizes = [self.output_size_per_partition]
@@ -483,7 +430,8 @@ class ColumnParallelLinear(LinearBase):
                          params_dtype,
                          quant_config,
                          prefix,
-                         return_bias=return_bias)
+                         return_bias=return_bias,
+                         disable_tp=disable_tp)

         self.gather_output = gather_output

@@ -512,8 +460,6 @@ class ColumnParallelLinear(LinearBase):
         else:
             self.register_parameter("bias", None)

-        self.tp_rank = get_tensor_model_parallel_rank()
-
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):

         output_dim = getattr(param, "output_dim", None)
@@ -554,7 +500,8 @@ class ColumnParallelLinear(LinearBase):
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

-    def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
+    def weight_loader_v2(self, param: BasevLLMParameter,
+                         loaded_weight: torch.Tensor):
         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
         if len(loaded_weight.shape) == 0:
@@ -570,7 +517,7 @@ class ColumnParallelLinear(LinearBase):
         # Matrix multiply.
         assert self.quant_method is not None
         output_parallel = self.quant_method.apply(self, input_, bias)
-        if self.gather_output:
+        if self.gather_output and self.tp_size > 1:
             # All-gather across the partitions.
             output = tensor_model_parallel_all_gather(output_parallel)
         else:
@@ -584,7 +531,7 @@ class ColumnParallelLinear(LinearBase):
         s = f"in_features={self.input_size}"
         s += f", output_features={self.output_size_per_partition}"
         s += f", bias={self.bias is not None}"
-        s += f", tp_size={get_tensor_model_parallel_world_size()}"
+        s += f", tp_size={self.tp_size}"
         s += f", gather_output={self.gather_output}"
         return s

@@ -611,6 +558,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         prefix: The name of the layer in the state dict, including all parents
             (e.g. model.layers.0.qkv_proj)
         return_bias: If true, return bias together with outputs in forward pass.
+        disable_tp: If true, all weights matrix won't be sharded, this layer
+            will be treated as a "Replicated" MergedLinear.
     """

     def __init__(
@@ -625,10 +574,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         prefix: str = "",
         *,
         return_bias: bool = True,
+        disable_tp: bool = False,
     ):
         self.output_sizes = output_sizes
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.tp_rank = get_tensor_model_parallel_rank()
+        self.tp_size = (get_tensor_model_parallel_world_size()
+                        if not disable_tp else 1)
+        self.tp_rank = (get_tensor_model_parallel_rank()
+                        if not disable_tp else 0)

         assert all(output_size % self.tp_size == 0
                    for output_size in output_sizes)
@@ -640,7 +592,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                          params_dtype=params_dtype,
                          quant_config=quant_config,
                          prefix=prefix,
-                         return_bias=return_bias)
+                         return_bias=return_bias,
+                         disable_tp=disable_tp)

     def weight_loader(self,
                       param: Parameter,
@@ -832,8 +785,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):

         assert loaded_shard_id < len(self.output_sizes)

-        tp_size = get_tensor_model_parallel_world_size()
-
         if isinstance(param, BlockQuantScaleParameter):
             from vllm.model_executor.layers.quantization.fp8 import (
                 Fp8LinearMethod, Fp8MoEMethod)
@@ -845,17 +796,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             block_n, _ = weight_block_size[0], weight_block_size[1]
             shard_offset = (
                 (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
-                block_n) // tp_size
+                block_n) // self.tp_size
             shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
-                          block_n // tp_size)
+                          block_n // self.tp_size)
         else:
-            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
-            shard_size = self.output_sizes[loaded_shard_id] // tp_size
+            shard_offset = sum(
+                self.output_sizes[:loaded_shard_id]) // self.tp_size
+            shard_size = self.output_sizes[loaded_shard_id] // self.tp_size

         param.load_merged_column_weight(loaded_weight=loaded_weight,
                                         shard_id=loaded_shard_id,
                                         shard_offset=shard_offset,
-                                        shard_size=shard_size)
+                                        shard_size=shard_size,
+                                        tp_rank=self.tp_rank)


 class QKVParallelLinear(ColumnParallelLinear):
@@ -883,6 +836,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         prefix: The name of the layer in the state dict, including all parents
             (e.g. model.layers.0.qkv_proj)
         return_bias: If true, return bias together with outputs in forward pass.
+        disable_tp: If true, weights matrix won't be sharded through tp rank.
     """

     def __init__(
@@ -898,6 +852,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         prefix: str = "",
         *,
         return_bias: bool = True,
+        disable_tp: bool = False,
     ):
         self.hidden_size = hidden_size
         self.head_size = head_size
@@ -906,7 +861,8 @@ class QKVParallelLinear(ColumnParallelLinear):
             total_num_kv_heads = total_num_heads
         self.total_num_kv_heads = total_num_kv_heads
         # Divide the weight matrix along the last dimension.
-        tp_size = get_tensor_model_parallel_world_size()
+        tp_size = (get_tensor_model_parallel_world_size()
+                   if not disable_tp else 1)
         self.num_heads = divide(self.total_num_heads, tp_size)
         if tp_size >= self.total_num_kv_heads:
             self.num_kv_heads = 1
@@ -932,7 +888,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                          params_dtype=params_dtype,
                          quant_config=quant_config,
                          prefix=prefix,
-                         return_bias=return_bias)
+                         return_bias=return_bias,
+                         disable_tp=disable_tp)

     def _get_shard_offset_mapping(self, loaded_shard_id: str):
         shard_offset_mapping = {
@@ -993,10 +950,13 @@ class QKVParallelLinear(ColumnParallelLinear):
                          loaded_shard_id: Optional[str] = None):
         if loaded_shard_id is None:  # special case for certain models
             if isinstance(param, PerTensorScaleParameter):
-                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
+                param.load_qkv_weight(loaded_weight=loaded_weight,
+                                      shard_id=0,
+                                      tp_rank=self.tp_rank)
                 return
             elif type(param) in (RowvLLMParameter, BasevLLMParameter):
-                param.load_qkv_weight(loaded_weight=loaded_weight)
+                param.load_qkv_weight(loaded_weight=loaded_weight,
+                                      tp_rank=self.tp_rank)
                 return
             # TODO: @dsikka - move to parameter.py
             self._load_fused_module_from_checkpoint(param, loaded_weight)
@@ -1020,7 +980,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                               num_heads=self.num_kv_head_replicas,
                               shard_id=loaded_shard_id,
                               shard_offset=shard_offset,
-                              shard_size=shard_size)
+                              shard_size=shard_size,
+                              tp_rank=self.tp_rank)

     def weight_loader(self,
                       param: Parameter,
@@ -1226,6 +1187,7 @@ class RowParallelLinear(LinearBase):
         prefix: The name of the layer in the state dict, including all parents
             (e.g. model.layers.0.down_proj)
         return_bias: If true, return bias together with outputs in forward pass.
+        disable_tp: If true, weights matrix won't be sharded through tp rank.
     """

     def __init__(
@@ -1241,10 +1203,13 @@ class RowParallelLinear(LinearBase):
         prefix: str = "",
         *,
         return_bias: bool = True,
+        disable_tp: bool = False,
     ):
         # Divide the weight matrix along the first dimension.
-        self.tp_rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = (get_tensor_model_parallel_rank()
+                        if not disable_tp else 0)
+        self.tp_size = (get_tensor_model_parallel_world_size()
+                        if not disable_tp else 1)
         self.input_size_per_partition = divide(input_size, self.tp_size)
         self.output_size_per_partition = output_size
         self.output_partition_sizes = [output_size]
@@ -1255,7 +1220,8 @@ class RowParallelLinear(LinearBase):
                          params_dtype,
                          quant_config,
                          prefix,
-                         return_bias=return_bias)
+                         return_bias=return_bias,
+                         disable_tp=disable_tp)

         self.input_is_parallel = input_is_parallel
         self.reduce_results = reduce_results
@@ -1339,10 +1305,9 @@ class RowParallelLinear(LinearBase):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            tp_rank = get_tensor_model_parallel_rank()
             splitted_input = split_tensor_along_last_dim(
                 input_, num_partitions=self.tp_size)
-            input_parallel = splitted_input[tp_rank].contiguous()
+            input_parallel = splitted_input[self.tp_rank].contiguous()

         # Matrix multiply.
         assert self.quant_method is not None

@@ -69,6 +69,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         # Store all module names (from transformers) that support
         # BNB quantization.
         self.target_modules: list[str] = []
+        self.tp_disabled_modules: list[str] = []
         # Store the mapping of expert parameters for MoE models.
         self.expert_params_mapping: list[tuple[str, str, int, str]] = []
         # mapping weight names from transformers to vllm.
@@ -322,14 +323,24 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                                   quant_state_dict) -> Generator:
         from bitsandbytes.functional import quantize_4bit

-        tp_size = get_tensor_model_parallel_world_size()
-        tp_rank = get_tensor_model_parallel_rank()
+        global_tp_size = get_tensor_model_parallel_world_size()
+        global_tp_rank = get_tensor_model_parallel_rank()

         for (
                 org_weight_name,
                 mapped_weight_name,
                 weight_tensor,
         ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
+
+            # override tp_size and tp_rank if the module has disabled TP
+            if any(tp_disabled_module in mapped_weight_name
+                   for tp_disabled_module in self.tp_disabled_modules):
+                tp_size = 1
+                tp_rank = 0
+            else:
+                tp_size = global_tp_size
+                tp_rank = global_tp_rank
+
             if any(target_module in mapped_weight_name
                    for target_module in self.target_modules
                    ) and mapped_weight_name.endswith(".weight"):
@@ -418,12 +429,16 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                     # Map vllm's names to transformers's names.
                     rep_name, sub_modules = modules_info
                     for sub_name in sub_modules:
-                        self.target_modules.append(
-                            name.replace(rep_name, sub_name))
+                        new_name = name.replace(rep_name, sub_name)
+                        self.target_modules.append(new_name)
+                        if module.disable_tp:
+                            self.tp_disabled_modules.append(new_name)
                 # Add original module name even if the module has stacked map,
                 # in case model has a mixture of disk-merged and disk-split
                 # weights with same last name.
                 self.target_modules.append(name)
+                if module.disable_tp:
+                    self.tp_disabled_modules.append(name)
             elif isinstance(module, FusedMoE) and hasattr(
                     module.quant_method, "quant_config"):
                 # TODO: support FusedMoE with prequant and 8bit.

@@ -43,7 +43,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
-                                               MergedReplicatedLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -435,12 +434,13 @@ class DeepseekV2MLAAttention(nn.Module):
         self.max_position_embeddings = max_position_embeddings

         if self.q_lora_rank is not None:
-            self.fused_qkv_a_proj = MergedReplicatedLinear(
+            self.fused_qkv_a_proj = MergedColumnParallelLinear(
                 self.hidden_size,
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                 bias=False,
                 quant_config=quant_config,
-                prefix=f"{prefix}.fused_qkv_a_proj")
+                prefix=f"{prefix}.fused_qkv_a_proj",
+                disable_tp=True)
         else:
             self.kv_a_proj_with_mqa = ReplicatedLinear(
                 self.hidden_size,

@@ -51,14 +51,10 @@ from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.layernorm import RMSNorm
-# yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
-                                               MergedReplicatedLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
-# yapf: enable
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -174,20 +170,22 @@ class Glm4vVisionMLP(nn.Module):
         use_data_parallel: bool = False,
     ):
         super().__init__()
-        cls_gate_up = (MergedReplicatedLinear
-                       if use_data_parallel else MergedColumnParallelLinear)
-        self.gate_up_proj = cls_gate_up(input_size=in_features,
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=in_features,
             output_sizes=[hidden_features] * 2,
             bias=bias,
             quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        cls_down = (ReplicatedLinear
-                    if use_data_parallel else RowParallelLinear)
-        self.down_proj = cls_down(hidden_features,
+            prefix=f"{prefix}.gate_up_proj",
+            disable_tp=use_data_parallel,
+        )
+        self.down_proj = RowParallelLinear(
+            hidden_features,
             in_features,
             bias=bias,
             quant_config=quant_config,
-            prefix=f"{prefix}.down_proj")
+            prefix=f"{prefix}.down_proj",
+            disable_tp=use_data_parallel,
+        )
         self.act_fn = SiluAndMul()

     def forward(self, x: torch.Tensor):
@@ -234,30 +232,13 @@ class Glm4vVisionAttention(nn.Module):
         # Per attention head and per partition values.
         self.tp_size = (1 if use_data_parallel else
                         get_tensor_model_parallel_world_size())
-        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
+        self.tp_rank = (0 if use_data_parallel else
+                        parallel_state.get_tensor_model_parallel_rank())
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads)
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)

-        if use_data_parallel:
-            self.qkv = ReplicatedLinear(
-                input_size=embed_dim,
-                output_size=3 * projection_size,
-                bias=False,
-                quant_config=quant_config,
-                # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
-                prefix=f"{prefix}.qkv_proj"
-                if quant_config else f"{prefix}.qkv",
-            )
-            self.proj = ReplicatedLinear(
-                input_size=projection_size,
-                output_size=embed_dim,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-                bias=False,
-            )
-        else:
         self.qkv = QKVParallelLinear(
             hidden_size=embed_dim,
             head_size=self.hidden_size_per_attention_head,
@@ -266,8 +247,8 @@ class Glm4vVisionAttention(nn.Module):
             bias=False,
             quant_config=quant_config,
             # Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
-            prefix=f"{prefix}.qkv_proj"
-            if quant_config else f"{prefix}.qkv",
+            prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
+            disable_tp=use_data_parallel,
         )
         self.proj = RowParallelLinear(
             input_size=projection_size,
@@ -275,6 +256,7 @@ class Glm4vVisionAttention(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.proj",
             bias=False,
+            disable_tp=use_data_parallel,
         )

         # Detect attention implementation.
@@ -494,15 +476,6 @@ class Glm4vPatchMerger(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = d_model
-        if use_data_parallel:
-            self.proj = ReplicatedLinear(
-                input_size=self.hidden_size,
-                output_size=self.hidden_size,
-                bias=bias,
-                quant_config=quant_config,
-                prefix=f"{prefix}.proj",
-            )
-        else:
         self.proj = ColumnParallelLinear(
             self.hidden_size,
             self.hidden_size,
@@ -510,25 +483,24 @@ class Glm4vPatchMerger(nn.Module):
             gather_output=True,
             quant_config=quant_config,
             prefix=f"{prefix}.proj",
+            disable_tp=use_data_parallel,
         )
         self.post_projection_norm = nn.LayerNorm(self.hidden_size)
-        cls_gate_up = (MergedReplicatedLinear
-                       if use_data_parallel else MergedColumnParallelLinear)
-        self.gate_up_proj = cls_gate_up(
+        self.gate_up_proj = MergedColumnParallelLinear(
             input_size=self.hidden_size,
             output_sizes=[context_dim] * 2,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.gate_up_proj",
+            disable_tp=use_data_parallel,
         )
-        cls_down = (ReplicatedLinear
-                    if use_data_parallel else RowParallelLinear)
-        self.down_proj = cls_down(
+        self.down_proj = RowParallelLinear(
             context_dim,
             self.hidden_size,
             bias=bias,
             quant_config=quant_config,
             prefix=f"{prefix}.down_proj",
+            disable_tp=use_data_parallel,
         )
         self.act_fn = SiluAndMul()
         self.extra_activation_func = nn.GELU()

@@ -48,7 +48,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 # yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
-                                               MergedReplicatedLinear,
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
@@ -178,22 +177,20 @@ class Qwen2_5_VisionMLP(nn.Module):
                  prefix: str = "",
                  use_data_parallel: bool = False):
         super().__init__()
-        cls_gate_up_proj = (MergedReplicatedLinear if use_data_parallel else
-                            MergedColumnParallelLinear)
-        self.gate_up_proj = cls_gate_up_proj(
+        self.gate_up_proj = MergedColumnParallelLinear(
             input_size=in_features,
             output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
             bias=bias,
             quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
+            prefix=f"{prefix}.gate_up_proj",
+            disable_tp=use_data_parallel)

-        cls_down_proj = (ReplicatedLinear
-                         if use_data_parallel else RowParallelLinear)
-        self.down_proj = cls_down_proj(hidden_features,
+        self.down_proj = RowParallelLinear(hidden_features,
                                            in_features,
                                            bias=bias,
                                            quant_config=quant_config,
-                                           prefix=f"{prefix}.down_proj")
+                                           prefix=f"{prefix}.down_proj",
+                                           disable_tp=use_data_parallel)
         self.act_fn = act_fn

     def forward(self, x: torch.Tensor):
@@ -243,15 +240,6 @@ class Qwen2_5_VisionAttention(nn.Module):
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, self.tp_size)

-        if use_data_parallel:
-            self.qkv = ReplicatedLinear(embed_dim,
-                                        self.hidden_size_per_attention_head *
-                                        3 * num_heads,
-                                        bias=True,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.qkv")
-
-        else:
         self.qkv = QKVParallelLinear(
             hidden_size=embed_dim,
             head_size=self.hidden_size_per_attention_head,
@@ -259,14 +247,14 @@ class Qwen2_5_VisionAttention(nn.Module):
             total_num_kv_heads=num_heads,
             bias=True,
             quant_config=quant_config,
-            prefix=f"{prefix}.qkv")
+            prefix=f"{prefix}.qkv",
+            disable_tp=use_data_parallel)

-        cls_proj = (ReplicatedLinear
-                    if use_data_parallel else RowParallelLinear)
-        self.proj = cls_proj(input_size=projection_size,
+        self.proj = RowParallelLinear(input_size=projection_size,
                                       output_size=embed_dim,
                                       quant_config=quant_config,
-                                      prefix=f"{prefix}.proj")
+                                      prefix=f"{prefix}.proj",
+                                      disable_tp=use_data_parallel)

         # Detect attention implementation.
         self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)

@@ -21,7 +21,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
-                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -667,35 +666,21 @@ class Step3VisionAttention(nn.Module):

         self.q_size = self.num_heads * self.head_dim

-        if use_data_parallel:
-            self.qkv_proj = ReplicatedLinear(
-                self.embed_dim,
-                3 * self.q_size,
-                bias=True,
-                quant_config=quant_config,
-                prefix=prefix,
-            )
-            self.out_proj = ReplicatedLinear(
-                self.total_num_heads * self.head_dim,
-                self.embed_dim,
-                bias=True,
-                quant_config=quant_config,
-                prefix=prefix,
-            )
-        else:
         self.qkv_proj = QKVParallelLinear(
             self.embed_dim,
             self.head_dim,
             self.total_num_heads,
             bias=True,
             quant_config=quant_config,
-            prefix=prefix,
+            prefix=f"{prefix}.qkv_proj",
+            disable_tp=use_data_parallel,
         )
         self.out_proj = RowParallelLinear(self.embed_dim,
                                           self.embed_dim,
                                           bias=True,
                                           quant_config=quant_config,
-                                          prefix=prefix)
+                                          prefix=f"{prefix}.out_proj",
+                                          disable_tp=use_data_parallel)

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads,
@@ -740,20 +725,18 @@ class Step3VisionMLP(nn.Module):
         super().__init__()
         self.config = config
         self.activation_fn = get_act_fn(config.hidden_act)
-        cls_fc1 = (ReplicatedLinear
-                   if use_data_parallel else ColumnParallelLinear)
-        self.fc1 = cls_fc1(config.hidden_size,
+        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                         config.intermediate_size,
                                         bias=True,
                                         quant_config=quant_config,
-                                        prefix=prefix)
-        cls_fc2 = (ReplicatedLinear
-                   if use_data_parallel else RowParallelLinear)
-        self.fc2 = cls_fc2(config.intermediate_size,
+                                        prefix=f"{prefix}.fc1",
+                                        disable_tp=use_data_parallel)
+        self.fc2 = RowParallelLinear(config.intermediate_size,
                                      config.hidden_size,
                                      bias=True,
                                      quant_config=quant_config,
-                                     prefix=prefix)
+                                     prefix=f"{prefix}.fc2",
+                                     disable_tp=use_data_parallel)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)

@@ -57,6 +57,8 @@ class BasevLLMParameter(Parameter):
             weight_loader = _make_synced_weight_loader(weight_loader)

         self._weight_loader = weight_loader
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()

     @property
     def weight_loader(self):
@@ -116,10 +118,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         return self._output_dim

     def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
         shard_size = self.data.shape[self.output_dim]
         loaded_weight = loaded_weight.narrow(self.output_dim,
-                                             tp_rank * shard_size, shard_size)
+                                             self.tp_rank * shard_size,
+                                             shard_size)
         assert self.data.shape == loaded_weight.shape
         self.data.copy_(loaded_weight)

@@ -127,6 +129,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):

         shard_offset = kwargs.get("shard_offset")
         shard_size = kwargs.get("shard_size")
+
         # TODO: move these to PackedColumnParameter and PackedvLLMParameter
         if isinstance(
                 self,
@@ -137,11 +140,11 @@ class _ColumnvLLMParameter(BasevLLMParameter):

         param_data = self.data

-        tp_rank = get_tensor_model_parallel_rank()
         param_data = param_data.narrow(self.output_dim, shard_offset,
                                        shard_size)
         loaded_weight = loaded_weight.narrow(self.output_dim,
-                                             tp_rank * shard_size, shard_size)
+                                             self.tp_rank * shard_size,
+                                             shard_size)
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)

@@ -161,8 +164,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
                 shard_offset=shard_offset, shard_size=shard_size)

         param_data = self.data
-        tp_rank = get_tensor_model_parallel_rank()
-        shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
+        shard_id = (self.tp_rank if shard_id == "q" else self.tp_rank //
+                    num_heads)
         param_data = param_data.narrow(self.output_dim, shard_offset,
                                        shard_size)
         loaded_weight = loaded_weight.narrow(self.output_dim,
@@ -189,10 +192,10 @@ class RowvLLMParameter(BasevLLMParameter):
         return self._input_dim

     def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
-        tp_rank = get_tensor_model_parallel_rank()
         shard_size = self.data.shape[self.input_dim]
         loaded_weight = loaded_weight.narrow(self.input_dim,
-                                             tp_rank * shard_size, shard_size)
+                                             self.tp_rank * shard_size,
+                                             shard_size)

         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
@@ -414,9 +417,6 @@ class SharedWeightParameter(BasevLLMParameter):
             "weight_loader": self._fake_weight_loader
         }

-        self.tp_rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
-
         if self.tp_size > 1:
             raise NotImplementedError(f"{self.__class__.__name__} does not "
                                       "currently support tensor parallelism")
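The parameter-side change above makes BasevLLMParameter cache tp_rank/tp_size at construction, and the new LinearBase.__post_init__ copies the owning layer's values onto its parameters, so the load_*_parallel_weight helpers narrow checkpoint tensors against self.tp_rank instead of querying the distributed state each time. With disable_tp the narrow degenerates to a full copy. Below is a rough, self-contained illustration of that shard arithmetic with plain tensors; the helper is hypothetical and is not a vLLM API.

# Plain-torch illustration of the narrow performed by load_row_parallel_weight;
# narrow_for_rank is a hypothetical helper that only shows the shard arithmetic.
import torch

def narrow_for_rank(full_weight: torch.Tensor, tp_rank: int, tp_size: int,
                    input_dim: int = 1) -> torch.Tensor:
    # With disable_tp the layer reports tp_size == 1 and tp_rank == 0,
    # so the "shard" is simply the whole tensor.
    shard_size = full_weight.shape[input_dim] // tp_size
    return full_weight.narrow(input_dim, tp_rank * shard_size, shard_size)

full = torch.randn(8, 16)
assert narrow_for_rank(full, tp_rank=0, tp_size=1).shape == (8, 16)  # TP disabled
assert narrow_for_rank(full, tp_rank=1, tp_size=4).shape == (8, 4)   # rank 1 of 4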