Support more parallel styles in Transformers backend TP (#22651)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-08-11 18:42:48 +01:00 committed by GitHub
parent 65abe111a3
commit 458e74eb90
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -107,10 +107,17 @@ def replace_linear_class(
raise ValueError(
f"Unsupported parallel style type {type(style)}, expected str")
vllm_linear_cls = {
"colwise": ColumnParallelLinear,
"rowwise": RowParallelLinear,
}.get(style, ReplicatedLinear)
vllm_linear_cls, vllm_linear_kwargs = {
"colwise": (ColumnParallelLinear, {}),
"colwise_rep": (ColumnParallelLinear, {
"gather_output": True
}),
"rowwise": (RowParallelLinear, {}),
"rowwise_rep": (RowParallelLinear, {
"input_is_parallel": False
}),
"replicate": (ReplicatedLinear, {}),
}.get(style, (ReplicatedLinear, {}))
return vllm_linear_cls(
input_size=linear.in_features,
@ -118,6 +125,7 @@ def replace_linear_class(
bias=linear.bias is not None,
quant_config=quant_config,
return_bias=False,
**vllm_linear_kwargs,
)
@ -506,7 +514,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
# Some weight loaders expect linear layers to inherit from vLLM's
# LinearBase class, so we set a default style which causes any
# unspecified linear layers to be replaced with ReplicatedLinear
tp_plan[".*"] = "replicated"
tp_plan[".*"] = "replicate"
def _tensor_parallel(module: nn.Module, prefix: str = ""):
for child_name, child_module in module.named_children():