mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-31 04:57:04 +08:00
Signed-off-by: zhewenli <zhewenli@meta.com> Signed-off-by: Zhewen Li <zhewenli@meta.com>
This commit is contained in:
parent
0ee6416f67
commit
ae339b1a67
@ -30,6 +30,7 @@ from vllm.model_executor.parameter import (
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.utils.deep_gemm import (
|
from vllm.utils.deep_gemm import (
|
||||||
|
DeepGemmQuantScaleFMT,
|
||||||
fp8_gemm_nt,
|
fp8_gemm_nt,
|
||||||
is_deep_gemm_e8m0_used,
|
is_deep_gemm_e8m0_used,
|
||||||
is_deep_gemm_supported,
|
is_deep_gemm_supported,
|
||||||
@ -268,12 +269,15 @@ class W8A8BlockFp8LinearOp:
|
|||||||
weight: torch.Tensor,
|
weight: torch.Tensor,
|
||||||
weight_scale: torch.Tensor,
|
weight_scale: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert self.deepgemm_input_quant_op is not None
|
if DeepGemmQuantScaleFMT.from_oracle() == DeepGemmQuantScaleFMT.UE8M0:
|
||||||
q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
|
q_input, input_scale = per_token_group_quant_fp8_packed_for_deepgemm(
|
||||||
input_2d,
|
input_2d,
|
||||||
group_size=self.act_quant_group_shape.col,
|
group_size=self.act_quant_group_shape.col,
|
||||||
use_ue8m0=True,
|
use_ue8m0=True,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
assert self.deepgemm_input_quant_op is not None
|
||||||
|
q_input, input_scale = self.deepgemm_input_quant_op(input_2d)
|
||||||
output = torch.empty(
|
output = torch.empty(
|
||||||
(q_input.shape[0], weight.shape[0]),
|
(q_input.shape[0], weight.shape[0]),
|
||||||
dtype=torch.bfloat16,
|
dtype=torch.bfloat16,
|
||||||
|
|||||||
@ -399,6 +399,7 @@ def should_use_deepgemm_for_fp8_linear_for_nk(
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"calc_diff",
|
"calc_diff",
|
||||||
|
"DeepGemmQuantScaleFMT",
|
||||||
"fp8_gemm_nt",
|
"fp8_gemm_nt",
|
||||||
"m_grouped_fp8_gemm_nt_contiguous",
|
"m_grouped_fp8_gemm_nt_contiguous",
|
||||||
"fp8_m_grouped_gemm_nt_masked",
|
"fp8_m_grouped_gemm_nt_masked",
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user