[CPU] update torch 2.8 and fix missing fields in TorchSDPAMetadata (#25652)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Li, Jiang 2025-09-25 21:29:11 +08:00 committed by GitHub
parent 69a8c8e99a
commit eb32335e35
7 changed files with 59 additions and 53 deletions


@@ -58,11 +58,8 @@ function cpu_tests() {
 # pytest -x -v -s tests/kernels/attention/test_cache.py -m cpu_model
 # pytest -x -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model
-# Note: disable Bart until supports V1
-pytest -x -v -s tests/models/language/generation -m cpu_model \
-  --ignore=tests/models/language/generation/test_bart.py
-VLLM_CPU_SGL_KERNEL=1 pytest -x -v -s tests/models/language/generation -m cpu_model \
-  --ignore=tests/models/language/generation/test_bart.py
+pytest -x -v -s tests/models/language/generation -m cpu_model
+VLLM_CPU_SGL_KERNEL=1 pytest -x -v -s tests/models/language/generation -m cpu_model
 pytest -x -v -s tests/models/language/pooling -m cpu_model
 pytest -x -v -s tests/models/multimodal/generation \


@@ -114,9 +114,6 @@ WORKDIR /workspace/vllm
 RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
     cp requirements/test.in requirements/cpu-test.in && \
     sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
-    sed -i 's/^torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \
-    sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \
-    sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \
     uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu
 RUN --mount=type=cache,target=/root/.cache/uv \


@@ -1,12 +1,10 @@
-# Temporarily used for x86 CPU backend to avoid performance regression of torch>2.6.0+cpu,
-# see https://github.com/pytorch/pytorch/pull/151218
 cmake>=3.26.1
 ninja
 packaging>=24.2
 setuptools>=77.0.3,<80.0.0
 setuptools-scm>=8
 --extra-index-url https://download.pytorch.org/whl/cpu
-torch==2.6.0+cpu; platform_machine == "x86_64" # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218
+torch==2.8.0+cpu; platform_machine == "x86_64"
 torch==2.8.0; platform_machine == "ppc64le" or platform_machine == "aarch64" or platform_system == "Darwin"
 wheel
 jinja2>=3.1.6
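The per-platform torch pins above rely on standard PEP 508 environment markers, so only the line whose marker matches the host is installed. A minimal sketch of how such markers resolve, using the packaging library; the marker strings are copied from the requirements above, everything else is illustrative:

from packaging.markers import Marker

# Evaluate the same markers used in the requirements file against a sample
# environment to see which torch pin would be selected on that host.
x86_marker = Marker('platform_machine == "x86_64"')
other_marker = Marker(
    'platform_machine == "ppc64le" or platform_machine == "aarch64" '
    'or platform_system == "Darwin"')

env = {"platform_machine": "x86_64", "platform_system": "Linux"}
print(x86_marker.evaluate(env))    # True  -> torch==2.8.0+cpu is installed
print(other_marker.evaluate(env))  # False -> the other pin is skipped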


@@ -8,7 +8,7 @@ numba == 0.61.2; python_version > '3.9' and platform_machine != "s390x"
 packaging>=24.2
 setuptools>=77.0.3,<80.0.0
 --extra-index-url https://download.pytorch.org/whl/cpu
-torch==2.6.0+cpu; platform_machine == "x86_64" # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218
+torch==2.8.0+cpu; platform_machine == "x86_64"
 torch==2.8.0; platform_system == "Darwin"
 torch==2.8.0; platform_machine == "ppc64le" or platform_machine == "aarch64"
@@ -23,7 +23,7 @@ datasets # for benchmark scripts
 # Intel Extension for PyTorch, only for x86_64 CPUs
 intel-openmp==2024.2.1; platform_machine == "x86_64"
-intel_extension_for_pytorch==2.6.0; platform_machine == "x86_64" # torch>2.6.0+cpu has performance regression on x86 platform, see https://github.com/pytorch/pytorch/pull/151218
+intel_extension_for_pytorch==2.8.0; platform_machine == "x86_64"
 triton==3.2.0; platform_machine == "x86_64" # Triton is required for torch 2.6+cpu, as it is imported in torch.compile.
 # Use this to gather CPU info and optimize based on ARM Neoverse cores
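Note that intel_extension_for_pytorch is versioned in lockstep with torch, which is why the IPEX pin moves to 2.8.0 together with the torch pin above. A quick sanity check one might run in the resulting environment (illustrative only, not part of the repo):

import torch
import intel_extension_for_pytorch as ipex  # only installed on x86_64

# After this upgrade both packages should report the same 2.8.x release.
print("torch:", torch.__version__)  # e.g. 2.8.0+cpu
print("ipex: ", ipex.__version__)   # e.g. 2.8.0
assert torch.__version__.split("+")[0][:3] == ipex.__version__[:3]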


@@ -85,6 +85,19 @@ class TorchSDPABackend(AttentionBackend):
 @dataclass
 class TorchSDPAMetadata(AttentionMetadata):
     """Attention metadata for prefill and decode batched together."""
+    # Total number of prefill requests.
+    num_prefills: int
+    # Number of prefill tokens.
+    num_prefill_tokens: int
+    # Number of decode tokens. Note that it is equivalent to the number of
+    # decode requests.
+    num_decode_tokens: int
+    # (num_tokens,). The indices of the token slots that input tokens will be
+    # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size
+    # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot
+    # in block 0, and 1st slot in block 1, respectively.
+    slot_mapping: torch.Tensor
+
     """Metadata for PagedAttention."""
     # (batch_size,). The length of sequences (entire tokens seen so far) per
     # sequence.
@@ -420,7 +433,6 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
                 num_prompt_req],  # prefill
             query_start_loc=query_start_loc_cpu[:num_reqs +
                                                 1],  # for logits index
-            enable_kv_scales_calculation=False,
         )
         return attn_metadata
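A minimal sketch of the slot-to-block arithmetic described in the slot_mapping comment above; the variable names are illustrative, not vLLM APIs, and offsets are shown 0-based:

import torch

block_size = 16
slot_mapping = torch.tensor([35, 2, 17])

# Each flat slot index decomposes into a physical block number and an offset
# within that block: 35 -> block 2, offset 3; 2 -> block 0, offset 2;
# 17 -> block 1, offset 1.
block_numbers = slot_mapping // block_size
block_offsets = slot_mapping % block_size
print(block_numbers.tolist(), block_offsets.tolist())  # [2, 0, 1] [3, 2, 1]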


@@ -68,6 +68,8 @@ class TopKTopPSampler(nn.Module):
                     "native implementation of top-p & top-k sampling. For the "
                     "best performance, please install FlashInfer.")
                 self.forward = self.forward_native
+        elif current_platform.is_cpu():
+            self.forward = self.forward_cpu
         else:
             self.forward = self.forward_native
@@ -119,6 +121,45 @@ class TopKTopPSampler(nn.Module):
         # because of slicing operation in logits_processor.
         return flashinfer_sample(logits.contiguous(), k, p, generators), None
 
+    def forward_cpu(
+        self,
+        logits: torch.Tensor,
+        generators: dict[int, torch.Generator],
+        k: Optional[torch.Tensor],
+        p: Optional[torch.Tensor],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        PyTorch-native implementation of top-k and top-p sampling for CPU.
+
+        The logits tensor may be updated in-place.
+        """
+        logits = self.apply_top_k_top_p(logits, k, p)
+        logits_to_return = None
+        if self.logprobs_mode == "processed_logits":
+            logits_to_return = logits
+        elif self.logprobs_mode == "processed_logprobs":
+            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
+
+        # Note: this is a workaround for
+        # https://github.com/pytorch/pytorch/pull/151218
+        @torch.compile(dynamic=True)
+        def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
+            probs = logits.softmax(dim=-1, dtype=torch.float32)
+            q = torch.empty_like(probs)
+            q.exponential_()
+            return probs.div(q).argmax(dim=-1).view(-1)
+
+        if len(generators) != logits.shape[0]:
+            return compiled_random_sample(logits), logits_to_return
+        else:
+            probs = logits.softmax(dim=-1, dtype=torch.float32)
+            q = torch.empty_like(probs)
+            q.exponential_()
+            for i, generator in generators.items():
+                q[i].exponential_(generator=generator)
+            return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
+
 
 def apply_top_k_top_p(
     logits: torch.Tensor,
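The compiled_random_sample helper above uses the exponential-race (Gumbel-max style) trick: dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax selects index i with probability probs[i], which matches torch.multinomial while staying friendly to torch.compile. A small empirical check of that equivalence; the names and numbers are illustrative, not vLLM code:

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.3, 0.6])
n = 100_000

# argmax(probs / q) with q ~ Exp(1) draws index i with probability probs[i].
q = torch.empty(n, probs.numel()).exponential_()
samples = (probs / q).argmax(dim=-1)
freq = torch.bincount(samples, minlength=probs.numel()).float() / n
print(freq)  # approximately tensor([0.10, 0.30, 0.60])

The seeded branch in forward_cpu applies the same trick per request, refilling q[i] with that request's torch.Generator so seeded sampling stays reproducible.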


@@ -8,18 +8,13 @@ import torch
 from vllm import envs
 from vllm.config import VllmConfig
-from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.model_executor.utils import set_random_seed
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.platforms.cpu import CpuPlatform, LogicalCPUInfo
-from vllm.sequence import IntermediateTensors
-from vllm.v1.core.sched.output import SchedulerOutput
-from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.cpu_model_runner import CPUModelRunner
 from vllm.v1.worker.gpu_worker import (Worker,
                                        init_worker_distributed_environment)
-from vllm.v1.worker.utils import is_residual_scattered_for_sp
 
 logger = init_logger(__name__)
@@ -102,40 +97,6 @@ class CPUWorker(Worker):
         set_random_seed(self.model_config.seed)
         self.model_runner.warming_up_model()
 
-    @torch.inference_mode()
-    def execute_model(
-        self,
-        scheduler_output: "SchedulerOutput",
-    ) -> Optional[ModelRunnerOutput]:
-        intermediate_tensors = None
-        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
-        num_input_tokens = self.model_runner._get_num_input_tokens(
-            num_scheduled_tokens)
-        all_gather_tensors = {
-            "residual":
-            not is_residual_scattered_for_sp(self.vllm_config,
-                                             num_input_tokens)
-        }
-
-        if not get_pp_group().is_first_rank:
-            intermediate_tensors = IntermediateTensors(
-                get_pp_group().recv_tensor_dict(
-                    all_gather_group=get_tp_group(),
-                    all_gather_tensors=all_gather_tensors))
-
-        output = self.model_runner.execute_model(scheduler_output,
-                                                 intermediate_tensors)
-
-        if not get_pp_group().is_last_rank:
-            assert isinstance(output, IntermediateTensors)
-            get_pp_group().send_tensor_dict(
-                output.tensors,
-                all_gather_group=get_tp_group(),
-                all_gather_tensors=all_gather_tensors)
-            return None
-
-        assert isinstance(output, ModelRunnerOutput)
-        return output if self.is_driver_worker else None
-
     def _get_autobind_cpu_ids(
         self, cpu_selector: Callable[[list[LogicalCPUInfo]],
                                      list[LogicalCPUInfo]]