mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-09 05:42:15 +08:00
[Feature][Quantization] auto_round format add support for regex (#24024)
Signed-off-by: n1ck-guo <heng.guo@intel.com>
Signed-off-by: Heng Guo <heng.guo@intel.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
parent 8ae169286f
commit 29350922c6
@ -4,6 +4,7 @@
|
|||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import regex as re
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -128,11 +129,44 @@ class AutoRoundConfig(QuantizationConfig):
|
|||||||
|
|
||||||
def get_layer_config(self, layer, layer_name: str):
|
def get_layer_config(self, layer, layer_name: str):
|
||||||
def get_config(name: str, quantized: bool = True):
|
def get_config(name: str, quantized: bool = True):
|
||||||
cfg = self.extra_config.get(name, {}) if self.extra_config else {}
|
if not self.extra_config:
|
||||||
|
return (
|
||||||
|
self.weight_bits if quantized else 16,
|
||||||
|
self.group_size if quantized else -1,
|
||||||
|
self.sym if quantized else True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# exact match first
|
||||||
|
if name in self.extra_config:
|
||||||
|
cfg = self.extra_config[name]
|
||||||
|
return (
|
||||||
|
cfg.get("bits", self.weight_bits if quantized else 16),
|
||||||
|
cfg.get("group_size", self.group_size if quantized else -1),
|
||||||
|
cfg.get("sym", self.sym if quantized else True),
|
||||||
|
)
|
||||||
|
|
||||||
|
REGEX_SPECIAL_CHARS = set(r"*+?^$()[]{}|\\")
|
||||||
|
for pattern, cfg in self.extra_config.items():
|
||||||
|
if not isinstance(pattern, str) or not any(
|
||||||
|
c in REGEX_SPECIAL_CHARS for c in pattern
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
if re.search(re.compile(pattern), name) is not None:
|
||||||
|
return (
|
||||||
|
cfg.get("bits", self.weight_bits if quantized else 16),
|
||||||
|
cfg.get("group_size", self.group_size if quantized else -1),
|
||||||
|
cfg.get("sym", self.sym if quantized else True),
|
||||||
|
)
|
||||||
|
except re.error:
|
||||||
|
# Invalid regex, ignore.
|
||||||
|
continue
|
||||||
|
|
||||||
return (
|
return (
|
||||||
cfg.get("bits", self.weight_bits if quantized else 16),
|
self.weight_bits if quantized else 16,
|
||||||
cfg.get("group_size", self.group_size if quantized else -1),
|
self.group_size if quantized else -1,
|
||||||
cfg.get("sym", self.sym if quantized else True),
|
self.sym if quantized else True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. Exact match from config
|
# 1. Exact match from config
|
||||||
@ -176,7 +210,7 @@ class AutoRoundConfig(QuantizationConfig):
|
|||||||
f"consistent quant config for {sub_names}"
|
f"consistent quant config for {sub_names}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. Fallback
|
# 5. Fallback or try a regular expression match
|
||||||
return get_config(layer_name, quantized)
|
return get_config(layer_name, quantized)
|
||||||
|
|
||||||
def check_quantized(self, weight_bits: int) -> bool:
|
def check_quantized(self, weight_bits: int) -> bool:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user