mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 06:57:03 +08:00
update unit tests to use ScaledMMLinearKernels
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
4ce0ba2df4
commit
dd5a70ec71
@ -20,8 +20,13 @@ from vllm.config import (
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@ -35,21 +40,23 @@ class TestSiluMul(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int = 128):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
|
||||
self.weight_scale = torch.rand(1, dtype=torch.float32)
|
||||
self.input_scale = torch.rand(1, dtype=torch.float32)
|
||||
self.input_scale_ub = None
|
||||
if TEST_FP8:
|
||||
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
act_q_static=True,
|
||||
act_q_group_shape=GroupShape.PER_TENSOR,
|
||||
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
if TEST_FP8:
|
||||
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
|
||||
return x2
|
||||
return self.fp8_linear.apply_weights(self, y)
|
||||
else:
|
||||
return y
|
||||
|
||||
@ -81,11 +88,19 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||
|
||||
if TEST_FP8:
|
||||
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
|
||||
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
|
||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
act_q_static=True,
|
||||
act_q_group_shape=GroupShape.PER_TENSOR,
|
||||
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
self.weight = (
|
||||
torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
|
||||
)
|
||||
self.weight_scale = torch.rand(1, dtype=torch.float32)
|
||||
self.input_scale = torch.rand(1, dtype=torch.float32)
|
||||
self.input_scale_ub = None
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
# Reshape input
|
||||
@ -99,13 +114,9 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
||||
norm_output, residual_output = self.norm(mm, residual)
|
||||
|
||||
if TEST_FP8:
|
||||
self.input_scale = self.input_scale.to(norm_output.device)
|
||||
# scaled_mm with static input quantization
|
||||
fp8_linear_result = self.fp8_linear.apply(
|
||||
norm_output,
|
||||
self.w,
|
||||
self.wscale,
|
||||
input_scale=self.scale.to(norm_output.device),
|
||||
)
|
||||
fp8_linear_result = self.fp8_linear.apply_weights(self, norm_output)
|
||||
|
||||
return fp8_linear_result, residual_output
|
||||
|
||||
|
||||
@ -18,19 +18,24 @@ from vllm.config import (
|
||||
VllmConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
cutlass_fp8_supported,
|
||||
maybe_create_device_identity,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ..utils import override_cutlass_fp8_supported
|
||||
from ..utils import TestFP8Layer, override_cutlass_fp8_supported
|
||||
from .backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@ -54,6 +59,8 @@ class TestModel(torch.nn.Module):
|
||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
|
||||
weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
|
||||
|
||||
quant_scale = ScaleDesc(torch.float32, static, group_shape)
|
||||
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
|
||||
if static:
|
||||
@ -66,9 +73,12 @@ class TestModel(torch.nn.Module):
|
||||
]
|
||||
|
||||
with override_cutlass_fp8_supported(not cuda_force_torch):
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=static,
|
||||
act_quant_group_shape=group_shape,
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
act_q_static=static,
|
||||
act_q_group_shape=group_shape,
|
||||
weight_quant_strategy=weight_quant_strategy,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
self.enable_rms_norm_custom_op = self.norm[0].enabled()
|
||||
@ -79,20 +89,20 @@ class TestModel(torch.nn.Module):
|
||||
x = resid = torch.relu(x)
|
||||
y = self.norm[0](x)
|
||||
|
||||
x2 = self.fp8_linear.apply(
|
||||
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
|
||||
)
|
||||
layer1 = TestFP8Layer(self.w[0], self.wscale[0], input_scale=self.scale[0])
|
||||
x2 = self.fp8_linear.apply_weights(layer1, y)
|
||||
# make sure resid is used for replacement to work
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
x3 = self.fp8_linear.apply(
|
||||
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
|
||||
)
|
||||
layer2 = TestFP8Layer(self.w[1], self.wscale[1], input_scale=self.scale[1])
|
||||
x3 = self.fp8_linear.apply_weights(layer2, y2)
|
||||
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
|
||||
x4 = self.fp8_linear.apply(
|
||||
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
|
||||
layer3 = TestFP8Layer(self.w[2], self.wscale[2], input_scale=self.scale[2])
|
||||
x4 = self.fp8_linear.apply_weights(
|
||||
layer3,
|
||||
y3,
|
||||
)
|
||||
|
||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||
|
||||
@ -26,14 +26,19 @@ from vllm.distributed.parallel_state import (
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
GroupShape,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
|
||||
from ..utils import has_module_attribute, multi_gpu_test
|
||||
from ..utils import TestFP8Layer, has_module_attribute, multi_gpu_test
|
||||
from .backend import TestBackend
|
||||
|
||||
|
||||
@ -81,43 +86,49 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
||||
self.eps = eps
|
||||
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
|
||||
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
self.w = [
|
||||
self.input_scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
self.weight = [
|
||||
torch.rand(hidden_size, hidden_size)
|
||||
.to(dtype=current_platform.fp8_dtype())
|
||||
.t()
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
act_q_static=True,
|
||||
act_q_group_shape=GroupShape.PER_TENSOR,
|
||||
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
z = torch.relu(hidden_states)
|
||||
x = resid = tensor_model_parallel_all_reduce(z)
|
||||
y = self.norm[0](x)
|
||||
|
||||
z2 = self.fp8_linear.apply(
|
||||
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
|
||||
layer1 = TestFP8Layer(
|
||||
self.weight[0], self.weight_scale[0], input_scale=self.input_scale[0]
|
||||
)
|
||||
z2 = self.fp8_linear.apply_weights(layer1, y)
|
||||
|
||||
x2 = tensor_model_parallel_all_reduce(z2)
|
||||
y2, resid = self.norm[1](x2, resid)
|
||||
|
||||
z3 = self.fp8_linear.apply(
|
||||
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
|
||||
layer2 = TestFP8Layer(
|
||||
self.weight[1], self.weight_scale[1], input_scale=self.input_scale[1]
|
||||
)
|
||||
z3 = self.fp8_linear.apply(layer2, y2)
|
||||
|
||||
x3 = tensor_model_parallel_all_reduce(z3)
|
||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||
|
||||
z4 = self.fp8_linear.apply(
|
||||
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
|
||||
layer3 = TestFP8Layer(
|
||||
self.weight[2], self.weight_scale[2], input_scale=self.input_scale[2]
|
||||
)
|
||||
z4 = self.fp8_linear.apply(layer3, y3)
|
||||
|
||||
x4 = tensor_model_parallel_all_reduce(z4)
|
||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||
return y4
|
||||
|
||||
@ -28,16 +28,23 @@ from vllm.config import (
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from ..utils import TestFP8Layer
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
FP4_DTYPE = torch.uint8
|
||||
|
||||
@ -170,11 +177,18 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.quant_key.scale.static,
|
||||
act_quant_group_shape=self.quant_key.scale.group_shape,
|
||||
)
|
||||
if self.quant_key.scale.group_shape.is_per_tensor():
|
||||
weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
|
||||
else:
|
||||
weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
|
||||
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
act_q_static=self.quant_key.scale.static,
|
||||
act_q_group_shape=self.quant_key.scale.group_shape,
|
||||
weight_quant_strategy=weight_quant_strategy,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
hidden_size = self.num_qo_heads * self.head_size
|
||||
self.w = kwargs.get(
|
||||
"w",
|
||||
@ -190,12 +204,8 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
"""Forward pass that creates the pattern to be fused."""
|
||||
attn_output = self.attn(q, k, v)
|
||||
return self.fp8_linear.apply(
|
||||
input=attn_output,
|
||||
weight=self.w["weight"],
|
||||
weight_scale=self.w["wscale"],
|
||||
input_scale=self.w["scale"],
|
||||
)
|
||||
layer = TestFP8Layer(self.w["weight"], self.w["wscale"], self.w["scale"])
|
||||
return self.fp8_linear.apply_weights(layer, attn_output)
|
||||
|
||||
|
||||
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
|
||||
|
||||
@ -27,11 +27,17 @@ from vllm.distributed.parallel_state import (
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.system_utils import update_environment_variables
|
||||
|
||||
from ..utils import multi_gpu_test
|
||||
from ..utils import TestFP8Layer, multi_gpu_test
|
||||
from .backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@ -107,8 +113,13 @@ class TestQuantModel(torch.nn.Module):
|
||||
# Initialize weights
|
||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(act_quant_static=True)
|
||||
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
act_q_static=True,
|
||||
act_q_group_shape=GroupShape.PER_TENSOR,
|
||||
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
# Create a weight that is compatible with torch._scaled_mm,
|
||||
# which expects a column-major layout.
|
||||
@ -138,14 +149,9 @@ class TestQuantModel(torch.nn.Module):
|
||||
|
||||
# layer normalization
|
||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||
|
||||
# scaled_mm with static input quantization
|
||||
fp8_linear_result = self.fp8_linear.apply(
|
||||
norm_output,
|
||||
self.w,
|
||||
self.wscale,
|
||||
input_scale=self.scale.to(norm_output.device),
|
||||
)
|
||||
layer = TestFP8Layer(None, None, self.scale.to(norm_output.device))
|
||||
fp8_linear_result = self.fp8_linear.apply(layer, norm_output)
|
||||
|
||||
return fp8_linear_result, residual_output
|
||||
|
||||
|
||||
@ -24,13 +24,18 @@ from vllm.config import (
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
init_fp8_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearQuantStrategy,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
kFp8StaticTensorSym,
|
||||
kNvfp4Quant,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
maybe_create_device_identity,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
@ -50,22 +55,26 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
||||
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
|
||||
super().__init__()
|
||||
self.silu_and_mul = SiluAndMul()
|
||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||
self.scale = torch.rand(1, dtype=torch.float32)
|
||||
|
||||
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
self.weight_scale = torch.rand(1, dtype=torch.float32)
|
||||
self.input_scale = torch.rand(1, dtype=torch.float32)
|
||||
self.input_scale_ub = None
|
||||
self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||
|
||||
with override_cutlass_fp8_supported(not cuda_force_torch):
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=True,
|
||||
act_quant_group_shape=GroupShape.PER_TENSOR,
|
||||
self.fp8_linear = init_fp8_linear_kernel(
|
||||
act_q_static=True,
|
||||
act_q_group_shape=GroupShape.PER_TENSOR,
|
||||
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
|
||||
out_dtype=torch.get_default_dtype(),
|
||||
module_name=self.__class__.__name__,
|
||||
)
|
||||
|
||||
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
|
||||
|
||||
def forward(self, x):
|
||||
y = self.silu_and_mul(x)
|
||||
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
|
||||
x2 = self.fp8_linear.apply_weights(self, y)
|
||||
return x2
|
||||
|
||||
def ops_in_model_before(self):
|
||||
|
||||
@ -1411,3 +1411,14 @@ def flat_product(*iterables: Iterable[Any]):
|
||||
for element in itertools.product(*iterables):
|
||||
normalized = (e if isinstance(e, tuple) else (e,) for e in element)
|
||||
yield tuple(itertools.chain(*normalized))
|
||||
|
||||
|
||||
class TestFP8Layer(torch.nn.Module):
|
||||
"""Helper class for ScaledMMLinearKernels."""
|
||||
|
||||
def __init__(self, weight, weight_scale, input_scale):
|
||||
super().__init__()
|
||||
self.weight_scale = weight_scale
|
||||
self.weight = weight
|
||||
self.input_scale = input_scale
|
||||
self.input_scale_ub = None
|
||||
|
||||
@ -2,10 +2,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
from typing import Any, Literal, Self
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config.utils import config
|
||||
|
||||
|
||||
@ -33,7 +33,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
maybe_create_device_identity,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
@ -97,9 +96,6 @@ class FBGEMMFp8Config(QuantizationConfig):
|
||||
class FBGEMMFp8LinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: FBGEMMFp8Config):
|
||||
self.quant_config = quant_config
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN
|
||||
)
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
self.fp8_linear_kernel = init_fp8_linear_kernel(
|
||||
|
||||
@ -41,7 +41,7 @@ class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||
weight_quant_strategy: ScaledMMLinearQuantStrategy
|
||||
activation_group_shape: GroupShape
|
||||
out_dtype: torch.dtype
|
||||
out_dtype: torch.dtype | None
|
||||
|
||||
|
||||
_FP8ParamsT = tuple[
|
||||
|
||||
@ -64,7 +64,7 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
|
||||
],
|
||||
}
|
||||
|
||||
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel, covariant=True)
|
||||
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
|
||||
_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig)
|
||||
|
||||
|
||||
|
||||
@ -7,11 +7,9 @@ import torch
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
FP8ScaledMMCallBack = Callable[..., torch.Tensor]
|
||||
|
||||
|
||||
def apply_weights_fp8(
|
||||
scaled_mm_func: FP8ScaledMMCallBack,
|
||||
scaled_mm_func: Callable[..., torch.Tensor],
|
||||
quant_fp8_func: QuantFP8,
|
||||
w: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user