mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-07 10:09:08 +08:00
[Quantization] Add compressed-tensors NVFP4 MoE Support (#19990)
Signed-off-by: Dipika Sikka <dipikasikka1@gmail.com> Signed-off-by: Dipika <dipikasikka1@gmail.com>
This commit is contained in:
parent
7b1895e6ce
commit
6f2f53a82d
@ -17,7 +17,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
|||||||
CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4,
|
CompressedTensorsW4A4Fp4, CompressedTensorsW4A16Fp4,
|
||||||
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
|
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
|
||||||
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
|
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
|
||||||
CompressedTensorsWNA16)
|
CompressedTensorsWNA16, cutlass_fp4_supported)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
sparse_cutlass_supported)
|
sparse_cutlass_supported)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@ -668,8 +668,8 @@ def test_compressed_tensors_nvfp4(vllm_runner, args):
|
|||||||
assert isinstance(qkv_proj.quant_method,
|
assert isinstance(qkv_proj.quant_method,
|
||||||
CompressedTensorsLinearMethod)
|
CompressedTensorsLinearMethod)
|
||||||
if isinstance(qkv_proj.scheme, scheme) or isinstance(
|
if isinstance(qkv_proj.scheme, scheme) or isinstance(
|
||||||
qkv_proj.scheme, CompressedTensorsW4A16Fp4
|
qkv_proj.scheme,
|
||||||
) and not CompressedTensorsW4A4Fp4.cutlass_fp4_supported():
|
CompressedTensorsW4A16Fp4) and not cutlass_fp4_supported():
|
||||||
assert True
|
assert True
|
||||||
else:
|
else:
|
||||||
raise AssertionError("FP4 Scheme Mismatch")
|
raise AssertionError("FP4 Scheme Mismatch")
|
||||||
|
|||||||
@ -1246,6 +1246,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
param.materialize(final_shape, dtype=loaded_weight.dtype)
|
param.materialize(final_shape, dtype=loaded_weight.dtype)
|
||||||
|
|
||||||
expert_data = param.data if full_load else param.data[expert_id]
|
expert_data = param.data if full_load else param.data[expert_id]
|
||||||
|
|
||||||
# Case input scale: input_scale loading is only supported for fp8
|
# Case input scale: input_scale loading is only supported for fp8
|
||||||
if "input_scale" in weight_name:
|
if "input_scale" in weight_name:
|
||||||
# this is needed for compressed-tensors only
|
# this is needed for compressed-tensors only
|
||||||
@ -1273,6 +1274,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
tp_rank=self.tp_rank)
|
tp_rank=self.tp_rank)
|
||||||
return True if return_success else None
|
return True if return_success else None
|
||||||
|
|
||||||
|
# TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
|
||||||
if "ModelOpt" in quant_method_name:
|
if "ModelOpt" in quant_method_name:
|
||||||
if ('weight_scale_2' in weight_name
|
if ('weight_scale_2' in weight_name
|
||||||
or 'input_scale' in weight_name):
|
or 'input_scale' in weight_name):
|
||||||
@ -1289,7 +1291,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
tp_rank=self.tp_rank)
|
tp_rank=self.tp_rank)
|
||||||
return True if return_success else None
|
return True if return_success else None
|
||||||
|
|
||||||
# Case weight scales, zero_points and offset
|
# Case weight scales, zero_points and offset, weight/input global scales
|
||||||
if ("scale" in weight_name or "zero" in weight_name
|
if ("scale" in weight_name or "zero" in weight_name
|
||||||
or "offset" in weight_name):
|
or "offset" in weight_name):
|
||||||
# load the weight scales and zp based on the quantization scheme
|
# load the weight scales and zp based on the quantization scheme
|
||||||
|
|||||||
@ -33,6 +33,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
|||||||
find_matched_target, is_activation_quantization_format,
|
find_matched_target, is_activation_quantization_format,
|
||||||
should_ignore_layer)
|
should_ignore_layer)
|
||||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
|
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
|
||||||
|
cutlass_fp4_supported)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -375,7 +377,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
|
|
||||||
if is_activation_quantization_format(self.quant_format):
|
if is_activation_quantization_format(self.quant_format):
|
||||||
if self._is_fp4a4_nvfp4(weight_quant, input_quant):
|
if self._is_fp4a4_nvfp4(weight_quant, input_quant):
|
||||||
if CompressedTensorsW4A4Fp4.cutlass_fp4_supported(
|
if cutlass_fp4_supported(
|
||||||
) or envs.VLLM_USE_NVFP4_CT_EMULATIONS:
|
) or envs.VLLM_USE_NVFP4_CT_EMULATIONS:
|
||||||
return CompressedTensorsW4A4Fp4()
|
return CompressedTensorsW4A4Fp4()
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -21,8 +21,12 @@ from vllm.model_executor.layers.quantization.utils import replace_parameter
|
|||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
check_moe_marlin_supports_layer, marlin_make_workspace_new,
|
check_moe_marlin_supports_layer, marlin_make_workspace_new,
|
||||||
marlin_moe_permute_scales)
|
marlin_moe_permute_scales)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||||
|
prepare_moe_fp4_layer_for_marlin)
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||||
prepare_moe_fp8_layer_for_marlin)
|
prepare_moe_fp8_layer_for_marlin)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
|
||||||
|
cutlass_fp4_supported)
|
||||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
|
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
@ -46,12 +50,11 @@ class GPTQMarlinState(Enum):
|
|||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CompressedTensorsMoEMethod",
|
"CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
|
||||||
"CompressedTensorsW8A8Fp8MoEMethod",
|
|
||||||
"CompressedTensorsW8A8Fp8MoECutlassMethod",
|
"CompressedTensorsW8A8Fp8MoECutlassMethod",
|
||||||
"CompressedTensorsW8A8Int8MoEMethod",
|
"CompressedTensorsW8A8Int8MoEMethod",
|
||||||
"CompressedTensorsWNA16MarlinMoEMethod",
|
"CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod",
|
||||||
"CompressedTensorsWNA16MoEMethod",
|
"CompressedTensorsW4A4MoeMethod"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -84,6 +87,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
|||||||
else:
|
else:
|
||||||
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
|
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
|
||||||
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
|
return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
|
||||||
|
elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
|
||||||
|
return CompressedTensorsW4A4MoeMethod()
|
||||||
elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant):
|
elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant):
|
||||||
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
|
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
|
||||||
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
|
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
|
||||||
@ -95,6 +100,268 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
|||||||
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
|
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
|
||||||
|
|
||||||
|
|
||||||
|
class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.use_marlin = not cutlass_fp4_supported()
|
||||||
|
self.group_size = 16
|
||||||
|
|
||||||
|
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||||
|
hidden_size: int, intermediate_size_per_partition: int,
|
||||||
|
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||||
|
|
||||||
|
layer.num_experts = num_experts
|
||||||
|
layer.params_dtype = params_dtype
|
||||||
|
|
||||||
|
w13_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts,
|
||||||
|
2 * intermediate_size_per_partition,
|
||||||
|
# 2 fp4 items are packed in the input dimension
|
||||||
|
hidden_size // 2,
|
||||||
|
requires_grad=False,
|
||||||
|
dtype=torch.uint8),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w13_weight_packed", w13_weight)
|
||||||
|
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
w2_weight = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts,
|
||||||
|
hidden_size,
|
||||||
|
# 2 fp4 items are packed in the input dimension
|
||||||
|
intermediate_size_per_partition // 2,
|
||||||
|
dtype=torch.uint8),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w2_weight_packed", w2_weight)
|
||||||
|
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||||
|
|
||||||
|
# Weight Scales
|
||||||
|
w13_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts,
|
||||||
|
2 * intermediate_size_per_partition,
|
||||||
|
# 2 fp4 items are packed in the input dimension
|
||||||
|
hidden_size // self.group_size,
|
||||||
|
dtype=torch.float8_e4m3fn),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.GROUP.value})
|
||||||
|
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
|
w2_weight_scale = torch.nn.Parameter(
|
||||||
|
torch.empty(
|
||||||
|
num_experts,
|
||||||
|
hidden_size,
|
||||||
|
# 2 fp4 items are packed in the input dimension
|
||||||
|
intermediate_size_per_partition // self.group_size,
|
||||||
|
dtype=torch.float8_e4m3fn),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.GROUP.value})
|
||||||
|
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||||
|
|
||||||
|
# Weight Global Scales
|
||||||
|
w13_weight_scale_2 = torch.nn.Parameter(torch.empty(
|
||||||
|
num_experts, 2, dtype=torch.float32),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2)
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||||
|
set_weight_attrs(w13_weight_scale_2, extra_weight_attrs)
|
||||||
|
|
||||||
|
w2_weight_scale_2 = torch.nn.Parameter(torch.empty(
|
||||||
|
num_experts, dtype=torch.float32),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2)
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||||
|
set_weight_attrs(w2_weight_scale_2, extra_weight_attrs)
|
||||||
|
|
||||||
|
# Input Global Scales
|
||||||
|
w13_input_scale = torch.nn.Parameter(torch.empty(num_experts,
|
||||||
|
2,
|
||||||
|
dtype=torch.float32),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w13_input_global_scale", w13_input_scale)
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||||
|
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
||||||
|
|
||||||
|
w2_input_scale = torch.nn.Parameter(torch.empty(num_experts,
|
||||||
|
dtype=torch.float32),
|
||||||
|
requires_grad=False)
|
||||||
|
layer.register_parameter("w2_input_global_scale", w2_input_scale)
|
||||||
|
extra_weight_attrs.update(
|
||||||
|
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||||
|
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||||
|
|
||||||
|
def swizzle_blockscale(self, scale: torch.tensor):
|
||||||
|
assert (scale.dtype == torch.float8_e4m3fn)
|
||||||
|
# Pad and blockwise interleave weight_scale
|
||||||
|
scale_ndim = scale.ndim
|
||||||
|
if scale.ndim == 2:
|
||||||
|
scale = scale.unsqueeze(0)
|
||||||
|
assert scale.ndim == 3
|
||||||
|
B, M, K = scale.shape
|
||||||
|
round_up_multiple = lambda x, m: (x + m - 1) // m * m
|
||||||
|
M_padded = round_up_multiple(M, 128)
|
||||||
|
K_padded = round_up_multiple(K, 4)
|
||||||
|
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
|
||||||
|
padded_scale[:B, :M, :K] = scale
|
||||||
|
batches, rows, cols = padded_scale.shape
|
||||||
|
assert rows % 128 == 0
|
||||||
|
assert cols % 4 == 0
|
||||||
|
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
|
||||||
|
cols // 4, 4)
|
||||||
|
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
|
||||||
|
swizzled_scale = swizzled_scale.contiguous().cuda()
|
||||||
|
return (swizzled_scale.reshape(M, K)
|
||||||
|
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
|
||||||
|
# From packed to weight
|
||||||
|
layer.w13_weight = torch.nn.Parameter(layer.w13_weight_packed.data,
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
layer.w2_weight = torch.nn.Parameter(layer.w2_weight_packed.data,
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
if not torch.allclose(layer.w13_weight_global_scale[:, 0],
|
||||||
|
layer.w13_weight_global_scale[:, 1]):
|
||||||
|
logger.warning_once(
|
||||||
|
"w1_weight_global_scale must match w3_weight_global_scale. "
|
||||||
|
"Accuracy may be affected.")
|
||||||
|
|
||||||
|
# Take inverse of global scale saved to disk
|
||||||
|
layer.w13_weight_scale_2 = torch.nn.Parameter(
|
||||||
|
1 / layer.w13_weight_global_scale[:, 0], requires_grad=False)
|
||||||
|
|
||||||
|
layer.w2_weight_scale_2 = torch.nn.Parameter(
|
||||||
|
1 / layer.w2_weight_global_scale.data, requires_grad=False)
|
||||||
|
|
||||||
|
if self.use_marlin:
|
||||||
|
prepare_moe_fp4_layer_for_marlin(layer)
|
||||||
|
return
|
||||||
|
|
||||||
|
# swizzle weight scales
|
||||||
|
layer.w13_blockscale_swizzled = torch.nn.Parameter(
|
||||||
|
self.swizzle_blockscale(layer.w13_weight_scale),
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
layer.w2_blockscale_swizzled = torch.nn.Parameter(
|
||||||
|
self.swizzle_blockscale(layer.w2_weight_scale),
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
# w13
|
||||||
|
w13_input_global_scale = layer.w13_input_global_scale.max(
|
||||||
|
dim=1).values.to(torch.float32)
|
||||||
|
|
||||||
|
layer.g1_alphas = torch.nn.Parameter(
|
||||||
|
((1 / w13_input_global_scale) * layer.w13_weight_scale_2),
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
layer.w13_input_scale_quant = torch.nn.Parameter(
|
||||||
|
(w13_input_global_scale), requires_grad=False)
|
||||||
|
|
||||||
|
# w2
|
||||||
|
layer.g2_alphas = torch.nn.Parameter(
|
||||||
|
((1 / layer.w2_input_global_scale) * layer.w2_weight_scale_2).to(
|
||||||
|
torch.float32),
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
layer.w2_input_scale_quant = torch.nn.Parameter(
|
||||||
|
(layer.w2_input_global_scale), requires_grad=False)
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
renormalize: bool,
|
||||||
|
use_grouped_topk: bool = False,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
global_num_experts: int = -1,
|
||||||
|
expert_map: Optional[torch.Tensor] = None,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
scoring_func: str = "softmax",
|
||||||
|
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
apply_router_weight_on_input: bool = False,
|
||||||
|
activation: str = "silu",
|
||||||
|
enable_eplb: bool = False,
|
||||||
|
expert_load_view: Optional[torch.Tensor] = None,
|
||||||
|
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||||
|
logical_replica_count: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if enable_eplb:
|
||||||
|
raise NotImplementedError("EPLB not supported for "
|
||||||
|
"`CompressedTensorsW4A4MoeMethod` yet.")
|
||||||
|
|
||||||
|
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
top_k=top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
|
scoring_func=scoring_func,
|
||||||
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.use_marlin:
|
||||||
|
return torch.ops.vllm.fused_marlin_moe(
|
||||||
|
x,
|
||||||
|
layer.w13_weight,
|
||||||
|
layer.w2_weight,
|
||||||
|
layer.w13_weight_scale,
|
||||||
|
layer.w2_weight_scale,
|
||||||
|
router_logits,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
global_scale1=layer.w13_weight_scale_2,
|
||||||
|
global_scale2=layer.w2_weight_scale_2,
|
||||||
|
quant_type_id=scalar_types.float4_e2m1f.id,
|
||||||
|
global_num_experts=global_num_experts,
|
||||||
|
expert_map=expert_map)
|
||||||
|
|
||||||
|
assert activation == "silu", "Only SiLU activation is supported."
|
||||||
|
assert not apply_router_weight_on_input, (
|
||||||
|
"Router weight on input is not "
|
||||||
|
"supported for CompressedTensorsW4A4MoeMethod.")
|
||||||
|
assert expert_map is None, ("Expert Parallelism / expert_map "
|
||||||
|
"is currently not supported for "
|
||||||
|
"CompressedTensorsW4A4MoeMethod.")
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||||
|
cutlass_moe_fp4)
|
||||||
|
|
||||||
|
# Cutlass moe takes in activations in BF16/Half precision
|
||||||
|
# and fp4 quantized weights loaded from the checkpoint
|
||||||
|
return cutlass_moe_fp4(a=x,
|
||||||
|
w1_fp4=layer.w13_weight,
|
||||||
|
w1_blockscale=layer.w13_blockscale_swizzled,
|
||||||
|
w1_alphas=layer.g1_alphas,
|
||||||
|
w2_fp4=layer.w2_weight,
|
||||||
|
w2_blockscale=layer.w2_blockscale_swizzled,
|
||||||
|
w2_alphas=layer.g2_alphas,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
m=x.shape[0],
|
||||||
|
n=layer.w2_weight.shape[2] * 2,
|
||||||
|
k=x.shape[1],
|
||||||
|
e=layer.w13_weight.shape[0],
|
||||||
|
a1_gscale=layer.w13_input_scale_quant,
|
||||||
|
a2_gscale=layer.w2_input_scale_quant,
|
||||||
|
device=x.device).to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@ -5,8 +5,7 @@ import torch
|
|||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm._custom_ops import (cutlass_scaled_fp4_mm,
|
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||||
cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||||
CompressedTensorsScheme)
|
CompressedTensorsScheme)
|
||||||
@ -15,7 +14,6 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
|
|||||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||||
ModelWeightParameter,
|
ModelWeightParameter,
|
||||||
PerTensorScaleParameter)
|
PerTensorScaleParameter)
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -33,15 +31,6 @@ class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
|
|||||||
return 80
|
return 80
|
||||||
return 100
|
return 100
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def cutlass_fp4_supported(cls) -> bool:
|
|
||||||
if not current_platform.is_cuda():
|
|
||||||
return False
|
|
||||||
capability_tuple = current_platform.get_device_capability()
|
|
||||||
capability = -1 if capability_tuple is None else capability_tuple.to_int( # noqa: E501
|
|
||||||
)
|
|
||||||
return cutlass_scaled_mm_supports_fp4(capability)
|
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module,
|
def create_weights(self, layer: torch.nn.Module,
|
||||||
output_partition_sizes: list[int],
|
output_partition_sizes: list[int],
|
||||||
input_size_per_partition: int,
|
input_size_per_partition: int,
|
||||||
|
|||||||
@ -2,9 +2,14 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm._custom_ops import cutlass_scaled_mm_supports_fp4
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.scalar_type import scalar_types
|
from vllm.scalar_type import scalar_types
|
||||||
|
|
||||||
__all__ = ["break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant"]
|
__all__ = [
|
||||||
|
"break_fp4_bytes", "dequantize_to_dtype", "ref_nvfp4_quant",
|
||||||
|
"cutlass_fp4_supported"
|
||||||
|
]
|
||||||
|
|
||||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||||
|
|
||||||
@ -12,6 +17,14 @@ kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.],
|
|||||||
dtype=torch.float32)
|
dtype=torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def cutlass_fp4_supported() -> bool:
|
||||||
|
if not current_platform.is_cuda():
|
||||||
|
return False
|
||||||
|
capability_tuple = current_platform.get_device_capability()
|
||||||
|
capability = -1 if capability_tuple is None else capability_tuple.to_int()
|
||||||
|
return cutlass_scaled_mm_supports_fp4(capability)
|
||||||
|
|
||||||
|
|
||||||
def break_fp4_bytes(a, dtype):
|
def break_fp4_bytes(a, dtype):
|
||||||
assert a.dtype == torch.uint8
|
assert a.dtype == torch.uint8
|
||||||
m, n = a.shape
|
m, n = a.shape
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user