[amd][gptoss] Perf gain because of block alignment (#28024)
Signed-off-by: Smit Kadvani <smit.kadvani@gmail.com>
Co-authored-by: Smit Shaileshbhai Kadvani <kadvani@meta.com>
commit 11fd69dd54
parent c0a4b95d64
vllm/model_executor/layers/quantization/mxfp4.py

@@ -43,6 +43,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     _can_support_mxfp4,
     _swizzle_mxfp4,
+    get_padding_alignment,
 )
 from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 from vllm.model_executor.utils import set_weight_attrs

@@ -282,10 +283,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             )
             hidden_size = round_up(hidden_size, 128)
         elif current_platform.is_rocm():
+            pad_align = get_padding_alignment()
             intermediate_size_per_partition_after_pad = round_up(
-                intermediate_size_per_partition, 256
+                intermediate_size_per_partition, pad_align
             )
-            hidden_size = round_up(hidden_size, 256)
+            hidden_size = round_up(hidden_size, pad_align)
         else:
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 64
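For context on the hunk above: round_up(x, y) pads x up to the next multiple of y, so lowering the ROCm alignment from 256 to 128 on non-gfx950 GPUs shrinks the padded MoE weight shapes, which is presumably the perf gain named in the commit title. A minimal sketch of the arithmetic, assuming the usual round_up definition and a hypothetical per-partition size of 2880 (not taken from this diff):

def round_up(x: int, y: int) -> int:
    # Next-multiple helper; signature assumed to match the one vLLM uses above.
    return ((x + y - 1) // y) * y

intermediate_size_per_partition = 2880  # hypothetical value for illustration
print(round_up(intermediate_size_per_partition, 256))  # 3072 -> 192 padded elements
print(round_up(intermediate_size_per_partition, 128))  # 2944 -> 64 padded elements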
vllm/model_executor/layers/quantization/utils/mxfp4_utils.py

@@ -7,6 +7,7 @@ import torch
 
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.triton_utils import triton
 from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
 
 logger = init_logger(__name__)

@@ -99,6 +100,14 @@ def _can_support_mxfp4(
     )
 
 
+def get_padding_alignment():
+    return (
+        256
+        if triton.runtime.driver.active.get_current_target().arch in ("gfx950",)
+        else 128
+    )
+
+
 def _dequant_mxfp4(
     x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
 ) -> torch.Tensor:
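A hedged usage sketch of the new helper: it asks the active Triton driver for the current GPU architecture and returns the wider 256-element alignment only on gfx950, falling back to 128 elsewhere. The example uses only imports that appear in this diff; the mapping of gfx942 to MI300-series hardware is an assumption, not something stated in the commit.

# Sketch: inspect which padding alignment this GPU will get (ROCm build assumed).
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
    get_padding_alignment,
)
from vllm.triton_utils import triton

arch = triton.runtime.driver.active.get_current_target().arch
print(arch, get_padding_alignment())  # e.g. "gfx942" 128 (MI300, assumed) or "gfx950" 256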