mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-26 11:47:08 +08:00
format; update fbgemm path
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
1f65cd56e5
commit
5fbe76bc0a
@ -86,7 +86,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
|||||||
act_q_group_shape=self.act_q_group_shape,
|
act_q_group_shape=self.act_q_group_shape,
|
||||||
weight_quant_strategy=weight_quant_strategy,
|
weight_quant_strategy=weight_quant_strategy,
|
||||||
out_dtype=self.out_dtype,
|
out_dtype=self.out_dtype,
|
||||||
module_name=self.__class__.__name__
|
module_name=self.__class__.__name__,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -18,6 +18,12 @@ from vllm.model_executor.layers.quantization.base_config import (
|
|||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||||
|
init_fp8_linear_kernel,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (
|
||||||
|
ScaledMMLinearQuantStrategy,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||||
apply_fp8_marlin_linear,
|
apply_fp8_marlin_linear,
|
||||||
prepare_fp8_layer_for_marlin,
|
prepare_fp8_layer_for_marlin,
|
||||||
@ -96,6 +102,14 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
|||||||
)
|
)
|
||||||
self.out_dtype = torch.get_default_dtype()
|
self.out_dtype = torch.get_default_dtype()
|
||||||
|
|
||||||
|
self.fp8_linear_kernel = init_fp8_linear_kernel(
|
||||||
|
act_q_static=False,
|
||||||
|
act_q_group_shape=GroupShape.PER_TOKEN,
|
||||||
|
weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL,
|
||||||
|
out_dtype=self.out_dtype,
|
||||||
|
module_name=self.__class__.__name__,
|
||||||
|
)
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
self,
|
self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
@ -184,12 +198,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
|
|||||||
bias=bias,
|
bias=bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.fp8_linear.apply(
|
return self.fp8_linear_kernel.apply_weights(layer, x, bias)
|
||||||
input=x,
|
|
||||||
weight=layer.weight,
|
|
||||||
weight_scale=layer.weight_scale,
|
|
||||||
out_dtype=self.out_dtype,
|
|
||||||
input_scale=None,
|
|
||||||
input_scale_ub=layer.input_scale_ub,
|
|
||||||
bias=bias,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -398,7 +398,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
act_q_group_shape=self.act_q_group_shape,
|
act_q_group_shape=self.act_q_group_shape,
|
||||||
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
|
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
|
||||||
out_dtype=self.out_dtype,
|
out_dtype=self.out_dtype,
|
||||||
module_name=self.__class__.__name__
|
module_name=self.__class__.__name__,
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_weights(
|
def create_weights(
|
||||||
|
|||||||
@ -15,6 +15,11 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
|
|||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
|
||||||
CutlassScaledMMLinearKernel,
|
CutlassScaledMMLinearKernel,
|
||||||
)
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
|
||||||
|
ChannelWiseTorchScaledMMLinearKernel,
|
||||||
|
PerTensorTorchScaledMMLinearKernel,
|
||||||
|
RowWiseTorchScaledMMLinearKernel,
|
||||||
|
)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
|
||||||
ROCmScaledMMLinearKernel,
|
ROCmScaledMMLinearKernel,
|
||||||
)
|
)
|
||||||
@ -25,19 +30,14 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
|
|||||||
ScaledMMLinearLayerConfig,
|
ScaledMMLinearLayerConfig,
|
||||||
ScaledMMLinearQuantStrategy,
|
ScaledMMLinearQuantStrategy,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.pytorch import (
|
|
||||||
ChannelWiseTorchScaledMMLinearKernel,
|
|
||||||
PerTensorTorchScaledMMLinearKernel,
|
|
||||||
RowWiseTorchScaledMMLinearKernel,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
|
||||||
TritonScaledMMLinearKernel,
|
TritonScaledMMLinearKernel,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
|
||||||
XLAScaledMMLinearKernel,
|
XLAScaledMMLinearKernel,
|
||||||
)
|
)
|
||||||
from vllm.platforms import PlatformEnum, current_platform
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||||
|
from vllm.platforms import PlatformEnum, current_platform
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -133,7 +133,7 @@ def choose_scaled_mm_linear_kernel(
|
|||||||
def init_fp8_linear_kernel(
|
def init_fp8_linear_kernel(
|
||||||
act_q_static: bool,
|
act_q_static: bool,
|
||||||
act_q_group_shape: GroupShape,
|
act_q_group_shape: GroupShape,
|
||||||
weight_quant_strategy: ScaledMMLinearQuantStrategy,
|
weight_quant_strategy: ScaledMMLinearQuantStrategy,
|
||||||
out_dtype: torch.dtype,
|
out_dtype: torch.dtype,
|
||||||
module_name: str,
|
module_name: str,
|
||||||
) -> FP8ScaledMMLinearKernel:
|
) -> FP8ScaledMMLinearKernel:
|
||||||
|
|||||||
@ -108,10 +108,9 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
|
|||||||
act_q_group_shape=GroupShape.PER_TOKEN,
|
act_q_group_shape=GroupShape.PER_TOKEN,
|
||||||
weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL,
|
weight_quant_strategy=ScaledMMLinearQuantStrategy.CHANNEL,
|
||||||
out_dtype=self.out_dtype,
|
out_dtype=self.out_dtype,
|
||||||
module_name=self.__class__.__name__
|
module_name=self.__class__.__name__,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
|
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
|
||||||
|
|
||||||
@ -136,6 +135,4 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: torch.Tensor | None = None,
|
bias: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return self.fp8_linear_kernel.apply_weights(
|
return self.fp8_linear_kernel.apply_weights(layer, x, bias)
|
||||||
layer, x, bias
|
|
||||||
)
|
|
||||||
|
|||||||
@ -178,7 +178,7 @@ class QuarkW8A8Fp8(QuarkScheme):
|
|||||||
act_q_group_shape=self.act_quant_group_shape,
|
act_q_group_shape=self.act_quant_group_shape,
|
||||||
weight_quant_strategy=weight_quant_strategy,
|
weight_quant_strategy=weight_quant_strategy,
|
||||||
out_dtype=self.out_dtype,
|
out_dtype=self.out_dtype,
|
||||||
module_name=self.__class__.__name__
|
module_name=self.__class__.__name__,
|
||||||
)
|
)
|
||||||
|
|
||||||
def apply_weights(
|
def apply_weights(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user