mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-01 14:27:06 +08:00
Fix Transformers backend tensor parallel for multimodal models (#22673)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
45c3936e94
commit
d0a6301588
@ -505,30 +505,47 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
|||||||
Apply the model's tensor parallelization plan.
|
Apply the model's tensor parallelization plan.
|
||||||
Currently only supports linear layers.
|
Currently only supports linear layers.
|
||||||
"""
|
"""
|
||||||
tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {}
|
# Look for tp plans in all of the PreTrainedModels found in self.model
|
||||||
|
is_pretrained_model = lambda m: isinstance(m, PreTrainedModel)
|
||||||
|
supports_tp_plan = lambda m: m.config.base_model_tp_plan is not None
|
||||||
|
pretrained_models = filter(is_pretrained_model, self.model.modules())
|
||||||
|
models_with_tp_plan = filter(supports_tp_plan, pretrained_models)
|
||||||
|
|
||||||
if not tp_plan and self.tp_size > 1:
|
if not any(models_with_tp_plan) and self.tp_size > 1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{type(self.model)} does not support tensor parallel yet!")
|
f"{type(self.model)} does not support tensor parallel yet!")
|
||||||
|
|
||||||
# Some weight loaders expect linear layers to inherit from vLLM's
|
def _tensor_parallel(module: nn.Module,
|
||||||
# LinearBase class, so we set a default style which causes any
|
prefix: str = "",
|
||||||
# unspecified linear layers to be replaced with ReplicatedLinear
|
tp_plan=None):
|
||||||
tp_plan[".*"] = "replicate"
|
tp_plan = tp_plan or {}
|
||||||
|
|
||||||
def _tensor_parallel(module: nn.Module, prefix: str = ""):
|
# If the current module is a PreTrainedModel, set the tp_plan for
|
||||||
|
# all of its children
|
||||||
|
if isinstance(module, PreTrainedModel):
|
||||||
|
tp_plan = module.config.base_model_tp_plan or {}
|
||||||
|
tp_plan = {
|
||||||
|
maybe_prefix(prefix, k): v
|
||||||
|
for k, v in tp_plan.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Some weight loaders expect linear layers to inherit from vLLM's
|
||||||
|
# LinearBase class, so we set a default style which causes any
|
||||||
|
# unspecified linear layers to be replaced with ReplicatedLinear
|
||||||
for child_name, child_module in module.named_children():
|
for child_name, child_module in module.named_children():
|
||||||
qual_name = maybe_prefix(prefix, child_name)
|
qual_name = maybe_prefix(prefix, child_name)
|
||||||
for pattern, style in tp_plan.items():
|
if isinstance(child_module, nn.Linear):
|
||||||
if re.match(pattern, qual_name) and isinstance(
|
generator = (p for p in tp_plan if re.match(p, qual_name))
|
||||||
child_module, nn.Linear):
|
pattern = next(generator, None)
|
||||||
new_module = replace_linear_class(
|
style = tp_plan.get(pattern, "replicate")
|
||||||
child_module, style, self.quant_config)
|
new_module = replace_linear_class(child_module, style,
|
||||||
setattr(module, child_name, new_module)
|
self.quant_config)
|
||||||
log_replacement(qual_name, child_module, new_module)
|
setattr(module, child_name, new_module)
|
||||||
break
|
log_replacement(qual_name, child_module, new_module)
|
||||||
else:
|
else:
|
||||||
_tensor_parallel(child_module, prefix=qual_name)
|
_tensor_parallel(child_module,
|
||||||
|
prefix=qual_name,
|
||||||
|
tp_plan=tp_plan)
|
||||||
|
|
||||||
_tensor_parallel(self.model)
|
_tensor_parallel(self.model)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user