diff --git a/vllm/model_executor/models/deepseek_vl2.py b/vllm/model_executor/models/deepseek_vl2.py index 1bd2802a8683..5eab02b17151 100644 --- a/vllm/model_executor/models/deepseek_vl2.py +++ b/vllm/model_executor/models/deepseek_vl2.py @@ -408,13 +408,17 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): if isinstance(module, nn.Linear): parent, attr_name = self._get_parent_and_attr(vit, name) if isinstance(parent, timm.layers.Mlp) and attr_name == "fc1": - new_linear = replace_linear_class(module, "colwise", - quant_config) + new_linear = replace_linear_class(module, + "colwise", + quant_config, + prefix=name) setattr(parent, attr_name, new_linear) elif isinstance(parent, timm.layers.Mlp) and attr_name == "fc2": - new_linear = replace_linear_class(module, "rowwise", - quant_config) + new_linear = replace_linear_class(module, + "rowwise", + quant_config, + prefix=name) setattr(parent, attr_name, new_linear) return vit diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index edf3dddb1bad..f7ced6134da5 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -106,8 +106,11 @@ def can_enable_torch_compile(vllm_config: VllmConfig) -> bool: def replace_linear_class( - linear: nn.Linear, style: Literal["colwise", "rowwise"], - quant_config: QuantizationConfig + linear: nn.Linear, + style: Literal["colwise", "rowwise"], + quant_config: QuantizationConfig, + *, + prefix: str = "", ) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]: """ Replace nn.Linear with one of vLLM's tensor parallel linear classes. @@ -141,6 +144,7 @@ def replace_linear_class( output_size=linear.out_features, bias=linear.bias is not None, quant_config=quant_config, + prefix=prefix, return_bias=False, **vllm_linear_kwargs, ) @@ -557,8 +561,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP): generator = (p for p in tp_plan if re.match(p, qual_name)) pattern = next(generator, None) style = tp_plan.get(pattern, "replicate") - new_module = replace_linear_class(child_module, style, - self.quant_config) + new_module = replace_linear_class(child_module, + style, + self.quant_config, + prefix=qual_name) setattr(module, child_name, new_module) log_replacement(qual_name, child_module, new_module) else: