[ROCm] Enable Triton ScaledMM fallback + kernel selection fix (#26668)
Signed-off-by: Shivam <shivampr.dev@gmail.com>
Signed-off-by: Shivam <shivamprasad91@gmail.com>
commit cd7740ac5c
parent 02a5880394
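Summary: this change replaces the per-kernel get_min_capability() hook with an abstract is_supported() classmethod, moving platform and compute-capability checks into each ScaledMM kernel. choose_scaled_mm_linear_kernel() now asks each candidate kernel.is_supported(compute_capability) before kernel.can_implement(config); TritonScaledMMLinearKernel is promoted from a CutlassScaledMMLinearKernel subclass to a standalone ScaledMMLinearKernel with its own weight processing and apply path, and is registered as a CUDA fallback after CUTLASS (it was already listed for ROCm). A CPU-only test for the selection logic is added, and the Buildkite Mantis install is made non-fatal on CPU-only runners.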
@@ -836,7 +836,7 @@ steps:
   - tests/models/multimodal
   no_gpu: true
   commands:
-  - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git
+  - "pip install git+https://github.com/TIGER-AI-Lab/Mantis.git || echo 'Mantis installation skipped (decord not available on CPU-only environment)'"
   - pytest -v -s models/multimodal/processing --ignore models/multimodal/processing/test_tensor_schema.py
 
 - label: Multi-Modal Processor Test
@@ -0,0 +1,91 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Tests for ScaledMM kernel selection logic (CPU-only)
+
+Run `pytest tests/kernels/quantization/test_scaled_mm_kernel_selection.py`.
+"""
+
+import inspect
+from abc import ABC
+
+import pytest
+
+from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
+    ScaledMMLinearLayerConfig,
+)
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
+    AiterScaledMMLinearKernel,
+)
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
+    CPUScaledMMLinearKernel,
+)
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
+    ScaledMMLinearKernel,
+)
+
+pytestmark = pytest.mark.cpu_test
+
+
+def test_is_supported_is_abstract():
+    """Test that is_supported() is properly defined as abstract."""
+    assert issubclass(ScaledMMLinearKernel, ABC)
+    assert hasattr(ScaledMMLinearKernel, "is_supported")
+
+
+def test_cpu_kernel_implements_is_supported():
+    """Test that CPUScaledMMLinearKernel implements is_supported() method."""
+    assert hasattr(CPUScaledMMLinearKernel, "is_supported"), (
+        "CPUScaledMMLinearKernel missing is_supported() method"
+    )
+    # Verify it's a classmethod by checking if it can be called with the class
+    # and by checking the method type
+    assert inspect.ismethod(CPUScaledMMLinearKernel.is_supported) or inspect.isfunction(
+        CPUScaledMMLinearKernel.is_supported
+    ), "CPUScaledMMLinearKernel.is_supported() should be a classmethod"
+    # Verify it can be called as a classmethod
+    result, reason = CPUScaledMMLinearKernel.is_supported()
+    assert isinstance(result, bool), "is_supported() should return a bool"
+    assert reason is None or isinstance(reason, str), "reason should be str or None"
+
+
+def test_aiter_kernel_implements_is_supported():
+    """Test that AiterScaledMMLinearKernel implements is_supported() method."""
+    assert hasattr(AiterScaledMMLinearKernel, "is_supported"), (
+        "AiterScaledMMLinearKernel missing is_supported() method"
+    )
+    # Verify it's a classmethod by checking if it can be called with the class
+    # and by checking the method type
+    assert inspect.ismethod(
+        AiterScaledMMLinearKernel.is_supported
+    ) or inspect.isfunction(AiterScaledMMLinearKernel.is_supported), (
+        "AiterScaledMMLinearKernel.is_supported() should be a classmethod"
+    )
+    # Verify it can be called as a classmethod
+    # (will return False on CPU, which is expected)
+    result, reason = AiterScaledMMLinearKernel.is_supported()
+    assert isinstance(result, bool), "is_supported() should return a bool"
+    assert reason is None or isinstance(reason, str), "reason should be str or None"
+    # On CPU, it should return False with a reason about requiring ROCm
+    # This validates the method works correctly even on non-ROCm platforms
+
+
+def test_cpu_kernel_accepts_all_configs():
+    """Test that CPUScaledMMLinearKernel accepts all config combinations."""
+    configs = [
+        ScaledMMLinearLayerConfig(
+            is_channelwise=False,
+            is_static_input_scheme=True,
+            input_symmetric=True,
+        ),
+        ScaledMMLinearLayerConfig(
+            is_channelwise=True,
+            is_static_input_scheme=False,
+            input_symmetric=False,
+        ),
+    ]
+
+    for config in configs:
+        can_impl, reason = CPUScaledMMLinearKernel.can_implement(config)
+        assert can_impl, (
+            f"CPUScaledMMLinearKernel should accept config {config}: {reason}"
+        )
@@ -17,7 +17,9 @@ class ScaledMMLinearLayerConfig:
 class ScaledMMLinearKernel(ABC):
     @classmethod
     @abstractmethod
-    def get_min_capability(cls) -> int:
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
         raise NotImplementedError
 
     @classmethod
@@ -35,6 +37,7 @@ class ScaledMMLinearKernel(ABC):
         azp_adj_param_name: str,
     ) -> None:
         assert self.can_implement(c)
+        assert self.is_supported()
         self.config = c
         self.w_q_name = w_q_param_name
         self.w_s_name = w_s_param_name
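To make the new contract concrete, here is a small probe of the API (illustrative, not part of the diff; output depends on the platform the snippet runs on, and it assumes a vLLM checkout with this change applied):

# Illustration only: is_supported() is a classmethod returning (ok, reason).
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
    CutlassScaledMMLinearKernel,
)

ok, reason = CutlassScaledMMLinearKernel.is_supported(compute_capability=70)
# On a CUDA build: (False, "requires capability 75, got 70").
# On non-CUDA platforms: (False, "Requires CUDA.").
print(ok, reason)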
@@ -27,7 +27,7 @@ from vllm.platforms import PlatformEnum, current_platform
 # in priority/performance order (when available)
 _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
     PlatformEnum.CPU: [CPUScaledMMLinearKernel],
-    PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
+    PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
     PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
     PlatformEnum.TPU: [XLAScaledMMLinearKernel],
 }
@@ -55,41 +55,25 @@ def choose_scaled_mm_linear_kernel(
         type[ScaledMMLinearKernel]: Chosen kernel.
     """
 
-    if compute_capability is None:
-        _cc = current_platform.get_device_capability()
-        if _cc is not None:
-            compute_capability = _cc[0] * 10 + _cc[1]
-
     failure_reasons = []
     for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
         if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
-            failure_reasons.append(
-                f" {kernel.__name__} disabled by environment variable"
-            )
+            failure_reasons.append(f"{kernel.__name__}: disabled by env var")
             continue
 
         # If the current platform uses compute_capability,
         # make sure the kernel supports the compute cability.
-        if compute_capability is not None:
-            kernel_min_capability = kernel.get_min_capability()
-            if (
-                kernel_min_capability is not None
-                and kernel_min_capability > compute_capability
-            ):
-                failure_reasons.append(
-                    f"{kernel.__name__} requires capability "
-                    f"{kernel_min_capability}, current compute capability "
-                    f"is {compute_capability}"
-                )
-                continue
+        is_supported, reason = kernel.is_supported(compute_capability)
+        if not is_supported:
+            failure_reasons.append(f"{kernel.__name__}: {reason}")
+            continue
 
-        can_implement, failure_reason = kernel.can_implement(config)
-        if can_implement:
-            return kernel
-        else:
-            failure_reasons.append(
-                f" {kernel.__name__} cannot implement due to: {failure_reason}"
-            )
+        can_implement, reason = kernel.can_implement(config)
+        if not can_implement:
+            failure_reasons.append(f"{kernel.__name__}: {reason}")
+            continue
+
+        return kernel
 
     raise ValueError(
         "Failed to find a kernel that can implement the "
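A hypothetical walkthrough of the rewritten loop (not part of the diff): with the Triton kernel now listed after CUTLASS for CUDA, disabling CUTLASS via the existing VLLM_DISABLED_KERNELS variable should make selection fall through to it. This sketch assumes choose_scaled_mm_linear_kernel is exported from the same scaled_mm package as ScaledMMLinearLayerConfig and accepts the compute_capability argument seen above.

import os

from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    ScaledMMLinearLayerConfig,
    choose_scaled_mm_linear_kernel,  # assumed to be exported here
)

config = ScaledMMLinearLayerConfig(
    is_channelwise=False, is_static_input_scheme=True, input_symmetric=True
)

# Knock out CUTLASS so the loop records a failure reason and continues.
os.environ["VLLM_DISABLED_KERNELS"] = "CutlassScaledMMLinearKernel"
kernel_cls = choose_scaled_mm_linear_kernel(config, compute_capability=None)
print(kernel_cls.__name__)  # expected on CUDA: TritonScaledMMLinearKernel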
@@ -14,17 +14,21 @@ from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
 
 class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
     @classmethod
-    def get_min_capability(cls) -> int:
-        return 90
-
-    @classmethod
-    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
         if not current_platform.is_rocm():
             return (
                 False,
                 "AiterScaledMMLinearKernel requires `aiter` which is not "
                 + "currently supported on non-ROCm platform.",
             )
+        if compute_capability is None:
+            _cc = current_platform.get_device_capability()
+            if _cc is not None:
+                compute_capability = _cc.major * 10 + _cc.minor
+        if compute_capability is not None and compute_capability < 90:
+            return False, f"requires capability 90, got {compute_capability}"
 
         try:
             import aiter  # noqa: F401 # deliberately attempt to import aiter
@@ -34,8 +38,8 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
                 "AiterScaledMMLinearKernel requires `aiter` which is not "
                 + "installed on ROCm.",
             )
-        # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
-        if not (rocm_aiter_ops.is_linear_enabled()):
+
+        if not rocm_aiter_ops.is_linear_enabled():
             return (
                 False,
                 "AiterScaledMMLinearKernel is disabled. "
@@ -44,6 +48,10 @@ class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
                 + "`VLLM_ROCM_USE_AITER_LINEAR` default is True.",
             )
 
+        return True, None
+
+    @classmethod
+    def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
         if not c.input_symmetric:
             return (
                 False,
@@ -19,14 +19,15 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
 
 class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
     @classmethod
-    def get_min_capability(cls) -> int:
-        return 75
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
+        if not current_platform.is_cpu():
+            return False, "Requires CPU."
+        return True, None
 
     @classmethod
     def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
-        if not current_platform.is_cpu():
-            return False, "CPUScaledMM requires running on CPU."
-
         return True, None
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -16,14 +16,21 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
 
 class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
     @classmethod
-    def get_min_capability(cls) -> int:
-        return 75
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
+        if not current_platform.is_cuda():
+            return False, "Requires CUDA."
+        if compute_capability is None:
+            _cc = current_platform.get_device_capability()
+            if _cc is not None:
+                compute_capability = _cc.major * 10 + _cc.minor
+        if compute_capability is not None and compute_capability < 75:
+            return False, f"requires capability 75, got {compute_capability}"
+        return True, None
 
     @classmethod
     def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
-        if not current_platform.is_cuda():
-            return False, "CutlassScaledMM requires running on CUDA."
-
         return True, None
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
@@ -4,34 +4,53 @@
 
 import torch
 
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa: E501
+    triton_scaled_mm,
+)
+from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.platforms import current_platform
 
-from .cutlass import CutlassScaledMMLinearKernel
-from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
+from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
 
 
-class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
+class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
     @classmethod
-    def get_min_capability(cls) -> int:
-        return 75
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
+        if current_platform.is_cuda_alike():
+            return True, None
+        return False, "Requires ROCm or CUDA."
 
     @classmethod
     def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
-        if current_platform.is_cpu():
-            return (
-                False,
-                "TritonScaledMMLinearKernel requires Triton which is not "
-                + "currently supported on CPU.",
-            )
         if not c.input_symmetric:
-            return (
-                False,
-                "TritonScaledMMLinearKernel only supports symmetric " + "quantization.",
-            )
+            return False, "Only symmetric input is supported."
         return True, None
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        super().process_weights_after_loading(layer)
+        weight = getattr(layer, self.w_q_name)
+        replace_parameter(
+            layer,
+            self.w_q_name,
+            torch.nn.Parameter(weight.t().data, requires_grad=False),
+        )
+
+        # INPUT SCALE
+        if self.config.is_static_input_scheme:
+            input_scale = getattr(layer, self.i_s_name)
+            replace_parameter(
+                layer,
+                self.i_s_name,
+                torch.nn.Parameter(input_scale.max(), requires_grad=False),
+            )
+            setattr(layer, self.i_zp_name, None)
+        else:
+            setattr(layer, self.i_s_name, None)
+            setattr(layer, self.i_zp_name, None)
+
+        setattr(layer, self.azp_adj_name, None)
 
     def apply_weights(
         self,
@@ -39,4 +58,14 @@ class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel):
         x: torch.Tensor,
         bias: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        return super().apply_weights(layer, x, bias)
+        w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
+
+        x_q, x_s, x_zp = ops.scaled_int8_quant(
+            x.contiguous(), i_s, i_zp, symmetric=True
+        )
+
+        assert x_zp is None, "Triton kernel only supports symmetric quantization"
+
+        return triton_scaled_mm(
+            x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
+        )
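For readers unfamiliar with the W8A8 path, the following standalone sketch (plain PyTorch, not vLLM code; names and shapes are illustrative) spells out the arithmetic that apply_weights() delegates to scaled_int8_quant + triton_scaled_mm: dynamic symmetric per-tensor quantization of the activation, an integer matmul against the pre-transposed int8 weight, then rescaling by both scales.

# Reference arithmetic only; a sketch under assumed shapes/dtypes.
import torch

def ref_scaled_mm(
    x: torch.Tensor,       # fp activation, [M, K]
    w_q: torch.Tensor,     # int8 weight, already transposed to [K, N]
    w_s: torch.Tensor,     # weight scale, scalar or [N]
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    # Dynamic symmetric per-tensor quantization (x_zp stays None).
    x_s = (x.abs().max() / 127.0).clamp(min=1e-8)
    x_q = torch.clamp((x / x_s).round(), -128, 127).to(torch.int8)
    # Integer matmul, then de-scale by both quantization scales.
    out = (x_q.to(torch.int32) @ w_q.to(torch.int32)).to(x.dtype)
    out = out * x_s * w_s
    return out + bias if bias is not None else out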
@@ -17,11 +17,12 @@ from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
 
 class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
     @classmethod
-    def get_min_capability(cls) -> int:
-        raise NotImplementedError(
-            "TPU platform does not have a concept of compute capability, "
-            "this method should not be called."
-        )
+    def is_supported(
+        cls, compute_capability: int | None = None
+    ) -> tuple[bool, str | None]:
+        if not current_platform.is_tpu():
+            return False, "Requires TPU."
+        return True, None
 
     @classmethod
     def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: