mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-06 03:57:02 +08:00
fix merge artifacts
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
8937e96eab
commit
4e488dac33
@ -4,7 +4,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.plugins
|
||||
import vllm.config
|
||||
from vllm.compilation.fusion import FUSED_OPS, FusedRMSQuantKey, RMSNormQuantFusionPass
|
||||
from vllm.compilation.fx_utils import find_op_nodes
|
||||
from vllm.compilation.matcher_utils import QUANT_OPS
|
||||
@ -35,19 +35,18 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.rocm import (
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
FP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
ScaleDesc,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_block_fp8_supported,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import is_deep_gemm_supported
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import cutlass_block_fp8_supported
|
||||
|
||||
from ..utils import TestFP8Layer
|
||||
from ..utils import TestBlockFP8Layer, TestFP8Layer
|
||||
from .backend import TestBackend
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
@ -56,74 +55,23 @@ RMS_OP = torch.ops._C.rms_norm.default
|
||||
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
|
||||
|
||||
|
||||
class W8A8Fp8LinearWrapper:
|
||||
"""
|
||||
Wrapper class for W8A8BlockFp8LinearOp that provides a callable interface
|
||||
and the is_quant_fp8_enabled() method required by tests.
|
||||
|
||||
This class creates a W8A8 (weight-8bit, activation-8bit) FP8 linear operation
|
||||
with blockwise quantization support.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
group_shape: GroupShape,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Initialize the W8A8 FP8 linear wrapper.
|
||||
|
||||
Args:
|
||||
group_shape: The quantization group shape for activations.
|
||||
For blockwise quantization, this is typically (1, block_size).
|
||||
weight: The FP8 quantized weight tensor.
|
||||
weight_scale: The per-group scaling factors for dequantizing the weights.
|
||||
input_scale: The per-group scaling factors for quantizing the input activations.
|
||||
Can be None for dynamic quantization.
|
||||
"""
|
||||
# Create the blockwise FP8 linear operator
|
||||
# Note: weight_group_shape uses a square group (group_shape[1], group_shape[1])
|
||||
# to match the expected weight layout for blockwise quantization
|
||||
self.linear_op = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
|
||||
act_quant_group_shape=group_shape,
|
||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
|
||||
use_aiter_and_is_supported=False,
|
||||
)
|
||||
self.weight = weight
|
||||
self.weight_scale = weight_scale
|
||||
self.input_scale = input_scale
|
||||
|
||||
def __call__(self, input: torch.Tensor) -> torch.Tensor:
|
||||
"""Make the wrapper callable like the original partial function."""
|
||||
return self.linear_op.apply(
|
||||
input=input,
|
||||
weight=self.weight,
|
||||
weight_scale=self.weight_scale,
|
||||
input_scale=self.input_scale,
|
||||
bias=None
|
||||
)
|
||||
|
||||
def is_quant_fp8_enabled(self) -> bool:
|
||||
"""Check if FP8 quantization custom op is enabled."""
|
||||
return self.linear_op.input_quant_op.enabled()
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float,
|
||||
force_kernel: FP8ScaledMMLinearKernel,
|
||||
force_kernel: FP8ScaledMMLinearKernel | None,
|
||||
group_shape: GroupShape,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.group_shape = group_shape
|
||||
self.norm = [RMSNorm(hidden_size, eps) for _ in range(4)]
|
||||
static = group_shape == GroupShape.PER_TENSOR
|
||||
self.enable_rms_norm_custom_op = self.norm[0].enabled()
|
||||
|
||||
act_quant_scale_desc = ScaleDesc(torch.float32, static, group_shape)
|
||||
is_static = group_shape == GroupShape.PER_TENSOR
|
||||
act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape)
|
||||
w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape)
|
||||
self.activation_quant_key = QuantKey(
|
||||
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
|
||||
@ -142,29 +90,27 @@ class TestModel(torch.nn.Module):
|
||||
)
|
||||
for _ in range(3)
|
||||
]
|
||||
else:
|
||||
else: # PER_TOKEN
|
||||
self.wscale = [
|
||||
torch.rand((hidden_size, 1), dtype=torch.float32) for _ in range(3)
|
||||
]
|
||||
|
||||
if static:
|
||||
self.act_scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
|
||||
else:
|
||||
self.act_scale = [None for _ in range(3)]
|
||||
self.act_scale = (
|
||||
[torch.rand(1, dtype=torch.float32) for _ in range(3)]
|
||||
if is_static
|
||||
else [None for _ in range(3)]
|
||||
)
|
||||
|
||||
# Initialize weights
|
||||
self.w = [
|
||||
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE) for _ in range(3)
|
||||
]
|
||||
|
||||
if not group_shape.is_per_group():
|
||||
self.w = [self.w[0].t() for _ in range(3)]
|
||||
|
||||
self.enable_rms_norm_custom_op = self.norm[0].enabled()
|
||||
|
||||
if group_shape.is_per_group():
|
||||
self.fp8_linear_layers = [
|
||||
W8A8Fp8LinearWrapper(
|
||||
TestBlockFP8Layer(
|
||||
group_shape=group_shape,
|
||||
weight=self.w[i],
|
||||
weight_scale=self.wscale[i],
|
||||
@ -172,9 +118,6 @@ class TestModel(torch.nn.Module):
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
|
||||
0
|
||||
].is_quant_fp8_enabled()
|
||||
else:
|
||||
self.fp8_linear_layers = [
|
||||
TestFP8Layer(
|
||||
@ -187,15 +130,11 @@ class TestModel(torch.nn.Module):
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
|
||||
0
|
||||
].is_quant_fp8_enabled()
|
||||
|
||||
self.enable_rms_norm_custom_op = self.norm[0].enabled()
|
||||
self.group_shape = group_shape
|
||||
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
|
||||
0
|
||||
].is_quant_fp8_enabled()
|
||||
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
# avoid having graph input be an arg to a pattern directly
|
||||
x = resid = torch.relu(x)
|
||||
@ -249,13 +188,12 @@ CUDA_FP8_KERNELS = [
|
||||
ChannelWiseTorchScaledMMLinearKernel,
|
||||
]
|
||||
|
||||
# Blockwise group shapes that use W8A8BlockFp8LinearOp
|
||||
|
||||
BLOCKWISE_GROUP_SHAPES = [
|
||||
GroupShape(1, 128),
|
||||
GroupShape(1, 64),
|
||||
]
|
||||
|
||||
# Non-blockwise group shapes that use FP8ScaledMMLinearKernel
|
||||
NON_BLOCKWISE_GROUP_SHAPES = [
|
||||
GroupShape.PER_TOKEN,
|
||||
GroupShape.PER_TENSOR,
|
||||
@ -265,27 +203,27 @@ NON_BLOCKWISE_GROUP_SHAPES = [
|
||||
def _generate_kernel_groupshape_combinations():
|
||||
"""
|
||||
Generate valid (kernel, group_shape) combinations for testing.
|
||||
|
||||
Returns:
|
||||
List of (kernel, group_shape) tuples where:
|
||||
- Blockwise group shapes use None as kernel (W8A8BlockFp8LinearOp doesn't use FP8ScaledMMLinearKernel)
|
||||
- Non-blockwise group shapes are paired with compatible kernels
|
||||
"""
|
||||
combinations = []
|
||||
|
||||
kernels = CUDA_FP8_KERNELS if current_platform.is_cuda() else ROCM_FP8_KERNELS
|
||||
|
||||
# Non-blockwise group shapes with FP8ScaledMMLinearKernel
|
||||
for kernel in kernels:
|
||||
for group_shape in NON_BLOCKWISE_GROUP_SHAPES:
|
||||
# PerTensorTorchScaledMMLinearKernel only works with PER_TENSOR
|
||||
if kernel == PerTensorTorchScaledMMLinearKernel and group_shape != GroupShape.PER_TENSOR:
|
||||
if (
|
||||
kernel == PerTensorTorchScaledMMLinearKernel
|
||||
and group_shape != GroupShape.PER_TENSOR
|
||||
):
|
||||
continue
|
||||
# ChannelWiseTorchScaledMMLinearKernel only works with PER_TOKEN
|
||||
if kernel == ChannelWiseTorchScaledMMLinearKernel and group_shape != GroupShape.PER_TOKEN:
|
||||
if (
|
||||
kernel == ChannelWiseTorchScaledMMLinearKernel
|
||||
and group_shape != GroupShape.PER_TOKEN
|
||||
):
|
||||
continue
|
||||
# RowWiseTorchScaledMMLinearKernel only works with PER_TOKEN
|
||||
if kernel == RowWiseTorchScaledMMLinearKernel and group_shape != GroupShape.PER_TOKEN:
|
||||
if (
|
||||
kernel == RowWiseTorchScaledMMLinearKernel
|
||||
and group_shape != GroupShape.PER_TOKEN
|
||||
):
|
||||
continue
|
||||
combinations.append((kernel, group_shape))
|
||||
|
||||
@ -296,7 +234,6 @@ def _generate_kernel_groupshape_combinations():
|
||||
return combinations
|
||||
|
||||
|
||||
# Generate valid combinations of (kernel, group_shape)
|
||||
KERNEL_GROUPSHAPE_COMBINATIONS = _generate_kernel_groupshape_combinations()
|
||||
|
||||
|
||||
@ -360,12 +297,6 @@ def test_fusion_rmsnorm_quant(
|
||||
backend2 = TestBackend(noop_pass, cleanup_pass)
|
||||
model = TestModel(hidden_size, eps, force_kernel, group_shape)
|
||||
|
||||
# # skip the test if we cannot force the kernel for non-blockwise group shapes
|
||||
# if force_kernel is not None:
|
||||
# selected_kernels = [layer.kernel for layer in model.fp8_linear_layers]
|
||||
# if not any(isinstance(kernel, force_kernel) for kernel in selected_kernels):
|
||||
# pytest.skip(f"{force_kernel.__name__} couldn't be forced")
|
||||
|
||||
# First dimension dynamic
|
||||
x = torch.rand(num_tokens, hidden_size)
|
||||
torch._dynamo.mark_dynamic(x, 0)
|
||||
|
||||
@ -48,7 +48,14 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
FP8ScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape,
|
||||
QuantKey,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_block_fp8_supported,
|
||||
)
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.tokenizers import get_tokenizer
|
||||
@ -1368,3 +1375,44 @@ class TestFP8Layer(torch.nn.Module):
|
||||
self, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(self, y, bias)
|
||||
|
||||
|
||||
class TestBlockFP8Layer:
|
||||
"""
|
||||
Test wrapper for W8A8BlockFp8LinearOp to match TestFP8Layer interface.
|
||||
|
||||
This is a workaround until W8A8BlockFp8LinearOp has a similar API to
|
||||
FP8ScaledMMLinearKernel (i.e., a kernel abstraction for blockwise quantization).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group_shape: GroupShape,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
):
|
||||
self.kernel = None # For compatibility with TestFP8Layer interface
|
||||
self.linear_op = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(group_shape[1], group_shape[1]),
|
||||
act_quant_group_shape=group_shape,
|
||||
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
|
||||
use_aiter_and_is_supported=False,
|
||||
)
|
||||
self.weight = weight
|
||||
self.weight_scale = weight_scale
|
||||
self.input_scale = input_scale
|
||||
|
||||
def forward(
|
||||
self, y: torch.Tensor, bias: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
return self.linear_op.apply(
|
||||
input=y,
|
||||
weight=self.weight,
|
||||
weight_scale=self.weight_scale,
|
||||
input_scale=self.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def is_quant_fp8_enabled(self) -> bool:
|
||||
return self.linear_op.input_quant_op.enabled()
|
||||
|
||||
@ -177,6 +177,13 @@ class RowWiseTorchScaledMMLinearKernel(TorchScaledMMLinearKernel):
|
||||
)
|
||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
|
||||
if c.out_dtype == torch.float16:
|
||||
# hipblaslt rowwise _scaled_mm only supports BFloat16
|
||||
return (
|
||||
False,
|
||||
"RowWiseTorchScaledMMLinearKernel only supports BFloat16.",
|
||||
)
|
||||
|
||||
if per_tensor_activation_scales or per_tensor_weight_scales:
|
||||
return (
|
||||
False,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user