[Hardware][PPC64LE] Enable V1 for ppc64le and ARM (#20554)

Signed-off-by: Akash Kaothalkar <akash.kaothalkar@ibm.com>
Co-authored-by: Akash Kaothalkar <akash.kaothalkar@ibm.com>
Co-authored-by: Nikhil Gupta <nikhil.gupta2@arm.com>
Akash kaothalkar 2025-07-09 08:30:41 +05:30 committed by GitHub
parent 977180c912
commit 6db31e7a27
4 changed files with 77 additions and 13 deletions


@@ -36,6 +36,7 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
@@ -1096,7 +1097,6 @@ class EngineArgs:
        If VLLM_USE_V1 is specified by the user but the VllmConfig
        is incompatible, we raise an error.
        """
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update()

        device_config = DeviceConfig(
@@ -1123,9 +1123,16 @@
        # Set default arguments for V0 or V1 Engine.
        if use_v1:
            self._set_default_args_v1(usage_context, model_config)
            # Disable chunked prefill for POWER (ppc64le)/ARM CPUs in V1
            if current_platform.is_cpu(
            ) and current_platform.get_cpu_architecture() in (
                    CpuArchEnum.POWERPC, CpuArchEnum.ARM):
                logger.info(
                    "Chunked prefill is not supported for ARM and POWER CPUs; "
                    "disabling it for V1 backend.")
                self.enable_chunked_prefill = False
        else:
            self._set_default_args_v0(model_config)
        assert self.enable_chunked_prefill is not None

        if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
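
A minimal sketch of the gate added above, for illustration only (the helper name is hypothetical and does not exist in vLLM):

    from vllm.platforms import CpuArchEnum, current_platform

    def v1_disables_chunked_prefill() -> bool:
        # Hypothetical helper, illustration only: mirrors the EngineArgs
        # check above. POWER/ARM CPU hosts turn chunked prefill off when
        # the V1 engine defaults are applied.
        return current_platform.is_cpu() and (
            current_platform.get_cpu_architecture() in
            (CpuArchEnum.POWERPC, CpuArchEnum.ARM))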
@@ -1242,7 +1249,6 @@
            if self.enable_chunked_prefill and self.pipeline_parallel_size > 1:
                raise ValueError("Multi-Step Chunked-Prefill is not supported "
                                 "for pipeline-parallel-size > 1")
            from vllm.platforms import current_platform
            if current_platform.is_cpu():
                logger.warning("Multi-Step (--num-scheduler-steps > 1) is "
                               "currently not supported for CPUs and has been "
@@ -1391,7 +1397,6 @@
        # Skip this check if we are running on a non-GPU platform,
        # or if the device capability is not available
        # (e.g. in a Ray actor without GPUs).
        from vllm.platforms import current_platform
        if (current_platform.is_cuda()
                and current_platform.get_device_capability()
                and current_platform.get_device_capability().major < 8):
@@ -1652,7 +1657,6 @@
        # as the platform that vLLM is running on (e.g. the case of scaling
        # vLLM with Ray) and has no GPUs. In this case we use the default
        # values for non-H100/H200 GPUs.
        from vllm.platforms import current_platform
        try:
            device_memory = current_platform.get_device_total_memory()
            device_name = current_platform.get_device_name().lower()
@@ -1755,7 +1759,6 @@ class AsyncEngineArgs(EngineArgs):
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='Disable logging requests.')
        from vllm.platforms import current_platform
        current_platform.pre_register_and_update(parser)
        return parser


@@ -271,5 +271,6 @@ class CpuPlatform(Platform):
        """Returns whether the current platform can use v1 by default for the
        supplied model configuration.
        """
        return cls.supports_v1(
            model_config) and cls.get_cpu_architecture() == CpuArchEnum.X86
        arch = cls.get_cpu_architecture()
        return (cls.supports_v1(model_config) and arch
                in (CpuArchEnum.X86, CpuArchEnum.POWERPC, CpuArchEnum.ARM))
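
For illustration, the effect of the new default_v1 check is that X86, POWERPC, and ARM CPU architectures now default to the V1 engine whenever supports_v1() also accepts the model. A hypothetical standalone form of the same predicate:

    from vllm.platforms import CpuArchEnum

    # Hypothetical constant, illustration only; the real check lives in
    # CpuPlatform.default_v1 above.
    V1_DEFAULT_CPU_ARCHES = (CpuArchEnum.X86, CpuArchEnum.POWERPC,
                             CpuArchEnum.ARM)

    def arch_defaults_to_v1(arch: CpuArchEnum) -> bool:
        return arch in V1_DEFAULT_CPU_ARCHES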


@@ -316,7 +316,6 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
                 block_table: BlockTable) -> None:
        self.runner = runner
        self.block_table = block_table
        # For reorder
        self.reorder_prompt_req_index_list = np.empty(self.runner.max_num_reqs,
                                                      dtype=np.int64)
@@ -401,11 +400,14 @@ class TorchSDPAMetadataBuilderV1(AttentionMetadataBuilder[TorchSDPAMetadata]):
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            slot_mapping=slot_mapping,
            # needed to ensure correct inference when chunked_prefill is disabled
            seq_lens=runner.seq_lens_cpu[:num_reqs].tolist(),
            seq_lens_tensor=runner.
            seq_lens_cpu[num_prompt_req:num_reqs],  # decode
            max_decode_seq_len=max_decode_seq_len,  # decode
            block_tables=block_table_tensor[num_prompt_req:num_reqs],  # decode
            chunked_prefill=True,
            chunked_prefill=self.runner.scheduler_config.
            chunked_prefill_enabled,
            max_query_len=max_query_len,
            max_kv_len=max_prefill_seq_len,
            prefill_query_start_loc=runner.


@@ -11,7 +11,7 @@ from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.model_executor.utils import set_random_seed
from vllm.platforms import current_platform
from vllm.platforms import CpuArchEnum, current_platform
from vllm.sequence import IntermediateTensors
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
@@ -43,8 +43,12 @@ class CPUWorker(Worker):
        omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
        self.local_omp_cpuid = "all"
        if omp_cpuids == "auto":
            self.local_omp_cpuid = self.get_cpus_id_binding_based_on_numa_nodes(
            )
            if current_platform.get_cpu_architecture() == CpuArchEnum.POWERPC:
                self.local_omp_cpuid = (
                    self.get_cpus_id_binding_based_on_numa_nodes_ppc64le())
            else:
                self.local_omp_cpuid = (
                    self.get_cpus_id_binding_based_on_numa_nodes())
        else:
            self.local_omp_cpuid = omp_cpuids.split("|")[self.rank]
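
For reference, a sketch of how VLLM_CPU_OMP_THREADS_BIND is consumed here; the pipe-separated per-rank format is inferred from the split("|") above and the value shown is hypothetical:

    # Explicit binding: one CPU-id group per rank, separated by "|".
    omp_cpuids = "0-31|32-63"                      # hypothetical value, 2 ranks
    rank = 1
    local_omp_cpuid = omp_cpuids.split("|")[rank]  # -> "32-63"

    # "auto" instead falls through to the NUMA-aware helpers, with the new
    # ppc64le-specific path taken when the CPU architecture is POWERPC.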
@@ -153,3 +157,57 @@
                "fallback to no thread-binding. To get better performance, "
                "please try to manually bind threads.")
        return rank_to_cpus

    def get_cpus_id_binding_based_on_numa_nodes_ppc64le(self) -> str:
        """
        Power (ppc64le) specific: selects a subset of threads per core for
        each NUMA node. This is robust to SMT mode (SMT-8, SMT-4, etc.)
        because the OS only exposes available threads. This maximizes
        performance by avoiding oversubscription of logical CPUs on Power.
        """

        def select_threads_per_power_core(node_cpu_ids):
            return [cpu for cpu in node_cpu_ids if cpu % 8 < 4]

        rank_to_cpus = self.local_omp_cpuid
        world_size = self.vllm_config.parallel_config.world_size
        libnuma_found = util.find_spec("numa") is not None
        psutil_found = util.find_spec("psutil") is not None
        if libnuma_found and psutil_found:
            import psutil
            from numa import info
            cpus_allow_list = psutil.Process().cpu_affinity()
            numa_size = info.get_num_configured_nodes()
            node_to_cpus = []
            for i in range(numa_size):
                node_intersect = set(
                    info.node_to_cpus(i)).intersection(cpus_allow_list)
                if bool(node_intersect):
                    node_to_cpus.append(sorted(list(node_intersect)))
            if world_size > len(node_to_cpus):
                logger.error(
                    "Auto thread-binding failed due to "
                    "world size: %d is larger than "
                    "allowed NUMA nodes number: %d. "
                    "Please try to bind threads manually.", world_size,
                    len(node_to_cpus))
            else:
                node_cpus_this_rank = node_to_cpus[self.rank]
                node_cpus_this_rank = select_threads_per_power_core(
                    node_cpus_this_rank)
                cpu_count_per_numa = len(node_cpus_this_rank)
                num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
                                          cpu_count_per_numa // 2)
                end = cpu_count_per_numa - num_of_reserved_cpu
                rank_to_cpus_list = node_cpus_this_rank[:end]
                rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
                logger.info("ppc64le thread-binding list: %s", rank_to_cpus)
        else:
            logger.warning(
                "Auto thread-binding is not supported due to "
                "the lack of package numa and psutil, "
                "fallback to no thread-binding. To get better performance, "
                "please try to manually bind threads.")
        return rank_to_cpus
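
A worked example of what select_threads_per_power_core keeps: cpu % 8 < 4 retains only the first four SMT threads of each Power core, so two SMT-8 cores exposing logical CPUs 0-15 contribute 0-3 and 8-11.

    def select_threads_per_power_core(node_cpu_ids):
        return [cpu for cpu in node_cpu_ids if cpu % 8 < 4]

    # Two SMT-8 cores on one NUMA node -> logical CPUs 0..15.
    print(select_threads_per_power_core(list(range(16))))
    # [0, 1, 2, 3, 8, 9, 10, 11]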