Fix Transformers backend tensor parallel for multimodal models (#22673)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
commit d0a6301588
parent 45c3936e94
@@ -505,30 +505,47 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         """
         Apply the model's tensor parallelization plan.
         Currently only supports linear layers.
         """
-        tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {}
+        # Look for tp plans in all of the PreTrainedModels found in self.model
+        is_pretrained_model = lambda m: isinstance(m, PreTrainedModel)
+        supports_tp_plan = lambda m: m.config.base_model_tp_plan is not None
+        pretrained_models = filter(is_pretrained_model, self.model.modules())
+        models_with_tp_plan = filter(supports_tp_plan, pretrained_models)

-        if not tp_plan and self.tp_size > 1:
+        if not any(models_with_tp_plan) and self.tp_size > 1:
             raise ValueError(
                 f"{type(self.model)} does not support tensor parallel yet!")

-        # Some weight loaders expect linear layers to inherit from vLLM's
-        # LinearBase class, so we set a default style which causes any
-        # unspecified linear layers to be replaced with ReplicatedLinear
-        tp_plan[".*"] = "replicate"
-
-        def _tensor_parallel(module: nn.Module, prefix: str = ""):
+        def _tensor_parallel(module: nn.Module,
+                             prefix: str = "",
+                             tp_plan=None):
+            tp_plan = tp_plan or {}
+
+            # If the current module is a PreTrainedModel, set the tp_plan for
+            # all of its children
+            if isinstance(module, PreTrainedModel):
+                tp_plan = module.config.base_model_tp_plan or {}
+                tp_plan = {
+                    maybe_prefix(prefix, k): v
+                    for k, v in tp_plan.items()
+                }
+
+            # Some weight loaders expect linear layers to inherit from vLLM's
+            # LinearBase class, so we set a default style which causes any
+            # unspecified linear layers to be replaced with ReplicatedLinear
             for child_name, child_module in module.named_children():
                 qual_name = maybe_prefix(prefix, child_name)
-                for pattern, style in tp_plan.items():
-                    if re.match(pattern, qual_name) and isinstance(
-                            child_module, nn.Linear):
-                        new_module = replace_linear_class(
-                            child_module, style, self.quant_config)
-                        setattr(module, child_name, new_module)
-                        log_replacement(qual_name, child_module, new_module)
-                        break
+                if isinstance(child_module, nn.Linear):
+                    generator = (p for p in tp_plan if re.match(p, qual_name))
+                    pattern = next(generator, None)
+                    style = tp_plan.get(pattern, "replicate")
+                    new_module = replace_linear_class(child_module, style,
+                                                      self.quant_config)
+                    setattr(module, child_name, new_module)
+                    log_replacement(qual_name, child_module, new_module)
                 else:
-                    _tensor_parallel(child_module, prefix=qual_name)
+                    _tensor_parallel(child_module,
+                                     prefix=qual_name,
+                                     tp_plan=tp_plan)

         _tensor_parallel(self.model)
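Note: the key change is that the tensor parallel plan is no longer read once from self.model.config but picked up per PreTrainedModel during the module walk, which is what makes multimodal wrappers work (the wrapper config typically has no plan, while the nested language model's config does). Below is a minimal standalone sketch of that traversal pattern in plain PyTorch, with no vLLM or Transformers imports: whenever the walk enters a sub-model that declares its own plan, the plan is re-keyed with the sub-model's prefix and replaces the inherited one, and each nn.Linear is assigned the style of the first matching regex pattern, falling back to "replicate". ToyTextModel, ToyVisionModel, assign_styles, and the placement of base_model_tp_plan directly on the toy modules (rather than on a HF config object) are illustrative assumptions, not vLLM or Transformers APIs.

import re
import torch.nn as nn


class ToyTextModel(nn.Module):
    # Hypothetical stand-in for a sub-model that carries its own
    # tensor parallel plan (regex pattern -> partitioning style).
    base_model_tp_plan = {r"layers\.\d+\.fc": "colwise"}

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.ModuleDict({"fc": nn.Linear(8, 8)})])


class ToyVisionModel(nn.Module):
    # No plan of its own: its linear layers fall back to "replicate".
    base_model_tp_plan = None

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)


class ToyMultimodal(nn.Module):
    def __init__(self):
        super().__init__()
        self.text_model = ToyTextModel()
        self.vision_model = ToyVisionModel()


def maybe_prefix(prefix: str, name: str) -> str:
    return f"{prefix}.{name}" if prefix else name


def assign_styles(module: nn.Module, prefix: str = "", tp_plan=None, out=None):
    """Recursive walk mirroring the diff: refresh tp_plan when a sub-model
    declares one, match each nn.Linear's qualified name against the plan's
    patterns, and default unmatched linears to 'replicate'."""
    tp_plan = tp_plan or {}
    out = out if out is not None else {}

    plan = getattr(module, "base_model_tp_plan", None)
    if plan is not None:
        # Re-key the sub-model's plan so patterns are relative to the root.
        tp_plan = {maybe_prefix(prefix, k): v for k, v in plan.items()}

    for child_name, child in module.named_children():
        qual_name = maybe_prefix(prefix, child_name)
        if isinstance(child, nn.Linear):
            pattern = next((p for p in tp_plan if re.match(p, qual_name)), None)
            out[qual_name] = tp_plan.get(pattern, "replicate")
        else:
            assign_styles(child, qual_name, tp_plan, out)
    return out


if __name__ == "__main__":
    styles = assign_styles(ToyMultimodal())
    # {'text_model.layers.0.fc': 'colwise', 'vision_model.proj': 'replicate'}
    print(styles)

In the real code the replacement step swaps the matched nn.Linear for a vLLM linear class via replace_linear_class; the sketch only records the chosen style to keep the traversal logic visible.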