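Update PTPC path to use init_fp8_linear_kernel; fix kernel-init keyword names, import paths, and hard-coded weight quant strategy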

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-10-31 14:52:51 +00:00
parent dd001064c0
commit 7d361487f7
6 changed files with 28 additions and 20 deletions

View File

@ -82,10 +82,11 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
else: else:
weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy] weight_quant_strategy = QUANT_STRATEGY_MAP[self.strategy]
self.fp8_linear_kernel = init_fp8_linear_kernel( self.fp8_linear_kernel = init_fp8_linear_kernel(
is_static_input_scheme=self.is_static_input_scheme, act_q_static=self.is_static_input_scheme,
act_q_group_shape=self.act_q_group_shape,
weight_quant_strategy=weight_quant_strategy, weight_quant_strategy=weight_quant_strategy,
activation_group_shape=self.act_q_group_shape,
out_dtype=self.out_dtype, out_dtype=self.out_dtype,
module_name=self.__class__.__name__
) )
@classmethod @classmethod

View File

@ -394,10 +394,11 @@ class Fp8LinearMethod(LinearMethodBase):
) )
else: else:
self.fp8_linear_kernel = init_fp8_linear_kernel( self.fp8_linear_kernel = init_fp8_linear_kernel(
is_static_input_scheme=self.act_q_static, act_q_static=self.act_q_static,
act_q_group_shape=self.act_q_group_shape,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
activation_group_shape=self.act_q_group_shape,
out_dtype=self.out_dtype, out_dtype=self.out_dtype,
module_name=self.__class__.__name__
) )
def create_weights( def create_weights(

View File

@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
ScaledMMLinearLayerConfig, ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy, ScaledMMLinearQuantStrategy,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.torch import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
ChannelWiseTorchScaledMMLinearKernel, ChannelWiseTorchScaledMMLinearKernel,
PerTensorTorchScaledMMLinearKernel, PerTensorTorchScaledMMLinearKernel,
RowWiseTorchScaledMMLinearKernel, RowWiseTorchScaledMMLinearKernel,
@ -37,7 +37,7 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
XLAScaledMMLinearKernel, XLAScaledMMLinearKernel,
) )
from vllm.platforms import PlatformEnum, current_platform from vllm.platforms import PlatformEnum, current_platform
from vllm.vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
logger = init_logger(__name__) logger = init_logger(__name__)
@ -133,12 +133,13 @@ def choose_scaled_mm_linear_kernel(
def init_fp8_linear_kernel( def init_fp8_linear_kernel(
act_q_static: bool, act_q_static: bool,
act_q_group_shape: GroupShape, act_q_group_shape: GroupShape,
weight_quant_strategy: ScaledMMLinearQuantStrategy,
out_dtype: torch.dtype, out_dtype: torch.dtype,
module_name: str, module_name: str,
) -> FP8ScaledMMLinearKernel: ) -> FP8ScaledMMLinearKernel:
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
is_static_input_scheme=act_q_static, is_static_input_scheme=act_q_static,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR, weight_quant_strategy=weight_quant_strategy,
activation_group_shape=act_q_group_shape, activation_group_shape=act_q_group_shape,
out_dtype=out_dtype, out_dtype=out_dtype,
) )

View File

@ -16,11 +16,16 @@ from vllm.model_executor.layers.quantization.fp8 import (
Fp8KVCacheMethod, Fp8KVCacheMethod,
Fp8LinearMethod, Fp8LinearMethod,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
is_layer_skipped, is_layer_skipped,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
ACTIVATION_SCHEMES = ["static", "dynamic"] ACTIVATION_SCHEMES = ["static", "dynamic"]
@ -98,11 +103,15 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
) )
super().__init__(quant_config=quant_config) super().__init__(quant_config=quant_config)
# Force weight quantization # Force weight quantization
self.quant_config.is_checkpoint_fp8_serialized = False self.fp8_linear_kernel = init_fp8_linear_kernel(
self.fp8_linear = Fp8LinearOp( act_q_static=False,
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN act_q_group_shape=GroupShape.PER_TOKEN,
weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__
) )
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
@ -127,11 +136,6 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.fp8_linear.apply( return self.fp8_linear_kernel.apply_weights(
input=x, layer, x, bias
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=None,
input_scale_ub=None,
bias=bias,
) )

View File

@ -174,10 +174,11 @@ class QuarkW8A8Fp8(QuarkScheme):
weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme] weight_quant_strategy = QUANT_STRATEGY_MAP[self.weight_qscheme]
self.fp8_linear_kernel = init_fp8_linear_kernel( self.fp8_linear_kernel = init_fp8_linear_kernel(
is_static_input_scheme=self.is_static_input_scheme, act_q_static=self.is_static_input_scheme,
act_q_group_shape=self.act_quant_group_shape,
weight_quant_strategy=weight_quant_strategy, weight_quant_strategy=weight_quant_strategy,
activation_group_shape=self.act_quant_group_shape,
out_dtype=self.out_dtype, out_dtype=self.out_dtype,
module_name=self.__class__.__name__
) )
def apply_weights( def apply_weights(