Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Misc] Add ignored layers for fp8 quantization (#6657)
parent 38c4b7e863
commit 0eb0757bef
@@ -5,6 +5,9 @@ from typing import Any, Dict, Iterable, Optional
 from pydantic import BaseModel, Field
 from torch.nn import Module
 
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    FUSED_LAYER_NAME_MAPPING)
+
 
 class CompressionFormat(Enum):
     dense = "dense"
@@ -86,13 +89,6 @@ def is_activation_quantization_format(format: str) -> bool:
     return format in _ACTIVATION_QUANTIZATION_FORMATS
 
 
-# fused_name: List[shard_name]
-_FUSED_LAYER_NAME_MAPPING = {
-    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-    "gate_up_proj": ["gate_proj", "up_proj"]
-}
-
-
 def should_ignore_layer(layer_name: Optional[str],
                         ignore: Iterable[str]) -> bool:
     if layer_name is None:
@@ -106,8 +102,8 @@ def should_ignore_layer(layer_name: Optional[str],
     # in the safetensors checkpoint. So, we convert the name
     # from the fused version to unfused + check to make sure that
     # each shard of the fused layer has the same scheme.
-    if proj_name in _FUSED_LAYER_NAME_MAPPING:
-        shard_proj_names = _FUSED_LAYER_NAME_MAPPING[proj_name]
+    if proj_name in FUSED_LAYER_NAME_MAPPING:
+        shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
 
         # Convert fused_name --> [shard_names]
         shard_names = [
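
With the mapping now imported from quant_utils, should_ignore_layer keeps working as before: a fused module name such as qkv_proj is expanded into the per-shard names that actually appear in the checkpoint before the ignore list is consulted. The sketch below is illustrative only (expand_fused_name is not a vLLM helper and the layer name is hypothetical); it assumes the same two-entry mapping shown in the diff.

# Illustrative sketch of the fused-name expansion; not vLLM's exact helper.
FUSED_LAYER_NAME_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

def expand_fused_name(layer_name: str) -> list:
    # layer_name: model.layers.0.self_attn.qkv_proj -> proj_name: qkv_proj
    proj_name = layer_name.split(".")[-1]
    shards = FUSED_LAYER_NAME_MAPPING.get(proj_name)
    if shards is None:
        return [layer_name]
    return [layer_name.replace(proj_name, shard) for shard in shards]

print(expand_fused_name("model.layers.0.self_attn.qkv_proj"))
# ['model.layers.0.self_attn.q_proj',
#  'model.layers.0.self_attn.k_proj',
#  'model.layers.0.self_attn.v_proj']
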
@@ -11,6 +11,8 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     apply_fp8_linear, create_per_channel_scale_param)
 from vllm.model_executor.utils import set_weight_attrs
@@ -18,14 +20,6 @@ from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
 
-# Note: this is a hack. We should update each model to register the
-# stacked params and get it from there instead in a future PR.
-# fused_name: List[shard_name]
-_FUSED_LAYER_NAME_MAPPING = {
-    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-    "gate_up_proj": ["gate_proj", "up_proj"]
-}
-
 
 class FBGEMMFp8Config(QuantizationConfig):
     """Config class for FBGEMM Fp8."""
@@ -62,37 +56,10 @@ class FBGEMMFp8Config(QuantizationConfig):
         input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
         return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
 
-    def _is_layer_skipped(self, prefix: str) -> bool:
-        # prefix: model.layers.0.self_attn.q_proj
-        # proj_name: q_proj
-        proj_name = prefix.split(".")[-1]
-        if proj_name in _FUSED_LAYER_NAME_MAPPING:
-            shard_prefixes = [
-                prefix.replace(proj_name, shard_proj_name)
-                for shard_proj_name in _FUSED_LAYER_NAME_MAPPING[proj_name]
-            ]
-
-            is_skipped = None
-            for shard_prefix in shard_prefixes:
-                is_shard_skipped = shard_prefix in self.ignore_list
-
-                if is_skipped is None:
-                    is_skipped = is_shard_skipped
-                elif is_shard_skipped != is_skipped:
-                    raise ValueError(
-                        f"Detected some but not all shards of {prefix} "
-                        "are quantized. All shards of fused layers "
-                        "to have the same precision.")
-        else:
-            is_skipped = prefix in self.ignore_list
-
-        assert is_skipped is not None
-        return is_skipped
-
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
-            if self._is_layer_skipped(prefix):
+            if is_layer_skipped(prefix, self.ignore_list):
                 return UnquantizedLinearMethod()
             return FBGEMMFp8LinearMethod(self)
         return None
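
In place of the config's private _is_layer_skipped method, the check now goes through the shared helper with the config's ignore_list passed explicitly. A minimal hedged sketch of the replacement call, assuming a vLLM build that contains this change; the prefixes below are hypothetical.

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)

ignore_list = ["model.layers.0.mlp.down_proj"]  # hypothetical checkpoint prefix

# Equivalent to the removed self._is_layer_skipped(prefix):
assert is_layer_skipped("model.layers.0.mlp.down_proj", ignore_list)
assert not is_layer_skipped("model.layers.1.mlp.down_proj", ignore_list)
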
@@ -8,12 +8,15 @@ from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   fused_moe)
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
     cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)
@@ -33,6 +36,7 @@ class Fp8Config(QuantizationConfig):
         self,
         is_checkpoint_fp8_serialized: bool = False,
         activation_scheme: str = "dynamic",
+        ignored_layers: Optional[List[str]] = None,
     ) -> None:
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         if is_checkpoint_fp8_serialized:
@@ -42,6 +46,7 @@ class Fp8Config(QuantizationConfig):
             raise ValueError(
                 f"Unsupported activation scheme {activation_scheme}")
         self.activation_scheme = activation_scheme
+        self.ignored_layers = ignored_layers or []
 
     @classmethod
     def get_name(cls) -> str:
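
With the new keyword argument wired through __init__, the config can be built with an explicit ignore list. A minimal sketch, assuming Fp8Config is importable from vllm.model_executor.layers.quantization.fp8 (the import path is inferred, not shown in this diff) and using a hypothetical layer name:

from vllm.model_executor.layers.quantization.fp8 import Fp8Config

config = Fp8Config(
    is_checkpoint_fp8_serialized=True,
    activation_scheme="dynamic",
    ignored_layers=["lm_head"],  # hypothetical module to leave unquantized
)
assert config.ignored_layers == ["lm_head"]

# Falls back to an empty list when omitted, per "ignored_layers or []".
assert Fp8Config().ignored_layers == []
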
@@ -64,14 +69,18 @@ class Fp8Config(QuantizationConfig):
         quant_method = cls.get_from_keys(config, ["quant_method"])
         is_checkpoint_fp8_serialized = ("fp8" in quant_method)
         activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+        ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
         return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
-                   activation_scheme=activation_scheme)
+                   activation_scheme=activation_scheme,
+                   ignored_layers=ignored_layers)
 
     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
 
         if isinstance(layer, LinearBase):
+            if is_layer_skipped(prefix, self.ignored_layers):
+                return UnquantizedLinearMethod()
             return Fp8LinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return Fp8MoEMethod(self)
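
Taken together, from_config now reads the optional ignored_layers key from the checkpoint's quantization config, and get_quant_method returns UnquantizedLinearMethod for any linear layer whose prefix matches that list. A hedged sketch of the config path, with the same assumed import path as above and hypothetical layer names:

from vllm.model_executor.layers.quantization.fp8 import Fp8Config

hf_quant_config = {
    "quant_method": "fp8",
    "activation_scheme": "dynamic",
    "ignored_layers": ["lm_head", "model.layers.0.mlp.gate_up_proj"],
}
config = Fp8Config.from_config(hf_quant_config)
assert config.is_checkpoint_fp8_serialized
assert config.ignored_layers == ["lm_head", "model.layers.0.mlp.gate_up_proj"]
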
@@ -1,10 +1,48 @@
 """This file is used for /tests and /benchmarks"""
+from typing import List
+
 import numpy
 import torch
 
 SUPPORTED_NUM_BITS = [4, 8]
 SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 
+# Note: this is a hack. We should update each model to register the
+# stacked params and get it from there instead in a future PR.
+# fused_name: List[shard_name]
+FUSED_LAYER_NAME_MAPPING = {
+    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+    "gate_up_proj": ["gate_proj", "up_proj"]
+}
+
+
+def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
+    # prefix: model.layers.0.self_attn.q_proj
+    # proj_name: q_proj
+    proj_name = prefix.split(".")[-1]
+    if proj_name in FUSED_LAYER_NAME_MAPPING:
+        shard_prefixes = [
+            prefix.replace(proj_name, shard_proj_name)
+            for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
+        ]
+
+        is_skipped = None
+        for shard_prefix in shard_prefixes:
+            is_shard_skipped = shard_prefix in ignored_layers
+
+            if is_skipped is None:
+                is_skipped = is_shard_skipped
+            elif is_shard_skipped != is_skipped:
+                raise ValueError(
+                    f"Detected some but not all shards of {prefix} "
+                    "are quantized. All shards of fused layers "
+                    "must have the same precision.")
+    else:
+        is_skipped = prefix in ignored_layers
+
+    assert is_skipped is not None
+    return is_skipped
+
 
 def get_pack_factor(num_bits):
     assert num_bits in SUPPORTED_NUM_BITS, f"Unsupported num_bits = {num_bits}"
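
A hedged end-to-end check of the fused-layer path in the new helper: if every shard of a fused projection is in the ignore list, the fused module is treated as skipped, while ignoring only some shards raises the mismatch error. The import path matches the new import lines above; the layer names are hypothetical.

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    FUSED_LAYER_NAME_MAPPING, is_layer_skipped)

shards = FUSED_LAYER_NAME_MAPPING["gate_up_proj"]     # ["gate_proj", "up_proj"]
ignored = [f"model.layers.0.mlp.{s}" for s in shards]

# Both shards ignored -> the fused gate_up_proj module is skipped.
assert is_layer_skipped("model.layers.0.mlp.gate_up_proj", ignored)

# Only one shard ignored -> mixed precision within a fused layer is rejected.
try:
    is_layer_skipped("model.layers.0.mlp.gate_up_proj", ignored[:1])
except ValueError:
    pass
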