update unit tests to use ScaledMMLinearKernels

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
vllmellm 2025-11-01 16:28:03 +00:00
parent 4ce0ba2df4
commit dd5a70ec71
12 changed files with 152 additions and 89 deletions

View File

@ -20,8 +20,13 @@ from vllm.config import (
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -35,21 +40,23 @@ class TestSiluMul(torch.nn.Module):
def __init__(self, hidden_size: int = 128): def __init__(self, hidden_size: int = 128):
super().__init__() super().__init__()
self.silu_and_mul = SiluAndMul() self.silu_and_mul = SiluAndMul()
self.wscale = torch.rand(1, dtype=torch.float32) self.weight_scale = torch.rand(1, dtype=torch.float32)
self.scale = torch.rand(1, dtype=torch.float32) self.input_scale = torch.rand(1, dtype=torch.float32)
self.input_scale_ub = None
if TEST_FP8: if TEST_FP8:
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = Fp8LinearOp( self.fp8_linear = init_fp8_linear_kernel(
act_quant_static=True, act_q_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR, act_q_group_shape=GroupShape.PER_TENSOR,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
) )
def forward(self, x): def forward(self, x):
y = self.silu_and_mul(x) y = self.silu_and_mul(x)
if TEST_FP8: if TEST_FP8:
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale) return self.fp8_linear.apply_weights(self, y)
return x2
else: else:
return y return y
@ -81,11 +88,19 @@ class TestFusedAddRMSNorm(torch.nn.Module):
torch.nn.init.normal_(self.gate_proj, std=0.02) torch.nn.init.normal_(self.gate_proj, std=0.02)
if TEST_FP8: if TEST_FP8:
self.fp8_linear = Fp8LinearOp(act_quant_static=True) self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True,
self.scale = torch.rand(1, dtype=torch.float32) act_q_group_shape=GroupShape.PER_TENSOR,
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t() weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
self.wscale = torch.rand(1, dtype=torch.float32) out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
self.weight = (
torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
)
self.weight_scale = torch.rand(1, dtype=torch.float32)
self.input_scale = torch.rand(1, dtype=torch.float32)
self.input_scale_ub = None
def forward(self, hidden_states, residual): def forward(self, hidden_states, residual):
# Reshape input # Reshape input
@ -99,13 +114,9 @@ class TestFusedAddRMSNorm(torch.nn.Module):
norm_output, residual_output = self.norm(mm, residual) norm_output, residual_output = self.norm(mm, residual)
if TEST_FP8: if TEST_FP8:
self.input_scale = self.input_scale.to(norm_output.device)
# scaled_mm with static input quantization # scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply( fp8_linear_result = self.fp8_linear.apply_weights(self, norm_output)
norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(norm_output.device),
)
return fp8_linear_result, residual_output return fp8_linear_result, residual_output

View File

@ -18,19 +18,24 @@ from vllm.config import (
VllmConfig, VllmConfig,
) )
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
QuantKey, QuantKey,
ScaleDesc, ScaleDesc,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
cutlass_fp8_supported, cutlass_fp8_supported,
maybe_create_device_identity, maybe_create_device_identity,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from ..utils import override_cutlass_fp8_supported from ..utils import TestFP8Layer, override_cutlass_fp8_supported
from .backend import TestBackend from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
@ -54,6 +59,8 @@ class TestModel(torch.nn.Module):
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)] self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
quant_scale = ScaleDesc(torch.float32, static, group_shape) quant_scale = ScaleDesc(torch.float32, static, group_shape)
self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True) self.quant_key = QuantKey(dtype=FP8_DTYPE, scale=quant_scale, symmetric=True)
if static: if static:
@ -66,9 +73,12 @@ class TestModel(torch.nn.Module):
] ]
with override_cutlass_fp8_supported(not cuda_force_torch): with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = Fp8LinearOp( self.fp8_linear = init_fp8_linear_kernel(
act_quant_static=static, act_q_static=static,
act_quant_group_shape=group_shape, act_q_group_shape=group_shape,
weight_quant_strategy=weight_quant_strategy,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
) )
self.enable_rms_norm_custom_op = self.norm[0].enabled() self.enable_rms_norm_custom_op = self.norm[0].enabled()
@ -79,20 +89,20 @@ class TestModel(torch.nn.Module):
x = resid = torch.relu(x) x = resid = torch.relu(x)
y = self.norm[0](x) y = self.norm[0](x)
x2 = self.fp8_linear.apply( layer1 = TestFP8Layer(self.w[0], self.wscale[0], input_scale=self.scale[0])
y, self.w[0], self.wscale[0], input_scale=self.scale[0] x2 = self.fp8_linear.apply_weights(layer1, y)
)
# make sure resid is used for replacement to work # make sure resid is used for replacement to work
y2, resid = self.norm[1](x2, resid) y2, resid = self.norm[1](x2, resid)
x3 = self.fp8_linear.apply( layer2 = TestFP8Layer(self.w[1], self.wscale[1], input_scale=self.scale[1])
y2, self.w[1], self.wscale[1], input_scale=self.scale[1] x3 = self.fp8_linear.apply_weights(layer2, y2)
)
y3, resid = self.norm[2](x3, resid) # use resid here y3, resid = self.norm[2](x3, resid) # use resid here
x4 = self.fp8_linear.apply( layer3 = TestFP8Layer(self.w[2], self.wscale[2], input_scale=self.scale[2])
y3, self.w[2], self.wscale[2], input_scale=self.scale[2] x4 = self.fp8_linear.apply_weights(
layer3,
y3,
) )
y4, resid = self.norm[3](x4, resid) # use resid here y4, resid = self.norm[3](x4, resid) # use resid here

View File

@ -26,14 +26,19 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel, initialize_model_parallel,
) )
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
GroupShape, GroupShape,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
from ..utils import has_module_attribute, multi_gpu_test from ..utils import TestFP8Layer, has_module_attribute, multi_gpu_test
from .backend import TestBackend from .backend import TestBackend
@ -81,43 +86,49 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
self.eps = eps self.eps = eps
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)] self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)] self.weight_scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.w = [ self.input_scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.weight = [
torch.rand(hidden_size, hidden_size) torch.rand(hidden_size, hidden_size)
.to(dtype=current_platform.fp8_dtype()) .to(dtype=current_platform.fp8_dtype())
.t() .t()
for _ in range(3) for _ in range(3)
] ]
self.fp8_linear = Fp8LinearOp( self.fp8_linear = init_fp8_linear_kernel(
act_quant_static=True, act_q_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR, act_q_group_shape=GroupShape.PER_TENSOR,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
) )
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
def forward(self, hidden_states): def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly # avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states) z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z) x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x) y = self.norm[0](x)
z2 = self.fp8_linear.apply( layer1 = TestFP8Layer(
y, self.w[0], self.wscale[0], input_scale=self.scale[0] self.weight[0], self.weight_scale[0], input_scale=self.input_scale[0]
) )
z2 = self.fp8_linear.apply_weights(layer1, y)
x2 = tensor_model_parallel_all_reduce(z2) x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid) y2, resid = self.norm[1](x2, resid)
z3 = self.fp8_linear.apply( layer2 = TestFP8Layer(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1] self.weight[1], self.weight_scale[1], input_scale=self.input_scale[1]
) )
z3 = self.fp8_linear.apply_weights(layer2, y2)
x3 = tensor_model_parallel_all_reduce(z3) x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here y3, resid = self.norm[2](x3, resid) # use resid here
z4 = self.fp8_linear.apply( layer3 = TestFP8Layer(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2] self.weight[2], self.weight_scale[2], input_scale=self.input_scale[2]
) )
z4 = self.fp8_linear.apply_weights(layer3, y3)
x4 = tensor_model_parallel_all_reduce(z4) x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here y4, resid = self.norm[3](x4, resid) # use resid here
return y4 return y4

View File

@ -28,16 +28,23 @@ from vllm.config import (
set_current_vllm_config, set_current_vllm_config,
) )
from vllm.forward_context import get_forward_context, set_forward_context from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kNvfp4Quant, kNvfp4Quant,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer from vllm.utils.flashinfer import has_flashinfer
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from ..utils import TestFP8Layer
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8 FP4_DTYPE = torch.uint8
@ -170,11 +177,18 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.fp8_linear = Fp8LinearOp( if self.quant_key.scale.group_shape.is_per_tensor():
act_quant_static=self.quant_key.scale.static, weight_quant_strategy = ScaledMMLinearQuantStrategy.TENSOR
act_quant_group_shape=self.quant_key.scale.group_shape, else:
) weight_quant_strategy = ScaledMMLinearQuantStrategy.CHANNEL
self.fp8_linear = init_fp8_linear_kernel(
act_q_static=self.quant_key.scale.static,
act_q_group_shape=self.quant_key.scale.group_shape,
weight_quant_strategy=weight_quant_strategy,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
hidden_size = self.num_qo_heads * self.head_size hidden_size = self.num_qo_heads * self.head_size
self.w = kwargs.get( self.w = kwargs.get(
"w", "w",
@ -190,12 +204,8 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
"""Forward pass that creates the pattern to be fused.""" """Forward pass that creates the pattern to be fused."""
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
return self.fp8_linear.apply( layer = TestFP8Layer(self.w["weight"], self.w["wscale"], self.w["scale"])
input=attn_output, return self.fp8_linear.apply_weights(layer, attn_output)
weight=self.w["weight"],
weight_scale=self.w["wscale"],
input_scale=self.w["scale"],
)
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):

View File

@ -27,11 +27,17 @@ from vllm.distributed.parallel_state import (
initialize_model_parallel, initialize_model_parallel,
) )
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
from ..utils import multi_gpu_test from ..utils import TestFP8Layer, multi_gpu_test
from .backend import TestBackend from .backend import TestBackend
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
@ -107,8 +113,13 @@ class TestQuantModel(torch.nn.Module):
# Initialize weights # Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02) torch.nn.init.normal_(self.gate_proj, std=0.02)
self.fp8_linear = Fp8LinearOp(act_quant_static=True) self.fp8_linear = init_fp8_linear_kernel(
act_q_static=True,
act_q_group_shape=GroupShape.PER_TENSOR,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
self.scale = torch.rand(1, dtype=torch.float32) self.scale = torch.rand(1, dtype=torch.float32)
# Create a weight that is compatible with torch._scaled_mm, # Create a weight that is compatible with torch._scaled_mm,
# which expects a column-major layout. # which expects a column-major layout.
@ -138,14 +149,9 @@ class TestQuantModel(torch.nn.Module):
# layer normalization # layer normalization
norm_output, residual_output = self.norm(all_reduce, residual) norm_output, residual_output = self.norm(all_reduce, residual)
# scaled_mm with static input quantization # scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply( layer = TestFP8Layer(self.w, self.wscale, self.scale.to(norm_output.device))
norm_output, fp8_linear_result = self.fp8_linear.apply_weights(layer, norm_output)
self.w,
self.wscale,
input_scale=self.scale.to(norm_output.device),
)
return fp8_linear_result, residual_output return fp8_linear_result, residual_output

View File

@ -24,13 +24,18 @@ from vllm.config import (
set_current_vllm_config, set_current_vllm_config,
) )
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearQuantStrategy,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kNvfp4Quant, kNvfp4Quant,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
maybe_create_device_identity, maybe_create_device_identity,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
@ -50,22 +55,26 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs): def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
super().__init__() super().__init__()
self.silu_and_mul = SiluAndMul() self.silu_and_mul = SiluAndMul()
self.wscale = torch.rand(1, dtype=torch.float32) self.weight_scale = torch.rand(1, dtype=torch.float32)
self.scale = torch.rand(1, dtype=torch.float32) self.input_scale = torch.rand(1, dtype=torch.float32)
self.input_scale_ub = None
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
with override_cutlass_fp8_supported(not cuda_force_torch): with override_cutlass_fp8_supported(not cuda_force_torch):
self.fp8_linear = Fp8LinearOp( self.fp8_linear = init_fp8_linear_kernel(
act_quant_static=True, act_q_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR, act_q_group_shape=GroupShape.PER_TENSOR,
weight_quant_strategy=ScaledMMLinearQuantStrategy.TENSOR,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
) )
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled() self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled() self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
def forward(self, x): def forward(self, x):
y = self.silu_and_mul(x) y = self.silu_and_mul(x)
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale) x2 = self.fp8_linear.apply_weights(self, y)
return x2 return x2
def ops_in_model_before(self): def ops_in_model_before(self):

View File

@ -1411,3 +1411,14 @@ def flat_product(*iterables: Iterable[Any]):
for element in itertools.product(*iterables): for element in itertools.product(*iterables):
normalized = (e if isinstance(e, tuple) else (e,) for e in element) normalized = (e if isinstance(e, tuple) else (e,) for e in element)
yield tuple(itertools.chain(*normalized)) yield tuple(itertools.chain(*normalized))
class TestFP8Layer(torch.nn.Module):
"""Helper class for ScaledMMLinearKernels."""
def __init__(self, weight, weight_scale, input_scale):
super().__init__()
self.weight_scale = weight_scale
self.weight = weight
self.input_scale = input_scale
self.input_scale_ub = None

View File

@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib import hashlib
from typing import Any, Literal, Self from typing import Any, Literal
from pydantic import model_validator from pydantic import model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from typing_extensions import Self
from vllm.config.utils import config from vllm.config.utils import config

View File

@ -33,7 +33,6 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped, is_layer_skipped,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
maybe_create_device_identity, maybe_create_device_identity,
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz,
) )
@ -97,9 +96,6 @@ class FBGEMMFp8Config(QuantizationConfig):
class FBGEMMFp8LinearMethod(LinearMethodBase): class FBGEMMFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: FBGEMMFp8Config): def __init__(self, quant_config: FBGEMMFp8Config):
self.quant_config = quant_config self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp(
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN
)
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
self.fp8_linear_kernel = init_fp8_linear_kernel( self.fp8_linear_kernel = init_fp8_linear_kernel(

View File

@ -41,7 +41,7 @@ class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig): class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
weight_quant_strategy: ScaledMMLinearQuantStrategy weight_quant_strategy: ScaledMMLinearQuantStrategy
activation_group_shape: GroupShape activation_group_shape: GroupShape
out_dtype: torch.dtype out_dtype: torch.dtype | None
_FP8ParamsT = tuple[ _FP8ParamsT = tuple[

View File

@ -64,7 +64,7 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
], ],
} }
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel, covariant=True) _KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig) _KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig)

View File

@ -7,11 +7,9 @@ import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.platforms import current_platform from vllm.platforms import current_platform
FP8ScaledMMCallBack = Callable[..., torch.Tensor]
def apply_weights_fp8( def apply_weights_fp8(
scaled_mm_func: FP8ScaledMMCallBack, scaled_mm_func: Callable[..., torch.Tensor],
quant_fp8_func: QuantFP8, quant_fp8_func: QuantFP8,
w: torch.Tensor, w: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,