update ptpc path; bug fixes

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>

parent dd001064c0
commit 7d361487f7
@@ -82,10 +82,11 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         else:
             weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy]
             self.fp8_linear_kernel = init_fp8_linear_kernel(
-                is_static_input_scheme=self.is_static_input_scheme,
+                act_q_static=self.is_static_input_scheme,
+                act_q_group_shape=self.act_q_group_shape,
                 weight_quant_strategy=weight_quant_strategy,
-                activation_group_shape=self.act_q_group_shape,
                 out_dtype=self.out_dtype,
                 module_name=self.__class__.__name__
             )

     @classmethod
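Note on the hunk above: the call site was passing keyword names that do not match init_fp8_linear_kernel's signature (act_q_static, act_q_group_shape), so construction failed at startup. A minimal, self-contained sketch of the failure mode, with init_kernel as a hypothetical stand-in for the real helper:

def init_kernel(act_q_static, act_q_group_shape):
    # Hypothetical stand-in for init_fp8_linear_kernel.
    return act_q_static, act_q_group_shape

try:
    # Old keyword names, as removed in this hunk:
    init_kernel(is_static_input_scheme=True, activation_group_shape=(1, -1))
except TypeError as exc:
    print(exc)  # ... got an unexpected keyword argument 'is_static_input_scheme'

# New keyword names, as added in this hunk:
print(init_kernel(act_q_static=True, act_q_group_shape=(1, -1)))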
@@ -394,10 +394,11 @@ class Fp8LinearMethod(LinearMethodBase):
             )
         else:
             self.fp8_linear_kernel = init_fp8_linear_kernel(
-                is_static_input_scheme=self.act_q_static,
+                act_q_static=self.act_q_static,
+                act_q_group_shape=self.act_q_group_shape,
                 weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
-                activation_group_shape=self.act_q_group_shape,
                 out_dtype=self.out_dtype,
                 module_name=self.__class__.__name__
             )

     def create_weights(
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
     ScaledMMLinearLayerConfig,
     ScaledMMLinearQuantStrategy,
 )
-from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import (
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
     ChannelWiseTorchScaledMMLinearKernel,
     PerTensorTorchScaledMMLinearKernel,
     RowWiseTorchScaledMMLinearKernel,
@@ -37,7 +37,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
     XLAScaledMMLinearKernel,
 )
 from vllm.platforms import PlatformEnum, current_platform
-from vllm.vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
+from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

 logger = init_logger(__name__)
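Both import fixes above are path bugs that fail at import time: a doubled vllm.vllm prefix, and a stale reference to the scaled_mm.torch module after its rename to scaled_mm.pytorch (a rename that also sidesteps confusion with the top-level torch package). A quick self-contained check, assuming neither bad path resolves in your environment:

import importlib

for path in (
    "vllm.vllm.model_executor.layers.quantization.utils.quant_utils",  # doubled prefix
    "vllm.model_executor.layers.quantization.kernels.scaled_mm.torch", # renamed to .pytorch
):
    try:
        importlib.import_module(path)
    except ModuleNotFoundError as exc:
        print(exc)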
@@ -133,12 +133,13 @@ def choose_scaled_mm_linear_kernel(
 def init_fp8_linear_kernel(
     act_q_static: bool,
     act_q_group_shape: GroupShape,
+    weight_quant_strategy: ScaledMMLinearQuantStrategy,
     out_dtype: torch.dtype,
     module_name: str,
 ) -> FP8ScaledMMLinearKernel:
     scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
         is_static_input_scheme=act_q_static,
-        weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
+        weight_quant_strategy=weight_quant_strategy,
         activation_group_shape=act_q_group_shape,
         out_dtype=out_dtype,
     )
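The signature change above is the core fix: before it, FP8ScaledMMLinearLayerConfig always received ScaledMMLinearQuantStrategy.TENSOR, so callers with channel-wise weight scales could not express that. A minimal sketch of the before/after behavior, with stand-in names:

import enum

class Strategy(enum.Enum):
    TENSOR = "tensor"    # one scale for the whole weight tensor
    CHANNEL = "channel"  # one scale per output channel

def init_kernel_old() -> Strategy:
    return Strategy.TENSOR        # hardcoded: channel-wise callers were ignored

def init_kernel_new(weight_quant_strategy: Strategy) -> Strategy:
    return weight_quant_strategy  # threaded through from the call site

print(init_kernel_old())                  # Strategy.TENSOR, always
print(init_kernel_new(Strategy.CHANNEL))  # Strategy.CHANNEL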
@@ -16,11 +16,16 @@ from vllm.model_executor.layers.quantization.fp8 import (
     Fp8KVCacheMethod,
     Fp8LinearMethod,
 )
+from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
+    init_fp8_linear_kernel,
+)
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa E501
+    ScaledMMLinearQuantStrategy,
+)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
     is_layer_skipped,
 )
-from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
 from vllm.platforms import current_platform

 ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -98,11 +103,15 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
         )
         super().__init__(quant_config=quant_config)
         # Force weight quantization
         self.quant_config.is_checkpoint_fp8_serialized = False
-        self.fp8_linear = Fp8LinearOp(
-            act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN
+        self.fp8_linear_kernel = init_fp8_linear_kernel(
+            act_q_static=False,
+            act_q_group_shape=GroupShape.PER_TOKEN,
+            weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL,
+            out_dtype=self.out_dtype,
+            module_name=self.__class__.__name__
         )

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
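PTPC stands for per-token (activation) and per-channel (weight) scaling, which is exactly what the new arguments encode: act_q_group_shape=GroupShape.PER_TOKEN together with weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL. A rough float32 reference of the scheme (rounding stands in for the FP8 cast; 448.0 is the e4m3 finite max):

import torch

def ptpc_linear(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    x_scale = x.abs().amax(dim=-1, keepdim=True) / 448.0  # one scale per token
    w_scale = w.abs().amax(dim=-1, keepdim=True) / 448.0  # one scale per out-channel
    xq = (x / x_scale).round()  # stand-in for the FP8 quantize
    wq = (w / w_scale).round()
    # Dequantize via the outer product of the two scale vectors.
    return (xq @ wq.t()) * x_scale * w_scale.t()

x, w = torch.randn(4, 64), torch.randn(8, 64)
print((ptpc_linear(x, w) - x @ w.t()).abs().max())  # small quantization error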
@@ -127,11 +136,6 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        return self.fp8_linear.apply(
-            input=x,
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            input_scale=None,
-            input_scale_ub=None,
-            bias=bias,
+        return self.fp8_linear_kernel.apply_weights(
+            layer, x, bias
         )
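The hunk above also changes the calling convention: the kernel's apply_weights reads the weight and its scales off the layer itself instead of receiving them as explicit keyword arguments. A hedged sketch of that shape (stand-in class and hypothetical weight_scale attribute; the real interface lives in ScaledMMLinearKernel.py):

import torch

class FP8KernelSketch:
    def apply_weights(self, layer, x, bias=None):
        # Weight and per-channel scale are attributes of the layer.
        out = (x @ layer.weight.t()) * layer.weight_scale.t()
        return out if bias is None else out + bias

layer = torch.nn.Linear(64, 8, bias=False)
layer.weight_scale = torch.ones(8, 1)  # hypothetical per-channel scale
print(FP8KernelSketch().apply_weights(layer, torch.randn(4, 64)).shape)  # (4, 8)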
@@ -174,10 +174,11 @@ class QuarkW8A8Fp8(QuarkScheme):

         weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme]
         self.fp8_linear_kernel = init_fp8_linear_kernel(
-            is_static_input_scheme=self.is_static_input_scheme,
+            act_q_static=self.is_static_input_scheme,
+            act_q_group_shape=self.act_quant_group_shape,
             weight_quant_strategy=weight_quant_strategy,
-            activation_group_shape=self.act_quant_group_shape,
             out_dtype=self.out_dtype,
             module_name=self.__class__.__name__
         )

     def apply_weights(
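QuarkW8A8Fp8 mirrors the compressed-tensors hunk: QUANT_STRATEGY_MAP turns the checkpoint's qscheme into the kernel enum instead of letting everything fall through to per-tensor. The practical difference between the two strategies is the weight-scale shape; the strategy names below follow the diff, but this helper is purely illustrative:

def weight_scale_shape(strategy: str, out_features: int) -> tuple[int, int]:
    # "TENSOR": one scale for the whole weight; "CHANNEL": one per output channel.
    return (1, 1) if strategy == "TENSOR" else (out_features, 1)

print(weight_scale_shape("TENSOR", 8))   # (1, 1)
print(weight_scale_shape("CHANNEL", 8))  # (8, 1)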