mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 18:45:21 +08:00
[Feature] models: pass layer prefix to replace_linear_class for per-layer quantization routing. Addresses #23239 (#23556)
Signed-off-by: Shrey Gupta <shreyg1303@gmail.com>
This commit is contained in:
parent
a69693e38f
commit
1b7b161a09
@ -408,13 +408,17 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
|||||||
if isinstance(module, nn.Linear):
|
if isinstance(module, nn.Linear):
|
||||||
parent, attr_name = self._get_parent_and_attr(vit, name)
|
parent, attr_name = self._get_parent_and_attr(vit, name)
|
||||||
if isinstance(parent, timm.layers.Mlp) and attr_name == "fc1":
|
if isinstance(parent, timm.layers.Mlp) and attr_name == "fc1":
|
||||||
new_linear = replace_linear_class(module, "colwise",
|
new_linear = replace_linear_class(module,
|
||||||
quant_config)
|
"colwise",
|
||||||
|
quant_config,
|
||||||
|
prefix=name)
|
||||||
setattr(parent, attr_name, new_linear)
|
setattr(parent, attr_name, new_linear)
|
||||||
elif isinstance(parent,
|
elif isinstance(parent,
|
||||||
timm.layers.Mlp) and attr_name == "fc2":
|
timm.layers.Mlp) and attr_name == "fc2":
|
||||||
new_linear = replace_linear_class(module, "rowwise",
|
new_linear = replace_linear_class(module,
|
||||||
quant_config)
|
"rowwise",
|
||||||
|
quant_config,
|
||||||
|
prefix=name)
|
||||||
setattr(parent, attr_name, new_linear)
|
setattr(parent, attr_name, new_linear)
|
||||||
|
|
||||||
return vit
|
return vit
|
||||||
|
|||||||
@ -106,8 +106,11 @@ def can_enable_torch_compile(vllm_config: VllmConfig) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def replace_linear_class(
|
def replace_linear_class(
|
||||||
linear: nn.Linear, style: Literal["colwise", "rowwise"],
|
linear: nn.Linear,
|
||||||
quant_config: QuantizationConfig
|
style: Literal["colwise", "rowwise"],
|
||||||
|
quant_config: QuantizationConfig,
|
||||||
|
*,
|
||||||
|
prefix: str = "",
|
||||||
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
|
) -> Union[ColumnParallelLinear, RowParallelLinear, ReplicatedLinear]:
|
||||||
"""
|
"""
|
||||||
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
|
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
|
||||||
@ -141,6 +144,7 @@ def replace_linear_class(
|
|||||||
output_size=linear.out_features,
|
output_size=linear.out_features,
|
||||||
bias=linear.bias is not None,
|
bias=linear.bias is not None,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
|
prefix=prefix,
|
||||||
return_bias=False,
|
return_bias=False,
|
||||||
**vllm_linear_kwargs,
|
**vllm_linear_kwargs,
|
||||||
)
|
)
|
||||||
@ -557,8 +561,10 @@ class TransformersBase(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
|||||||
generator = (p for p in tp_plan if re.match(p, qual_name))
|
generator = (p for p in tp_plan if re.match(p, qual_name))
|
||||||
pattern = next(generator, None)
|
pattern = next(generator, None)
|
||||||
style = tp_plan.get(pattern, "replicate")
|
style = tp_plan.get(pattern, "replicate")
|
||||||
new_module = replace_linear_class(child_module, style,
|
new_module = replace_linear_class(child_module,
|
||||||
self.quant_config)
|
style,
|
||||||
|
self.quant_config,
|
||||||
|
prefix=qual_name)
|
||||||
setattr(module, child_name, new_module)
|
setattr(module, child_name, new_module)
|
||||||
log_replacement(qual_name, child_module, new_module)
|
log_replacement(qual_name, child_module, new_module)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user