mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-22 05:48:00 +08:00
Signed-off-by: zhewenli <zhewenli@meta.com> Signed-off-by: Zhewen Li <zhewenli@meta.com>
This commit is contained in:
parent
0ee6416f67
commit
ae339b1a67
@ -30,6 +30,7 @@ from vllm.model_executor.parameter import (
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.deep_gemm import (
|
||||
DeepGemmQuantScaleFMT,
|
||||
fp8_gemm_nt,
|
||||
is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
@ -268,12 +269,15 @@ class W8A8BlockFp8LinearOp:
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert self.deepgemm_input_quant_op is not None
|
||||
q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
|
||||
input_2d,
|
||||
group_size=self.act_quant_group_shape.col,
|
||||
use_ue8m0=True,
|
||||
)
|
||||
if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0:
|
||||
q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
|
||||
input_2d,
|
||||
group_size=self.act_quant_group_shape.col,
|
||||
use_ue8m0=True,
|
||||
)
|
||||
else:
|
||||
assert self.deepgemm_input_quant_op is not None
|
||||
q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
|
||||
output = torch.empty(
|
||||
(q_input.shape[0], weight.shape[0]),
|
||||
dtype=torch.bfloat16,
|
||||
|
||||
@ -399,6 +399,7 @@ def should_use_deepgemm_for_fp8_linear_for_nk(
|
||||
|
||||
__all__ = [
|
||||
"calc_diff",
|
||||
"DeepGemmQuantScaleFMT",
|
||||
"fp8_gemm_nt",
|
||||
"m_grouped_fp8_gemm_nt_contiguous",
|
||||
"fp8_m_grouped_gemm_nt_masked",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user