Llama 3.1 405B fp4 changes upstreaming from 355_wip (#25135)

Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Doug Lehr <douglehr@amd.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Aleksandr Malyshev 2025-09-25 18:16:53 -07:00 committed by yewentao256
parent 6f97de4e47
commit c064c82674
6 changed files with 301 additions and 38 deletions

View File

@@ -106,6 +106,8 @@ if TYPE_CHECKING:
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
VLLM_ROCM_USE_AITER_MLA: bool = True
VLLM_ROCM_USE_AITER_MHA: bool = True
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
VLLM_ROCM_USE_TRITON_ROPE: bool = False
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
VLLM_ROCM_FP8_PADDING: bool = True
@@ -934,6 +936,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
("true", "1")),
# Whether to use aiter fp4 gemm asm.
# By default is disabled.
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "False").lower() in
("true", "1")),
# Whether to use the aiter triton rope kernel.
# By default is disabled.
"VLLM_ROCM_USE_TRITON_ROPE":
lambda: (os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in
("true", "1")),
# Whether to use aiter triton fp8 bmm kernel
# By default is enabled.
"VLLM_ROCM_USE_AITER_FP8BMM":
@@ -1539,6 +1553,8 @@ def compute_hash() -> str:
"VLLM_ROCM_USE_AITER_RMSNORM",
"VLLM_ROCM_USE_AITER_MLA",
"VLLM_ROCM_USE_AITER_MHA",
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
"VLLM_ROCM_USE_TRITON_ROPE",
"VLLM_ROCM_USE_AITER_FP8BMM",
"VLLM_ROCM_USE_SKINNY_GEMM",
"VLLM_ROCM_FP8_PADDING",

View File

@@ -323,6 +323,12 @@ class ReplicatedLinear(LinearBase):
return_bias: bool = True,
disable_tp: bool = False,
):
# For MergedReplicatedLinear, use the output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = self.output_sizes
else:
self.output_partition_sizes = [output_size]
super().__init__(input_size,
output_size,
skip_bias_add,
@@ -335,7 +341,8 @@ class ReplicatedLinear(LinearBase):
# All linear layers support a quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size, [self.output_size],
self.input_size,
self.output_partition_sizes,
self.input_size,
self.output_size,
self.params_dtype,
@@ -374,12 +381,15 @@ class ReplicatedLinear(LinearBase):
param.data.copy_(loaded_weight)
def forward(
self, x: torch.Tensor
self,
x: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
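For context on the output_partition_sizes change above, a hedged sketch of why the merged case matters: quantization methods can then allocate one scale block per logical output partition (e.g. a fused gate/up projection). The helper and sizes below are illustrative, not vLLM APIs:

import torch

def create_weights_sketch(input_size: int,
                          output_partition_sizes: list[int]):
    # One fused weight, plus one scale tensor per logical partition.
    # That split is only possible when the per-partition sizes are
    # passed through, which is what the change above does.
    output_size = sum(output_partition_sizes)
    weight = torch.empty(output_size, input_size)
    scales = [torch.ones(n, 1) for n in output_partition_sizes]
    return weight, scales

# Merged case (e.g. fused gate/up projections): two partitions.
w, s = create_weights_sketch(4096, [11008, 11008])
# Plain ReplicatedLinear case: a single partition, as before this change.
w2, s2 = create_weights_sketch(4096, [4096])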
@@ -413,7 +423,7 @@ class ColumnParallelLinear(LinearBase):
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
disable_tp: If true, weights matrix won't be sharded through tp rank.
"""
@@ -535,13 +545,15 @@ class ColumnParallelLinear(LinearBase):
param.load_column_parallel_weight(loaded_weight=loaded_weight)
def forward(
self, input_
self,
input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
@@ -1326,7 +1338,8 @@ class RowParallelLinear(LinearBase):
param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward(
self, input_
self,
input_,
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
@@ -1340,9 +1353,8 @@ class RowParallelLinear(LinearBase):
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
output_parallel = self.quant_method.apply(self, input_parallel, bias_)
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
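A self-contained numeric sketch of the bias rule above (ranks simulated in one process): the bias is fused only into rank 0's partial GEMM, so the all-reduce, a plain sum here, adds it exactly once:

import torch
import torch.nn.functional as F

# Two "ranks" in one process: each holds half of the input features
# and the matching half of the weight.
x = torch.randn(4, 8)
w = torch.randn(6, 8)
b = torch.randn(6)
x0, x1 = x[:, :4], x[:, 4:]
w0, w1 = w[:, :4], w[:, 4:]

partial0 = F.linear(x0, w0, b)   # rank 0: bias fused into the GEMM
partial1 = F.linear(x1, w1)      # rank > 0: bias_ is None
out = partial0 + partial1        # stands in for the all-reduce

assert torch.allclose(out, F.linear(x, w, b), atol=1e-5)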

View File

@@ -395,6 +395,7 @@ class QuarkLinearMethod(LinearMethodBase):
scheme = layer.scheme
if scheme is None:
raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x, bias=bias)

View File

@@ -1,12 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import cache
from typing import Any, Callable, Optional
import torch
import torch.nn.functional as F
from vllm.logger import init_logger
from vllm import envs
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4)
@@ -14,7 +15,90 @@ from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.platforms import current_platform
logger = init_logger(__name__)
@cache
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
return current_platform.is_rocm() \
and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM \
and envs.VLLM_ROCM_USE_AITER
try:
from aiter.ops.shuffle import shuffle_weight
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
from aiter.ops.triton.quant import dynamic_mxfp4_quant
from vllm.utils import direct_register_custom_op
if is_rocm_aiter_fp4_asm_gemm_enabled():
from aiter import gemm_a4w4, per_1x32_f4_quant_hip
def gemm_with_dynamic_quant(
x: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
rocm_use_aiter_fp4_asm_gemm: bool = False,
out_dtype: Optional[torch.dtype] = torch.bfloat16,
x_scales: Optional[torch.Tensor] = None,
) -> torch.Tensor:
M = x.shape[0]
if rocm_use_aiter_fp4_asm_gemm:
if x_scales is None:
# use hip quant kernel for performance
x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
else:
x_q = x
x_s = x_scales
# 32 alignment is enough for dim0 padding of output for
# gemm_a4w4 kernel
y = torch.empty((M + 31) // 32 * 32,
weight.shape[0],
device=x_q.device,
dtype=out_dtype)
gemm_a4w4(x_q,
weight,
x_s,
weight_scale.view(x_s.dtype),
y,
bpreshuffle=True)
return y[:M]
else:
if x_scales is None:
x_q, x_s = dynamic_mxfp4_quant(x)
else:
x_q = x
x_s = x_scales
y = torch.empty(x_q.shape[0],
weight.shape[0],
device=x_q.device,
dtype=out_dtype)
gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
return y
def gemm_with_dynamic_quant_fake(
x: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
rocm_use_aiter_fp4_asm_gemm: bool = False,
out_dtype: Optional[torch.dtype] = torch.bfloat16,
x_scales: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.empty((*x.shape[:-1], weight.shape[0]),
dtype=out_dtype,
device=x.device)
direct_register_custom_op(
op_name="gemm_with_dynamic_quant",
op_func=gemm_with_dynamic_quant,
mutates_args=[],
fake_impl=gemm_with_dynamic_quant_fake,
dispatch_key=current_platform.dispatch_key,
)
except ImportError:
dynamic_mxfp4_quant = gemm_afp4wfp4 = None
__all__ = ["QuarkW4A4MXFP4"]
@@ -27,29 +111,15 @@ class QuarkW4A4MXFP4(QuarkScheme):
self.qscheme = "per_group"
self.weight_quant_spec = weight_quant_spec
self.input_quant_spec = input_quant_spec
self.static_input_scales = not input_quant_spec.get("is_dynamic")
if self.static_input_scales:
self.emulate = not current_platform.supports_mx()
self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()
if not self.emulate and (dynamic_mxfp4_quant is None
or gemm_afp4wfp4 is None):
# Currently need these kernels if not emulating
raise NotImplementedError(
"QuarkW4A4MXFP4 with static input scales is currently not "
"implemented. Please open an issue.")
if not current_platform.supports_mx():
self.emulate = True
logger.warning_once(
"The current platform does not support native MXFP4 "
"computation. Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision.")
else:
self.emulate = True
logger.warning_once(
"The current platform supports native MXFP4 "
"computation, but kernels are not yet integrated in vLLM. "
"Simulated weight dequantization and activation "
"QDQ (quantize and dequantize) will be used, with the linear "
"layers computed in high precision.")
f"{self.__class__.__name__} requires AITER to be installed "
"for non-emulation mode! Please refer to "
"https://github.com/ROCm/aiter for installation details.")
@classmethod
def get_min_capability(cls) -> int:
@@ -58,8 +128,65 @@ class QuarkW4A4MXFP4(QuarkScheme):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data,
requires_grad=False)
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)
if self.emulate:
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)
try:
from quark.torch.export.nn.modules import realquantizer
from quark.torch.quantization.config.config import (
QuantizationSpec)
except ImportError as err:
raise ImportError(
"The package `amd-quark` is required to use AMD Quark "
"MX-FP4 models. Please install it with `pip install "
"amd-quark`.") from err
weight_quant_spec = QuantizationSpec.from_dict(
self.weight_quant_spec)
weight_quantizer = realquantizer.get_real_quantizer(
qspec=weight_quant_spec,
quantizer=None,
real_quantized=True,
reorder=False,
float_dtype=self.out_dtype,
scale_shape=layer.weight_scale.shape,
zero_point_shape=None,
)
weight_quantizer.scale.data = layer.weight_scale.data
layer.weight = torch.nn.Parameter(
weight_quantizer(layer.weight.data).to(self.out_dtype),
requires_grad=False,
)
layer.weight_scale = None
# This call is necessary to release the scales memory.
torch.cuda.empty_cache()
else:
if self.rocm_use_aiter_fp4_asm_gemm:
# shuffle weight scale
weight_scale_shuffle = layer.weight_scale.data
sm, sn = weight_scale_shuffle.shape
weight_scale_shuffle = weight_scale_shuffle.view(
sm // 32, 2, 16, sn // 8, 2, 4, 1)
weight_scale_shuffle = weight_scale_shuffle.permute(
0, 3, 5, 2, 4, 1, 6).contiguous()
weight_scale_shuffle = weight_scale_shuffle.view(sm, sn)
layer.weight_scale = torch.nn.Parameter(weight_scale_shuffle,
requires_grad=False)
# shuffle weight
weight_shuffle = layer.weight.data
weight_shuffle = shuffle_weight(weight_shuffle,
layout=(16, 16))
layer.weight = torch.nn.Parameter(weight_shuffle,
requires_grad=False)
else:
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data.T.contiguous(),
requires_grad=False)
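A minimal check, with illustrative shapes satisfying the multiples-of-32-and-8 constraint the reshape implies, that the scale pre-shuffle above is a pure layout permutation:

import torch

sm, sn = 64, 16                  # multiples of 32 and 8 respectively
scales = torch.arange(sm * sn, dtype=torch.float32).view(sm, sn)

shuffled = (scales.view(sm // 32, 2, 16, sn // 8, 2, 4, 1)
            .permute(0, 3, 5, 2, 4, 1, 6)
            .contiguous()
            .view(sm, sn))

# Same elements in a different order: purely a layout change to match
# the asm gemm's expected scale tiling.
assert torch.equal(shuffled.flatten().sort().values, scales.flatten())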
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: list[int],
@@ -104,9 +231,9 @@ class QuarkW4A4MXFP4(QuarkScheme):
if self.emulate:
dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype)
x = quant_dequant_mxfp4(x)
return F.linear(x, dq_w, bias)
else:
raise NotImplementedError()
return torch.ops.vllm.gemm_with_dynamic_quant(
x, layer.weight, layer.weight_scale,
self.rocm_use_aiter_fp4_asm_gemm, self.out_dtype)
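A rough sketch of the emulation branch's idea: degrade weight and activation with a quantize-dequantize pass, then run the matmul in high precision. The per-tensor fake quant below is a generic stand-in, not the real per-32-element-block mxfp4 kernels:

import torch
import torch.nn.functional as F

def fake_quant(t: torch.Tensor, bits: int = 4) -> torch.Tensor:
    # Per-tensor symmetric quantize-dequantize; the real path instead
    # quantizes per 32-element block with a shared scale (mxfp4).
    qmax = 2 ** (bits - 1) - 1
    scale = t.abs().amax().clamp(min=1e-8) / qmax
    return (t / scale).round().clamp(-qmax, qmax) * scale

x = torch.randn(4, 64)
w = torch.randn(32, 64)
# Emulation: degrade both operands, then run the matmul in high precision.
y = F.linear(fake_quant(x), fake_quant(w))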

View File

@@ -8,6 +8,8 @@ import torch
from vllm.model_executor.custom_op import CustomOp
from .common import apply_rotary_emb_torch
from .rocm_aiter_rope_ops import (is_rocm_triton_rotary_embedding_enabled,
rocm_aiter_rotary_emb)
@CustomOp.register("rotary_embedding")
@@ -45,6 +47,8 @@ class RotaryEmbedding(CustomOp):
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
self.is_rocm_triton_rotary_embedding_enabled = \
is_rocm_triton_rotary_embedding_enabled()
def _compute_inv_freq(self, base: float) -> torch.Tensor:
"""Compute the inverse frequency."""
@@ -120,14 +124,31 @@ class RotaryEmbedding(CustomOp):
return query, key
from vllm import _custom_ops as ops
self._match_cos_sin_cache_dtype(query)
# ops.rotary_embedding() is an in-place operation
# that updates the query and key tensors.
ops.rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache, self.is_neox_style)
return query, key
def forward_hip(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if self.is_rocm_triton_rotary_embedding_enabled:
self._match_cos_sin_cache_dtype(query)
rocm_aiter_rotary_emb(positions, query, key, self.cos_sin_cache,
self.head_size, self.rotary_dim,
self.is_neox_style)
else:
# ops.rotary_embedding() is an in-place operation
# that updates the query and key tensors.
self.forward_cuda(positions, query, key)
return query, key
def forward_xpu(
self,
positions: torch.Tensor,

View File

@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
def is_rocm_triton_rotary_embedding_enabled() -> bool:
return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
and envs.VLLM_ROCM_USE_TRITON_ROPE)
def rocm_aiter_rotary_emb_with_key_forward_triton_impl(
positions: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
rotate_style: int = 0,
is_nope_first: bool = False,
) -> None:
import aiter.ops.triton.rope as ops
ops.rope_cached_thd_positions_2c_fwd_inplace(
query,
key,
cos,
sin,
positions,
rotate_style,
reuse_freqs_front_part=True,
nope_first=is_nope_first,
)
def rocm_aiter_rotary_emb_with_key_forward_triton_fake(
positions: torch.Tensor,
sin: torch.Tensor,
cos: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
rotate_style: int = 0,
is_nope_first: bool = False,
) -> None:
pass
if is_rocm_triton_rotary_embedding_enabled():
direct_register_custom_op(
op_name="rocm_aiter_rotary_emb_with_key_forward_triton",
op_func=rocm_aiter_rotary_emb_with_key_forward_triton_impl,
mutates_args=["key", "query"],
fake_impl=rocm_aiter_rotary_emb_with_key_forward_triton_fake,
dispatch_key=current_platform.dispatch_key,
)
def rocm_aiter_rotary_emb(positions: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, cos_sin_cache: torch.Tensor,
head_size: int, rotary_dim: int,
is_neox_style: bool):
num_tokens = positions.numel()
cos, sin = cos_sin_cache.chunk(2, dim=-1)
query_shape = query.shape
key_shape = key.shape
rotate_style = 0 if is_neox_style else 1
query = query.view(num_tokens, -1, head_size)
key = key.view(num_tokens, -1, head_size)
query_ = query[..., :rotary_dim]
key_ = key[..., :rotary_dim]
positions = positions.view(*query.shape[:1])
torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton(
positions,
sin,
cos,
query_,
key_,
rotate_style,
False,
)
query = query.view(query_shape)
key = key.view(key_shape)
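A minimal sketch of the view/slice plumbing in rocm_aiter_rotary_emb: query and key are reshaped per head, only the first rotary_dim elements are handed to the in-place op, and because these are views the original tensors are updated. Shapes are illustrative:

import torch

num_tokens, num_heads, head_size, rotary_dim = 8, 4, 128, 64
query = torch.randn(num_tokens, num_heads * head_size)
key = torch.randn(num_tokens, num_heads * head_size)

# Flatten to per-head views and take only the rotary slice, as
# rocm_aiter_rotary_emb does before calling the in-place triton op.
q = query.view(num_tokens, -1, head_size)
k = key.view(num_tokens, -1, head_size)
q_rot, k_rot = q[..., :rotary_dim], k[..., :rotary_dim]

# q_rot / k_rot share storage with query / key, so an in-place kernel
# writing into them (mutates_args=["key", "query"]) updates the originals.
q_rot.zero_()
assert torch.all(query.view(num_tokens, -1, head_size)[..., :rotary_dim] == 0)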