mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-09 14:35:45 +08:00
[Bugfix] Fix unstable silu_mul+nvfp4 quant fusion test (#24370)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
parent
a3645ed94d
commit
e68dc2f014
@ -1,9 +1,12 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
from tests.kernels.quantization.nvfp4_utils import quant_nvfp4_tensor
|
||||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
@ -64,24 +67,27 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
|||||||
|
|
||||||
class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, hidden_size: int, **kwargs):
|
def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.silu_and_mul = SiluAndMul()
|
self.silu_and_mul = SiluAndMul()
|
||||||
self.w = torch.randint(256, (hidden_size, hidden_size // 2),
|
|
||||||
dtype=FP4_DTYPE)
|
# create nvfp4 weight
|
||||||
self.wscale = torch.randn(hidden_size,
|
w = torch.rand((hidden_size, hidden_size))
|
||||||
hidden_size // 16).to(dtype=FP8_DTYPE)
|
self.w, self.w_block_scale, self.w_global_scale = quant_nvfp4_tensor(w)
|
||||||
self.wscale2 = torch.rand(1, dtype=torch.float32)
|
|
||||||
self.scale = torch.rand(1, dtype=torch.float32)
|
# get global scale offline
|
||||||
|
_, _, self.y_global_scale = quant_nvfp4_tensor(self.silu_and_mul(x))
|
||||||
|
|
||||||
|
self.alpha = 1.0 / (self.w_global_scale * self.y_global_scale)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y = self.silu_and_mul(x)
|
y = self.silu_and_mul(x)
|
||||||
y_quant, y_block_scale = scaled_fp4_quant(y, 1 / self.scale)
|
y_quant, y_block_scale = scaled_fp4_quant(y, self.y_global_scale)
|
||||||
out = cutlass_scaled_fp4_mm(a=y_quant,
|
out = cutlass_scaled_fp4_mm(a=y_quant,
|
||||||
b=self.w,
|
b=self.w,
|
||||||
block_scale_a=y_block_scale,
|
block_scale_a=y_block_scale,
|
||||||
block_scale_b=self.wscale,
|
block_scale_b=self.w_block_scale,
|
||||||
alpha=self.scale * self.wscale2,
|
alpha=self.alpha,
|
||||||
out_dtype=y.dtype)
|
out_dtype=y.dtype)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -95,8 +101,9 @@ class TestSiluMulNvfp4QuantModel(torch.nn.Module):
|
|||||||
@pytest.mark.parametrize("num_tokens", [64])
|
@pytest.mark.parametrize("num_tokens", [64])
|
||||||
@pytest.mark.parametrize("hidden_size", [128])
|
@pytest.mark.parametrize("hidden_size", [128])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model_class", [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
|
"model_class",
|
||||||
if is_nvfp4_supported() else [TestSiluMulFp8QuantModel])
|
cast(list[type], [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
|
||||||
|
if is_nvfp4_supported() else [TestSiluMulFp8QuantModel]))
|
||||||
# cuda_force_torch used to test torch code path on platforms that
|
# cuda_force_torch used to test torch code path on platforms that
|
||||||
# cutlass_fp8_supported() == True.
|
# cutlass_fp8_supported() == True.
|
||||||
@pytest.mark.parametrize("cuda_force_torch",
|
@pytest.mark.parametrize("cuda_force_torch",
|
||||||
@ -111,6 +118,8 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
|
|||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
torch.set_default_dtype(torch.float16)
|
torch.set_default_dtype(torch.float16)
|
||||||
|
|
||||||
|
x = torch.rand(num_tokens, hidden_size * 2)
|
||||||
|
|
||||||
# Reshape pass is needed for the fusion pass to work
|
# Reshape pass is needed for the fusion pass to work
|
||||||
config = VllmConfig()
|
config = VllmConfig()
|
||||||
config.compilation_config = CompilationConfig(
|
config.compilation_config = CompilationConfig(
|
||||||
@ -119,10 +128,10 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
|
|||||||
|
|
||||||
backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
|
backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
|
||||||
model = model_class(hidden_size=hidden_size,
|
model = model_class(hidden_size=hidden_size,
|
||||||
cuda_force_torch=cuda_force_torch)
|
cuda_force_torch=cuda_force_torch,
|
||||||
|
x=x)
|
||||||
|
|
||||||
# First dimension dynamic
|
# First dimension dynamic
|
||||||
x = torch.rand(num_tokens, hidden_size * 2)
|
|
||||||
torch._dynamo.mark_dynamic(x, 0)
|
torch._dynamo.mark_dynamic(x, 0)
|
||||||
|
|
||||||
result = model(x)
|
result = model(x)
|
||||||
@ -131,10 +140,15 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
|
|||||||
result2 = model2(x)
|
result2 = model2(x)
|
||||||
|
|
||||||
# Check that it gives the same answer
|
# Check that it gives the same answer
|
||||||
|
if model_class == TestSiluMulFp8QuantModel:
|
||||||
|
atol, rtol = 1e-3, 1e-3
|
||||||
|
elif model_class == TestSiluMulNvfp4QuantModel:
|
||||||
|
atol, rtol = 1e-1, 1e-1
|
||||||
|
|
||||||
torch.testing.assert_close(result[0].to(dtype=torch.float16),
|
torch.testing.assert_close(result[0].to(dtype=torch.float16),
|
||||||
result2[0].to(dtype=torch.float16),
|
result2[0].to(dtype=torch.float16),
|
||||||
atol=1e-3,
|
atol=atol,
|
||||||
rtol=1e-3)
|
rtol=rtol)
|
||||||
|
|
||||||
# In pre-nodes, quant op should be present and fused kernels should not
|
# In pre-nodes, quant op should be present and fused kernels should not
|
||||||
backend.check_before_ops(model.ops_in_model_before())
|
backend.check_before_ops(model.ops_in_model_before())
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm._custom_ops import scaled_fp4_quant
|
||||||
from vllm.scalar_type import scalar_types
|
from vllm.scalar_type import scalar_types
|
||||||
|
|
||||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||||
@ -65,3 +66,10 @@ def break_fp4_bytes(a, dtype):
|
|||||||
|
|
||||||
# Reshape to final form
|
# Reshape to final form
|
||||||
return values.reshape(m, n * 2).to(dtype=dtype)
|
return values.reshape(m, n * 2).to(dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def quant_nvfp4_tensor(a: torch.Tensor):
|
||||||
|
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
||||||
|
torch.abs(a).max().to(torch.float32))
|
||||||
|
a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
|
||||||
|
return a_quant, a_block_scale, a_global_scale
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user