[Core] Allow disabling TP sharding for parallel Linear layer (#23024)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Isotr0py 2025-09-06 13:53:58 +08:00 committed by GitHub
parent 6432739ef1
commit 53b19ccdd5
7 changed files with 203 additions and 280 deletions
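The change threads a new `disable_tp` keyword through the parallel linear layers. A minimal usage sketch, assuming a vLLM build that contains this commit and an already-initialized tensor-parallel group (e.g. inside a model's `__init__`); sizes and prefixes below are illustrative:

```python
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)

# With disable_tp=True the layer keeps the full weight on every rank
# (tp_size is treated as 1, tp_rank as 0) and skips the all-gather /
# all-reduce, which is what data-parallel vision towers and DeepSeek's
# fused_qkv_a_proj want even when the rest of the model runs with TP.
up_proj = ColumnParallelLinear(1024, 4096,
                               bias=False,
                               prefix="vision.mlp.up_proj",   # illustrative
                               disable_tp=True)
down_proj = RowParallelLinear(4096, 1024,
                              bias=False,
                              prefix="vision.mlp.down_proj",  # illustrative
                              disable_tp=True)
```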

vllm/model_executor/layers/linear.py

@ -223,6 +223,7 @@ class LinearBase(CustomOp):
quant_config: Quantization configure.
prefix: Prefix for parameter names.
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, tensor parallelism will be disabled for this layer.
"""
def __init__(
@ -235,6 +236,7 @@ class LinearBase(CustomOp):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
super().__init__()
@ -254,6 +256,17 @@ class LinearBase(CustomOp):
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)
self.return_bias = return_bias
self.disable_tp = disable_tp
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
def __post_init__(self):
for param in self.parameters():
if isinstance(param, BasevLLMParameter):
param.tp_rank = self.tp_rank
param.tp_size = self.tp_size
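In other words, `disable_tp` collapses the layer's view of the TP group to a single rank, and `__post_init__` pushes that view into every `BasevLLMParameter` so the weight loaders shard (or don't) consistently. A standalone sketch of just that selection logic (the helper name is illustrative, not a vLLM API):

```python
def resolve_tp(disable_tp: bool, world_size: int, rank: int) -> tuple[int, int]:
    """Return the (tp_size, tp_rank) the layer and its parameters will use."""
    return (1, 0) if disable_tp else (world_size, rank)

assert resolve_tp(False, 4, 3) == (4, 3)  # normal TP: this rank loads its shard
assert resolve_tp(True, 4, 3) == (1, 0)   # disabled: every rank loads the full weight
```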
@CustomOp.register("replicated_linear")
@ -270,6 +283,7 @@ class ReplicatedLinear(LinearBase):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: Has no effect for replicated linear layers.
"""
def __init__(
@ -283,26 +297,21 @@ class ReplicatedLinear(LinearBase):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# If MergedReplicatedLinear, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = self.output_sizes
else:
self.output_partition_sizes = [output_size]
super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
# All the linear layer supports quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size,
self.output_partition_sizes,
self.input_size, [self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
@ -358,74 +367,6 @@ class ReplicatedLinear(LinearBase):
return s
class MergedReplicatedLinear(ReplicatedLinear):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_sizes: list of output dimensions of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
"""
def __init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
):
self.output_sizes = output_sizes
super().__init__(input_size,
sum(output_sizes),
bias,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
return_bias=return_bias)
def weight_loader(self,
param: Union[Parameter, BasevLLMParameter],
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
assert loaded_shard_id is not None
assert loaded_shard_id < len(self.output_sizes)
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
assert self.quant_method is not None
assert isinstance(self.quant_method,
(Fp8LinearMethod, Fp8MoEMethod))
weight_block_size = self.quant_method.quant_config.weight_block_size
assert weight_block_size is not None
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
block_n)
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
block_n)
elif isinstance(param, PerTensorScaleParameter):
shard_offset = loaded_shard_id
shard_size = 1
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
param.data[shard_offset:shard_offset + shard_size] = loaded_weight
@CustomOp.register("column_parallel_linear")
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
@ -448,7 +389,9 @@ class ColumnParallelLinear(LinearBase):
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, the weight matrix won't be sharded across TP ranks.
"""
def __init__(
@ -464,9 +407,13 @@ class ColumnParallelLinear(LinearBase):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
@ -483,7 +430,8 @@ class ColumnParallelLinear(LinearBase):
params_dtype,
quant_config,
prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
self.gather_output = gather_output
@ -512,8 +460,6 @@ class ColumnParallelLinear(LinearBase):
else:
self.register_parameter("bias", None)
self.tp_rank = get_tensor_model_parallel_rank()
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
output_dim = getattr(param, "output_dim", None)
@ -554,7 +500,8 @@ class ColumnParallelLinear(LinearBase):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
def weight_loader_v2(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
@ -570,7 +517,7 @@ class ColumnParallelLinear(LinearBase):
# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
if self.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
@ -584,7 +531,7 @@ class ColumnParallelLinear(LinearBase):
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size_per_partition}"
s += f", bias={self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}"
s += f", tp_size={self.tp_size}"
s += f", gather_output={self.gather_output}"
return s
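The `extra_repr` now reports the layer's own `tp_size`, and the forward path above only all-gathers when there is actually more than one partition. A sketch of that guard (illustrative helper, not a vLLM function):

```python
def needs_all_gather(gather_output: bool, tp_size: int) -> bool:
    # Mirrors `if self.gather_output and self.tp_size > 1` above.
    return gather_output and tp_size > 1

assert needs_all_gather(True, 4)       # sharded outputs must be gathered
assert not needs_all_gather(True, 1)   # disable_tp=True: nothing to gather
assert not needs_all_gather(False, 4)  # caller keeps the sharded output
```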
@ -611,6 +558,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, none of the weight matrices will be sharded and the
layer behaves like a replicated merged linear layer.
"""
def __init__(
@ -625,10 +574,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
self.output_sizes = output_sizes
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
assert all(output_size % self.tp_size == 0
for output_size in output_sizes)
@ -640,7 +592,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
def weight_loader(self,
param: Parameter,
@ -832,8 +785,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
assert loaded_shard_id < len(self.output_sizes)
tp_size = get_tensor_model_parallel_world_size()
if isinstance(param, BlockQuantScaleParameter):
from vllm.model_executor.layers.quantization.fp8 import (
Fp8LinearMethod, Fp8MoEMethod)
@ -845,17 +796,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
block_n, _ = weight_block_size[0], weight_block_size[1]
shard_offset = (
(sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) //
block_n) // tp_size
block_n) // self.tp_size
shard_size = ((self.output_sizes[loaded_shard_id] + block_n - 1) //
block_n // tp_size)
block_n // self.tp_size)
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
shard_offset = sum(
self.output_sizes[:loaded_shard_id]) // self.tp_size
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
shard_size=shard_size,
tp_rank=self.tp_rank)
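Using the layer's own `self.tp_size`/`self.tp_rank` here lets the shard arithmetic degrade gracefully when TP is disabled. A worked example of the non-quantized branch (sizes are illustrative):

```python
def merged_shard(output_sizes: list[int], shard_id: int, tp_size: int) -> tuple[int, int]:
    """(shard_offset, shard_size) for one sub-matrix of a merged column-parallel weight."""
    offset = sum(output_sizes[:shard_id]) // tp_size
    size = output_sizes[shard_id] // tp_size
    return offset, size

# A gate_up_proj packing [gate, up] = [11008, 11008]:
assert merged_shard([11008, 11008], 1, tp_size=4) == (2752, 2752)
assert merged_shard([11008, 11008], 1, tp_size=1) == (11008, 11008)  # disable_tp=True
```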
class QKVParallelLinear(ColumnParallelLinear):
@ -883,6 +836,7 @@ class QKVParallelLinear(ColumnParallelLinear):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, the weight matrix won't be sharded across TP ranks.
"""
def __init__(
@ -898,6 +852,7 @@ class QKVParallelLinear(ColumnParallelLinear):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
self.hidden_size = hidden_size
self.head_size = head_size
@ -906,7 +861,8 @@ class QKVParallelLinear(ColumnParallelLinear):
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.num_heads = divide(self.total_num_heads, tp_size)
if tp_size >= self.total_num_kv_heads:
self.num_kv_heads = 1
@ -932,7 +888,8 @@ class QKVParallelLinear(ColumnParallelLinear):
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
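Because the effective `tp_size` becomes 1 when TP is disabled, the head partitioning above reduces to "every rank holds all heads". A standalone sketch of the usual splitting rule for reference (a simplified re-statement, not the class itself):

```python
def partition_heads(total_heads: int, total_kv_heads: int, tp_size: int) -> tuple[int, int, int]:
    """Return (num_heads, num_kv_heads, num_kv_head_replicas) per rank."""
    num_heads = total_heads // tp_size
    if tp_size >= total_kv_heads:
        # Fewer KV heads than ranks: replicate each KV head across ranks.
        return num_heads, 1, tp_size // total_kv_heads
    return num_heads, total_kv_heads // tp_size, 1

assert partition_heads(32, 8, tp_size=4) == (8, 2, 1)
assert partition_heads(32, 8, tp_size=1) == (32, 8, 1)  # disable_tp=True
assert partition_heads(32, 2, tp_size=4) == (8, 1, 2)   # KV heads replicated
```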
def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
@ -993,10 +950,13 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_shard_id: Optional[str] = None):
if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter):
param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
param.load_qkv_weight(loaded_weight=loaded_weight,
shard_id=0,
tp_rank=self.tp_rank)
return
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_qkv_weight(loaded_weight=loaded_weight)
param.load_qkv_weight(loaded_weight=loaded_weight,
tp_rank=self.tp_rank)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
@ -1020,7 +980,8 @@ class QKVParallelLinear(ColumnParallelLinear):
num_heads=self.num_kv_head_replicas,
shard_id=loaded_shard_id,
shard_offset=shard_offset,
shard_size=shard_size)
shard_size=shard_size,
tp_rank=self.tp_rank)
def weight_loader(self,
param: Parameter,
@ -1226,6 +1187,7 @@ class RowParallelLinear(LinearBase):
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.down_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, the weight matrix won't be sharded across TP ranks.
"""
def __init__(
@ -1241,10 +1203,13 @@ class RowParallelLinear(LinearBase):
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
@ -1255,7 +1220,8 @@ class RowParallelLinear(LinearBase):
params_dtype,
quant_config,
prefix,
return_bias=return_bias)
return_bias=return_bias,
disable_tp=disable_tp)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
@ -1339,10 +1305,9 @@ class RowParallelLinear(LinearBase):
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
input_parallel = splitted_input[self.tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
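With the rank cached on the layer, the non-parallel-input path splits activations without re-querying the TP group; when TP is disabled the "split" trivially returns the whole tensor. A small torch sketch of the same idea, using `torch.chunk` in place of vLLM's `split_tensor_along_last_dim`:

```python
import torch

def take_input_shard(x: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Chunk along the last dim and keep this rank's slice.
    return torch.chunk(x, tp_size, dim=-1)[tp_rank].contiguous()

x = torch.randn(2, 8)
assert take_input_shard(x, tp_size=4, tp_rank=1).shape == (2, 2)
assert torch.equal(take_input_shard(x, tp_size=1, tp_rank=0), x)  # disable_tp=True
```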

vllm/model_executor/model_loader/bitsandbytes_loader.py

@ -69,6 +69,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# Store all module names (from transformers) that support
# BNB quantization.
self.target_modules: list[str] = []
self.tp_disabled_modules: list[str] = []
# Store the mapping of expert parameters for MoE models.
self.expert_params_mapping: list[tuple[str, str, int, str]] = []
# mapping weight names from transformers to vllm.
@ -322,14 +323,24 @@ class BitsAndBytesModelLoader(BaseModelLoader):
quant_state_dict) -> Generator:
from bitsandbytes.functional import quantize_4bit
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
global_tp_size = get_tensor_model_parallel_world_size()
global_tp_rank = get_tensor_model_parallel_rank()
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
# override tp_size and tp_rank if the module has disabled TP
if any(tp_disabled_module in mapped_weight_name
for tp_disabled_module in self.tp_disabled_modules):
tp_size = 1
tp_rank = 0
else:
tp_size = global_tp_size
tp_rank = global_tp_rank
if any(target_module in mapped_weight_name
for target_module in self.target_modules
) and mapped_weight_name.endswith(".weight"):
@ -418,12 +429,16 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# Map vllm's names to transformers's names.
rep_name, sub_modules = modules_info
for sub_name in sub_modules:
self.target_modules.append(
name.replace(rep_name, sub_name))
new_name = name.replace(rep_name, sub_name)
self.target_modules.append(new_name)
if module.disable_tp:
self.tp_disabled_modules.append(new_name)
# Add original module name even if the module has stacked map,
# in case model has a mixture of disk-merged and disk-split
# weights with same last name.
self.target_modules.append(name)
if module.disable_tp:
self.tp_disabled_modules.append(name)
elif isinstance(module, FusedMoE) and hasattr(
module.quant_method, "quant_config"):
# TODO: support FusedMoE with prequant and 8bit.
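The BitsAndBytes loader mirrors the layer-level flag at load time: any weight whose mapped name matches a module recorded in `tp_disabled_modules` is quantized against the full tensor rather than this rank's shard. A sketch of that name-matching override (module names are illustrative):

```python
tp_disabled_modules = ["model.layers.0.fused_qkv_a_proj"]  # illustrative

def effective_tp(mapped_weight_name: str,
                 global_tp_size: int,
                 global_tp_rank: int) -> tuple[int, int]:
    # Same substring match as the loader: TP-disabled modules see (1, 0).
    if any(m in mapped_weight_name for m in tp_disabled_modules):
        return 1, 0
    return global_tp_size, global_tp_rank

assert effective_tp("model.layers.0.fused_qkv_a_proj.weight", 4, 3) == (1, 0)
assert effective_tp("model.layers.0.self_attn.o_proj.weight", 4, 3) == (4, 3)
```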

vllm/model_executor/models/deepseek_v2.py

@ -43,7 +43,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
MergedReplicatedLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
@ -435,12 +434,13 @@ class DeepseekV2MLAAttention(nn.Module):
self.max_position_embeddings = max_position_embeddings
if self.q_lora_rank is not None:
self.fused_qkv_a_proj = MergedReplicatedLinear(
self.fused_qkv_a_proj = MergedColumnParallelLinear(
self.hidden_size,
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fused_qkv_a_proj")
prefix=f"{prefix}.fused_qkv_a_proj",
disable_tp=True)
else:
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,

vllm/model_executor/models/glm4_1v.py

@ -51,14 +51,10 @@ from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.layernorm import RMSNorm
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
MergedReplicatedLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
# yapf: enable
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
@ -174,20 +170,22 @@ class Glm4vVisionMLP(nn.Module):
use_data_parallel: bool = False,
):
super().__init__()
cls_gate_up = (MergedReplicatedLinear
if use_data_parallel else MergedColumnParallelLinear)
self.gate_up_proj = cls_gate_up(input_size=in_features,
output_sizes=[hidden_features] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
cls_down = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.down_proj = cls_down(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
disable_tp=use_data_parallel,
)
self.down_proj = RowParallelLinear(
hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
disable_tp=use_data_parallel,
)
self.act_fn = SiluAndMul()
def forward(self, x: torch.Tensor):
@ -234,48 +232,32 @@ class Glm4vVisionAttention(nn.Module):
# Per attention head and per partition values.
self.tp_size = (1 if use_data_parallel else
get_tensor_model_parallel_world_size())
self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
self.tp_rank = (0 if use_data_parallel else
parallel_state.get_tensor_model_parallel_rank())
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads)
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, self.tp_size)
if use_data_parallel:
self.qkv = ReplicatedLinear(
input_size=embed_dim,
output_size=3 * projection_size,
bias=False,
quant_config=quant_config,
# Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
prefix=f"{prefix}.qkv_proj"
if quant_config else f"{prefix}.qkv",
)
self.proj = ReplicatedLinear(
input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
bias=False,
)
else:
self.qkv = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=False,
quant_config=quant_config,
# Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
prefix=f"{prefix}.qkv_proj"
if quant_config else f"{prefix}.qkv",
)
self.proj = RowParallelLinear(
input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
bias=False,
)
self.qkv = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=False,
quant_config=quant_config,
# Change qkv prefix to align with GLM-4.5V-FP8 quantization cfg
prefix=f"{prefix}.qkv_proj" if quant_config else f"{prefix}.qkv",
disable_tp=use_data_parallel,
)
self.proj = RowParallelLinear(
input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
bias=False,
disable_tp=use_data_parallel,
)
# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
@ -494,41 +476,31 @@ class Glm4vPatchMerger(nn.Module):
) -> None:
super().__init__()
self.hidden_size = d_model
if use_data_parallel:
self.proj = ReplicatedLinear(
input_size=self.hidden_size,
output_size=self.hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.proj",
)
else:
self.proj = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=bias,
gather_output=True,
quant_config=quant_config,
prefix=f"{prefix}.proj",
)
self.proj = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=bias,
gather_output=True,
quant_config=quant_config,
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel,
)
self.post_projection_norm = nn.LayerNorm(self.hidden_size)
cls_gate_up = (MergedReplicatedLinear
if use_data_parallel else MergedColumnParallelLinear)
self.gate_up_proj = cls_gate_up(
self.gate_up_proj = MergedColumnParallelLinear(
input_size=self.hidden_size,
output_sizes=[context_dim] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
disable_tp=use_data_parallel,
)
cls_down = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.down_proj = cls_down(
self.down_proj = RowParallelLinear(
context_dim,
self.hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
disable_tp=use_data_parallel,
)
self.act_fn = SiluAndMul()
self.extra_activation_func = nn.GELU()

vllm/model_executor/models/qwen2_5_vl.py

@ -48,7 +48,6 @@ from vllm.model_executor.layers.layernorm import RMSNorm
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
MergedReplicatedLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
@ -178,22 +177,20 @@ class Qwen2_5_VisionMLP(nn.Module):
prefix: str = "",
use_data_parallel: bool = False):
super().__init__()
cls_gate_up_proj = (MergedReplicatedLinear if use_data_parallel else
MergedColumnParallelLinear)
self.gate_up_proj = cls_gate_up_proj(
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
prefix=f"{prefix}.gate_up_proj",
disable_tp=use_data_parallel)
cls_down_proj = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.down_proj = cls_down_proj(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
self.down_proj = RowParallelLinear(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
disable_tp=use_data_parallel)
self.act_fn = act_fn
def forward(self, x: torch.Tensor):
@ -243,30 +240,21 @@ class Qwen2_5_VisionAttention(nn.Module):
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, self.tp_size)
if use_data_parallel:
self.qkv = ReplicatedLinear(embed_dim,
self.hidden_size_per_attention_head *
3 * num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv")
self.qkv = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv",
disable_tp=use_data_parallel)
else:
self.qkv = QKVParallelLinear(
hidden_size=embed_dim,
head_size=self.hidden_size_per_attention_head,
total_num_heads=num_heads,
total_num_kv_heads=num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv")
cls_proj = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.proj = cls_proj(input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj")
self.proj = RowParallelLinear(input_size=projection_size,
output_size=embed_dim,
quant_config=quant_config,
prefix=f"{prefix}.proj",
disable_tp=use_data_parallel)
# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)

vllm/model_executor/models/step3_vl.py

@ -21,7 +21,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@ -667,35 +666,21 @@ class Step3VisionAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
if use_data_parallel:
self.qkv_proj = ReplicatedLinear(
self.embed_dim,
3 * self.q_size,
bias=True,
quant_config=quant_config,
prefix=prefix,
)
self.out_proj = ReplicatedLinear(
self.total_num_heads * self.head_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=prefix,
)
else:
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.total_num_heads,
bias=True,
quant_config=quant_config,
prefix=prefix,
)
self.out_proj = RowParallelLinear(self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=prefix)
self.qkv_proj = QKVParallelLinear(
self.embed_dim,
self.head_dim,
self.total_num_heads,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
disable_tp=use_data_parallel,
)
self.out_proj = RowParallelLinear(self.embed_dim,
self.embed_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
disable_tp=use_data_parallel)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads,
@ -740,20 +725,18 @@ class Step3VisionMLP(nn.Module):
super().__init__()
self.config = config
self.activation_fn = get_act_fn(config.hidden_act)
cls_fc1 = (ReplicatedLinear
if use_data_parallel else ColumnParallelLinear)
self.fc1 = cls_fc1(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=prefix)
cls_fc2 = (ReplicatedLinear
if use_data_parallel else RowParallelLinear)
self.fc2 = cls_fc2(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=prefix)
self.fc1 = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
disable_tp=use_data_parallel)
self.fc2 = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
disable_tp=use_data_parallel)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.fc1(hidden_states)

vllm/model_executor/parameter.py

@ -57,6 +57,8 @@ class BasevLLMParameter(Parameter):
weight_loader = _make_synced_weight_loader(weight_loader)
self._weight_loader = weight_loader
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
@property
def weight_loader(self):
@ -116,10 +118,10 @@ class _ColumnvLLMParameter(BasevLLMParameter):
return self._output_dim
def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.data.shape[self.output_dim]
loaded_weight = loaded_weight.narrow(self.output_dim,
tp_rank * shard_size, shard_size)
self.tp_rank * shard_size,
shard_size)
assert self.data.shape == loaded_weight.shape
self.data.copy_(loaded_weight)
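With `tp_rank`/`tp_size` stored on `BasevLLMParameter`, the column- and row-parallel loaders above narrow the checkpoint tensor by the rank the owning layer assigned, so TP-disabled layers (rank 0, size 1) copy the full tensor. A small torch illustration of that narrow (shapes are illustrative):

```python
import torch

def load_column_shard(checkpoint_weight: torch.Tensor, param_shape: tuple[int, ...],
                      output_dim: int, tp_rank: int) -> torch.Tensor:
    # Same pattern as above: the shard size comes from the (already sharded)
    # parameter, and tp_rank selects which slice of the checkpoint to copy.
    shard_size = param_shape[output_dim]
    return checkpoint_weight.narrow(output_dim, tp_rank * shard_size, shard_size)

full = torch.arange(8 * 4, dtype=torch.float32).view(8, 4)  # checkpoint [out, in]
assert load_column_shard(full, (2, 4), output_dim=0, tp_rank=3).shape == (2, 4)
assert torch.equal(load_column_shard(full, (8, 4), output_dim=0, tp_rank=0), full)
```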
@ -127,6 +129,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size")
# TODO: move these to PackedColumnParameter and PackedvLLMParameter
if isinstance(
self,
@ -137,11 +140,11 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data = self.data
tp_rank = get_tensor_model_parallel_rank()
param_data = param_data.narrow(self.output_dim, shard_offset,
shard_size)
loaded_weight = loaded_weight.narrow(self.output_dim,
tp_rank * shard_size, shard_size)
self.tp_rank * shard_size,
shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
@ -161,8 +164,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_offset=shard_offset, shard_size=shard_size)
param_data = self.data
tp_rank = get_tensor_model_parallel_rank()
shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
shard_id = (self.tp_rank if shard_id == "q" else self.tp_rank //
num_heads)
param_data = param_data.narrow(self.output_dim, shard_offset,
shard_size)
loaded_weight = loaded_weight.narrow(self.output_dim,
@ -189,10 +192,10 @@ class RowvLLMParameter(BasevLLMParameter):
return self._input_dim
def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.data.shape[self.input_dim]
loaded_weight = loaded_weight.narrow(self.input_dim,
tp_rank * shard_size, shard_size)
self.tp_rank * shard_size,
shard_size)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
@ -414,9 +417,6 @@ class SharedWeightParameter(BasevLLMParameter):
"weight_loader": self._fake_weight_loader
}
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > 1:
raise NotImplementedError(f"{self.__class__.__name__} does not "
"currently support tensor parallelism")