mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-23 22:44:33 +08:00
reduce test boilerplate
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
This commit is contained in:
parent
fb72ec8218
commit
93fb7071f5
@ -20,14 +20,12 @@ from vllm.config import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
|
||||||
init_fp8_linear_kernel,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
kFp8StaticTensorSym,
|
kFp8StaticTensorSym,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from ..utils import TestFP8Layer
|
||||||
|
|
||||||
from .backend import TestBackend
|
from .backend import TestBackend
|
||||||
|
|
||||||
@ -43,20 +41,14 @@ class TestSiluMul(torch.nn.Module):
|
|||||||
self.silu_and_mul = SiluAndMul()
|
self.silu_and_mul = SiluAndMul()
|
||||||
self.weight_scale = torch.rand(1, dtype=torch.float32)
|
self.weight_scale = torch.rand(1, dtype=torch.float32)
|
||||||
self.input_scale = torch.rand(1, dtype=torch.float32)
|
self.input_scale = torch.rand(1, dtype=torch.float32)
|
||||||
self.input_scale_ub = None
|
|
||||||
if TEST_FP8:
|
if TEST_FP8:
|
||||||
self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||||
self.fp8_linear = init_fp8_linear_kernel(
|
self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, self.weight,
|
||||||
activation_quant_key=self.quant_key,
|
self.weight_scale, self.input_scale)
|
||||||
weight_quant_key=self.quant_key,
|
|
||||||
out_dtype=torch.get_default_dtype(),
|
|
||||||
module_name=self.__class__.__name__,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y = self.silu_and_mul(x)
|
y = self.silu_and_mul(x)
|
||||||
if TEST_FP8:
|
if TEST_FP8:
|
||||||
return self.fp8_linear.apply_weights(self, y)
|
return self.fp8_linear(y)
|
||||||
else:
|
else:
|
||||||
return y
|
return y
|
||||||
|
|
||||||
@ -90,18 +82,13 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
|||||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||||
|
|
||||||
if TEST_FP8:
|
if TEST_FP8:
|
||||||
self.fp8_linear = init_fp8_linear_kernel(
|
|
||||||
activation_quant_key=self.quant_key,
|
|
||||||
weight_quant_key=self.quant_key,
|
|
||||||
out_dtype=torch.get_default_dtype(),
|
|
||||||
module_name=self.__class__.__name__,
|
|
||||||
)
|
|
||||||
self.weight = (
|
self.weight = (
|
||||||
torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
|
torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
|
||||||
)
|
)
|
||||||
self.weight_scale = torch.rand(1, dtype=torch.float32)
|
self.weight_scale = torch.rand(1, dtype=torch.float32)
|
||||||
self.input_scale = torch.rand(1, dtype=torch.float32)
|
self.input_scale = torch.rand(1, dtype=torch.float32)
|
||||||
self.input_scale_ub = None
|
self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key,
|
||||||
|
self.weight, self.weight_scale, self.input_scale)
|
||||||
|
|
||||||
def forward(self, hidden_states, residual):
|
def forward(self, hidden_states, residual):
|
||||||
# Reshape input
|
# Reshape input
|
||||||
@ -117,7 +104,7 @@ class TestFusedAddRMSNorm(torch.nn.Module):
|
|||||||
if TEST_FP8:
|
if TEST_FP8:
|
||||||
self.input_scale = self.input_scale.to(norm_output.device)
|
self.input_scale = self.input_scale.to(norm_output.device)
|
||||||
# scaled_mm with static input quantization
|
# scaled_mm with static input quantization
|
||||||
fp8_linear_result = self.fp8_linear.apply_weights(self, norm_output)
|
fp8_linear_result = self.fp8_linear(norm_output)
|
||||||
|
|
||||||
return fp8_linear_result, residual_output
|
return fp8_linear_result, residual_output
|
||||||
|
|
||||||
|
|||||||
@ -18,9 +18,7 @@ from vllm.config import (
|
|||||||
VllmConfig,
|
VllmConfig,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
|
||||||
init_fp8_linear_kernel,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
GroupShape,
|
GroupShape,
|
||||||
QuantKey,
|
QuantKey,
|
||||||
@ -76,36 +74,30 @@ class TestModel(torch.nn.Module):
|
|||||||
]
|
]
|
||||||
|
|
||||||
with override_cutlass_fp8_supported(not cuda_force_torch):
|
with override_cutlass_fp8_supported(not cuda_force_torch):
|
||||||
self.fp8_linear = init_fp8_linear_kernel(
|
self.fp8_linear_1 = TestFP8Layer(self.activation_quant_key, self.weight_quant_key,
|
||||||
activation_quant_key=self.activation_quant_key,
|
self.w[0], self.wscale[0], self.scale[0])
|
||||||
weight_quant_key=self.weight_quant_key,
|
self.fp8_linear_2 = TestFP8Layer(self.activation_quant_key, self.weight_quant_key,
|
||||||
out_dtype=torch.get_default_dtype(),
|
self.w[1], self.wscale[1], self.scale[1])
|
||||||
module_name=self.__class__.__name__,
|
self.fp8_linear_3 = TestFP8Layer(self.activation_quant_key, self.weight_quant_key,
|
||||||
)
|
self.w[2], self.wscale[2], self.scale[2])
|
||||||
|
|
||||||
self.enable_rms_norm_custom_op = self.norm[0].enabled()
|
self.enable_rms_norm_custom_op = self.norm[0].enabled()
|
||||||
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
|
self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# avoid having graph input be an arg to a pattern directly
|
# avoid having graph input be an arg to a pattern directly
|
||||||
x = resid = torch.relu(x)
|
x = resid = torch.relu(x)
|
||||||
y = self.norm[0](x)
|
y = self.norm[0](x)
|
||||||
|
|
||||||
layer1 = TestFP8Layer(self.w[0], self.wscale[0], input_scale=self.scale[0])
|
x2 = self.fp8_linear_1(y)
|
||||||
x2 = self.fp8_linear.apply_weights(layer1, y)
|
|
||||||
# make sure resid is used for replacement to work
|
# make sure resid is used for replacement to work
|
||||||
y2, resid = self.norm[1](x2, resid)
|
y2, resid = self.norm[1](x2, resid)
|
||||||
|
|
||||||
layer2 = TestFP8Layer(self.w[1], self.wscale[1], input_scale=self.scale[1])
|
x3 = self.fp8_linear_2(y2)
|
||||||
x3 = self.fp8_linear.apply_weights(layer2, y2)
|
|
||||||
|
|
||||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||||
|
|
||||||
layer3 = TestFP8Layer(self.w[2], self.wscale[2], input_scale=self.scale[2])
|
x4 = self.fp8_linear_3(y3)
|
||||||
x4 = self.fp8_linear.apply_weights(
|
|
||||||
layer3,
|
|
||||||
y3,
|
|
||||||
)
|
|
||||||
|
|
||||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||||
return y4
|
return y4
|
||||||
|
|||||||
@ -26,9 +26,7 @@ from vllm.distributed.parallel_state import (
|
|||||||
initialize_model_parallel,
|
initialize_model_parallel,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
|
||||||
init_fp8_linear_kernel,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
kFp8StaticTensorSym,
|
kFp8StaticTensorSym,
|
||||||
)
|
)
|
||||||
@ -93,12 +91,14 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
|||||||
for _ in range(3)
|
for _ in range(3)
|
||||||
]
|
]
|
||||||
|
|
||||||
self.fp8_linear = init_fp8_linear_kernel(
|
self.fp8_linear_1 = TestFP8Layer(self.quant_key,self.quant_key,
|
||||||
activation_quant_key=self.quant_key,
|
self.weight[0],self.wscale[0], input_scale=self.input_scale[0])
|
||||||
weight_quant_key=self.quant_key,
|
|
||||||
out_dtype=torch.get_default_dtype(),
|
self.fp8_linear_2 = TestFP8Layer(self.quant_key,self.quant_key,
|
||||||
module_name=self.__class__.__name__,
|
self.weight[1],self.wscale[1], input_scale=self.input_scale[1])
|
||||||
)
|
|
||||||
|
self.fp8_linear_3 = TestFP8Layer(self.quant_key, self.quant_key,
|
||||||
|
self.weight[2], self.wscale[2],input_scale=self.input_scale[2])
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
# avoid having graph input be an arg to a pattern directly
|
# avoid having graph input be an arg to a pattern directly
|
||||||
@ -106,26 +106,18 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
|||||||
x = resid = tensor_model_parallel_all_reduce(z)
|
x = resid = tensor_model_parallel_all_reduce(z)
|
||||||
y = self.norm[0](x)
|
y = self.norm[0](x)
|
||||||
|
|
||||||
layer1 = TestFP8Layer(
|
|
||||||
self.weight[0], self.weight_scale[0], input_scale=self.input_scale[0]
|
z2 = self.fp8_linear_1(y)
|
||||||
)
|
|
||||||
z2 = self.fp8_linear.apply_weights(layer1, y)
|
|
||||||
|
|
||||||
x2 = tensor_model_parallel_all_reduce(z2)
|
x2 = tensor_model_parallel_all_reduce(z2)
|
||||||
y2, resid = self.norm[1](x2, resid)
|
y2, resid = self.norm[1](x2, resid)
|
||||||
|
|
||||||
layer2 = TestFP8Layer(
|
z3 = self.fp8_linear_2(y2)
|
||||||
self.weight[1], self.weight_scale[1], input_scale=self.input_scale[1]
|
|
||||||
)
|
|
||||||
z3 = self.fp8_linear.apply(layer2, y2)
|
|
||||||
|
|
||||||
x3 = tensor_model_parallel_all_reduce(z3)
|
x3 = tensor_model_parallel_all_reduce(z3)
|
||||||
y3, resid = self.norm[2](x3, resid) # use resid here
|
y3, resid = self.norm[2](x3, resid) # use resid here
|
||||||
|
|
||||||
layer3 = TestFP8Layer(
|
z4 = self.fp8_linear_3(y3)
|
||||||
self.weight[2], self.weight_scale[2], input_scale=self.input_scale[2]
|
|
||||||
)
|
|
||||||
z4 = self.fp8_linear.apply(layer3, y3)
|
|
||||||
|
|
||||||
x4 = tensor_model_parallel_all_reduce(z4)
|
x4 = tensor_model_parallel_all_reduce(z4)
|
||||||
y4, resid = self.norm[3](x4, resid) # use resid here
|
y4, resid = self.norm[3](x4, resid) # use resid here
|
||||||
@ -138,7 +130,7 @@ class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
|
|||||||
return [
|
return [
|
||||||
torch.ops.vllm.all_reduce.default,
|
torch.ops.vllm.all_reduce.default,
|
||||||
torch.ops._C.static_scaled_fp8_quant.default
|
torch.ops._C.static_scaled_fp8_quant.default
|
||||||
if self.fp8_linear.quant_fp8.enabled()
|
if self.fp8_linear.is_quant_fp8_enabled()
|
||||||
else torch.ops.aten.reciprocal.default,
|
else torch.ops.aten.reciprocal.default,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -28,9 +28,7 @@ from vllm.config import (
|
|||||||
set_current_vllm_config,
|
set_current_vllm_config,
|
||||||
)
|
)
|
||||||
from vllm.forward_context import get_forward_context, set_forward_context
|
from vllm.forward_context import get_forward_context, set_forward_context
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
|
||||||
init_fp8_linear_kernel,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
QuantKey,
|
QuantKey,
|
||||||
kFp8StaticTensorSym,
|
kFp8StaticTensorSym,
|
||||||
@ -174,12 +172,6 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
self.fp8_linear = init_fp8_linear_kernel(
|
|
||||||
activation_quant_key=self.quant_key,
|
|
||||||
weight_quant_key=self.quant_key,
|
|
||||||
out_dtype=torch.get_default_dtype(),
|
|
||||||
module_name=self.__class__.__name__,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_size = self.num_qo_heads * self.head_size
|
hidden_size = self.num_qo_heads * self.head_size
|
||||||
self.w = kwargs.get(
|
self.w = kwargs.get(
|
||||||
@ -192,12 +184,13 @@ class TestAttentionFp8StaticQuantPatternModel(AttentionQuantPatternModel):
|
|||||||
"scale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
"scale": torch.tensor([1.0], dtype=torch.float32, device=self.device),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key, self.w["weight"],
|
||||||
|
self.w["wscale"], self.w["scale"])
|
||||||
|
|
||||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||||
"""Forward pass that creates the pattern to be fused."""
|
"""Forward pass that creates the pattern to be fused."""
|
||||||
attn_output = self.attn(q, k, v)
|
attn_output = self.attn(q, k, v)
|
||||||
layer = TestFP8Layer(self.w["weight"], self.w["wscale"], self.w["scale"])
|
return self.fp8_linear(attn_output)
|
||||||
return self.fp8_linear.apply_weights(layer, attn_output)
|
|
||||||
|
|
||||||
|
|
||||||
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
|
class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
|
||||||
|
|||||||
@ -27,9 +27,7 @@ from vllm.distributed.parallel_state import (
|
|||||||
initialize_model_parallel,
|
initialize_model_parallel,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
|
||||||
init_fp8_linear_kernel,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
kFp8StaticTensorSym,
|
kFp8StaticTensorSym,
|
||||||
)
|
)
|
||||||
@ -114,17 +112,14 @@ class TestQuantModel(torch.nn.Module):
|
|||||||
# Initialize weights
|
# Initialize weights
|
||||||
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
torch.nn.init.normal_(self.gate_proj, std=0.02)
|
||||||
|
|
||||||
self.fp8_linear = init_fp8_linear_kernel(
|
|
||||||
activation_quant_key=self.quant_key,
|
|
||||||
weight_quant_key=self.quant_key,
|
|
||||||
out_dtype=torch.get_default_dtype(),
|
|
||||||
module_name=self.__class__.__name__,
|
|
||||||
)
|
|
||||||
self.scale = torch.rand(1, dtype=torch.float32)
|
self.scale = torch.rand(1, dtype=torch.float32)
|
||||||
# Create a weight that is compatible with torch._scaled_mm,
|
# Create a weight that is compatible with torch._scaled_mm,
|
||||||
# which expects a column-major layout.
|
# which expects a column-major layout.
|
||||||
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
|
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
|
||||||
self.wscale = torch.rand(1, dtype=torch.float32)
|
self.wscale = torch.rand(1, dtype=torch.float32)
|
||||||
|
self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key,
|
||||||
|
self.w, self.wscale, self.scale)
|
||||||
|
|
||||||
|
|
||||||
def forward(self, hidden_states, residual):
|
def forward(self, hidden_states, residual):
|
||||||
"""
|
"""
|
||||||
@ -150,8 +145,7 @@ class TestQuantModel(torch.nn.Module):
|
|||||||
# layer normalization
|
# layer normalization
|
||||||
norm_output, residual_output = self.norm(all_reduce, residual)
|
norm_output, residual_output = self.norm(all_reduce, residual)
|
||||||
# scaled_mm with static input quantization
|
# scaled_mm with static input quantization
|
||||||
layer = TestFP8Layer(None, None, self.scale.to(norm_output.device))
|
fp8_linear_result = self.fp8_linear(norm_output)
|
||||||
fp8_linear_result = self.fp8_linear.apply(layer, norm_output)
|
|
||||||
|
|
||||||
return fp8_linear_result, residual_output
|
return fp8_linear_result, residual_output
|
||||||
|
|
||||||
|
|||||||
@ -24,9 +24,7 @@ from vllm.config import (
|
|||||||
set_current_vllm_config,
|
set_current_vllm_config,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.activation import SiluAndMul
|
from vllm.model_executor.layers.activation import SiluAndMul
|
||||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
|
||||||
init_fp8_linear_kernel,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
kFp8StaticTensorSym,
|
kFp8StaticTensorSym,
|
||||||
kNvfp4Quant,
|
kNvfp4Quant,
|
||||||
@ -36,7 +34,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
from ..utils import override_cutlass_fp8_supported
|
from ..utils import TestFP8Layer, override_cutlass_fp8_supported
|
||||||
from .backend import TestBackend
|
from .backend import TestBackend
|
||||||
|
|
||||||
FP8_DTYPE = current_platform.fp8_dtype()
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
@ -55,22 +53,19 @@ class TestSiluMulFp8QuantModel(torch.nn.Module):
|
|||||||
self.silu_and_mul = SiluAndMul()
|
self.silu_and_mul = SiluAndMul()
|
||||||
self.weight_scale = torch.rand(1, dtype=torch.float32)
|
self.weight_scale = torch.rand(1, dtype=torch.float32)
|
||||||
self.input_scale = torch.rand(1, dtype=torch.float32)
|
self.input_scale = torch.rand(1, dtype=torch.float32)
|
||||||
self.input_scale_ub = None
|
|
||||||
self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
|
||||||
|
|
||||||
with override_cutlass_fp8_supported(not cuda_force_torch):
|
with override_cutlass_fp8_supported(not cuda_force_torch):
|
||||||
self.fp8_linear = init_fp8_linear_kernel(
|
self.fp8_linear = TestFP8Layer(self.quant_key, self.quant_key,
|
||||||
activation_quant_key=self.quant_key,
|
self.weight, self.weight_scale, self.input_scale)
|
||||||
weight_quant_key=self.quant_key,
|
|
||||||
out_dtype=torch.get_default_dtype(),
|
|
||||||
module_name=self.__class__.__name__,
|
|
||||||
)
|
|
||||||
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
|
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
|
||||||
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
|
self.enable_quant_fp8_custom_op = self.fp8_linear.is_quant_fp8_enabled()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
y = self.silu_and_mul(x)
|
y = self.silu_and_mul(x)
|
||||||
x2 = self.fp8_linear.apply_weights(self, y)
|
x2 = self.fp8_linear(y)
|
||||||
return x2
|
return x2
|
||||||
|
|
||||||
def ops_in_model_before(self):
|
def ops_in_model_before(self):
|
||||||
|
|||||||
@ -49,6 +49,8 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
|
|||||||
from vllm.utils.mem_constants import GB_bytes
|
from vllm.utils.mem_constants import GB_bytes
|
||||||
from vllm.utils.network_utils import get_open_port
|
from vllm.utils.network_utils import get_open_port
|
||||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||||
|
from vllm.model_executor.layers.quantization.kernels.scaled_mm import init_fp8_linear_kernel
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
|
||||||
|
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
from amdsmi import (
|
from amdsmi import (
|
||||||
@ -1414,11 +1416,45 @@ def flat_product(*iterables: Iterable[Any]):
|
|||||||
|
|
||||||
|
|
||||||
class TestFP8Layer(torch.nn.Module):
|
class TestFP8Layer(torch.nn.Module):
|
||||||
"""Helper class for ScaledMMLinearKernels."""
|
"""
|
||||||
|
Test helper class for evaluating FP8 linear operations with quantization.
|
||||||
|
|
||||||
def __init__(self, weight, weight_scale, input_scale):
|
It supports configurable activation and weight quantization parameters,
|
||||||
|
and provides a forward method that applies the FP8 linear transformation
|
||||||
|
with optional bias.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
activation_quant_key (QuantKey): Key for activation quantization configuration.
|
||||||
|
weight_quant_key (QuantKey): Key for weight quantization configuration.
|
||||||
|
weight (torch.Tensor): Weight tensor for linear transformation.
|
||||||
|
weight_scale (torch.Tensor): Per-tensor or per-group scale for weights.
|
||||||
|
input_scale (torch.Tensor): Scale tensor for input quantization.
|
||||||
|
out_dtype (torch.dtype, optional): Output tensor data type. Defaults to torch.get_default_dtype().
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
activation_quant_key: QuantKey,
|
||||||
|
weight_quant_key: QuantKey,
|
||||||
|
weight:torch.Tensor,
|
||||||
|
weight_scale:torch.Tensor,
|
||||||
|
input_scale:torch.Tensor,
|
||||||
|
out_dtype: torch.dtype = torch.get_default_dtype()
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight_scale = weight_scale
|
self.weight_scale = weight_scale
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
self.input_scale = input_scale
|
self.input_scale = input_scale
|
||||||
self.input_scale_ub = None
|
self.input_scale_ub = None
|
||||||
|
|
||||||
|
self.kernel = init_fp8_linear_kernel(
|
||||||
|
activation_quant_key=activation_quant_key,
|
||||||
|
weight_quant_key=weight_quant_key,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
module_name=self.__class__.__name__,
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_quant_fp8_enabled(self) -> bool:
|
||||||
|
return self.kernel.quant_fp8.enabled()
|
||||||
|
|
||||||
|
def forward(self, y: torch.Tensor, bias: torch.Tensor | None=None) -> torch.Tensor:
|
||||||
|
return self.kernel.apply_weights(self, y, bias)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user