Fix Transformers backend tensor parallel for multimodal models (#22673)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor, 2025-08-13 01:12:30 +01:00, committed by GitHub
Parent: 45c3936e94
Commit: d0a6301588

@@ -505,30 +505,47 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         Apply the model's tensor parallelization plan.
         Currently only supports linear layers.
         """
-        tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {}
+        # Look for tp plans in all of the PreTrainedModels found in self.model
+        is_pretrained_model = lambda m: isinstance(m, PreTrainedModel)
+        supports_tp_plan = lambda m: m.config.base_model_tp_plan is not None
+        pretrained_models = filter(is_pretrained_model, self.model.modules())
+        models_with_tp_plan = filter(supports_tp_plan, pretrained_models)
 
-        if not tp_plan and self.tp_size > 1:
+        if not any(models_with_tp_plan) and self.tp_size > 1:
             raise ValueError(
                 f"{type(self.model)} does not support tensor parallel yet!")
 
-        # Some weight loaders expect linear layers to inherit from vLLM's
-        # LinearBase class, so we set a default style which causes any
-        # unspecified linear layers to be replaced with ReplicatedLinear
-        tp_plan[".*"] = "replicate"
-
-        def _tensor_parallel(module: nn.Module, prefix: str = ""):
+        def _tensor_parallel(module: nn.Module,
+                             prefix: str = "",
+                             tp_plan=None):
+            tp_plan = tp_plan or {}
+
+            # If the current module is a PreTrainedModel, set the tp_plan for
+            # all of its children
+            if isinstance(module, PreTrainedModel):
+                tp_plan = module.config.base_model_tp_plan or {}
+                tp_plan = {
+                    maybe_prefix(prefix, k): v
+                    for k, v in tp_plan.items()
+                }
+
+            # Some weight loaders expect linear layers to inherit from vLLM's
+            # LinearBase class, so we set a default style which causes any
+            # unspecified linear layers to be replaced with ReplicatedLinear
             for child_name, child_module in module.named_children():
                 qual_name = maybe_prefix(prefix, child_name)
-                for pattern, style in tp_plan.items():
-                    if re.match(pattern, qual_name) and isinstance(
-                            child_module, nn.Linear):
-                        new_module = replace_linear_class(
-                            child_module, style, self.quant_config)
-                        setattr(module, child_name, new_module)
-                        log_replacement(qual_name, child_module, new_module)
-                        break
+                if isinstance(child_module, nn.Linear):
+                    generator = (p for p in tp_plan if re.match(p, qual_name))
+                    pattern = next(generator, None)
+                    style = tp_plan.get(pattern, "replicate")
+                    new_module = replace_linear_class(child_module, style,
+                                                      self.quant_config)
+                    setattr(module, child_name, new_module)
+                    log_replacement(qual_name, child_module, new_module)
                 else:
-                    _tensor_parallel(child_module, prefix=qual_name)
+                    _tensor_parallel(child_module,
+                                     prefix=qual_name,
+                                     tp_plan=tp_plan)
 
         _tensor_parallel(self.model)
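
The change scopes the tensor-parallel plan to the nearest enclosing PreTrainedModel, which is what makes it work for multimodal models: such models nest several PreTrainedModels (for example a language model and a vision tower), and typically only the language model publishes a base_model_tp_plan. The sketch below is a minimal, self-contained illustration of that traversal, not vLLM code: PlanCarrier, plan_for_linears, and the toy model are invented stand-ins for PreTrainedModel, the inner _tensor_parallel helper, and a multimodal checkpoint; only torch and the standard library are assumed.

import re

import torch.nn as nn


class PlanCarrier(nn.Module):
    """Stand-in for a HF PreTrainedModel that may own a base_model_tp_plan."""

    def __init__(self, tp_plan=None):
        super().__init__()
        self.base_model_tp_plan = tp_plan


def plan_for_linears(module, prefix="", tp_plan=None):
    """Map each qualified nn.Linear name to a parallel style, defaulting to
    "replicate" when no pattern in the active plan matches."""
    tp_plan = tp_plan or {}
    # Re-root the plan whenever we enter a plan-carrying module, prefixing
    # every pattern so it matches fully qualified child names.
    if isinstance(module, PlanCarrier):
        tp_plan = {
            f"{prefix}.{k}" if prefix else k: v
            for k, v in (module.base_model_tp_plan or {}).items()
        }
    styles = {}
    for child_name, child in module.named_children():
        qual_name = f"{prefix}.{child_name}" if prefix else child_name
        if isinstance(child, nn.Linear):
            pattern = next((p for p in tp_plan if re.match(p, qual_name)),
                           None)
            styles[qual_name] = tp_plan.get(pattern, "replicate")
        else:
            styles.update(plan_for_linears(child, qual_name, tp_plan))
    return styles


# Toy multimodal layout: the language model carries a plan, the vision tower
# does not, so its linears fall back to "replicate".
language_model = PlanCarrier({r"layers\.\d+\.proj": "colwise"})
language_model.layers = nn.ModuleList(
    [nn.ModuleDict({"proj": nn.Linear(4, 4)})])
vision_tower = PlanCarrier()
vision_tower.patch_embed = nn.Linear(4, 4)

model = nn.Module()
model.language_model = language_model
model.vision_tower = vision_tower

print(plan_for_linears(model))
# {'language_model.layers.0.proj': 'colwise',
#  'vision_tower.patch_embed': 'replicate'}

Linears that no pattern matches fall back to "replicate", mirroring the default the commit applies so that vLLM weight loaders still see LinearBase subclasses for every linear layer, including those in submodels that have no tp plan of their own.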