youkaichao ad34c0df0f
[core] platform agnostic executor via collective_rpc (#11256)
Signed-off-by: youkaichao <youkaichao@gmail.com>
2025-01-15 13:45:21 +08:00

from typing import TYPE_CHECKING, Optional

import torch

from vllm.logger import init_logger

from .interface import Platform, PlatformEnum, _Backend

if TYPE_CHECKING:
    from vllm.config import VllmConfig
else:
    VllmConfig = None

logger = init_logger(__name__)


class TpuPlatform(Platform):
    _enum = PlatformEnum.TPU
    device_name: str = "tpu"
    device_type: str = "tpu"
    dispatch_key: str = "XLA"
    ray_device_key: str = "TPU"
    device_control_env_var: str = "TPU_VISIBLE_CHIPS"

    supported_quantization: list[str] = [
        "tpu_int8", "compressed-tensors", "compressed_tensors"
    ]
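
    # `ray_device_key` and `device_control_env_var` let a platform-agnostic
    # executor schedule and pin TPU workers without any TPU-specific code.
    # A minimal sketch of the idea (illustrative only; the real logic lives
    # in the executor, and `local_rank` is this sketch's own name):
    #
    #     import os
    #     os.environ[TpuPlatform.device_control_env_var] = str(local_rank)
    #
    # so each worker process sees one chip, analogous to pinning a GPU
    # worker via CUDA_VISIBLE_DEVICES.
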
    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool) -> str:
        if selected_backend != _Backend.PALLAS:
            logger.info("Cannot use %s backend on TPU.", selected_backend)
        logger.info("Using Pallas backend.")
        return "vllm.attention.backends.pallas.PallasAttentionBackend"
    @classmethod
    def get_device_name(cls, device_id: int = 0) -> str:
        raise NotImplementedError

    @classmethod
    def get_device_total_memory(cls, device_id: int = 0) -> int:
        raise NotImplementedError

    @classmethod
    def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
        return True
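
    # XLA does not support `torch.inference_mode`, so this platform overrides
    # the base `Platform.inference_mode` to fall back to `torch.no_grad`
    # (see the base-class docstring in `.interface`).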
    @classmethod
    def inference_mode(cls):
        return torch.no_grad()
    @classmethod
    def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
        from vllm.config import CompilationLevel

        cache_config = vllm_config.cache_config
        if cache_config and cache_config.block_size is None:
            cache_config.block_size = 16

        compilation_config = vllm_config.compilation_config
        if compilation_config.level == CompilationLevel.NO_COMPILATION:
            # TPU does not support NO_COMPILATION
            compilation_config.level = CompilationLevel.DYNAMO_ONCE
        assert compilation_config.level < CompilationLevel.PIECEWISE, (
            "TPU does not support Inductor.")

        if compilation_config.backend == "":
            compilation_config.backend = "openxla"
        assert vllm_config.speculative_config is None, (
            "Speculative decoding is not yet supported for TPU backend")
        assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
            "Chunked prefill is not yet supported for TPU backend")
        if vllm_config.model_config.dtype in (torch.float16, torch.float32):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.", vllm_config.model_config.dtype)
            vllm_config.model_config.dtype = torch.bfloat16

        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            if scheduler_config.is_multi_step:
                parallel_config.worker_cls = \
                    "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
            else:
                parallel_config.worker_cls = \
                    "vllm.worker.tpu_worker.TPUWorker"