mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-20 15:17:02 +08:00
Support more parallel styles in Transformers backend TP (#22651)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
65abe111a3
commit
458e74eb90
@@ -107,10 +107,17 @@ def replace_linear_class(
         raise ValueError(
             f"Unsupported parallel style type {type(style)}, expected str")
 
-    vllm_linear_cls = {
-        "colwise": ColumnParallelLinear,
-        "rowwise": RowParallelLinear,
-    }.get(style, ReplicatedLinear)
+    vllm_linear_cls, vllm_linear_kwargs = {
+        "colwise": (ColumnParallelLinear, {}),
+        "colwise_rep": (ColumnParallelLinear, {
+            "gather_output": True
+        }),
+        "rowwise": (RowParallelLinear, {}),
+        "rowwise_rep": (RowParallelLinear, {
+            "input_is_parallel": False
+        }),
+        "replicate": (ReplicatedLinear, {}),
+    }.get(style, (ReplicatedLinear, {}))
 
     return vllm_linear_cls(
         input_size=linear.in_features,
@@ -118,6 +125,7 @@ def replace_linear_class(
         bias=linear.bias is not None,
         quant_config=quant_config,
         return_bias=False,
+        **vllm_linear_kwargs,
     )
 
 
@@ -506,7 +514,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         # Some weight loaders expect linear layers to inherit from vLLM's
         # LinearBase class, so we set a default style which causes any
         # unspecified linear layers to be replaced with ReplicatedLinear
-        tp_plan[".*"] = "replicated"
+        tp_plan[".*"] = "replicate"
 
         def _tensor_parallel(module: nn.Module, prefix: str = ""):
             for child_name, child_module in module.named_children():
Loading…
x
Reference in New Issue
Block a user