Format code; update FBGEMM FP8 linear path to use init_fp8_linear_kernel

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-10-31 15:08:19 +00:00
parent 1f65cd56e5
commit 5fbe76bc0a
6 changed files with 27 additions and 24 deletions

View File

@ -86,7 +86,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
act_q_group_shape=self.act_q_group_shape,
weight_quant_strategy=weight_quant_strategy,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__
module_name=self.__class__.__name__,
)
@classmethod

View File

@ -18,6 +18,12 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin,
@ -96,6 +102,14 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
)
self.out_dtype = torch.get_default_dtype()
self.fp8_linear_kernel = init_fp8_linear_kernel(
act_q_static=False,
act_q_group_shape=GroupShape.PER_TOKEN,
weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__,
)
def create_weights(
self,
layer: torch.nn.Module,
@ -184,12 +198,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
bias=bias,
)
return self.fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=None,
input_scale_ub=layer.input_scale_ub,
bias=bias,
)
return self.fp8_linear_kernel.apply_weights(layer, x, bias)

View File

@ -398,7 +398,7 @@ class Fp8LinearMethod(LinearMethodBase):
act_q_group_shape=self.act_q_group_shape,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__
module_name=self.__class__.__name__,
)
def create_weights(

View File

@ -15,6 +15,11 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
ChannelWiseTorchScaledMMLinearKernel,
PerTensorTorchScaledMMLinearKernel,
RowWiseTorchScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
ROCmScaledMMLinearKernel,
)
@ -25,19 +30,14 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
ScaledMMLinearLayerConfig,
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
ChannelWiseTorchScaledMMLinearKernel,
PerTensorTorchScaledMMLinearKernel,
RowWiseTorchScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
XLAScaledMMLinearKernel,
)
from vllm.platforms import PlatformEnum, current_platform
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import PlatformEnum, current_platform
logger = init_logger(__name__)
@ -133,7 +133,7 @@ def choose_scaled_mm_linear_kernel(
def init_fp8_linear_kernel(
act_q_static: bool,
act_q_group_shape: GroupShape,
weight_quant_strategy: ScaledMMLinearQuantStrategy,
weight_quant_strategy: ScaledMMLinearQuantStrategy,
out_dtype: torch.dtype,
module_name: str,
) -> FP8ScaledMMLinearKernel:

View File

@ -108,10 +108,9 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
act_q_group_shape=GroupShape.PER_TOKEN,
weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__
module_name=self.__class__.__name__,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
@ -136,6 +135,4 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return self.fp8_linear_kernel.apply_weights(
layer, x, bias
)
return self.fp8_linear_kernel.apply_weights(layer, x, bias)

View File

@ -178,7 +178,7 @@ class QuarkW8A8Fp8(QuarkScheme):
act_q_group_shape=self.act_quant_group_shape,
weight_quant_strategy=weight_quant_strategy,
out_dtype=self.out_dtype,
module_name=self.__class__.__name__
module_name=self.__class__.__name__,
)
def apply_weights(