mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-23 15:56:16 +08:00
[ROCm] Enable Triton ScaledMM fallback + kernel selection fix (#26668)
Signed-off-by: Shivam <shivampr.dev@gmail.com> Signed-off-by: Shivam <shivamprasad91@gmail.com>
This commit is contained in:
parent
02a5880394
commit
cd7740ac5c
@ -836,7 +836,7 @@ steps:
|
||||
- tests/models/multimodal
|
||||
no_gpu: true
|
||||
commands:
|
||||
- pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
|
||||
- "pip install git+https://github.com/TIGER-AI-Lab/Mantis.git || echo 'Mantis installation skipped (decord not available on CPU-only environment)'"
|
||||
- pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
|
||||
|
||||
- label: Multi-Modal Processor Test
|
||||
|
||||
@ -0,0 +1,91 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for ScaledMM kernel selection logic (CPU-only)
|
||||
|
||||
Run `pytest tests/kernels/quantization/test_scaled_mm_kernel_selection.py`.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from abc import ABC
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
|
||||
AiterScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
|
||||
CPUScaledMMLinearKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
|
||||
ScaledMMLinearKernel,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
def test_is_supported_is_abstract():
|
||||
"""Test that is_supported() is properly defined as abstract."""
|
||||
assert issubclass(ScaledMMLinearKernel, ABC)
|
||||
assert hasattr(ScaledMMLinearKernel, "is_supported")
|
||||
|
||||
|
||||
def test_cpu_kernel_implements_is_supported():
|
||||
"""Test that CPUScaledMMLinearKernel implements is_supported() method."""
|
||||
assert hasattr(CPUScaledMMLinearKernel, "is_supported"), (
|
||||
"CPUScaledMMLinearKernel missing is_supported() method"
|
||||
)
|
||||
# Verify it's a classmethod by checking if it can be called with the class
|
||||
# and by checking the method type
|
||||
assert inspect.ismethod(CPUScaledMMLinearKernel.is_supported) or inspect.isfunction(
|
||||
CPUScaledMMLinearKernel.is_supported
|
||||
), "CPUScaledMMLinearKernel.is_supported() should be a classmethod"
|
||||
# Verify it can be called as a classmethod
|
||||
result, reason = CPUScaledMMLinearKernel.is_supported()
|
||||
assert isinstance(result, bool), "is_supported() should return a bool"
|
||||
assert reason is None or isinstance(reason, str), "reason should be str or None"
|
||||
|
||||
|
||||
def test_aiter_kernel_implements_is_supported():
|
||||
"""Test that AiterScaledMMLinearKernel implements is_supported() method."""
|
||||
assert hasattr(AiterScaledMMLinearKernel, "is_supported"), (
|
||||
"AiterScaledMMLinearKernel missing is_supported() method"
|
||||
)
|
||||
# Verify it's a classmethod by checking if it can be called with the class
|
||||
# and by checking the method type
|
||||
assert inspect.ismethod(
|
||||
AiterScaledMMLinearKernel.is_supported
|
||||
) or inspect.isfunction(AiterScaledMMLinearKernel.is_supported), (
|
||||
"AiterScaledMMLinearKernel.is_supported() should be a classmethod"
|
||||
)
|
||||
# Verify it can be called as a classmethod
|
||||
# (will return False on CPU, which is expected)
|
||||
result, reason = AiterScaledMMLinearKernel.is_supported()
|
||||
assert isinstance(result, bool), "is_supported() should return a bool"
|
||||
assert reason is None or isinstance(reason, str), "reason should be str or None"
|
||||
# On CPU, it should return False with a reason about requiring ROCm
|
||||
# This validates the method works correctly even on non-ROCm platforms
|
||||
|
||||
|
||||
def test_cpu_kernel_accepts_all_configs():
|
||||
"""Test that CPUScaledMMLinearKernel accepts all config combinations."""
|
||||
configs = [
|
||||
ScaledMMLinearLayerConfig(
|
||||
is_channelwise=False,
|
||||
is_static_input_scheme=True,
|
||||
input_symmetric=True,
|
||||
),
|
||||
ScaledMMLinearLayerConfig(
|
||||
is_channelwise=True,
|
||||
is_static_input_scheme=False,
|
||||
input_symmetric=False,
|
||||
),
|
||||
]
|
||||
|
||||
for config in configs:
|
||||
can_impl, reason = CPUScaledMMLinearKernel.can_implement(config)
|
||||
assert can_impl, (
|
||||
f"CPUScaledMMLinearKernel should accept config {config}: {reason}"
|
||||
)
|
||||
@ -17,7 +17,9 @@ class ScaledMMLinearLayerConfig:
|
||||
class ScaledMMLinearKernel(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@ -35,6 +37,7 @@ class ScaledMMLinearKernel(ABC):
|
||||
azp_adj_param_name: str,
|
||||
) -> None:
|
||||
assert self.can_implement(c)
|
||||
assert self.is_supported()
|
||||
self.config = c
|
||||
self.w_q_name = w_q_param_name
|
||||
self.w_s_name = w_s_param_name
|
||||
|
||||
@ -27,7 +27,7 @@ from vllm.platforms import PlatformEnum, current_platform
|
||||
# in priority/performance order (when available)
|
||||
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
|
||||
PlatformEnum.CPU: [CPUScaledMMLinearKernel],
|
||||
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
|
||||
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
||||
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
|
||||
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
|
||||
}
|
||||
@ -55,41 +55,25 @@ def choose_scaled_mm_linear_kernel(
|
||||
type[ScaledMMLinearKernel]: Chosen kernel.
|
||||
"""
|
||||
|
||||
if compute_capability is None:
|
||||
_cc = current_platform.get_device_capability()
|
||||
if _cc is not None:
|
||||
compute_capability = _cc[0] * 10 + _cc[1]
|
||||
|
||||
failure_reasons = []
|
||||
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
|
||||
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
|
||||
failure_reasons.append(
|
||||
f" {kernel.__name__} disabled by environment variable"
|
||||
)
|
||||
failure_reasons.append(f"{kernel.__name__}: disabled by env var")
|
||||
continue
|
||||
|
||||
# If the current platform uses compute_capability,
|
||||
# make sure the kernel supports the compute cability.
|
||||
if compute_capability is not None:
|
||||
kernel_min_capability = kernel.get_min_capability()
|
||||
if (
|
||||
kernel_min_capability is not None
|
||||
and kernel_min_capability > compute_capability
|
||||
):
|
||||
failure_reasons.append(
|
||||
f"{kernel.__name__} requires capability "
|
||||
f"{kernel_min_capability}, current compute capability "
|
||||
f"is {compute_capability}"
|
||||
)
|
||||
is_supported, reason = kernel.is_supported(compute_capability)
|
||||
if not is_supported:
|
||||
failure_reasons.append(f"{kernel.__name__}: {reason}")
|
||||
continue
|
||||
|
||||
can_implement, reason = kernel.can_implement(config)
|
||||
if not can_implement:
|
||||
failure_reasons.append(f"{kernel.__name__}: {reason}")
|
||||
continue
|
||||
|
||||
can_implement, failure_reason = kernel.can_implement(config)
|
||||
if can_implement:
|
||||
return kernel
|
||||
else:
|
||||
failure_reasons.append(
|
||||
f" {kernel.__name__} cannot implement due to: {failure_reason}"
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
"Failed to find a kernel that can implement the "
|
||||
|
||||
@ -14,17 +14,21 @@ from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
|
||||
|
||||
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 90
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_rocm():
|
||||
return (
|
||||
False,
|
||||
"AiterScaledMMLinearKernel requires `aiter` which is not "
|
||||
+ "currently supported on non-ROCm platform.",
|
||||
)
|
||||
if compute_capability is None:
|
||||
_cc = current_platform.get_device_capability()
|
||||
if _cc is not None:
|
||||
compute_capability = _cc.major * 10 + _cc.minor
|
||||
if compute_capability is not None and compute_capability < 90:
|
||||
return False, f"requires capability 90, got {compute_capability}"
|
||||
|
||||
try:
|
||||
import aiter # noqa: F401 # deliberately attempt to import aiter
|
||||
@ -34,8 +38,8 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
"AiterScaledMMLinearKernel requires `aiter` which is not "
|
||||
+ "installed on ROCm.",
|
||||
)
|
||||
# Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
|
||||
if not (rocm_aiter_ops.is_linear_enabled()):
|
||||
|
||||
if not rocm_aiter_ops.is_linear_enabled():
|
||||
return (
|
||||
False,
|
||||
"AiterScaledMMLinearKernel is disabled. "
|
||||
@ -44,6 +48,10 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
+ "`VLLM_ROCM_USE_AITER_LINEAR` default is True.",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not c.input_symmetric:
|
||||
return (
|
||||
False,
|
||||
|
||||
@ -19,14 +19,15 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfi
|
||||
|
||||
class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cpu():
|
||||
return False, "Requires CPU."
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cpu():
|
||||
return False, "CPUScaledMM requires running on CPU."
|
||||
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
@ -16,14 +16,21 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfi
|
||||
|
||||
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cuda():
|
||||
return False, "Requires CUDA."
|
||||
if compute_capability is None:
|
||||
_cc = current_platform.get_device_capability()
|
||||
if _cc is not None:
|
||||
compute_capability = _cc.major * 10 + _cc.minor
|
||||
if compute_capability is not None and compute_capability < 75:
|
||||
return False, f"requires capability 75, got {compute_capability}"
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cuda():
|
||||
return False, "CutlassScaledMM requires running on CUDA."
|
||||
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
@ -4,34 +4,53 @@
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa: E501
|
||||
triton_scaled_mm,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .cutlass import CutlassScaledMMLinearKernel
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
|
||||
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
|
||||
|
||||
|
||||
class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if current_platform.is_cuda_alike():
|
||||
return True, None
|
||||
return False, "Requires ROCm or CUDA."
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if current_platform.is_cpu():
|
||||
return (
|
||||
False,
|
||||
"TritonScaledMMLinearKernel requires Triton which is not "
|
||||
+ "currently supported on CPU.",
|
||||
)
|
||||
if not c.input_symmetric:
|
||||
return (
|
||||
False,
|
||||
"TritonScaledMMLinearKernel only supports symmetric " + "quantization.",
|
||||
)
|
||||
return False, "Only symmetric input is supported."
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
super().process_weights_after_loading(layer)
|
||||
weight = getattr(layer, self.w_q_name)
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.w_q_name,
|
||||
torch.nn.Parameter(weight.t().data, requires_grad=False),
|
||||
)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.config.is_static_input_scheme:
|
||||
input_scale = getattr(layer, self.i_s_name)
|
||||
replace_parameter(
|
||||
layer,
|
||||
self.i_s_name,
|
||||
torch.nn.Parameter(input_scale.max(), requires_grad=False),
|
||||
)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
else:
|
||||
setattr(layer, self.i_s_name, None)
|
||||
setattr(layer, self.i_zp_name, None)
|
||||
|
||||
setattr(layer, self.azp_adj_name, None)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
@ -39,4 +58,14 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return super().apply_weights(layer, x, bias)
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
|
||||
|
||||
x_q, x_s, x_zp = ops.scaled_int8_quant(
|
||||
x.contiguous(), i_s, i_zp, symmetric=True
|
||||
)
|
||||
|
||||
assert x_zp is None, "Triton kernel only supports symmetric quantization"
|
||||
|
||||
return triton_scaled_mm(
|
||||
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
|
||||
)
|
||||
|
||||
@ -17,11 +17,12 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfi
|
||||
|
||||
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError(
|
||||
"TPU platform does have a concept of compute capability, "
|
||||
"this method should not be called."
|
||||
)
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_tpu():
|
||||
return False, "Requires TPU."
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user