Merge remote-tracking branch 'origin/main' into fix/gptq-rocm

This commit is contained in:
Andreas Karatzas 2025-12-15 21:57:21 +00:00
commit 036f6b8a1a
No known key found for this signature in database
GPG Key ID: 74A33CBB22F03519
14 changed files with 139 additions and 185 deletions

View File

@ -16,15 +16,15 @@ vLLM offers basic model inferencing and serving on Arm CPU platform, with suppor
# --8<-- [start:pre-built-wheels]
Pre-built vLLM wheels for Arm are available since version 0.11.2. These wheels contain pre-compiled C++ binaries.
Please replace `<version>` in the commands below with a specific version string (e.g., `0.11.2`).
```bash
uv pip install --pre vllm==<version>+cpu --extra-index-url https://wheels.vllm.ai/<version>%2Bcpu/
export VLLM_VERSION=$(curl -s https://api.github.com/repos/vllm-project/vllm/releases/latest | jq -r .tag_name | sed 's/^v//')
uv pip install vllm --extra-index-url https://wheels.vllm.ai/${VLLM_VERSION}/cpu
```
??? console "pip"
```bash
pip install --pre vllm==<version>+cpu --extra-index-url https://wheels.vllm.ai/<version>%2Bcpu/
pip install vllm==${VLLM_VERSION}+cpu --extra-index-url https://wheels.vllm.ai/${VLLM_VERSION}/cpu
```
The `uv` approach works for vLLM `v0.6.6` and later. A unique feature of `uv` is that packages in `--extra-index-url` have [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). If the latest public release is `v0.6.6.post1`, `uv`'s behavior allows installing a commit before `v0.6.6.post1` by specifying the `--extra-index-url`. In contrast, `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version.
@ -35,20 +35,28 @@ LLM inference is a fast-evolving field, and the latest code may contain bug fixe
* `https://wheels.vllm.ai/nightly/cpu/vllm`
To install from nightly index, copy the link address of the `*.whl` under this index to run, for example:
To install from nightly index, run:
```bash
uv pip install -U https://wheels.vllm.ai/c756fb678184b867ed94e5613a529198f1aee423/vllm-0.13.0rc2.dev11%2Bgc756fb678.cpu-cp38-abi3-manylinux_2_31_aarch64.whl # current nightly build (the filename will change!)
uv pip install vllm --extra-index-url https://wheels.vllm.ai/nightly/cpu
```
??? console "pip (there's a caveat)"
Using `pip` to install from nightly indices is _not supported_, because `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version. In contrast, `uv` gives the extra index [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes).
If you insist on using `pip`, you have to specify the full URL (link address) of the wheel file (which can be obtained from https://wheels.vllm.ai/nightly/cpu/vllm).
```bash
pip install https://wheels.vllm.ai/4fa7ce46f31cbd97b4651694caf9991cc395a259/vllm-0.13.0rc2.dev104%2Bg4fa7ce46f.cpu-cp38-abi3-manylinux_2_35_aarch64.whl # current nightly build (the filename will change!)
```
**Install specific revisions**
If you want to access the wheels for previous commits (e.g. to bisect the behavior change, performance regression), specify the full commit hash in the index:
https://wheels.vllm.ai/${VLLM_COMMIT}/cpu/vllm.
Then, copy the link address of the `*.whl` under this index to run:
If you want to access the wheels for previous commits (e.g. to bisect the behavior change, performance regression), you can specify the commit hash in the URL:
```bash
uv pip install -U <wheel-url>
export VLLM_COMMIT=730bd35378bf2a5b56b6d3a45be28b3092d26519 # use full commit hash from the main branch
uv pip install vllm --extra-index-url https://wheels.vllm.ai/${VLLM_COMMIT}/cpu
```
# --8<-- [end:pre-built-wheels]
@ -103,10 +111,10 @@ Testing has been conducted on AWS Graviton3 instances for compatibility.
See [Using Docker](../../deployment/docker.md) for instructions on using the official Docker image.
Stable vLLM Docker images are being pre-built for Arm from version 0.12.0. Available image tags are here: [https://gallery.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo](https://gallery.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo).
Please replace `<version>` in the command below with a specific version string (e.g., `0.12.0`).
```bash
docker pull public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:v<version>
export VLLM_VERSION=$(curl -s https://api.github.com/repos/vllm-project/vllm/releases/latest | jq -r .tag_name | sed 's/^v//')
docker pull public.ecr.aws/q9t5s3a7/vllm-arm64-cpu-release-repo:v${VLLM_VERSION}
```
You can also access the latest code with Docker images. These are not intended for production use and are meant for CI and testing only. They will expire after several days.

View File

@ -23,14 +23,6 @@ class TestParameterSweepItem:
{"compilation_config.use_inductor_graph_partition": True},
"--compilation-config.use_inductor_graph_partition=true",
),
(
{"compilation_config.use_inductor": False},
"--compilation-config.use_inductor=false",
),
(
{"compilation_config.use_inductor": True},
"--compilation-config.use_inductor=true",
),
],
)
def test_nested_boolean_params(self, input_dict, expected):

View File

@ -464,7 +464,10 @@ class MultiHeadAttention(nn.Module):
}
self.fa_version = None
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
if (
self.attn_backend == AttentionBackendEnum.FLASH_ATTN
and current_platform.is_cuda()
):
self.fa_version = get_flash_attn_version()
assert self._flash_attn_varlen_func is not None
self._flash_attn_varlen_func = functools.partial(

View File

@ -2,11 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import cache
from typing import cast, get_args
from typing import NamedTuple, cast, get_args
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.backends.registry import (
MAMBA_TYPE_TO_BACKEND_MAP,
MambaAttentionBackendEnum,
@ -18,6 +18,31 @@ from vllm.utils.import_utils import resolve_obj_by_qualname
logger = init_logger(__name__)
class AttentionSelectorConfig(NamedTuple):
    """Immutable bundle of the parameters used to select an attention backend.

    Being a NamedTuple makes the config hashable, so it can serve as a key
    for the memoized backend lookup (`_cached_get_attn_backend` is wrapped
    in `@cache`).
    """

    head_size: int
    dtype: torch.dtype
    kv_cache_dtype: CacheDType | None
    block_size: int | None
    use_mla: bool = False
    has_sink: bool = False
    use_sparse: bool = False
    use_mm_prefix: bool = False
    attn_type: str = AttentionType.DECODER

    def __repr__(self):
        # str()-format every field value. The default NamedTuple repr would
        # use repr() (quoting strings such as attn_type), so keep an
        # explicit implementation to preserve the unquoted output.
        rendered = ", ".join(
            f"{field}={value}" for field, value in self._asdict().items()
        )
        return f"AttentionSelectorConfig({rendered})"
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
@ -43,8 +68,7 @@ def get_attn_backend(
vllm_config = get_current_vllm_config()
backend_enum = vllm_config.attention_config.backend
return _cached_get_attn_backend(
backend=backend_enum,
attn_selector_config = AttentionSelectorConfig(
head_size=head_size,
dtype=dtype,
kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
@ -53,36 +77,25 @@ def get_attn_backend(
has_sink=has_sink,
use_sparse=use_sparse,
use_mm_prefix=use_mm_prefix,
attn_type=attn_type,
attn_type=attn_type or AttentionType.DECODER,
)
return _cached_get_attn_backend(
backend=backend_enum,
attn_selector_config=attn_selector_config,
)
@cache
def _cached_get_attn_backend(
backend,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int | None,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
attn_type: str | None = None,
attn_selector_config: AttentionSelectorConfig,
) -> type[AttentionBackend]:
from vllm.platforms import current_platform
attention_cls = current_platform.get_attn_backend_cls(
backend,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
attn_selector_config=attn_selector_config,
)
if not attention_cls:
raise ValueError(

View File

@ -8,7 +8,7 @@ from dataclasses import field
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Literal
from pydantic import Field, TypeAdapter, field_validator
from pydantic import ConfigDict, Field, TypeAdapter, field_validator
from pydantic.dataclasses import dataclass
import vllm.envs as envs
@ -96,7 +96,7 @@ class CUDAGraphMode(enum.Enum):
@config
@dataclass
@dataclass(config=ConfigDict(extra="forbid"))
class PassConfig:
"""Configuration for custom Inductor passes.
@ -251,7 +251,7 @@ class DynamicShapesType(str, enum.Enum):
@config
@dataclass
@dataclass(config=ConfigDict(extra="forbid"))
class DynamicShapesConfig:
"""Configuration to control/debug torch compile dynamic shapes."""
@ -290,7 +290,7 @@ class DynamicShapesConfig:
@config
@dataclass
@dataclass(config=ConfigDict(extra="forbid"))
class CompilationConfig:
"""Configuration for compilation.

View File

@ -8,7 +8,7 @@ from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal, cast, get_args
import torch
from pydantic import ConfigDict, SkipValidation, field_validator, model_validator
from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
@ -109,7 +109,7 @@ class ModelConfig:
"""Convert the model using adapters defined in
[vllm.model_executor.models.adapters][]. The most common use case is to
adapt a text generation model to be used for pooling tasks."""
tokenizer: SkipValidation[str] = None # type: ignore
tokenizer: str = Field(default=None)
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model
name or path will be used."""
tokenizer_mode: TokenizerMode | str = "auto"
@ -164,7 +164,7 @@ class ModelConfig:
"""The specific revision to use for the tokenizer on the Hugging Face Hub.
It can be a branch name, a tag name, or a commit id. If unspecified, will
use the default version."""
max_model_len: SkipValidation[int] = None # type: ignore
max_model_len: int = Field(default=None, gt=0)
"""Model context length (prompt and output). If unspecified, will be
automatically derived from the model config.
@ -175,7 +175,7 @@ class ModelConfig:
- 25.6k -> 25,600"""
spec_target_max_model_len: int | None = None
"""Specify the maximum length for spec decoding draft models."""
quantization: SkipValidation[QuantizationMethods | None] = None
quantization: QuantizationMethods | str | None = None
"""Method used to quantize the weights. If `None`, we first check the
`quantization_config` attribute in the model config file. If that is
`None`, we assume the model weights are not quantized and use `dtype` to
@ -597,6 +597,14 @@ class ModelConfig:
self._verify_cuda_graph()
self._verify_bnb_config()
@field_validator("tokenizer", "max_model_len", mode="wrap")
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
    """Bypass field validation for `None` placeholders.

    `tokenizer` and `max_model_len` may legitimately stay `None` while
    initialisation is delayed, so only non-None values are passed through
    the regular pydantic validation chain.
    """
    return value if value is None else handler(value)
@field_validator("tokenizer_mode", mode="after")
def _lowercase_tokenizer_mode(cls, tokenizer_mode: str) -> str:
return tokenizer_mode.lower()
@ -610,13 +618,14 @@ class ModelConfig:
@model_validator(mode="after")
def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
"""Called after __post_init__"""
if not isinstance(self.tokenizer, str):
raise ValueError(
f"tokenizer must be a string, got "
f"{type(self.tokenizer).__name__}: {self.tokenizer!r}. "
"Please provide a valid tokenizer path or HuggingFace model ID."
)
if not isinstance(self.max_model_len, int) or self.max_model_len <= 0:
if not isinstance(self.max_model_len, int):
raise ValueError(
f"max_model_len must be a positive integer, "
f"got {type(self.max_model_len).__name__}: {self.max_model_len!r}. "

View File

@ -6,7 +6,7 @@ from typing import Any
import torch
import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
@ -1004,27 +1004,30 @@ def vllm_is_batch_invariant() -> bool:
return VLLM_BATCH_INVARIANT
def override_envs_for_invariance():
curr_attn_backend = envs.VLLM_ATTENTION_BACKEND
def override_envs_for_invariance(
attention_backend: AttentionBackendEnum | None,
):
supported_backends = [
"FLASH_ATTN", # best supported backend
"FLASHINFER",
"FLASH_ATTN_MLA",
"TRITON_MLA",
AttentionBackendEnum.FLASH_ATTN, # best supported backend
AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.FLASH_ATTN_MLA,
AttentionBackendEnum.TRITON_MLA,
# Not yet supported MLA backends
# "FLASHMLA",
# "FLEX_ATTENTION", # IMA issue even if we disable batch invariance
# "FLASHINFER_MLA", https://github.com/vllm-project/vllm/pull/28967
# AttentionBackendEnum.FLASHMLA,
# AttentionBackendEnum.FLEX_ATTENTION, # IMA issue
# AttentionBackendEnum.FLASHINFER_MLA, # PR #28967
]
if curr_attn_backend not in supported_backends:
if attention_backend not in supported_backends:
supported_names = [b.name for b in supported_backends]
backend_name = attention_backend.name if attention_backend else None
error = (
"VLLM batch_invariant mode requires an attention backend in "
f"{supported_backends}, but got '{curr_attn_backend}'. "
"Please set the 'VLLM_ATTENTION_BACKEND' environment variable "
"to one of the supported backends before enabling batch_invariant."
f"{supported_names}, but got '{backend_name}'. "
"Please use --attention-backend or attention_config to set "
"one of the supported backends before enabling batch_invariant."
)
raise RuntimeError(error)
if os.environ["VLLM_ATTENTION_BACKEND"] != supported_backends[0]:
if attention_backend != supported_backends[0]:
warning = (
"You are using a decode-invariant form of batch invariance. "
"This will not be invariant between prefill and decode."
@ -1050,10 +1053,12 @@ def override_envs_for_invariance():
os.environ["VLLM_USE_AOT_COMPILE"] = "0"
def init_batch_invariance():
def init_batch_invariance(
attention_backend: AttentionBackendEnum | None,
):
# this will hit all the csrc overrides as well
if vllm_is_batch_invariant():
override_envs_for_invariance()
override_envs_for_invariance(attention_backend)
enable_batch_invariant_mode()
# Disable TF32 for batch invariance - it causes non-deterministic rounding

View File

@ -23,6 +23,7 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
else:
VllmConfig = None
@ -126,21 +127,13 @@ class CpuPlatform(Platform):
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:
if attn_selector_config.use_mla:
raise NotImplementedError("MLA is not supported on CPU.")
if use_sparse:
if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on CPU.")
return AttentionBackendEnum.CPU_ATTN.get_path()

View File

@ -14,7 +14,6 @@ from typing_extensions import ParamSpec
# import custom ops, trigger op registration
import vllm._C # noqa
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml
@ -23,6 +22,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
else:
@ -258,16 +258,8 @@ class CudaPlatformBase(Platform):
@classmethod
def get_valid_backends(
cls,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
device_capability: DeviceCapability,
attn_selector_config: "AttentionSelectorConfig",
) -> tuple[
list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]],
@ -275,21 +267,15 @@ class CudaPlatformBase(Platform):
valid_backends_priorities = []
invalid_reasons = {}
backend_priorities = _get_backend_priorities(use_mla, device_capability)
backend_priorities = _get_backend_priorities(
attn_selector_config.use_mla, device_capability
)
for priority, backend in enumerate(backend_priorities):
try:
backend_class = backend.get_class()
invalid_reasons_i = backend_class.validate_configuration(
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
device_capability=device_capability,
**attn_selector_config._asdict(),
)
except ImportError:
invalid_reasons_i = ["ImportError"]
@ -304,37 +290,19 @@ class CudaPlatformBase(Platform):
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
if attn_type is None:
attn_type = AttentionType.DECODER
device_capability = cls.get_device_capability()
assert device_capability is not None
attn_selector_config = attn_selector_config._replace(block_size=None)
# First try checking just the selected backend, if there is one.
if selected_backend is not None:
try:
backend_class = selected_backend.get_class()
invalid_reasons = backend_class.validate_configuration(
head_size,
dtype,
kv_cache_dtype,
None,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
device_capability=device_capability,
**attn_selector_config._asdict(),
)
except ImportError:
invalid_reasons = ["ImportError"]
@ -350,16 +318,8 @@ class CudaPlatformBase(Platform):
# No selected backend or the selected backend is invalid,
# so we try finding a valid backend.
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
head_size,
dtype,
kv_cache_dtype,
None,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
device_capability=device_capability,
attn_selector_config=attn_selector_config,
)
reasons_str = (
"{"
@ -369,11 +329,7 @@ class CudaPlatformBase(Platform):
)
+ "}"
)
config_str = (
f"head_size: {head_size}, dtype: {dtype}, "
f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, "
f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}"
)
config_str = attn_selector_config.__repr__()
logger.debug_once(
f"Some attention backends are not valid for {cls.device_name} with "
f"{config_str}. Reasons: {reasons_str}."

View File

@ -18,8 +18,8 @@ from vllm.logger import init_logger
if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.inputs import ProcessorInputs, PromptType
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
@ -226,15 +226,7 @@ class Platform:
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
"""Get the attention backend class of a device."""
return ""

View File

@ -15,6 +15,7 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
logger = init_logger(__name__)
@ -190,21 +191,16 @@ class RocmPlatform(Platform):
@classmethod
def get_attn_backend_cls(
cls,
selected_backend,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type: str | None = None,
selected_backend: "AttentionBackendEnum",
attn_selector_config: "AttentionSelectorConfig",
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
if use_sparse:
if kv_cache_dtype.startswith("fp8"):
block_size = attn_selector_config.block_size
kv_cache_dtype = attn_selector_config.kv_cache_dtype
if attn_selector_config.use_sparse:
if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
raise ValueError(
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
)
@ -214,7 +210,7 @@ class RocmPlatform(Platform):
logger.info_once("Using Sparse MLA backend on V1 engine.")
return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
if use_mla:
if attn_selector_config.use_mla:
if selected_backend is None:
selected_backend = (
AttentionBackendEnum.ROCM_AITER_MLA

View File

@ -16,6 +16,7 @@ from .interface import Platform, PlatformEnum
if TYPE_CHECKING:
from typing import TypeAlias
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
from vllm.config.cache import BlockSize
from vllm.pooling_params import PoolingParams
@ -57,17 +58,9 @@ class TpuPlatform(Platform):
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
if use_sparse:
if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on TPU.")
if selected_backend != AttentionBackendEnum.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)

View File

@ -14,6 +14,7 @@ from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
else:
VllmConfig = None
@ -42,15 +43,7 @@ class XPUPlatform(Platform):
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
from vllm.v1.attention.backends.utils import set_kv_cache_layout
@ -60,7 +53,7 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels."
)
if use_sparse:
if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.")
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
logger.info_once("Using Triton backend.")
@ -71,7 +64,7 @@ class XPUPlatform(Platform):
elif selected_backend:
raise ValueError(
f"Invalid attention backend for {cls.device_name}, "
f"with use_mla: {use_mla}"
f"with use_mla: {attn_selector_config.use_mla}"
)
logger.info("Using Flash Attention backend.")

View File

@ -931,10 +931,11 @@ def init_worker_distributed_environment(
backend: str = "nccl",
) -> None:
"""Initialize the distributed environment."""
attention_config = vllm_config.attention_config
parallel_config = vllm_config.parallel_config
from vllm.model_executor.layers.batch_invariant import init_batch_invariance
init_batch_invariance()
init_batch_invariance(attention_config.backend)
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_method = distributed_init_method or "env://"