Improve validation of TP in Transformers backend (#15540)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-03-26 14:26:48 +00:00 committed by GitHub
parent 1aa162e030
commit c091c0a588
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -229,7 +229,10 @@ class TransformersModel(nn.Module):
Apply the model's tensor parallelization plan.
Currently only supports linear layers.
"""
if self.tp_size > 1 and self.config.base_model_tp_plan is None:
if not self.model.supports_tp_plan:
if self.tp_size <= 1:
return
raise ValueError(
f"{type(self.model)} does not support tensor parallel yet!")