# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod

import torch
from torch._higher_order_ops import auto_functionalized
from torch._ops import OpOverload

from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey,
    _normalize_quant_group_shape,
    kFp8DynamicTensorSym,
    kFp8DynamicTokenSym,
    kFp8StaticTensorSym,
    kNvfp4Quant,
)
from vllm.platforms import current_platform

RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default

QUANT_OPS: dict[QuantKey, OpOverload] = {
    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTensorSym: torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa: E501
    kFp8DynamicTokenSym: torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa: E501
}

if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
    QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501
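
# Illustrative sketch: pattern code can look up the custom op for a matched
# quantization scheme directly from QUANT_OPS, e.g.
#
#     quant_op = QUANT_OPS[kFp8StaticTensorSym]
#     # quant_op is torch.ops._C.static_scaled_fp8_quant.default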


class MatcherCustomOp(ABC):
    """Base class for pattern-matching helper ops that dispatch to either the
    custom or the native implementation based on ``enabled``."""

    def __init__(self, enabled: bool):
        config = get_current_vllm_config()
        self.model_dtype = config.model_config.dtype if config.model_config else None
        self.device = config.device_config.device if config.device_config else None

        self.enabled = enabled
        self.forward = self.forward_custom if enabled else self.forward_native

    @abstractmethod
    def forward_custom(self, *args, **kws):
        pass

    @abstractmethod
    def forward_native(self, *args, **kws):
        pass

    def __call__(self, *args, **kws):
        return self.forward(*args, **kws)

    def empty(self, *args, **kws):
        return torch.empty(*args, dtype=self.model_dtype, device=self.device, **kws)

    def empty_f32(self, *args, **kws):
        return torch.empty(*args, dtype=torch.float32, device=self.device, **kws)

    def inputs(self) -> list[torch.Tensor]:
        """Utility for inputs to the pattern"""
        raise NotImplementedError
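
# Design note: self.forward is bound once in __init__, so a pattern traced
# through a MatcherCustomOp contains exactly one implementation (custom or
# native) and no per-call enabled/disabled branch.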


class MatcherRMSNorm(MatcherCustomOp):
    def __init__(self, epsilon: float, enabled: bool | None = None):
        if enabled is None:
            enabled = RMSNorm.enabled()

        super().__init__(enabled)
        self.epsilon = epsilon

    def inputs(self):
        input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
        weight = self.empty(16)
        return [input, weight]

    def forward_custom(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        result = torch.empty_like(input)
        _, result = auto_functionalized(
            RMS_OP,
            result=result,
            input=input,
            weight=weight,
            epsilon=self.epsilon,
        )

        return result

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        return RMSNorm.forward_static(
            input, self.epsilon, input.size(-1), self.model_dtype, weight
        )
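
# Usage sketch (illustrative; assumes a vLLM config is current so that
# MatcherCustomOp.__init__ can read dtype/device):
#
#     rms = MatcherRMSNorm(epsilon=1e-6)
#     input, weight = rms.inputs()
#     out = rms(input, weight)  # forward_custom or forward_native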


class MatcherFusedAddRMSNorm(MatcherCustomOp):
    def __init__(self, epsilon: float, enabled: bool | None = None):
        if enabled is None:
            enabled = RMSNorm.enabled()

        super().__init__(enabled)
        self.epsilon = epsilon

    def inputs(self):
        input = self.empty(5, 16) if self.enabled else self.empty_f32(5, 16)
        weight = self.empty(16)
        residual = self.empty(5, 16)
        return [input, weight, residual]

    def forward_custom(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # The fused op mutates input (the normed output) and residual in
        # place; auto_functionalized returns the functionalized values.
        _, result, residual = auto_functionalized(
            RMS_ADD_OP,
            input=input,
            residual=residual,
            weight=weight,
            epsilon=self.epsilon,
        )

        return result, residual

    def forward_native(
        self,
        input: torch.Tensor,
        weight: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return RMSNorm.forward_static(
            input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
        )
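
# The fused-add variant also threads the residual stream through (sketch,
# same assumptions as the MatcherRMSNorm example above):
#
#     fused = MatcherFusedAddRMSNorm(epsilon=1e-6)
#     input, weight, residual = fused.inputs()
#     out, residual = fused(input, weight, residual)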


class MatcherQuantFP8(MatcherCustomOp):
    def __init__(self, quant_key: QuantKey, enabled: bool | None = None):
        if enabled is None:
            enabled = QuantFP8.enabled()

        super().__init__(enabled)
        self.quant_key = quant_key
        assert quant_key in QUANT_OPS, f"unsupported quantization scheme {quant_key}"
        self.QUANT_OP = QUANT_OPS[quant_key]

        assert quant_key.dtype == current_platform.fp8_dtype(), (
            "MatcherQuantFP8 only supports the platform FP8 dtype"
        )
        assert quant_key.scale2 is None
        self.quant_fp8 = QuantFP8(quant_key.scale.static, quant_key.scale.group_shape)

    def forward_custom(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        result = torch.empty(
            input.shape, device=input.device, dtype=self.quant_key.dtype
        )

        if self.quant_key.scale.static:
            # Static quantization: the scale is a precomputed input.
            assert scale is not None
            _, result = auto_functionalized(
                self.QUANT_OP, result=result, input=input, scale=scale
            )
            return result, scale
        else:
            # Dynamic quantization: the op computes the scale into a
            # freshly allocated tensor.
            assert scale is None
            scale = self.make_scale(input)
            _, result, scale = auto_functionalized(
                self.QUANT_OP, result=result, input=input, scale=scale, scale_ub=None
            )
            return result, scale

    def forward_native(
        self,
        input: torch.Tensor,
        scale: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.quant_fp8(input, scale)

    def make_scale(self, input: torch.Tensor):
        normalized_group_shape = _normalize_quant_group_shape(
            input, self.quant_key.scale.group_shape
        )
        scale_shape = (
            input.shape[0] // normalized_group_shape[0],
            input.shape[1] // normalized_group_shape[1],
        )

        return torch.empty(scale_shape, device=input.device, dtype=torch.float32)

    def inputs(self) -> list[torch.Tensor]:
        input = self.empty(5, 16)
        if self.quant_key.scale.static:
            return [input, self.empty_f32(1, 1)]

        return [input]
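
# Worked example for make_scale (illustrative; exact shapes depend on
# _normalize_quant_group_shape): for a (5, 16) input, per-token scales
# normalize to one group per row, giving a (5, 1) float32 scale tensor,
# while per-tensor scales normalize to a single group, giving (1, 1).
#
#     quant = MatcherQuantFP8(kFp8DynamicTokenSym)
#     (input,) = quant.inputs()
#     result, scale = quant(input)  # scale allocated by make_scale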