mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-18 08:06:59 +08:00
Llamas 3.1 405B fp4 changes upstreaming from 355_wip (#25135)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Aleksandr Malyshev <maleksan@amd.com> Co-authored-by: Doug Lehr <douglehr@amd.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
6f97de4e47
commit
c064c82674
16
vllm/envs.py
16
vllm/envs.py
@ -106,6 +106,8 @@ if TYPE_CHECKING:
|
|||||||
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
|
VLLM_ROCM_USE_AITER_RMSNORM: bool = True
|
||||||
VLLM_ROCM_USE_AITER_MLA: bool = True
|
VLLM_ROCM_USE_AITER_MLA: bool = True
|
||||||
VLLM_ROCM_USE_AITER_MHA: bool = True
|
VLLM_ROCM_USE_AITER_MHA: bool = True
|
||||||
|
VLLM_ROCM_USE_AITER_FP4_ASM_GEMM: bool = False
|
||||||
|
VLLM_ROCM_USE_TRITON_ROPE: bool = False
|
||||||
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
|
VLLM_ROCM_USE_AITER_FP8BMM: bool = True
|
||||||
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
|
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
|
||||||
VLLM_ROCM_FP8_PADDING: bool = True
|
VLLM_ROCM_FP8_PADDING: bool = True
|
||||||
@ -934,6 +936,18 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
|
lambda: (os.getenv("VLLM_ROCM_USE_AITER_MHA", "True").lower() in
|
||||||
("true", "1")),
|
("true", "1")),
|
||||||
|
|
||||||
|
# Whether to use aiter fp4 gemm asm.
|
||||||
|
# By default is disabled.
|
||||||
|
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM":
|
||||||
|
lambda: (os.getenv("VLLM_ROCM_USE_AITER_FP4_ASM_GEMM", "False").lower() in
|
||||||
|
("true", "1")),
|
||||||
|
|
||||||
|
# Whether to use aiter rope.
|
||||||
|
# By default is disabled.
|
||||||
|
"VLLM_ROCM_USE_TRITON_ROPE":
|
||||||
|
lambda: (os.getenv("VLLM_ROCM_USE_TRITON_ROPE", "False").lower() in
|
||||||
|
("true", "1")),
|
||||||
|
|
||||||
# Whether to use aiter triton fp8 bmm kernel
|
# Whether to use aiter triton fp8 bmm kernel
|
||||||
# By default is enabled.
|
# By default is enabled.
|
||||||
"VLLM_ROCM_USE_AITER_FP8BMM":
|
"VLLM_ROCM_USE_AITER_FP8BMM":
|
||||||
@ -1539,6 +1553,8 @@ def compute_hash() -> str:
|
|||||||
"VLLM_ROCM_USE_AITER_RMSNORM",
|
"VLLM_ROCM_USE_AITER_RMSNORM",
|
||||||
"VLLM_ROCM_USE_AITER_MLA",
|
"VLLM_ROCM_USE_AITER_MLA",
|
||||||
"VLLM_ROCM_USE_AITER_MHA",
|
"VLLM_ROCM_USE_AITER_MHA",
|
||||||
|
"VLLM_ROCM_USE_AITER_FP4_ASM_GEMM",
|
||||||
|
"VLLM_ROCM_USE_TRITON_ROPE",
|
||||||
"VLLM_ROCM_USE_AITER_FP8BMM",
|
"VLLM_ROCM_USE_AITER_FP8BMM",
|
||||||
"VLLM_ROCM_USE_SKINNY_GEMM",
|
"VLLM_ROCM_USE_SKINNY_GEMM",
|
||||||
"VLLM_ROCM_FP8_PADDING",
|
"VLLM_ROCM_FP8_PADDING",
|
||||||
|
|||||||
@ -323,6 +323,12 @@ class ReplicatedLinear(LinearBase):
|
|||||||
return_bias: bool = True,
|
return_bias: bool = True,
|
||||||
disable_tp: bool = False,
|
disable_tp: bool = False,
|
||||||
):
|
):
|
||||||
|
# If MergedReplicatedLinear, use output size of each partition.
|
||||||
|
if hasattr(self, "output_sizes"):
|
||||||
|
self.output_partition_sizes = self.output_sizes
|
||||||
|
else:
|
||||||
|
self.output_partition_sizes = [output_size]
|
||||||
|
|
||||||
super().__init__(input_size,
|
super().__init__(input_size,
|
||||||
output_size,
|
output_size,
|
||||||
skip_bias_add,
|
skip_bias_add,
|
||||||
@ -335,7 +341,8 @@ class ReplicatedLinear(LinearBase):
|
|||||||
# All the linear layer supports quant method.
|
# All the linear layer supports quant method.
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
self.quant_method.create_weights(self,
|
self.quant_method.create_weights(self,
|
||||||
self.input_size, [self.output_size],
|
self.input_size,
|
||||||
|
self.output_partition_sizes,
|
||||||
self.input_size,
|
self.input_size,
|
||||||
self.output_size,
|
self.output_size,
|
||||||
self.params_dtype,
|
self.params_dtype,
|
||||||
@ -374,12 +381,15 @@ class ReplicatedLinear(LinearBase):
|
|||||||
param.data.copy_(loaded_weight)
|
param.data.copy_(loaded_weight)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||||
bias = self.bias if not self.skip_bias_add else None
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
output = self.quant_method.apply(self, x, bias)
|
output = self.quant_method.apply(self, x, bias)
|
||||||
output_bias = self.bias if self.skip_bias_add else None
|
output_bias = self.bias if self.skip_bias_add else None
|
||||||
|
|
||||||
if not self.return_bias:
|
if not self.return_bias:
|
||||||
return output
|
return output
|
||||||
return output, output_bias
|
return output, output_bias
|
||||||
@ -413,7 +423,7 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
output_sizes: list of output sizes packed into one output, like for QKV
|
output_sizes: list of output sizes packed into one output, like for QKV
|
||||||
the list would be size 3.
|
the list would be size 3.
|
||||||
prefix: The name of the layer in the state dict, including all parents
|
prefix: The name of the layer in the state dict, including all parents
|
||||||
(e.g. model.layers.0.qkv_proj)
|
(e.g. model.layers.0.qkv_proj)
|
||||||
return_bias: If true, return bias together with outputs in forward pass.
|
return_bias: If true, return bias together with outputs in forward pass.
|
||||||
disable_tp: If true, weights matrix won't be sharded through tp rank.
|
disable_tp: If true, weights matrix won't be sharded through tp rank.
|
||||||
"""
|
"""
|
||||||
@ -535,13 +545,15 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
param.load_column_parallel_weight(loaded_weight=loaded_weight)
|
param.load_column_parallel_weight(loaded_weight=loaded_weight)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, input_
|
self,
|
||||||
|
input_,
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||||
bias = self.bias if not self.skip_bias_add else None
|
bias = self.bias if not self.skip_bias_add else None
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
output_parallel = self.quant_method.apply(self, input_, bias)
|
output_parallel = self.quant_method.apply(self, input_, bias)
|
||||||
|
|
||||||
if self.gather_output and self.tp_size > 1:
|
if self.gather_output and self.tp_size > 1:
|
||||||
# All-gather across the partitions.
|
# All-gather across the partitions.
|
||||||
output = tensor_model_parallel_all_gather(output_parallel)
|
output = tensor_model_parallel_all_gather(output_parallel)
|
||||||
@ -1326,7 +1338,8 @@ class RowParallelLinear(LinearBase):
|
|||||||
param.load_row_parallel_weight(loaded_weight=loaded_weight)
|
param.load_row_parallel_weight(loaded_weight=loaded_weight)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, input_
|
self,
|
||||||
|
input_,
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||||
if self.input_is_parallel:
|
if self.input_is_parallel:
|
||||||
input_parallel = input_
|
input_parallel = input_
|
||||||
@ -1340,9 +1353,8 @@ class RowParallelLinear(LinearBase):
|
|||||||
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
# Only fuse bias add into GEMM for rank 0 (this ensures that
|
||||||
# bias will not get added more than once in TP>1 case)
|
# bias will not get added more than once in TP>1 case)
|
||||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||||
output_parallel = self.quant_method.apply(self,
|
output_parallel = self.quant_method.apply(self, input_parallel, bias_)
|
||||||
input_parallel,
|
|
||||||
bias=bias_)
|
|
||||||
if self.reduce_results and self.tp_size > 1:
|
if self.reduce_results and self.tp_size > 1:
|
||||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -395,6 +395,7 @@ class QuarkLinearMethod(LinearMethodBase):
|
|||||||
scheme = layer.scheme
|
scheme = layer.scheme
|
||||||
if scheme is None:
|
if scheme is None:
|
||||||
raise ValueError("A scheme must be defined for each layer")
|
raise ValueError("A scheme must be defined for each layer")
|
||||||
|
|
||||||
return scheme.apply_weights(layer, x, bias=bias)
|
return scheme.apply_weights(layer, x, bias=bias)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,12 +1,13 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from functools import cache
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm import envs
|
||||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||||
OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4)
|
OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4)
|
||||||
@ -14,7 +15,90 @@ from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
|||||||
PackedvLLMParameter)
|
PackedvLLMParameter)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
@cache
|
||||||
|
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
|
||||||
|
return current_platform.is_rocm() \
|
||||||
|
and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM \
|
||||||
|
and envs.VLLM_ROCM_USE_AITER
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from aiter.ops.shuffle import shuffle_weight
|
||||||
|
from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
|
||||||
|
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
||||||
|
|
||||||
|
from vllm.utils import direct_register_custom_op
|
||||||
|
if is_rocm_aiter_fp4_asm_gemm_enabled():
|
||||||
|
from aiter import gemm_a4w4, per_1x32_f4_quant_hip
|
||||||
|
|
||||||
|
def gemm_with_dynamic_quant(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
rocm_use_aiter_fp4_asm_gemm: bool = False,
|
||||||
|
out_dtype: Optional[torch.dtype] = torch.bfloat16,
|
||||||
|
x_scales: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
M = x.shape[0]
|
||||||
|
if rocm_use_aiter_fp4_asm_gemm:
|
||||||
|
if x_scales is None:
|
||||||
|
# use hip quant kernel for performance
|
||||||
|
x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
|
||||||
|
else:
|
||||||
|
x_q = x
|
||||||
|
x_s = x_scales
|
||||||
|
|
||||||
|
# 32 alignment is enough for dim0 padding of output for
|
||||||
|
# gemm_a4w4 kernel
|
||||||
|
y = torch.empty((M + 31) // 32 * 32,
|
||||||
|
weight.shape[0],
|
||||||
|
device=x_q.device,
|
||||||
|
dtype=out_dtype)
|
||||||
|
|
||||||
|
gemm_a4w4(x_q,
|
||||||
|
weight,
|
||||||
|
x_s,
|
||||||
|
weight_scale.view(x_s.dtype),
|
||||||
|
y,
|
||||||
|
bpreshuffle=True)
|
||||||
|
return y[:M]
|
||||||
|
else:
|
||||||
|
if x_scales is None:
|
||||||
|
x_q, x_s = dynamic_mxfp4_quant(x)
|
||||||
|
else:
|
||||||
|
x_q = x
|
||||||
|
x_s = x_scales
|
||||||
|
y = torch.empty(x_q.shape[0],
|
||||||
|
weight.shape[0],
|
||||||
|
device=x_q.device,
|
||||||
|
dtype=out_dtype)
|
||||||
|
|
||||||
|
gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
|
||||||
|
return y
|
||||||
|
|
||||||
|
def gemm_with_dynamic_quant_fake(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
x_scales: torch.Tensor = None,
|
||||||
|
rocm_use_aiter_fp4_asm_gemm: bool = False,
|
||||||
|
out_dtype: Optional[torch.dtype] = torch.bfloat16,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.empty((*x.shape[:-1], weight.shape[0]),
|
||||||
|
dtype=out_dtype,
|
||||||
|
device=x.device)
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="gemm_with_dynamic_quant",
|
||||||
|
op_func=gemm_with_dynamic_quant,
|
||||||
|
mutates_args=[],
|
||||||
|
fake_impl=gemm_with_dynamic_quant_fake,
|
||||||
|
dispatch_key=current_platform.dispatch_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
dynamic_mxfp4_quant = gemm_afp4wfp4 = None
|
||||||
|
|
||||||
__all__ = ["QuarkW4A4MXFP4"]
|
__all__ = ["QuarkW4A4MXFP4"]
|
||||||
|
|
||||||
@ -27,29 +111,15 @@ class QuarkW4A4MXFP4(QuarkScheme):
|
|||||||
self.qscheme = "per_group"
|
self.qscheme = "per_group"
|
||||||
self.weight_quant_spec = weight_quant_spec
|
self.weight_quant_spec = weight_quant_spec
|
||||||
self.input_quant_spec = input_quant_spec
|
self.input_quant_spec = input_quant_spec
|
||||||
|
self.emulate = not current_platform.supports_mx()
|
||||||
self.static_input_scales = not input_quant_spec.get("is_dynamic")
|
self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()
|
||||||
|
if not self.emulate and (dynamic_mxfp4_quant is None
|
||||||
if self.static_input_scales:
|
or gemm_afp4wfp4 is None):
|
||||||
|
# Currently need these kernels if not emulating
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"QuarkW4A4MXFP4 with static input scales is currently not "
|
f"{self.__class__.__name__} requires AITER to be installed "
|
||||||
"implemented. Please open an issue.")
|
"for non-emulation mode! Please refer to "
|
||||||
|
"https://github.com/ROCm/aiter for installation details.")
|
||||||
if not current_platform.supports_mx():
|
|
||||||
self.emulate = True
|
|
||||||
logger.warning_once(
|
|
||||||
"The current platform does not support native MXFP4 "
|
|
||||||
"computation. Simulated weight dequantization and activation "
|
|
||||||
"QDQ (quantize and dequantize) will be used, with the linear "
|
|
||||||
"layers computed in high precision.")
|
|
||||||
else:
|
|
||||||
self.emulate = True
|
|
||||||
logger.warning_once(
|
|
||||||
"The current platform supports native MXFP4 "
|
|
||||||
"computation, but kernels are not yet integrated in vLLM. "
|
|
||||||
"Simulated weight dequantization and activation "
|
|
||||||
"QDQ (quantize and dequantize) will be used, with the linear "
|
|
||||||
"layers computed in high precision.")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
@ -58,8 +128,65 @@ class QuarkW4A4MXFP4(QuarkScheme):
|
|||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
layer.weight = torch.nn.Parameter(layer.weight.data,
|
layer.weight = torch.nn.Parameter(layer.weight.data,
|
||||||
requires_grad=False)
|
requires_grad=False)
|
||||||
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
|
|
||||||
requires_grad=False)
|
if self.emulate:
|
||||||
|
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
|
||||||
|
requires_grad=False)
|
||||||
|
try:
|
||||||
|
from quark.torch.export.nn.modules import realquantizer
|
||||||
|
from quark.torch.quantization.config.config import (
|
||||||
|
QuantizationSpec)
|
||||||
|
except ImportError as err:
|
||||||
|
raise ImportError(
|
||||||
|
"The package `amd-quark` is required to use AMD Quark "
|
||||||
|
"MX-FP4 models. Please install it with `pip install "
|
||||||
|
"amd-quark`.") from err
|
||||||
|
|
||||||
|
weight_quant_spec = QuantizationSpec.from_dict(
|
||||||
|
self.weight_quant_spec)
|
||||||
|
|
||||||
|
weight_quantizer = realquantizer.get_real_quantizer(
|
||||||
|
qspec=weight_quant_spec,
|
||||||
|
quantizer=None,
|
||||||
|
real_quantized=True,
|
||||||
|
reorder=False,
|
||||||
|
float_dtype=self.out_dtype,
|
||||||
|
scale_shape=layer.weight_scale.shape,
|
||||||
|
zero_point_shape=None,
|
||||||
|
)
|
||||||
|
weight_quantizer.scale.data = layer.weight_scale.data
|
||||||
|
|
||||||
|
layer.weight = torch.nn.Parameter(
|
||||||
|
weight_quantizer(layer.weight.data).to(self.out_dtype),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
layer.weight_scale = None
|
||||||
|
|
||||||
|
# This call is necessary to release the scales memory.
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
else:
|
||||||
|
if self.rocm_use_aiter_fp4_asm_gemm:
|
||||||
|
# shuffle weight scale
|
||||||
|
weight_scale_shuffle = layer.weight_scale.data
|
||||||
|
sm, sn = weight_scale_shuffle.shape
|
||||||
|
weight_scale_shuffle = weight_scale_shuffle.view(
|
||||||
|
sm // 32, 2, 16, sn // 8, 2, 4, 1)
|
||||||
|
weight_scale_shuffle = weight_scale_shuffle.permute(
|
||||||
|
0, 3, 5, 2, 4, 1, 6).contiguous()
|
||||||
|
weight_scale_shuffle = weight_scale_shuffle.view(sm, sn)
|
||||||
|
layer.weight_scale = torch.nn.Parameter(weight_scale_shuffle,
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
# shuffle weight
|
||||||
|
weight_shuffle = layer.weight.data
|
||||||
|
weight_shuffle = shuffle_weight(weight_shuffle,
|
||||||
|
layout=(16, 16))
|
||||||
|
layer.weight = torch.nn.Parameter(weight_shuffle,
|
||||||
|
requires_grad=False)
|
||||||
|
else:
|
||||||
|
layer.weight_scale = torch.nn.Parameter(
|
||||||
|
layer.weight_scale.data.T.contiguous(),
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
def create_weights(self, layer: torch.nn.Module,
|
def create_weights(self, layer: torch.nn.Module,
|
||||||
output_partition_sizes: list[int],
|
output_partition_sizes: list[int],
|
||||||
@ -104,9 +231,9 @@ class QuarkW4A4MXFP4(QuarkScheme):
|
|||||||
|
|
||||||
if self.emulate:
|
if self.emulate:
|
||||||
dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype)
|
dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype)
|
||||||
|
|
||||||
x = quant_dequant_mxfp4(x)
|
x = quant_dequant_mxfp4(x)
|
||||||
|
|
||||||
return F.linear(x, dq_w, bias)
|
return F.linear(x, dq_w, bias)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError()
|
return torch.ops.vllm.gemm_with_dynamic_quant(
|
||||||
|
x, layer.weight, layer.weight_scale,
|
||||||
|
self.rocm_use_aiter_fp4_asm_gemm, self.out_dtype)
|
||||||
|
|||||||
@ -8,6 +8,8 @@ import torch
|
|||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|
||||||
from .common import apply_rotary_emb_torch
|
from .common import apply_rotary_emb_torch
|
||||||
|
from .rocm_aiter_rope_ops import (is_rocm_triton_rotary_embedding_enabled,
|
||||||
|
rocm_aiter_rotary_emb)
|
||||||
|
|
||||||
|
|
||||||
@CustomOp.register("rotary_embedding")
|
@CustomOp.register("rotary_embedding")
|
||||||
@ -45,6 +47,8 @@ class RotaryEmbedding(CustomOp):
|
|||||||
cache = cache.to(dtype)
|
cache = cache.to(dtype)
|
||||||
self.cos_sin_cache: torch.Tensor
|
self.cos_sin_cache: torch.Tensor
|
||||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||||
|
self.is_rocm_triton_rotary_embedding_enabled = \
|
||||||
|
is_rocm_triton_rotary_embedding_enabled()
|
||||||
|
|
||||||
def _compute_inv_freq(self, base: float) -> torch.Tensor:
|
def _compute_inv_freq(self, base: float) -> torch.Tensor:
|
||||||
"""Compute the inverse frequency."""
|
"""Compute the inverse frequency."""
|
||||||
@ -120,14 +124,31 @@ class RotaryEmbedding(CustomOp):
|
|||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
|
||||||
self._match_cos_sin_cache_dtype(query)
|
self._match_cos_sin_cache_dtype(query)
|
||||||
|
|
||||||
# ops.rotary_embedding() is an in-place operation
|
# ops.rotary_embedding() is an in-place operation
|
||||||
# that updates the query and key tensors.
|
# that updates the query and key tensors.
|
||||||
ops.rotary_embedding(positions, query, key, self.head_size,
|
ops.rotary_embedding(positions, query, key, self.head_size,
|
||||||
self.cos_sin_cache, self.is_neox_style)
|
self.cos_sin_cache, self.is_neox_style)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
|
def forward_hip(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: Optional[torch.Tensor] = None,
|
||||||
|
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
if self.is_rocm_triton_rotary_embedding_enabled:
|
||||||
|
self._match_cos_sin_cache_dtype(query)
|
||||||
|
rocm_aiter_rotary_emb(positions, query, key, self.cos_sin_cache,
|
||||||
|
self.head_size, self.rotary_dim,
|
||||||
|
self.is_neox_style)
|
||||||
|
else:
|
||||||
|
# ops.rotary_embedding() is an in-place operation
|
||||||
|
# that updates the query and key tensors.
|
||||||
|
self.forward_cuda(positions, query, key)
|
||||||
|
return query, key
|
||||||
|
|
||||||
def forward_xpu(
|
def forward_xpu(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
|
|||||||
@ -0,0 +1,86 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
|
def is_rocm_triton_rotary_embedding_enabled() -> bool:
|
||||||
|
return (current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER
|
||||||
|
and envs.VLLM_ROCM_USE_TRITON_ROPE)
|
||||||
|
|
||||||
|
|
||||||
|
def rocm_aiter_rotary_emb_with_key_forward_triton_impl(
|
||||||
|
positions: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
rotate_style: int = 0,
|
||||||
|
is_nope_first: bool = False,
|
||||||
|
) -> None:
|
||||||
|
import aiter.ops.triton.rope as ops
|
||||||
|
ops.rope_cached_thd_positions_2c_fwd_inplace(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
positions,
|
||||||
|
rotate_style,
|
||||||
|
reuse_freqs_front_part=True,
|
||||||
|
nope_first=is_nope_first,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rocm_aiter_rotary_emb_with_key_forward_triton_fake(
|
||||||
|
positions: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
rotate_style: int = 0,
|
||||||
|
is_nope_first: bool = False,
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if is_rocm_triton_rotary_embedding_enabled():
|
||||||
|
|
||||||
|
direct_register_custom_op(
|
||||||
|
op_name="rocm_aiter_rotary_emb_with_key_forward_triton",
|
||||||
|
op_func=rocm_aiter_rotary_emb_with_key_forward_triton_impl,
|
||||||
|
mutates_args=["key", "query"],
|
||||||
|
fake_impl=rocm_aiter_rotary_emb_with_key_forward_triton_fake,
|
||||||
|
dispatch_key=current_platform.dispatch_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rocm_aiter_rotary_emb(positions: torch.Tensor, query: torch.Tensor,
|
||||||
|
key: torch.Tensor, cos_sin_cache: torch.Tensor,
|
||||||
|
head_size: int, rotary_dim: int,
|
||||||
|
is_neox_style: bool):
|
||||||
|
num_tokens = positions.numel()
|
||||||
|
cos, sin = cos_sin_cache.chunk(2, dim=-1)
|
||||||
|
query_shape = query.shape
|
||||||
|
key_shape = key.shape
|
||||||
|
rotate_style = 0 if is_neox_style else 1
|
||||||
|
|
||||||
|
query = query.view(num_tokens, -1, head_size)
|
||||||
|
key = key.view(num_tokens, -1, head_size)
|
||||||
|
query_ = query[..., :rotary_dim]
|
||||||
|
key_ = key[..., :rotary_dim]
|
||||||
|
positions = positions.view(*query.shape[:1])
|
||||||
|
torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_triton(
|
||||||
|
positions,
|
||||||
|
sin,
|
||||||
|
cos,
|
||||||
|
query_,
|
||||||
|
key_,
|
||||||
|
rotate_style,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
query = query.view(query_shape)
|
||||||
|
key = key.view(key_shape)
|
||||||
Loading…
x
Reference in New Issue
Block a user