mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-03 22:35:22 +08:00
Support more parallel styles in Transformers backend TP (#22651)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
65abe111a3
commit
458e74eb90
@ -107,10 +107,17 @@ def replace_linear_class(
|
||||
raise ValueError(
|
||||
f"Unsupported parallel style type {type(style)}, expected str")
|
||||
|
||||
vllm_linear_cls = {
|
||||
"colwise": ColumnParallelLinear,
|
||||
"rowwise": RowParallelLinear,
|
||||
}.get(style, ReplicatedLinear)
|
||||
vllm_linear_cls, vllm_linear_kwargs = {
|
||||
"colwise": (ColumnParallelLinear, {}),
|
||||
"colwise_rep": (ColumnParallelLinear, {
|
||||
"gather_output": True
|
||||
}),
|
||||
"rowwise": (RowParallelLinear, {}),
|
||||
"rowwise_rep": (RowParallelLinear, {
|
||||
"input_is_parallel": False
|
||||
}),
|
||||
"replicate": (ReplicatedLinear, {}),
|
||||
}.get(style, (ReplicatedLinear, {}))
|
||||
|
||||
return vllm_linear_cls(
|
||||
input_size=linear.in_features,
|
||||
@ -118,6 +125,7 @@ def replace_linear_class(
|
||||
bias=linear.bias is not None,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
**vllm_linear_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@ -506,7 +514,7 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
# Some weight loaders expect linear layers to inherit from vLLM's
|
||||
# LinearBase class, so we set a default style which causes any
|
||||
# unspecified linear layers to be replaced with ReplicatedLinear
|
||||
tp_plan[".*"] = "replicated"
|
||||
tp_plan[".*"] = "replicate"
|
||||
|
||||
def _tensor_parallel(module: nn.Module, prefix: str = ""):
|
||||
for child_name, child_module in module.named_children():
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user