From d5cab6f65c320c383fcd1521431002970715729f Mon Sep 17 00:00:00 2001 From: Wei-Yu Lin Date: Thu, 11 Dec 2025 19:59:02 +0000 Subject: [PATCH 01/11] Remove tpu_inference fall back logic Signed-off-by: Wei-Yu Lin --- .../device_communicators/tpu_communicator.py | 14 - .../model_loader/default_loader.py | 13 - vllm/platforms/tpu.py | 253 +------------- vllm/v1/attention/backends/pallas.py | 65 +--- vllm/v1/worker/tpu_worker.py | 310 ------------------ 5 files changed, 2 insertions(+), 653 deletions(-) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index fa99078e9ff0d..9581a3dbc7b74 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -19,20 +19,6 @@ USE_RAY = parallel_config = ( logger = init_logger(__name__) -if not USE_TPU_INFERENCE: - logger.info("tpu_inference not found, using vLLM's TpuCommunicator") - if current_platform.is_tpu(): - import torch_xla - import torch_xla.core.xla_model as xm - import torch_xla.runtime as xr - from torch_xla._internal import pjrt - from torch_xla.distributed.xla_multiprocessing import ( - create_optimized_replica_groups, - ) - - if USE_RAY: - from vllm.v1.executor import ray_utils - class TpuCommunicator(DeviceCommunicatorBase): def __init__( diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 88c6d1e27e39c..4d85f8e3b478c 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -244,19 +244,6 @@ class DefaultModelLoader(BaseModelLoader): if current_platform.is_tpu(): from vllm.platforms.tpu import USE_TPU_INFERENCE - if not USE_TPU_INFERENCE: - # In PyTorch XLA, we should call `torch_xla.sync` - # frequently so that not too many ops are accumulated - # in the XLA program. - import torch_xla - - def _xla_weights_iterator(iterator: Generator): - for weights in iterator: - yield weights - torch_xla.sync(wait=False) - - weights_iterator = _xla_weights_iterator(weights_iterator) - if self.counter_before_loading_weights == 0.0: self.counter_before_loading_weights = time.perf_counter() # Apply the prefix. diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 7c479bf2b6a0e..f7a11d2c557c4 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -31,257 +31,6 @@ else: logger = init_logger(__name__) -USE_TPU_INFERENCE = False - - -class TpuPlatform(Platform): - _enum = PlatformEnum.TPU - device_name: str = "tpu" - device_type: str = "tpu" - dispatch_key: str = "XLA" - ray_device_key: str = "TPU" - dist_backend: str = "gloo" - device_control_env_var: str = "TPU_VISIBLE_CHIPS" - simple_compile_backend: str = "openxla" - - supported_quantization: list[str] = ["fp8", "tpu_int8", "compressed-tensors"] - - additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"] - - @classmethod - def import_kernels(cls) -> None: - # Do not import vllm._C - with contextlib.suppress(ImportError): - import vllm._moe_C # noqa: F401 - - @classmethod - def get_attn_backend_cls( - cls, - selected_backend: "AttentionBackendEnum", - attn_selector_config: "AttentionSelectorConfig", - ) -> str: - if attn_selector_config.use_sparse: - raise NotImplementedError("Sparse Attention is not supported on TPU.") - if selected_backend != AttentionBackendEnum.PALLAS: - logger.info("Cannot use %s backend on TPU.", selected_backend) - - logger.info("Using Pallas V1 backend.") - return AttentionBackendEnum.PALLAS.get_path() - - @classmethod - def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: - return [ - AttentionBackendEnum.PALLAS, - ] - - @classmethod - def get_vit_attn_backend( - cls, - head_size: int, - dtype: torch.dtype, - backend: Optional["AttentionBackendEnum"] = None, - ) -> "AttentionBackendEnum": - if backend is not None: - assert backend in cls.get_supported_vit_attn_backends(), ( - f"Backend {backend} is not supported for vit attention" - f"Supported backends are: {cls.get_supported_vit_attn_backends()}." - ) - logger.info_once(f"Using backend {backend} for vit attention.") - return backend - - logger.info_once( - f"Using default backend {AttentionBackendEnum.PALLAS} for vit attention." - ) - return AttentionBackendEnum.PALLAS - - @classmethod - def set_device(cls, device: torch.device) -> None: - """ - Set the device for the current platform. - """ - torch.tpu.set_device(device) - - @classmethod - def get_device_name(cls, device_id: int = 0) -> str: - chip_type, _ = device.get_local_chips() - return f"TPU {chip_type.name}" - - @classmethod - def get_device_total_memory(cls, device_id: int = 0) -> int: - raise NotImplementedError - - @classmethod - def get_punica_wrapper(cls) -> str: - return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" - - @classmethod - def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]: - return torch.finfo(dtype).min, torch.finfo(dtype).max - - @classmethod - def can_update_inplace(cls): - return False - - @classmethod - def get_lora_vocab_padding_size(cls) -> int: - return 1 - - @classmethod - def inference_mode(cls): - return torch.no_grad() - - @classmethod - def check_and_update_config(cls, vllm_config: VllmConfig) -> None: - from vllm.config import CompilationMode, CUDAGraphMode - - cache_config = vllm_config.cache_config - # For v0, the default block size is 16. - if cache_config and cache_config.block_size is None: - cache_config.block_size = cast(BlockSize, 16) - compilation_config = vllm_config.compilation_config - - # TPU only supports DYNAMO_TRACE_ONCE compilation mode - if compilation_config.mode != CompilationMode.DYNAMO_TRACE_ONCE: - logger.info( - "[TPU] Forcing DYNAMO_TRACE_ONCE compilation mode, and\ - disabling cudagraph." - ) - compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE - - if ( - compilation_config.cudagraph_mode is None - or compilation_config.cudagraph_mode.max_cudagraph_mode() - != CUDAGraphMode.NONE - ): - logger.info( - "[TPU] CUDA graph is not supported on TPU, disabling cudagraphs." - ) - compilation_config.cudagraph_mode = CUDAGraphMode.NONE - - if compilation_config.backend == "": - compilation_config.backend = "openxla" - - assert vllm_config.speculative_config is None, ( - "TPU does not support speculative decoding" - ) - - model_config = vllm_config.model_config - if model_config is not None and model_config.dtype in ( - torch.float16, - torch.float32, - ): - logger.warning( - "The TPU backend currently does not support %s. " - "Using bfloat16 instead.", - model_config.dtype, - ) - model_config.dtype = torch.bfloat16 - - from vllm.v1.attention.backends.pallas import PallasAttentionBackend - - cache_config.block_size = PallasAttentionBackend.get_page_size(vllm_config) # type: ignore[assignment] - - parallel_config = vllm_config.parallel_config - scheduler_config = vllm_config.scheduler_config - if parallel_config.worker_cls == "auto": - parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker" - - assert not vllm_config.speculative_config, ( - "Speculative decoding is not yet supported for TPU backend" - ) - - if ( - scheduler_config.is_multimodal_model - and not scheduler_config.disable_chunked_mm_input - ): - logger.warning( - "TPU does not support running Multimodal models" - " without setting `--disable_chunked_mm_input`. " - "Forcing --disable_chunked_mm_input." - ) - scheduler_config.disable_chunked_mm_input = True - - if model_config and model_config.use_mla: - logger.info( - "MLA is enabled on a non-GPU platform; forcing chunked " - "prefill and prefix caching to be disabled." - ) - vllm_config.scheduler_config.enable_chunked_prefill = False - vllm_config.scheduler_config.max_num_batched_tokens = max( - vllm_config.model_config.max_model_len, - vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS, - ) - - @classmethod - def is_pin_memory_available(cls): - logger.warning("Pin memory is not supported on TPU.") - return False - - @classmethod - def get_device_communicator_cls(cls) -> str: - return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa - - @classmethod - def validate_request( - cls, - prompt: PromptType, - params: ParamsType, - processed_inputs: ProcessorInputs, - ) -> None: - """Raises if this request is unsupported on this platform""" - from vllm.sampling_params import SamplingParams, SamplingType - - if ( - isinstance(params, SamplingParams) - and params.sampling_type == SamplingType.RANDOM_SEED - ): - raise ValueError("Torch XLA does not support per-request seed.") - - @classmethod - @torch.compile(backend="openxla") - def insert_blocks_to_device( - cls, - src_cache: torch.Tensor, - dst_cache: torch.Tensor, - src_block_indices: torch.Tensor, - dst_block_indices: torch.Tensor, - ) -> None: - torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True) - dst_cache[dst_block_indices] = src_cache[src_block_indices].to(dst_cache.device) - - @classmethod - @torch.compile(backend="openxla") - def swap_out_blocks_to_host( - cls, - src_cache: torch.Tensor, - dst_cache: torch.Tensor, - src_block_indices: torch.Tensor, - dst_block_indices: torch.Tensor, - ) -> None: - """tpu blocks to cpu blocks""" - torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True) - dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu() - - @classmethod - def use_sync_weight_loader(cls) -> bool: - return True - - @classmethod - def check_max_model_len(cls, max_model_len: int) -> int: - """ - Check max_model_len for the current platform. - """ - logger.warning( - "--max-model-len is not specified, " - "it's currently using model's default length %d, " - "which might be too large." - "Please input with --max-model-len based on your " - "request input length and output length, to avoid " - "unnecessary degradation.", - max_model_len, - ) - return max_model_len - try: from tpu_inference.platforms import ( @@ -291,5 +40,5 @@ try: TpuPlatform = TpuInferencePlatform # type: ignore USE_TPU_INFERENCE = True except ImportError: - logger.info("tpu_inference not found, using vLLM's TpuPlatform") + logger.error("tpu_inference not found, please install tpu_inference to run vllm on TPU") pass diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 525026bac5a7e..5f7f8e81a24c4 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -33,70 +33,7 @@ TPU_STR_DTYPE_TO_TORCH_DTYPE = { "uint8": torch.uint8, } -try: - import tpu_inference # noqa: F401 -except ImportError: - # Lazy import torch_xla - import torch_xla.core.xla_builder as xb - import torch_xla.experimental.custom_kernel # noqa: F401 - from torch.library import impl - from torch_xla._internal.jax_workarounds import requires_jax - from torch_xla.experimental.custom_kernel import XLA_LIB - - @requires_jax - def kv_cache_update_op_impl( - kv: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, - num_slices_per_block: int, - ): - from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update - - new_kv_cache = xb.call_jax( - kv_cache_update, - (kv, slot_mapping, kv_cache, num_kv_update_slices), - {"page_size": page_size, "num_slices_per_block": num_slices_per_block}, - ) - return new_kv_cache - - XLA_LIB.define( - "kv_cache_update_op(Tensor kv, Tensor slot_mapping," - "Tensor kv_cache, Tensor num_kv_update_slices, int page_size," - "int num_slices_per_block)" - "-> Tensor", - ) - - @impl(XLA_LIB, "kv_cache_update_op", "XLA") - def kv_cache_update_op_xla( - kv: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, - num_slices_per_block: int, - ) -> torch.Tensor: - new_kv_cache = kv_cache_update_op_impl( - kv, - slot_mapping, - kv_cache, - num_kv_update_slices, - page_size, - num_slices_per_block, - ) - return new_kv_cache - - @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") - def kv_cache_update_op_non_xla( - kv: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, - num_slices_per_block: int, - ) -> torch.Tensor: - return kv_cache +import tpu_inference # noqa: F401 class PallasAttentionBackend(AttentionBackend): diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 5f6136b178b46..b50def0e17de4 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -36,316 +36,6 @@ logger = init_logger(__name__) _R = TypeVar("_R") -if not USE_TPU_INFERENCE: - logger.info("tpu_inference not found, using vLLM's TPUWorker.") - import torch_xla.core.xla_model as xm - import torch_xla.debug.profiler as xp - import torch_xla.runtime as xr - - from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT - from vllm.v1.worker.tpu_model_runner import TPUModelRunner - - -class TPUWorker: - def __init__( - self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = False, - ): - self.is_driver_worker = is_driver_worker - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.use_spmd = envs.VLLM_XLA_USE_SPMD - self.original_parallel_config = None - if self.use_spmd: - # Under SPMD mode, distributed env is initialized as if there is - # only one worker/device. - self.original_parallel_config = self.parallel_config - self.parallel_config.tensor_parallel_size = 1 - self.parallel_config.pipeline_parallel_size = 1 - self.parallel_config.world_size = 1 - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.observability_config = vllm_config.observability_config - - self.parallel_config.rank = rank - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - - if self.cache_config.cache_dtype == "auto": - self.cache_dtype = self.model_config.dtype - else: - self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype] - - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils.import_utils import init_cached_hf_modules - - init_cached_hf_modules() - - # Delay profiler initialization to the start of the profiling. - # This is because in vLLM V1, MP runtime is initialized before the - # TPU Worker is initialized. The profiler server needs to start after - # MP runtime is initialized. - self.profiler = None - self.profile_dir = None - if vllm_config.profiler_config.profiler == "torch" and self.rank < 1: - # For TPU, we can only have 1 active profiler session for 1 profiler - # server. So we only profile on rank0. - self.profile_dir = vllm_config.profiler_config.torch_profiler_dir - logger.info( - "Profiling enabled. Traces will be saved to: %s", self.profile_dir - ) - - def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - def init_device(self): - os.environ["PJRT_DEVICE"] = "TPU" - # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D - # ring, the xla tpu compiler flag - # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to - # fix this. It will be removed after the bug in XLA compiler is fixed. - os.environ["LIBTPU_INIT_ARGS"] = ( - os.environ.get("LIBTPU_INIT_ARGS", "") - + " --xla_tpu_force_1d_allreduce_at_chunk_count=1" - " --xla_jf_conv_input_fusion=False" - ) - # --xla_jf_conv_input_fusion=False is used to improve the perf of - # quantized matmul. - torch.set_grad_enabled(False) - torch.set_default_dtype(self.model_config.dtype) - - # Initialize the distributed environment. - self._init_tpu_worker_distributed_environment( - self.vllm_config, self.rank, self.distributed_init_method, self.local_rank - ) - - # Device initialization should happen after initializing - # the distributed runtime. - self.device = xm.xla_device() - self.device_config.device = self.device - - # Set random seed. - set_random_seed(self.model_config.seed) - xm.set_rng_state(self.model_config.seed, self.device) - - # Increase the cache size limit, which is the maximum number of - # dynamo graphs that can be compiled. - # TODO (NickLucche) On gsm we compile 80+ graphs. - # Re-evaluate limit, with MM we may get close to this limit. - torch._dynamo.config.cache_size_limit = 128 - # Use persistent cache to avoid XLA recompilation. - # NOTE(woosuk): Set per-rank cache path since different ranks - # can have slightly different XLA graphs. - world_size = self.parallel_config.world_size - rank = xr.global_ordinal() - # The PyTorch/XLA compilation cache uses the Torch IR to generate keys. - # Consequently, changes in optimization flags, which affect compilation - # results, don't change the cache key. This can result in the wrong - # compilation being used. To prevent this, disabling the XLA compilation - # cache during development is recommended.We can disable it by - # `export VLLM_XLA_CACHE_PATH=` - if envs.VLLM_XLA_CACHE_PATH: - per_rank_path = os.path.join( - envs.VLLM_XLA_CACHE_PATH, f"tp{world_size}_rank{rank}" - ) - xr.initialize_cache(per_rank_path, readonly=False) - - # Init ModelRunner here, so that we have access to self.device. - self.model_runner = TPUModelRunner( - self.vllm_config, self.device, self.original_parallel_config - ) - - if rank == 0: - # If usage stat is enabled, collect relevant info. - report_usage_stats(self.vllm_config) - - def determine_available_memory(self) -> int: - kv_caches: dict[str, torch.Tensor] = {} - kv_cache_spec = self.model_runner.get_kv_cache_spec() - for layer_name, layer_spec in kv_cache_spec.items(): - if isinstance(layer_spec, AttentionSpec): - dtype = layer_spec.dtype - - # Use an empty tensor instead of `None` to force Dynamo to pass - # it by reference, rather by specializing on the value `None`. - tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device) - kv_caches[layer_name] = tpu_kv_cache - else: - raise NotImplementedError( - f"Unsupported KV cache spec '{type(layer_spec)}'" - ) - - runner_kv_caches: list[torch.Tensor] = [] - bind_kv_cache( - kv_caches, - self.vllm_config.compilation_config.static_forward_context, - runner_kv_caches, - ) - - # `max_num_tokens >= max_num_batched_tokens` due to padding. - with self.model_runner.maybe_setup_dummy_loras(self.lora_config): - self.model_runner.profile_run(self.model_runner.max_num_tokens) - - # Synchronize before measuring the memory usage. - xm.wait_device_ops() - - # During the profiling run, the model runs without KV cache. After - # the profiling run, the model always runs with KV cache. Here we clear - # the dynamo cache and cached bytecode to ensure the model always has - # one compiled bytecode. Having one FX graph/cached bytecode per - # compiled model is required for `support_torch_compile` decorator to - # skip dynamo guard. - with set_current_vllm_config(self.vllm_config): - self.model_runner.reset_dynamo_cache() - - # Get the maximum amount of memory used by the model weights and - # intermediate activations. - if self.use_spmd: - # This is a workaround for the TPU SPMD mode. The get_memory_info - # API doesn't work with SPMD mode in PyTorch/XLA. - # TODO: use xm.get_memory_info for SPMD once it's supported in - # PyTorch/XLA. - import tpu_info - - chip_type, _ = tpu_info.device.get_local_chips() - device_usage = tpu_info.metrics.get_chip_usage(chip_type) - total_memory_size = device_usage[0].total_memory - current_mem = device_usage[0].memory_usage - else: - m = xm.get_memory_info(self.device) - total_memory_size = m["bytes_limit"] - current_mem = m["bytes_used"] - # Ideally we would use profiled = m["peak_bytes_used"] to - # get weights + activations. But there is memory used during - # compilation / weight loading that impacts the peak and - # there is no way to reset peak memory in XLA, So we - # use the heuristic of 2% of weights. - profiled = current_mem * 1.02 - - # Calculate the TPU KV cache size based on profiling. - usable_memory_size = int( - total_memory_size * self.cache_config.gpu_memory_utilization - ) - tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) - head_size = self.model_config.get_head_size() - if head_size > 0: - padded_head_size = ( - cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT - ) - if padded_head_size != head_size: - logger.warning_once("head size is padded to %d", padded_head_size) - # We adjust the usable memory size for the KV cache to prevent OOM - # errors, even after padding the head_size. - tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size - return int(tpu_kv_cache_bytes) - - def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput: - return self.model_runner.sample_tokens(grammar_output) - - def execute_model( - self, scheduler_output: "SchedulerOutput" - ) -> ModelRunnerOutput | None: - return self.model_runner.execute_model(scheduler_output) - - def profile(self, is_start: bool = True): - if self.rank < 1: - if self.profile_dir is None: - raise RuntimeError("Profiler is not enabled.") - if is_start: - if self.profiler is None: - self.profiler = xp.start_server(9012) - xp.start_trace(self.profile_dir) - else: - xp.stop_trace() - - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.model_runner.add_lora(lora_request) - - def load_model(self) -> None: - self.model_runner.load_model() - - def update_config(self, overrides: dict[str, Any]) -> None: - self.model_runner.update_config(overrides) - - def reload_weights(self) -> None: - self.model_runner.reload_weights() - - def compile_or_warm_up_model(self) -> None: - if not self.model_config.enforce_eager: - self.model_runner.capture_model() - - # Reset the seed to ensure that the random state is not affected by - # the model initialization and profiling. - set_random_seed(self.model_config.seed) - - def reset_mm_cache(self) -> None: - self.model_runner.reset_mm_cache() - - def get_model(self) -> nn.Module: - return self.model_runner.get_model() - - def get_supported_tasks(self) -> tuple[SupportedTask, ...]: - return self.model_runner.get_supported_tasks() - - def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: - return self.model_runner.get_kv_cache_spec() - - def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: - """Allocate GPU KV cache with the specified kv_cache_config.""" - self.model_runner.initialize_kv_cache(kv_cache_config) - - def check_health(self) -> None: - # worker will always be healthy as long as it's running. - return - - def _init_tpu_worker_distributed_environment( - self, - vllm_config: VllmConfig, - rank: int, - distributed_init_method: str | None = None, - local_rank: int = -1, - ) -> None: - """Initialize the distributed environment.""" - if self.use_spmd: - xr.use_spmd() - # NOTE(woosuk): This is just to initialize the TP group and broadcast - # the input objects on CPU. The all-reduce and all-gather ops on TPU - # are invoked by `xm.all_reduce` and `xm.all_gather` which use their - # own context. - parallel_config = vllm_config.parallel_config - init_distributed_environment( - world_size=parallel_config.world_size, - rank=rank, - local_rank=local_rank, - distributed_init_method=distributed_init_method or "env://", - backend=current_platform.dist_backend, - ) - ensure_model_parallel_initialized( - parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size - ) - - ensure_kv_transfer_initialized(vllm_config) - - def shutdown(self) -> None: - self.model_runner.ensure_kv_transfer_shutdown() - - def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R: - """Apply a function on the model inside this worker.""" - return fn(self.get_model()) - - if USE_TPU_INFERENCE: from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker From bf1343cc1d258d5df993e022da480c62bf9244f5 Mon Sep 17 00:00:00 2001 From: Wei-Yu Lin Date: Fri, 12 Dec 2025 22:08:04 +0000 Subject: [PATCH 02/11] Remove torch_xla related code path excluding test files Signed-off-by: Wei-Yu Lin --- .../device_communicators/tpu_communicator.py | 85 - vllm/distributed/tpu_distributed_utils.py | 188 -- vllm/lora/ops/xla_ops/__init__.py | 6 - vllm/lora/ops/xla_ops/lora_ops.py | 141 -- vllm/lora/punica_wrapper/punica_tpu.py | 358 --- vllm/model_executor/layers/fused_moe/layer.py | 5 - .../layers/fused_moe/moe_pallas.py | 83 - .../fused_moe/unquantized_fused_moe_method.py | 52 +- .../layers/quantization/__init__.py | 2 - .../kernels/scaled_mm/__init__.py | 4 - .../quantization/kernels/scaled_mm/xla.py | 106 - .../layers/quantization/tpu_int8.py | 139 -- vllm/model_executor/model_loader/tpu.py | 118 - vllm/usage/usage_lib.py | 18 +- vllm/v1/attention/backends/pallas.py | 59 +- vllm/v1/worker/tpu_model_runner.py | 2191 ----------------- 16 files changed, 3 insertions(+), 3552 deletions(-) delete mode 100644 vllm/distributed/device_communicators/tpu_communicator.py delete mode 100644 vllm/distributed/tpu_distributed_utils.py delete mode 100644 vllm/lora/ops/xla_ops/__init__.py delete mode 100644 vllm/lora/ops/xla_ops/lora_ops.py delete mode 100644 vllm/lora/punica_wrapper/punica_tpu.py delete mode 100644 vllm/model_executor/layers/fused_moe/moe_pallas.py delete mode 100644 vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py delete mode 100644 vllm/model_executor/layers/quantization/tpu_int8.py delete mode 100644 vllm/model_executor/model_loader/tpu.py delete mode 100644 vllm/v1/worker/tpu_model_runner.py diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py deleted file mode 100644 index 9581a3dbc7b74..0000000000000 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ /dev/null @@ -1,85 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import os - -import torch -from torch.distributed import ProcessGroup - -from vllm.config import get_current_vllm_config -from vllm.logger import init_logger -from vllm.platforms import current_platform -from vllm.platforms.tpu import USE_TPU_INFERENCE - -from .base_device_communicator import DeviceCommunicatorBase - -USE_RAY = parallel_config = ( - get_current_vllm_config().parallel_config.distributed_executor_backend == "ray" -) - -logger = init_logger(__name__) - - -class TpuCommunicator(DeviceCommunicatorBase): - def __init__( - self, - cpu_group: ProcessGroup, - device: torch.device | None = None, - device_group: ProcessGroup | None = None, - unique_name: str = "", - ): - super().__init__(cpu_group, device, device_group, unique_name) - - # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node - # must be used together. Therefore, the local rank and world size can - # be simply calculated as follows. - global_rank = self.global_rank - global_world_size = self.global_world_size - - if USE_RAY: - logger.info("TpuCommunicator initialized with RAY") - # Calculate how many TPU nodes are in the current deployment. This - # is the Ray placement group if it is deployed with Ray. Default - # to the number of TPU nodes in the Ray cluster. The number of TPU - # nodes is computed by the total number of TPUs divided by the - # number of TPU accelerators per node, to account for clusters - # with both CPUs and TPUs. - num_nodes = ray_utils.get_num_tpu_nodes() - num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group() - if num_nodes_in_pg > 0: - num_nodes = num_nodes_in_pg - - local_world_size = global_world_size // num_nodes - local_rank = global_rank % local_world_size - else: - logger.info("TpuCommunicator initialized with MP") - # Sanity: Verify we run on a single host - num_hosts = torch_xla.tpu.num_tpu_workers() - assert num_hosts == 1 - - # Get the current number of TPUs (we have locally) - local_world_size = torch_xla.tpu.num_available_chips() - - # Get current rank - local_rank = global_rank % local_world_size - - # Ensure environment variables are set for multihost deployments. - # On GKE, this is needed for libtpu and TPU driver to know which TPU - # chip is actually visible. Otherwise the TPU driver will fail to - # initialize because the number of devices would be different from - # the number of visible worker addresses. - os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank) - os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank) - - pjrt.initialize_multiprocess(local_rank, local_world_size) - xr._init_world_size_ordinal() - self.groups = create_optimized_replica_groups() - - def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: - # TODO: Remove the groups specification after XLA compiler can support - # auto-reordering the ring order for all-reduce. - return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups) - - def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: - assert dim == -1, "TPUs only support dim=-1 for all-gather." - return xm.all_gather(input_, dim=dim) diff --git a/vllm/distributed/tpu_distributed_utils.py b/vllm/distributed/tpu_distributed_utils.py deleted file mode 100644 index 4ff1f0ce4410a..0000000000000 --- a/vllm/distributed/tpu_distributed_utils.py +++ /dev/null @@ -1,188 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections import OrderedDict -from typing import Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch_xla.distributed.spmd as xs -from torch.nn.parameter import Parameter - -from vllm.logger import init_logger -from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, -) - -logger = init_logger(__name__) - - -class XlaQKVParallelLinear(nn.Module): - def __init__(self, qkv_linear: nn.Module, mesh: Optional["xs.Mesh"] = None): - super().__init__() - assert isinstance(qkv_linear, QKVParallelLinear) - self.skip_bias_add = qkv_linear.skip_bias_add - self.return_bias = qkv_linear.return_bias - assert qkv_linear.tp_size == 1, "TP > 1 is only supported under SPMD." - - self.q_weight: Parameter - self.k_weight: Parameter - self.v_weight: Parameter - self.q_bias: Parameter | None - self.k_bias: Parameter | None - self.v_bias: Parameter | None - self._load_weights_from_qkv_linear(qkv_linear) - if mesh is not None: - self._shard_weight(mesh) - - def _shard_weight(self, mesh: "xs.Mesh"): - self.q_weight = Parameter(self.q_weight.to("xla"), requires_grad=False) - self.k_weight = Parameter(self.k_weight.to("xla"), requires_grad=False) - self.v_weight = Parameter(self.v_weight.to("xla"), requires_grad=False) - xs.mark_sharding(self.q_weight, mesh, ("x", None)) - xs.mark_sharding(self.k_weight, mesh, ("x", None)) - xs.mark_sharding(self.v_weight, mesh, ("x", None)) - if self.q_bias is not None: - assert self.k_bias is not None and self.v_bias is not None, ( - "QKVParallelLinear should have q, k, and v biases together." - ) - self.q_bias = Parameter(self.q_bias.to("xla"), requires_grad=False) - xs.mark_sharding(self.q_bias, mesh, ("x",)) - self.k_bias = Parameter(self.k_bias.to("xla"), requires_grad=False) - xs.mark_sharding(self.k_bias, mesh, ("x",)) - self.v_bias = Parameter(self.v_bias.to("xla"), requires_grad=False) - xs.mark_sharding(self.v_bias, mesh, ("x",)) - - def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module): - q_proj_size, k_proj_size, _ = qkv_linear.output_sizes - # The weight of qkv linear is a concatenation of q, k, and v weights - # along the output dimension. - qkv_weight = qkv_linear.weight.data.cpu() - q_weight = Parameter(qkv_weight[:q_proj_size], requires_grad=False) - k_weight = Parameter( - qkv_weight[q_proj_size : q_proj_size + k_proj_size], requires_grad=False - ) - v_weight = Parameter( - qkv_weight[q_proj_size + k_proj_size :], requires_grad=False - ) - self.register_parameter("q_weight", q_weight) - self.register_parameter("k_weight", k_weight) - self.register_parameter("v_weight", v_weight) - - if qkv_linear.bias is not None: - q_bias = Parameter(qkv_linear.bias[:q_proj_size], requires_grad=False) - k_bias = Parameter( - qkv_linear.bias[q_proj_size : q_proj_size + k_proj_size], - requires_grad=False, - ) - v_bias = Parameter( - qkv_linear.bias[q_proj_size + k_proj_size :], requires_grad=False - ) - self.register_parameter("q_bias", q_bias) - self.register_parameter("k_bias", k_bias) - self.register_parameter("v_bias", v_bias) - else: - self.register_parameter("q_bias", None) - self.register_parameter("k_bias", None) - self.register_parameter("v_bias", None) - - def forward(self, input): - # Same forward functionality as QKVParallelLinear, but doing qkv porj - # separately. - q_bias = self.q_bias if not self.skip_bias_add else None - k_bias = self.k_bias if not self.skip_bias_add else None - v_bias = self.v_bias if not self.skip_bias_add else None - q_proj = F.linear(input, self.q_weight, q_bias) - k_proj = F.linear(input, self.k_weight, k_bias) - v_proj = F.linear(input, self.v_weight, v_bias) - # The q/k/v projections will be split outside of the QKVParallelLinear. - # Because we are replacing XlaQKVParallelLinear with the - # QKVParallelLinear, we need to concatenate q, k, and v projections to - # match the output shape of the QKVParallelLinear implementation even if - # it seems to be redundant. - # The concat and the following split will be noop, and should be - # optimized away by the compiler. - qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1) - output_bias = ( - torch.cat([q_bias, k_bias, v_bias], dim=-1) if self.skip_bias_add else None - ) - if not self.return_bias: - return qkv_proj - return qkv_proj, output_bias - - -def partition_column_parallel_linear( - layer: torch.nn.Module, mesh: xs.Mesh -) -> torch.nn.Module: - assert isinstance(layer, ColumnParallelLinear) - xs.mark_sharding(layer.weight, mesh, ("x", None)) - logger.debug("Applied column-parallel sharding to %s", layer) - return layer - - -def partition_row_parallel_linear( - layer: torch.nn.Module, mesh: xs.Mesh -) -> torch.nn.Module: - assert isinstance(layer, RowParallelLinear) - xs.mark_sharding(layer.weight, mesh, (None, "x")) - logger.debug("Applied row-parallel sharding to %s", layer) - return layer - - -def partition_qkv_parallel_linear( - layer: torch.nn.Module, mesh: xs.Mesh -) -> torch.nn.Module: - assert isinstance(layer, QKVParallelLinear) - xla_layer = XlaQKVParallelLinear(layer, mesh) - logger.debug("Applied qkv parallel sharding to %s", layer) - return xla_layer - - -MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict( - [ - ("QKVParallelLinear", partition_qkv_parallel_linear), - ("ColumnParallelLinear", partition_column_parallel_linear), - ("RowParallelLinear", partition_row_parallel_linear), - ] -) - - -def get_fqn(module): - # Get the fully qualified name of the module - return module.__class__.__qualname__ - - -def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None: - """ - Recursively check a PyTorch model and apply appropriate sharding based on - the MODULE_TYPE_TO_WRAPPING_FUNC mapping. - - Args: - model: torch.nn.Module to process - mesh: An XLA SPMD mesh object used for sharding - """ - - def _process_module(module, name=None, parent=None): - for module_type, wrapping_func in MODULE_TYPE_TO_WRAPPING_FUNC.items(): - if get_fqn(module) == module_type: - wrapped_module = wrapping_func(module, mesh) - - assert parent is not None and name is not None, ( - "Top Level module is not expected to be wrapped." - ) - if wrapped_module is not module: - # Wrapped module and module are different py object. - # The original module should be replaced by the - # wrapped_module. - logger.debug("replace %s with %s", module, wrapped_module) - setattr(parent, name, wrapped_module) - - module = wrapped_module - break - - for child_name, child_module in list(module.named_children()): - _process_module(child_module, child_name, module) - - _process_module(model) diff --git a/vllm/lora/ops/xla_ops/__init__.py b/vllm/lora/ops/xla_ops/__init__.py deleted file mode 100644 index b5570ceca68ca..0000000000000 --- a/vllm/lora/ops/xla_ops/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink - -__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"] diff --git a/vllm/lora/ops/xla_ops/lora_ops.py b/vllm/lora/ops/xla_ops/lora_ops.py deleted file mode 100644 index 4924890b388cb..0000000000000 --- a/vllm/lora/ops/xla_ops/lora_ops.py +++ /dev/null @@ -1,141 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import jax -import jax.numpy as jnp -import torch -import torch.nn.functional as F -import torch_xla.core.xla_builder as xb -from torch.library import impl -from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard - - -@jax.jit -def bgmv_jax(inputs, loras, idxs): - return jnp.einsum( - "td,tX,Xld->tl", - inputs, - jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype), - loras, - ) - - -XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor") - - -@impl(XLA_LIB, "bgmv", "XLA") -def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): - if len(loras.shape) == 4: - loras = loras.squeeze(axis=1) - - jax_import_guard() - return xb.call_jax(bgmv_jax, (inputs, loras, idxs)) - - -@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd") -def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor): - T, _ = inputs.shape - if len(loras.shape) == 4: - loras = loras.squeeze(axis=1) - _, L, _ = loras.shape - - return torch.empty((T, L), device=inputs.device) - - -def bgmv_expand( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - add_inputs: bool = True, -): - """ - Args: - inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. - - lora_b_weights (torch.Tensor): LoRA weights of shape - [num_loras, lora_rank, hidden_size]. - - output_tensor (torch.Tensor): output tensor of shape - [num_tokens, hidden_size * num_slices]. - - lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] - indicating which LoRA matrix to use for each token. - add_inputs (bool): Whether or not to add the input tensor to the output - tensor. - """ - - outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) - - limit = output_tensor.shape[0] - if outputs.shape[0] == 1 and output_tensor.shape[0] != 1: - limit = 1 - - if output_tensor.shape[1] > outputs.shape[1]: - outputs = F.pad(outputs, (0, output_tensor.shape[1] - outputs.shape[1], 0, 0)) - - if add_inputs: - return output_tensor + outputs[:limit, : output_tensor.shape[1]] - else: - return outputs[:limit, : output_tensor.shape[1]] - - -def bgmv_shrink( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - lora_indices_tensor: torch.Tensor, - scaling: float = 1.0, -): - """ - Args: - inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. - lora_b_weights (torch.Tensor): LoRA weights of shape - [num_loras, lora_rank, hidden_size]. - lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] - indicating which LoRA matrix to use for each token. - scaling (float, optional): Scalar multiplier applied to the output. - """ - - return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) - - -def bgmv_expand_slice( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - slice_offset: int, - slice_size: int, - add_inputs: bool = True, -): - """ - Args: - inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size]. - - lora_b_weights (torch.Tensor): LoRA weights of shape - [num_loras, lora_rank, hidden_size]. - - output_tensor (torch.Tensor): output tensor of shape - [num_tokens, hidden_size * num_slices]. - - lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens] - indicating which LoRA matrix to use for each token. - add_inputs (bool): Whether or not to add the input tensor to the output - tensor. - """ - outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor) - - outputs = F.pad( - outputs, - ( - slice_offset, - output_tensor.shape[1] - (slice_offset + slice_size), - 0, - 0, - ), - ) - - if add_inputs: - return output_tensor + outputs - else: - return outputs diff --git a/vllm/lora/punica_wrapper/punica_tpu.py b/vllm/lora/punica_wrapper/punica_tpu.py deleted file mode 100644 index 0888772db54e7..0000000000000 --- a/vllm/lora/punica_wrapper/punica_tpu.py +++ /dev/null @@ -1,358 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import math -from typing import TYPE_CHECKING - -import torch -import torch.nn.functional as F -import torch_xla - -from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink -from vllm.lora.punica_wrapper.utils import convert_mapping - -if TYPE_CHECKING: - # avoid circuit import - from vllm.lora.layers import LoRAMapping - -from .punica_base import PunicaWrapperBase - - -class PunicaWrapperTPU(PunicaWrapperBase): - """ - PunicaWrapperTPU is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for - Multi-LoRA, and to provide the interface for the pytorch punica ops. - """ - - def __init__( - self, - max_num_batched_tokens: int, - max_batches: int, - device: torch.device | str, - **kwargs, - ): - PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) - - # PunicaWrapperBase defines some tensors with dtype=torch.int64, which - # isn't supported by the TPU. So convert those tensors to int32. - # Not all of them are used by the TPU so only convert the useful ones. - self._token_lora_indices = self._token_lora_indices.to(dtype=torch.int32) - self._sampler_indices = self._sampler_indices.to(dtype=torch.int32) - self._sampler_indices_padded = self._sampler_indices_padded.to( - dtype=torch.int32 - ) - - torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True) - torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True) - torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, True) - torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True) - torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, True) - - torch._dynamo.mark_dynamic(self._token_lora_indices, 0) - torch._dynamo.mark_dynamic(self._embeddings_indices, 1) - torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0) - - def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor: - return torch.narrow(self._token_lora_indices, 0, 0, x.size(0)) - - @property - def embeddings_indices(self) -> torch.Tensor: - """ - This property provides access to the indices used for lora embeddings, - specifically for VocabParallelEmbeddingWithLoRA. - """ - return self._embeddings_indices[:] - - @property - def sampler_indices_padded(self) -> torch.Tensor: - """ - This property provides access to padded sampler indices. - """ - return self._sampler_indices_padded[:] - - def shrink( - self, - x: torch.Tensor, - w_t_all: torch.Tensor, - scale: float, - ): - return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale) - - def expand( - self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, add_inputs: bool - ): - return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), add_inputs) - - def expand_slice( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: int, - y_slice_size: int, - add_inputs: bool, - ) -> torch.Tensor: - return bgmv_expand_slice( - x, - w_t_all, - y, - self._get_token_lora_indices(x), - y_offset, - y_slice_size, - add_inputs, - ) - - def add_shrink( - self, - y: tuple[torch.Tensor, ...] | torch.Tensor, - x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, ...], - scale: float, - **kwargs, - ) -> torch.Tensor | None: - """ - Performs GEMM for multiple slices of lora_a. - - Semantics: - for i in range(len(lora_a_stacked)): - y[i] += (x @ lora_a_stacked[i]) * scale - - Args: - y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors - x (torch.Tensor): Input tensor - lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights - scale (float): Scaling factor for the operation - """ - - torch.ops.xla.dynamo_set_buffer_donor_(y, True) - x = x.view(-1, x.shape[-1]) - - for slice_idx in range(len(lora_a_stacked)): - lora_s = lora_a_stacked[slice_idx] - y_s = self.shrink(x, lora_s, scale) - y[slice_idx, :, :] = y_s # type: ignore[index] - return y - - def add_expand( - self, - y: torch.Tensor, - x: tuple[torch.Tensor, ...] | torch.Tensor, - lora_b_stacked: tuple[torch.Tensor, ...], - output_slices: tuple[int, ...], - offset_start: int = 0, - add_inputs=True, - **kwargs, - ) -> torch.Tensor: - """ - Performs GEMM for multiple slices of lora_b. - - Semantics: - for i in range(len(lora_b_stacked)): - slice = output_slices[i] - y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] - offset += slice - - Args: - y (torch.Tensor): Output tensor. - x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors - lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight - output_slices (tuple[int, ...]): Every slice's size - add_inputs (bool): Defaults to True. - """ - y_org = y - y = y.view(-1, y.shape[-1]) - offset_left = 0 - - for slice_idx in range(len(lora_b_stacked)): - y = self.expand_slice( - y, - x[slice_idx], - lora_b_stacked[slice_idx], - offset_left, - output_slices[slice_idx], - add_inputs=add_inputs, - ) - offset_left += output_slices[slice_idx] - return y.view_as(y_org) - - def add_lora_embedding( - self, - y: torch.Tensor, - x: torch.Tensor, - lora_b_stacked: torch.Tensor, - add_inputs: bool = True, - **kwargs, - ) -> torch.Tensor: - """ - Applies lora specifically for VocabParallelEmbeddingWithLoRA. - - Semantics: - y += x @ lora_b_stacked - - Args: - y (torch.Tensor): Output tensor. - x (torch.Tensor): Input tensor. - lora_b_stacked (torch.Tensor): lora_b's weights. - add_inputs (bool): Default to True. - """ - - # Embedding layer only needs the expand op - return self.expand(y, x, lora_b_stacked, add_inputs) - - def add_lora_linear( - self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: tuple[torch.Tensor, ...], - lora_b_stacked: tuple[torch.Tensor, ...], - scale: float, - output_slices: tuple[int, ...], - *, - buffer: tuple[torch.Tensor, ...] | None = None, - **kwargs, - ) -> torch.Tensor: - """ - Applicable to linear-related lora. - - Semantics: - for i in range(len(lora_a_stacked)): - y[i] += ( - x[i].unsqueeze(0) - @ lora_a_stacked[indices[i], layer_idx, :, :] - @ lora_b_stacked[indices[i], layer_idx, :, :] - * scale - ).squeeze(0) - - Args: - y (torch.Tensor): Output tensor. Will not be changed in-place. - x (torch.Tensor): Input tensor (T, E) - lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight. - lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight. - scale (float): Scaling factor. - output_slices (tuple[int, ...]): Every slice's size. - buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None. - """ - - assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) - - if buffer is None: - r = lora_b_stacked[0].size(-1) - T = x.size(0) - buffer = torch.zeros( - (len(output_slices), T, r), - dtype=x.dtype, - device=x.device, - ) - buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) - return self.add_expand( - y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs - ) - - def add_lora_logits( - self, - y: torch.Tensor, - x: torch.Tensor, - lora_a_stacked: torch.Tensor, - lora_b_stacked: torch.Tensor, - scale, - *, - buffer: torch.Tensor | None = None, - **kwargs, - ) -> torch.Tensor: - """ - Applies lora specifically for LogitsProcessorWithLoRA. - - Semantics: - buffer = (x @ lora_a_stacked) * scale - y += buffer @ lora_b_stacked - - Args: - y (torch.Tensor): Output tensor. - x (torch.Tensor): Input tensor. - lora_a_stacked (torch.Tensor): lora_a's weights. - lora_b_stacked (torch.Tensor):lora_b's weights. - scale (float): Scaling factor. - buffer (Optional[torch.Tensor]):Default to None. - """ - y_org = y - y = y.view(-1, y.shape[-1]) - x = x.view(-1, x.shape[-1]) - - sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0)) - buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale) - y = bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True) - return y.view_as(y_org) - - # This performs the same tensor ops as the base method, except it does them - # on the CPU then transfers the results to the TPU - def _update_base_metadata( - self, - mapping: "LoRAMapping", - lora_index_to_id: list[int | None], - max_loras: int, - vocab_size: int, - ): - # Make sure we don't accidentally collect outside operations - torch_xla.sync() - - # Pad the prompt mapping to avoid running into recompiles on the TPU - # TODO: Should this happen inside mapping internally? If so how can we - # avoid having backend specific LoRAMapping classes? - mapping.prompt_mapping = self._pad_prompt_mapping(mapping.prompt_mapping) - - ( - base_indices, - sampler_indices, - sampler_indices_padded, - embeddings_indices, - indices_len, - ) = convert_mapping( - mapping, - lora_index_to_id, - max_loras, - vocab_size, - 0, # extra_vocab_size - "cpu", - ) - self._token_lora_indices = self._pad_to_shape( - base_indices, self._token_lora_indices.shape, dims=1 - ).to(self.device) - self._sampler_indices = self._pad_to_shape( - sampler_indices, self._sampler_indices.shape, dims=1 - ).to(self.device) - self._sampler_indices_padded = self._pad_to_shape( - sampler_indices_padded, self._sampler_indices_padded.shape, dims=1 - ).to(self.device) - self._embeddings_indices = self._pad_to_shape( - embeddings_indices, self._embeddings_indices.shape, dims=2 - ).to(self.device) - self.indices_len[:] = indices_len - - def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None: - self.batch_size = 1 - self._lora_indices_per_batch[: self.batch_size] = token_lora_tensor[ - : self.batch_size - ] - - def _pad_prompt_mapping(self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]: - num_reqs = len(prompt_mapping) - - # From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular - # import - MIN_NUM_SEQS = 8 - - padded_num_reqs = max(2 ** math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS) - pad_len = padded_num_reqs - num_reqs - - padding = [-1] * pad_len - return tuple(list(prompt_mapping) + padding) - - def _pad_to_shape(self, src, target_shape, dims=1): - if dims == 1: - pad_len = target_shape[0] - src.shape[0] - return F.pad(src, (0, pad_len), value=0).to(torch.int32) - else: - pad_rows = target_shape[0] - src.shape[0] - pad_cols = target_shape[1] - src.shape[1] - return F.pad(src, (0, pad_cols, 0, pad_rows), value=0).to(torch.int32) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 2e7267d56d838..d6226da76eaed 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -71,11 +71,6 @@ from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: rocm_aiter_grouped_topk, ) -if current_platform.is_tpu(): - from .moe_pallas import fused_moe as fused_moe_pallas -else: - fused_moe_pallas = None # type: ignore - from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( FusedMoEMethodBase, ) diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py deleted file mode 100644 index 66c00cf89873a..0000000000000 --- a/vllm/model_executor/layers/fused_moe/moe_pallas.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch -import torch.nn.functional as F - - -def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor: - """ - Compute the histogram of an int32 tensor. The bin edges are defined by the - min and max values, with step = 1. - """ - assert input.dtype == torch.int32, "input must be of torch.int32 dtype." - assert min <= max, "min must be less than or equal to max." - - def searchsorted( - sorted_sequence: torch.Tensor, values_to_search: torch.Tensor - ) -> torch.Tensor: - return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1) - - bin_edges = torch.linspace(min, max, max - min + 1, dtype=input.dtype).to( - input.device - ) - return searchsorted(bin_edges, input).to(torch.int32) - - -def fused_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - global_num_experts: int, - expert_map: torch.Tensor = None, - renormalize: bool = False, -) -> torch.Tensor: - """ - Args: - hidden_states: [*, hidden_size] - w1: [num_experts, intermediate_size * 2, hidden_size] - w2: [num_experts, hidden_size, intermediate_size] - gating_output: [*, num_experts] - """ - assert expert_map is None, "expert_map is not supported for pallas MoE." - import torch_xla.experimental.custom_kernel # noqa: F401 - - orig_shape = hidden_states.shape - hidden_size = hidden_states.shape[-1] - num_tokens = hidden_states.shape[:-1].numel() - num_experts = w1.shape[0] - intermediate_size = w2.shape[-1] - device = hidden_states.device - dtype = hidden_states.dtype - assert (num_tokens * topk) % 16 == 0, ( - "The Pallas GMM kernel requires num_tokens * topk to be a multiple of " - f"16 but got {num_tokens * topk}" - ) - - hidden_states = hidden_states.view(num_tokens, hidden_size) - gating_output = gating_output.view(num_tokens, num_experts) - topk_weights = gating_output.softmax(dim=-1, dtype=torch.float) - topk_weights, topk_indices = topk_weights.topk(topk, dim=-1) - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - topk_weights = topk_weights.to(dtype) - - topk_indices = topk_indices.flatten() - topk_argsort_indices = topk_indices.argsort() - topk_argsort_revert_indices = topk_argsort_indices.argsort() - token_indices = torch.arange(num_tokens, device=device).repeat_interleave(topk) - token_indices = token_indices[topk_argsort_indices] - group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1) - - x = hidden_states[token_indices] - x = torch.ops.xla.gmm(x, w1, group_sizes, transpose_rhs=True) - x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:] - x = torch.ops.xla.gmm(x, w2, group_sizes, transpose_rhs=True) - x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size) - - x = x * topk_weights.unsqueeze(dim=-1) - x = x.sum(dim=-2) - x = x.reshape(orig_shape) - return x diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 82dbccf3fa9da..4c03cff2e8131 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -38,10 +38,6 @@ if current_platform.is_cuda_alike(): else: TritonExperts = None # type: ignore -if current_platform.is_tpu(): - from .moe_pallas import fused_moe as fused_moe_pallas -else: - fused_moe_pallas = None # type: ignore logger = init_logger(__name__) @@ -403,53 +399,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): custom_routing_function=layer.custom_routing_function, ) - def forward_tpu( - self, - layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 - x: torch.Tensor, - router_logits: torch.Tensor, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert not layer.use_grouped_topk - assert layer.num_expert_group is None - assert layer.topk_group is None - assert layer.custom_routing_function is None - assert layer.apply_router_weight_on_input is False - if layer.scoring_func != "softmax": - raise NotImplementedError( - "Only softmax scoring function is supported for TPU." - ) - if layer.e_score_correction_bias is not None: - raise NotImplementedError( - "Expert score correction bias is not supported for TPU." - ) - assert layer.activation == "silu", ( - f"{layer.activation} is not supported for TPU." - ) - assert layer.routed_scaling_factor == 1.0, ( - f"routed_scaling_factor {layer.routed_scaling_factor} is " - "not supported for TPU." - ) - if ( - layer.enable_eplb is not False - or layer.expert_load_view is not None - or layer.logical_to_physical_map is not None - or layer.logical_replica_count is not None - ): - raise NotImplementedError("Expert load balancing is not supported for TPU.") - return fused_moe_pallas( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk=layer.top_k, - gating_output=router_logits, - global_num_experts=layer.global_num_experts, - expert_map=layer.expert_map, - renormalize=layer.renormalize, - ) - - if current_platform.is_tpu(): - forward_native = forward_tpu - elif current_platform.is_cpu(): + if current_platform.is_cpu(): forward_native = forward_cpu elif current_platform.is_xpu(): forward_native = forward_xpu diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 18aaae394f935..1a4378f5df3db 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -130,12 +130,10 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: from .ptpc_fp8 import PTPCFp8Config from .rtn import RTNConfig from .torchao import TorchAOConfig - from .tpu_int8 import Int8TpuConfig method_to_config: dict[str, type[QuantizationConfig]] = { "awq": AWQConfig, "deepspeedfp": DeepSpeedFPConfig, - "tpu_int8": Int8TpuConfig, "fp8": Fp8Config, "fbgemm_fp8": FBGEMMFp8Config, "fp_quant": FPQuantConfig, diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py index 20d050d387d49..4ccc4182367a6 100644 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -19,9 +19,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( TritonScaledMMLinearKernel, ) -from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( - XLAScaledMMLinearKernel, -) from vllm.platforms import PlatformEnum, current_platform # in priority/performance order (when available) @@ -29,7 +26,6 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { PlatformEnum.CPU: [CPUScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], - PlatformEnum.TPU: [XLAScaledMMLinearKernel], } diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py deleted file mode 100644 index 0be858c51993d..0000000000000 --- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +++ /dev/null @@ -1,106 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import warnings - -import torch -from functorch.experimental.control_flow import cond # noqa: F401 - -from vllm.model_executor.layers.quantization.utils import replace_parameter -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - convert_to_channelwise, -) -from vllm.platforms import current_platform - -from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig - - -class XLAScaledMMLinearKernel(ScaledMMLinearKernel): - @classmethod - def is_supported( - cls, compute_capability: int | None = None - ) -> tuple[bool, str | None]: - if not current_platform.is_tpu(): - return False, "Requires TPU." - return True, None - - @classmethod - def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: - if not current_platform.is_tpu(): - return False, "ScaledMMXLA requires running on TPU." - - if c.is_static_input_scheme: - return False, "ScaledMMXLA requires dynamic activation scales." - - if not c.input_symmetric: - return False, "ScaledMMXLA requires symmetric activation scales." - - if not c.is_channelwise: - return False, "ScaledMMXLA requires channelwise weight scales" - - return True, None - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # WEIGHT - # [out, in] (different than cutlass_scaled_mm) - weight = getattr(layer, self.w_q_name) - replace_parameter( - layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False) - ) - - # WEIGHT SCALE - # XLA kernels support only per-tensor and per-channel. - # If we have a fused module (QKV, MLP) with per tensor scales (thus N - # scales being passed to the kernel), convert to the per-channel case. - is_fused_module = len(layer.logical_widths) > 1 - weight_scale = getattr(layer, self.w_s_name) - if is_fused_module and not self.config.is_channelwise: - weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) - - # [out_channel,] (different than cutlass_scaled_mm) - weight_scale = weight_scale.squeeze(-1) - replace_parameter( - layer, - self.w_s_name, - torch.nn.Parameter(weight_scale.data, requires_grad=False), - ) - - # Only support symmetric dynamic activation quantization. - setattr(layer, self.i_s_name, None) - setattr(layer, self.i_zp_name, None) - setattr(layer, self.azp_adj_name, None) - - # Filter warning for cond usage in apply_weights. It is okay - # to specialize the graph since bias is not dynamic. - warnings.filterwarnings( - "ignore", - message="Pred is a Python constant. When used with torch.cond, it specializes on one of the branches.", # noqa: E501 - ) - - def no_add_bias(self, x: torch.Tensor, bias: torch.Tensor | None): - return x - - def add_bias(self, x: torch.Tensor, bias: torch.Tensor | None): - return x + bias - - def apply_weights( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ) -> torch.Tensor: - w_q, w_s, _, _, _ = self._get_weight_params(layer) - - # Required to register custom ops. - import torch_xla.experimental.custom_kernel # noqa: F401 - - out = torch.ops.xla.quantized_matmul_int8( - x, - w_q, - w_s, - quantize_activation=True, - ) - - # Explicitly capture control flow to make dynamo happy. - # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501 - return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias]) diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py deleted file mode 100644 index 64bfa8fb80eb2..0000000000000 --- a/vllm/model_executor/layers/quantization/tpu_int8.py +++ /dev/null @@ -1,139 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from typing import Any, Optional - -import torch -from torch.nn import Module -from torch.nn.parameter import Parameter - -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase -from vllm.model_executor.layers.quantization import ( - QuantizationConfig, - QuantizationMethods, -) -from vllm.model_executor.parameter import ModelWeightParameter - -ACTIVATION_SCHEMES = ["none", "dynamic"] - - -class Int8TpuConfig(QuantizationConfig): - """Int8 Quantization Config class for TPU Backend.""" - - def __init__( - self, - activation_scheme: str = "none", - ) -> None: - super().__init__() - if activation_scheme not in ACTIVATION_SCHEMES: - raise ValueError(f"Unsupported activation scheme {activation_scheme}") - self.activation_scheme = activation_scheme - - def get_name(self) -> QuantizationMethods: - return "tpu_int8" - - def get_supported_act_dtypes(self) -> list[torch.dtype]: - return [torch.float16, torch.bfloat16] - - @classmethod - def get_min_capability(cls) -> int: - raise NotImplementedError("This function should not be called with TPU Backend") - - @staticmethod - def get_config_filenames() -> list[str]: - return [] - - @classmethod - def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig": - activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) - return cls(activation_scheme=activation_scheme) - - def get_quant_method( - self, layer: Module, prefix: str - ) -> Optional["TPUInt8LinearMethod"]: - if isinstance(layer, LinearBase): - return TPUInt8LinearMethod(self) - return None - - -class TPUInt8LinearMethod(LinearMethodBase): - """Int8 Linear method for TPU Quant.""" - - def __init__(self, quant_config: Int8TpuConfig): - self.quant_config = quant_config - self.quantize_activation = False - if self.quant_config.activation_scheme == "dynamic": - self.quantize_activation = True - - def create_weights( - self, - layer: Module, - input_size_per_partition: int, - output_partition_sizes: list[int], - input_size: int, - output_size: int, - params_dtype: torch.dtype, - **extra_weight_attrs, - ): - weight_loader = extra_weight_attrs.get("weight_loader") - weight = ModelWeightParameter( - data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype, - ), - input_dim=1, - output_dim=0, - weight_loader=weight_loader, - ) - layer.register_parameter("weight", weight) - - def _quantize_weight( - self, weight: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - weight_dtype = weight.dtype - weight = weight.cpu().to(torch.float32) - n_bit = 8 - eps = 1e-5 - max_int = 2 ** (n_bit - 1) - 1 - min_int = -(2 ** (n_bit - 1)) - max_val = weight.abs().amax(dim=-1, keepdim=True) - max_val = max_val.clamp(min=eps) - qscale = max_val / max_int - qweight = torch.clamp( - torch.round(weight * (1.0 / qscale)), min_int, max_int - ).to(torch.int8) - qscale = qscale.squeeze().to(weight_dtype) - return qweight, qscale - - def process_weights_after_loading(self, layer: Module) -> None: - layer.weight = Parameter(layer.weight.data, requires_grad=False) - device = layer.weight.device - qweight, qscale = self._quantize_weight(layer.weight) - qweight = qweight.to(device) - qscale = qscale.to(device) - layer.weight = Parameter(qweight, requires_grad=False) - layer.scale = Parameter(qscale, requires_grad=False) - - def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: torch.Tensor | None = None, - ) -> torch.Tensor: - try: - import torch_xla.experimental.custom_kernel # noqa: F401 - except ImportError as err: - raise ImportError( - "Please install torch_xla by following the instructions at " - "https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html " # noqa: E501 - "to run vLLM on TPU." - ) from err - weight = layer.weight - scale = layer.scale - out = torch.ops.xla.quantized_matmul_int8( - x, weight, scale, quantize_activation=self.quantize_activation - ) - if bias is not None: - out = out + bias - return out diff --git a/vllm/model_executor/model_loader/tpu.py b/vllm/model_executor/model_loader/tpu.py deleted file mode 100644 index fc142f1f07fae..0000000000000 --- a/vllm/model_executor/model_loader/tpu.py +++ /dev/null @@ -1,118 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import time - -import torch -import torch.nn as nn -import torch_xla.core.xla_model as xm -import torch_xla.distributed.spmd as xs - -from vllm.config import ModelConfig, VllmConfig -from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model -from vllm.logger import init_logger -from vllm.model_executor.model_loader.default_loader import DefaultModelLoader -from vllm.model_executor.model_loader.utils import ( - initialize_model, - process_weights_after_loading, -) -from vllm.utils.torch_utils import set_default_torch_dtype - -logger = init_logger(__name__) - - -class TPUModelLoader(DefaultModelLoader): - """ - A TPU model loader for model loading under SPMD mode. - """ - - def load_model( - self, - vllm_config: VllmConfig, - model_config: ModelConfig, - mesh: xs.Mesh | None = None, - ) -> nn.Module: - # Initialize model and load weights on CPU. Then, during SPMD partition, - # weights are sharded and transferred to TPUs. - self.counter_before_loading_weights = time.perf_counter() - model_config = vllm_config.model_config - assert model_config.quantization is None, "Quantization not supported" - target_device = torch.device("cpu") - with set_default_torch_dtype(model_config.dtype): - with target_device: - model = initialize_model(vllm_config=vllm_config) - - load_format = vllm_config.load_config.load_format - if load_format != "dummy": - weights_to_load = {name for name, _ in model.named_parameters()} - all_weights = self.get_all_weights(model_config, model) - loaded_weights = model.load_weights(all_weights) - self.counter_after_loading_weights = time.perf_counter() - logger.info( - "Loading weights took %.2f seconds", - self.counter_after_loading_weights - - self.counter_before_loading_weights, - ) - # We only enable strict check for non-quantized models - # that have loaded weights tracking currently. - if model_config.quantization is None and loaded_weights is not None: - weights_not_loaded = weights_to_load - loaded_weights - if weights_not_loaded: - raise ValueError( - "Following weights were not initialized from " - f"checkpoint: {weights_not_loaded}" - ) - else: - logger.info("Use dummy weight during weight loading.") - - process_weights_after_loading(model, model_config, target_device) - - counter_before_partition = time.perf_counter() - model = model.eval() - model = model.to("xla") - shard_model(model, mesh) - counter_after_partition = time.perf_counter() - logger.info( - "Partition model took %.2f seconds", - counter_after_partition - counter_before_partition, - ) - - # Ensure the model is properly loaded. - self._check_model_is_loaded(mesh, model) - - # Need to torch compile after model sharding are done. Because the - # compiler hints ('xs.mark_sharding') are torch ops. - if not model_config.is_multimodal_model: - model.model = torch.compile(model.model, backend="openxla") - else: - model.language_model.model = torch.compile( - model.language_model.model, backend="openxla" - ) - return model - - def _check_model_is_loaded(self, mesh: xs.Mesh | None, model: nn.Module) -> None: - """ - Ensure the model is properly loaded. - 1. All model parameters and buffers are on XLA device. - 2. Non-SPMD friendly layers are replaced as expected. - """ - device = xm.xla_device() - device_type = str(device.type) - - # Check parameters - for name, param in model.named_parameters(): - assert param.device.type == device_type, ( - f"Parameter {name} is on {param.device.type} instead of {device_type}" - ) - - # Check buffers - for name, buffer in model.named_buffers(): - assert buffer.device.type == device_type, ( - f"Buffer {name} is on {buffer.device.type} instead of {device_type}" - ) - - for module in model.modules(): - if (mesh is not None) and (get_fqn(module) == "QKVParallelLinear"): - raise AssertionError( - "QKVParallelLinear should be replaced by \ - XlaQKVParallelLinear under SPMD mode." - ) diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py index 69226763aafe6..b0886bba8a22a 100644 --- a/vllm/usage/usage_lib.py +++ b/vllm/usage/usage_lib.py @@ -186,20 +186,6 @@ class UsageMessage: except Exception: return False - def _report_torch_xla_usage(self) -> bool: - try: - import torch_xla - - self.gpu_count = torch_xla.runtime.world_size() - self.gpu_type = torch_xla.tpu.get_tpu_type() - self.gpu_memory_per_device = torch_xla.core.xla_model.get_memory_info()[ - "bytes_limit" - ] - self.cuda_runtime = "torch_xla" - return True - except Exception: - return False - def _report_usage_once( self, model_architecture: str, @@ -217,9 +203,7 @@ class UsageMessage: if current_platform.is_cuda(): self.cuda_runtime = torch.version.cuda if current_platform.is_tpu(): # noqa: SIM102 - if (not self._report_tpu_inference_usage()) and ( - not self._report_torch_xla_usage() - ): + if not self._report_tpu_inference_usage(): logger.exception("Failed to collect TPU information") self.provider = _detect_cloud_provider() self.architecture = platform.machine() diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 5f7f8e81a24c4..e5a0cf7420497 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -35,7 +35,7 @@ TPU_STR_DTYPE_TO_TORCH_DTYPE = { import tpu_inference # noqa: F401 - +# Note(weiyulin): some static functions are still used by tpu-inference class PallasAttentionBackend(AttentionBackend): @staticmethod def get_name() -> str: @@ -314,60 +314,3 @@ def write_to_kv_cache( ) # NOTE: the in-place copy will be optimized away by XLA compiler. kv_cache.copy_(new_kv_cache) - - -# We can move this function to a common utils file if it's also useful for other -# hardware. -def dtype_bits(dtype: torch.dtype): - if dtype.is_floating_point: - try: - return torch.finfo(dtype).bits - except TypeError: - pass - elif dtype.is_complex: - if dtype is torch.complex32: - return 32 - elif dtype is torch.complex64: - return 64 - elif dtype is torch.complex128: - return 128 - else: - try: - return torch.iinfo(dtype).bits - # torch.iinfo cannot support int4, int2, bits8... - except TypeError: - pass - str_dtype = str(dtype) - # support torch.int4, torch.int5, torch.uint5... - if str_dtype.startswith("torch.int") or str_dtype.startswith("torch.uint"): - return int(str_dtype[-1]) - raise TypeError(f"Getting the bit width of {dtype} is not supported") - - -def get_dtype_packing(dtype): - bits = dtype_bits(dtype) - if 32 % bits != 0: - raise ValueError( - f"The bit width must be divisible by 32, but got bits={bits}, " - "dtype={dtype}" - ) - return 32 // bits - - -def get_page_size_bytes( - block_size: int, num_kv_heads: int, head_size: int, kv_cache_dtype: torch.dtype -) -> int: - """Returns the size in bytes of one page of the KV cache.""" - padded_head_size = ( - cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT - ) - num_combined_kv_heads = num_kv_heads * 2 - - # NOTE: for the implicit padding in XLA - packing = get_dtype_packing(kv_cache_dtype) - num_combined_kv_heads = cdiv(num_combined_kv_heads, packing) * packing - - kv_cache_dtype_bits = dtype_bits(kv_cache_dtype) - return ( - block_size * num_combined_kv_heads * padded_head_size * kv_cache_dtype_bits // 8 - ) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py deleted file mode 100644 index c7404c4642d7e..0000000000000 --- a/vllm/v1/worker/tpu_model_runner.py +++ /dev/null @@ -1,2191 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import bisect -import gc -import time -from typing import TYPE_CHECKING, Any, cast -from unittest.mock import patch - -import numpy as np -import torch -import torch.nn as nn - -# TPU XLA related -import torch_xla -import torch_xla.core.xla_model as xm -import torch_xla.distributed.spmd as xs -import torch_xla.runtime as xr - -import vllm.envs as envs -from vllm.attention.backends.abstract import AttentionType -from vllm.attention.layer import Attention, MLAAttention -from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention -from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper -from vllm.config import ( - ParallelConfig, - VllmConfig, - get_layers_from_vllm_config, - update_config, -) -from vllm.distributed.kv_transfer import get_kv_transfer_group, has_kv_transfer_group -from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks -from vllm.forward_context import set_forward_context -from vllm.logger import init_logger -from vllm.lora.layers import BaseLayerWithLoRA -from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.model_loader import get_model_loader -from vllm.model_executor.model_loader.tpu import TPUModelLoader -from vllm.model_executor.models.interfaces import ( - SupportsMultiModal, - supports_transcription, -) -from vllm.model_executor.models.interfaces_base import ( - is_pooling_model, - is_text_generation_model, -) -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import ( - BatchedTensorInputs, - MultiModalKwargsItem, - PlaceholderRange, -) -from vllm.multimodal.utils import group_mm_kwargs_by_modality -from vllm.sequence import IntermediateTensors -from vllm.tasks import GenerationTask, PoolingTask, SupportedTask -from vllm.utils.math_utils import cdiv, prev_power_of_2 -from vllm.utils.platform_utils import is_pin_memory_available -from vllm.v1.attention.backends.pallas import ( - TPU_STR_DTYPE_TO_TORCH_DTYPE, - PallasAttentionBackend, - PallasMetadata, - get_page_size_bytes, -) -from vllm.v1.kv_cache_interface import ( - AttentionSpec, - FullAttentionSpec, - KVCacheConfig, - KVCacheSpec, - MLAAttentionSpec, - SlidingWindowSpec, -) -from vllm.v1.outputs import ( - EMPTY_MODEL_RUNNER_OUTPUT, - LogprobsLists, - LogprobsTensors, - ModelRunnerOutput, -) -from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata -from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler -from vllm.v1.worker.kv_connector_model_runner_mixin import ( - KVConnectorModelRunnerMixin, - KVConnectorOutput, -) -from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin -from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch - -from .utils import ( - MultiModalBudget, - add_kv_sharing_layers_to_kv_cache_groups, - bind_kv_cache, - sanity_check_mm_encoder_outputs, -) - -if TYPE_CHECKING: - from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput - -logger = init_logger(__name__) - -INVALID_TOKEN_ID = -1 -# Smallest output size -MIN_NUM_SEQS = 8 - - -######################################################### -# Ways to avoid recompilation -######################################################### -# -# The model executor has two primary components: -# 1. preparing the model and sampler inputs -# 2. executing the model and sampler. -# The core idea is to avoid any TPU computation during input preparation. For -# better compilation tracking and increased flexibility, the model execution and -# sampler are divided into several distinct components. -# -# Below are the detailed steps: -# -# Step 1 -# It is recommended to avoid TPU operations when preparing the model and sampler -# inputs. CPU tensors can be prepared and transferred to the XLA device using -# cpu_tensor.to(xla_device), which only triggers CPU to TPU transfers and avoids -# compilation. -# -# Step 2 -# The TPU execution should be decomposed into subgraphs (4 at the moment): -# 1. the main model -# 2. selecting hidden states for each request -# 3. sampler -# 4. encoder. -# Each subgraph should be decorated in a torch.compile. This is used to make -# sure that we have the same subgraph topology in both dummy_run and -# xecute_model. The results from these subgraphs should either be passed to -# other subgraphs, or transferred from TPU to CPU using xla_tensor.cpu() for -# subsequent processing on the CPU. -# -# Step 3 -# The dummy_run should be comprehensive, ensuring all potential input shapes and -# branch predictions are included as subgraph inputs to facilitate -# pre-compilation. -class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): - def __init__( - self, - vllm_config: VllmConfig, - device: torch.device, - original_parallel_config: ParallelConfig | None = None, - ): - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.original_parallel_config = original_parallel_config - self.scheduler_config = vllm_config.scheduler_config - self.speculative_config = vllm_config.speculative_config - self.observability_config = vllm_config.observability_config - self.device_config = vllm_config.device_config - - model_config = self.model_config - cache_config = self.cache_config - scheduler_config = self.scheduler_config - parallel_config = self.parallel_config - self.device = device - self.check_recompilation = envs.VLLM_XLA_CHECK_RECOMPILATION - - # SPMD Related - self.use_spmd = envs.VLLM_XLA_USE_SPMD - if self.use_spmd: - num_devices = xr.global_runtime_device_count() - mesh_shape = (num_devices, 1) - device_ids = np.array(range(num_devices)) - self.mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y")) - - self.enforce_eager = model_config.enforce_eager - - self.num_xla_graphs = 0 - self._update_num_xla_graphs("init") - - self.pin_memory = is_pin_memory_available() - self.dtype = self.model_config.dtype - if cache_config.cache_dtype == "auto": - model_dtype = self.dtype - if isinstance(model_dtype, str): - self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype] - else: - self.kv_cache_dtype = model_dtype - else: - self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - self._hidden_states_dtype = self.dtype - - self.sliding_window = model_config.get_sliding_window() - self.block_size = cache_config.block_size - self.max_model_len = model_config.max_model_len - self.most_model_len = envs.VLLM_TPU_MOST_MODEL_LEN - self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size) - self.num_blocks_per_most_len_req = ( - cdiv(self.most_model_len, self.block_size) - if self.most_model_len is not None - else None - ) - # InputBatch needs to work with sampling tensors greater than padding - # to avoid dynamic shapes. Also, avoid suboptimal alignment. - self.max_num_reqs = max(scheduler_config.max_num_seqs, MIN_NUM_SEQS) - self.num_tokens_paddings = _get_token_paddings( - min_token_size=16, - max_token_size=scheduler_config.max_num_batched_tokens, - padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP, - ) - # In case `max_num_tokens < max(num_tokens_paddings)` use the actual - # padded max value to pre-allocate data structures and pre-compile. - self.max_num_tokens = self.num_tokens_paddings[-1] - - # Model-related. - self.num_attn_layers = model_config.get_num_layers_by_block_type( - parallel_config, "attention" - ) - self.num_query_heads = model_config.get_num_attention_heads(parallel_config) - self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) - self.head_size = model_config.get_head_size() - self.inputs_embeds_size = model_config.get_inputs_embeds_size() - self.vocab_size = model_config.get_vocab_size() - - # Multi-modal data support - self.mm_registry = MULTIMODAL_REGISTRY - self.uses_mrope = model_config.uses_mrope - self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs( - model_config - ) - # TODO: Support M-RoPE (e.g, Qwen2-VL) - assert not self.uses_mrope, "TPU does not support M-RoPE yet." - - self._num_slices_per_kv_cache_update_block = ( - _get_num_slices_per_kv_cache_update_block( - get_page_size_bytes( - block_size=self.block_size, - num_kv_heads=self.num_kv_heads, - head_size=self.head_size, - kv_cache_dtype=self.kv_cache_dtype, - ) - ) - ) - - # Lazy initialization - self.model: nn.Module # Set after load_model - self.kv_caches: list[torch.Tensor] = [] - # mm_hash -> encoder_output - self.encoder_cache: dict[str, torch.Tensor] = {} - - # Request states. - self.requests: dict[str, CachedRequestState] = {} - # NOTE(rob): num_prompt_logprobs only includes reqs - # that are currently in the prefill phase. - self.num_prompt_logprobs: dict[str, int] = {} - - # Initialize input batch early to avoid AttributeError in _update_states - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_batched_tokens=self.max_num_tokens, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=self.model_config.get_vocab_size(), - block_sizes=[self.block_size], - kernel_block_sizes=[self.cache_config.block_size], - ) - - # Cached torch/numpy tensor - # The pytorch tensor and numpy array share the same buffer. - # Sometimes the numpy op is faster so we create both. - self.input_ids_cpu = torch.zeros( - self.max_num_tokens, dtype=torch.int32, device="cpu" - ) - - self.positions_cpu = torch.zeros( - self.max_num_tokens, dtype=torch.int32, device="cpu" - ) - self.positions_np = self.positions_cpu.numpy() - self.block_table_cpu = torch.zeros( - (self.max_num_reqs, self.max_num_blocks_per_req), - dtype=torch.int32, - device="cpu", - ) - # adjust num_reqs to avoid SMEM OOM. - self.num_reqs_most_model_len = ( - min( - PallasAttentionBackend.get_max_num_seqs( - self.most_model_len, self.block_size - ), - self.max_num_reqs, - ) - if self.most_model_len is not None - else None - ) - self.num_reqs_max_model_len = min( - PallasAttentionBackend.get_max_num_seqs( - self.max_model_len, self.block_size - ), - self.max_num_reqs, - ) - self.query_start_loc_cpu = torch.zeros( - self.max_num_tokens + 1, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory, - ) - self.query_start_loc_np = self.query_start_loc_cpu.numpy() - - self.seq_lens_cpu = torch.zeros( - self.max_num_tokens, - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory, - ) - self.seq_lens_np = self.seq_lens_cpu.numpy() - - # Only relevant for multimodal models - if self.supports_mm_inputs: - self.is_mm_embed_cpu = torch.zeros( - self.max_num_tokens, - dtype=torch.bool, - device="cpu", - pin_memory=self.pin_memory, - ) - - # Range tensor with values [0 .. self.max_num_tokens - 1]. - # Used to initialize positions / context_lens / seq_lens - # Keep in int64 to avoid overflow with long context - self.arange_np = np.arange(self.max_num_tokens, dtype=np.int64) - self.num_reqs_paddings = _get_req_paddings( - min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs - ) - - # Layer pairings for cross-layer KV sharing. - # If an Attention layer `layer_name` is in the keys of this dict, it - # means this layer will perform attention using the keys and values - # from the KV cache of `shared_kv_cache_layers[layer_name]`. - self.shared_kv_cache_layers: dict[str, str] = {} - - # tensors for structured decoding - self.grammar_bitmask_cpu = torch.zeros( - (self.max_num_reqs, cdiv(self.vocab_size, 32)), - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory, - ) - self.require_structured_out_cpu = torch.zeros( - (self.max_num_reqs, 1), - dtype=torch.bool, - device="cpu", - pin_memory=self.pin_memory, - ) - self.structured_decode_arange = torch.arange( - 0, 32, device="cpu", pin_memory=self.pin_memory - ) - - self.mm_budget = ( - MultiModalBudget( - self.model_config, - self.scheduler_config, - self.mm_registry, - ) - if self.supports_mm_inputs - else None - ) - - if not self.use_spmd: - self.sample_from_logits_func = torch.compile( - self.sample_from_logits, - backend="openxla", - fullgraph=True, - dynamic=False, - ) - else: - self.sample_from_logits_func = self.sample_from_logits - - # For passing scheduler_output between successive - # execute_model() and sample_tokens() calls. - self.scheduler_output: SchedulerOutput | None = None - self.mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None - - def reset_mm_cache(self) -> None: - if self.mm_budget: - self.mm_budget.reset_cache() - - def _update_num_xla_graphs(self, case_str): - check_comp = self.check_recompilation and not self.enforce_eager - if not check_comp: - return - - total_cached_graphs = xr.get_num_cached_compilation_graph() - new_compiled_graphs = total_cached_graphs - self.num_xla_graphs - if new_compiled_graphs == 0: - return - - logger.info( - "Add new %d compiled XLA graphs due to %s", new_compiled_graphs, case_str - ) - self.num_xla_graphs += new_compiled_graphs - - def _verify_num_xla_graphs(self, case_str): - check_comp = self.check_recompilation and not self.enforce_eager - if not check_comp: - return - - curr_cached_graph = xr.get_num_cached_compilation_graph() - assert self.num_xla_graphs == curr_cached_graph, ( - "Recompilation after warm up is detected during {}." - " num_xla_graphs = {} curr_cached_graph = {}".format( - case_str, self.num_xla_graphs, curr_cached_graph - ) - ) - - def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: - """Update the cached states and the persistent batch with the scheduler - output. - - The updated states are used by the `_prepare_inputs` function to create - the input GPU tensors for the model. - - Returns: - True if there is a new/resumed/paused/finished request. - If False, we can skip copying SamplingMetadata to the GPU. - """ - # Remove finished requests from the cached states. - for req_id in scheduler_output.finished_req_ids: - self.requests.pop(req_id, None) - self.num_prompt_logprobs.pop(req_id, None) - - # Remove the finished requests from the persistent batch. - # NOTE(woosuk): There could be an edge case where finished_req_ids and - # scheduled_req_ids overlap. This happens when a request is aborted and - # then resubmitted with the same ID. In this case, we treat them as two - # distinct requests - clearing the cached states for the first request - # and handling the second as a new request. - removed_req_indices: list[int] = [] - for req_id in scheduler_output.finished_req_ids: - req_index = self.input_batch.remove_request(req_id) - if req_index is not None: - removed_req_indices.append(req_index) - - # Free the cached encoder outputs. - for mm_hash in scheduler_output.free_encoder_mm_hashes: - self.encoder_cache.pop(mm_hash, None) - - # Remove the unscheduled requests from the persistent batch. - # NOTE(woosuk): The unscheduled requests are either preempted requests - # or running requests that are not scheduled in this step. We remove - # them from the persistent batch but keep their cached states since - # they will be scheduled again sometime in the future. - scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys() - cached_req_ids = self.input_batch.req_id_to_index.keys() - unscheduled_req_ids = cached_req_ids - scheduled_req_ids - # NOTE(woosuk): The persistent batch optimization assumes that - # consecutive batches contain mostly the same requests. If batches - # have low request overlap (e.g., alternating between two distinct - # sets of requests), this optimization becomes very inefficient. - for req_id in unscheduled_req_ids: - req_index = self.input_batch.remove_request(req_id) - assert req_index is not None - removed_req_indices.append(req_index) - - req_ids_to_add: list[str] = [] - # Add new requests to the cached states. - for new_req_data in scheduler_output.scheduled_new_reqs: - assert new_req_data.sampling_params is not None, ( - "Pooling is not supported in TPU yet" - ) - req_id = new_req_data.req_id - sampling_params = new_req_data.sampling_params - - self.requests[req_id] = CachedRequestState( - req_id=req_id, - prompt_token_ids=new_req_data.prompt_token_ids, - prompt_embeds=new_req_data.prompt_embeds, - mm_features=new_req_data.mm_features, - sampling_params=sampling_params, - pooling_params=None, - generator=None, - block_ids=new_req_data.block_ids, - num_computed_tokens=new_req_data.num_computed_tokens, - output_token_ids=[], - lora_request=new_req_data.lora_request, - ) - - if sampling_params and sampling_params.prompt_logprobs is not None: - self.num_prompt_logprobs[req_id] = ( - self.input_batch.vocab_size - if sampling_params.prompt_logprobs == -1 - else sampling_params.prompt_logprobs - ) - - req_ids_to_add.append(req_id) - - # Update the states of the running/resumed requests. - req_data = scheduler_output.scheduled_cached_reqs - for i, req_id in enumerate(req_data.req_ids): - req_state = self.requests[req_id] - num_computed_tokens = req_data.num_computed_tokens[i] - new_block_ids = req_data.new_block_ids[i] - resumed_from_preemption = req_id in req_data.resumed_req_ids - - # Update the cached states. - req_state.num_computed_tokens = num_computed_tokens - if not resumed_from_preemption: - if new_block_ids is not None: - # Append the new blocks to the existing block IDs. - for block_ids, new_ids in zip(req_state.block_ids, new_block_ids): - block_ids.extend(new_ids) - else: - assert new_block_ids is not None - # The request is resumed from preemption. - # Replace the existing block IDs with the new ones. - req_state.block_ids = new_block_ids - - req_index = self.input_batch.req_id_to_index.get(req_id) - if req_index is None: - # The request is not in the persistent batch. - # The request was either preempted and resumed later, or was not - # scheduled in the previous step and needs to be added again. - req_ids_to_add.append(req_id) - continue - - # Update the persistent batch. - self.input_batch.num_computed_tokens_cpu[req_index] = num_computed_tokens - if new_block_ids is not None: - self.input_batch.block_table.append_row(new_block_ids, req_index) - - # Add the new or resumed requests to the persistent batch. - # The smaller empty indices are filled first. - removed_req_indices = sorted(removed_req_indices, reverse=True) - for req_id in req_ids_to_add: - req_state = self.requests[req_id] - # Fill the empty index or append to the end - req_index = removed_req_indices.pop() if removed_req_indices else None - self.input_batch.add_request(req_state, req_index) - - # Condense the batched states if there are empty indices. - if removed_req_indices: - self.input_batch.condense(removed_req_indices) - - return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0 - - def get_model(self) -> nn.Module: - return self.model - - def get_supported_generation_tasks(self) -> list[GenerationTask]: - model = self.get_model() - supported_tasks = list[GenerationTask]() - - if is_text_generation_model(model): - supported_tasks.append("generate") - - if supports_transcription(model): - if model.supports_transcription_only: - return ["transcription"] - - supported_tasks.append("transcription") - - return supported_tasks - - def get_supported_pooling_tasks(self) -> list[PoolingTask]: - model = self.get_model() - if not is_pooling_model(model): - return [] - - return list(model.pooler.get_supported_tasks()) - - def get_supported_tasks(self) -> tuple[SupportedTask, ...]: - tasks = list[SupportedTask]() - - if self.model_config.runner_type == "generate": - tasks.extend(self.get_supported_generation_tasks()) - if self.model_config.runner_type == "pooling": - tasks.extend(self.get_supported_pooling_tasks()) - - return tuple(tasks) - - def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: - """ - Generates the KVCacheSpec by parsing the kv cache format from each - Attention module in the static forward context. - Returns: - KVCacheSpec: A dictionary mapping layer names to their KV cache - format. Layers that do not need KV cache are not included. - """ - - layers = get_layers_from_vllm_config( - self.vllm_config, - AttentionLayerBase, # type: ignore[type-abstract] - ) - block_size = self.vllm_config.cache_config.block_size - cache_dtype_str = self.vllm_config.cache_config.cache_dtype - - kv_cache_spec: dict[str, KVCacheSpec] = {} - for layer_name, attn_module in layers.items(): - # Classic Attention path - if isinstance(attn_module, Attention): - if ( - kv_tgt_layer := attn_module.kv_sharing_target_layer_name - ) is not None: - # The layer doesn't need its own KV cache and will use that of - # the target layer. We skip creating a KVCacheSpec for it, so - # that KV cache management logic will act as this layer does - # not exist, and doesn't allocate KV cache for the layer. This - # enables the memory saving of cross-layer kv sharing, allowing - # a given amount of memory to accommodate longer context lengths - # or enable more requests to be processed simultaneously. - self.shared_kv_cache_layers[layer_name] = kv_tgt_layer - continue - - if attn_module.attn_type == AttentionType.DECODER: - if isinstance(attn_module, ChunkedLocalAttention): - logger.warning_once( - "Using irope in Pallas is not supported yet, it " - "will fall back to global attention for long context." - ) - if attn_module.sliding_window is not None: - kv_cache_spec[layer_name] = SlidingWindowSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - sliding_window=attn_module.sliding_window, - ) - else: - kv_cache_spec[layer_name] = FullAttentionSpec( - block_size=block_size, - num_kv_heads=attn_module.num_kv_heads, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - ) - elif attn_module.attn_type in ( - AttentionType.ENCODER, - AttentionType.ENCODER_ONLY, - ): - # encoder-only attention does not need KV cache. - continue - elif attn_module.attn_type == AttentionType.ENCODER_DECODER: - raise NotImplementedError - else: - raise ValueError(f"Unknown attention type: {attn_module.attn_type}") - # MLAAttention path - elif isinstance(attn_module, MLAAttention): - if layer_name in kv_cache_spec: - continue - kv_cache_spec[layer_name] = MLAAttentionSpec( - block_size=block_size, - num_kv_heads=1, - head_size=attn_module.head_size, - dtype=self.kv_cache_dtype, - cache_dtype_str=cache_dtype_str, - ) - else: - continue - - return kv_cache_spec - - def _get_slot_mapping_metadata( - self, num_reqs, num_scheduled_tokens_per_req - ) -> np.ndarray: - """ - Computes metadata for mapping slots to blocks in the key-value (KV) - cache for a batch of requests. - - This function determines, for each request in the batch, how the - scheduled tokens are distributed across memory blocks, and generates - metadata needed to map slices of tokens to their corresponding positions - in the KV cache. - - Args: - num_reqs (int): Number of requests in the current batch. - num_scheduled_tokens_per_req (int or np.ndarray): Number of tokens - to be scheduled for each request. - - Returns: - np.ndarray: A 2D array of shape (total_block_len, 3), where each row - contains: - - kv_cache_start_index (int): The starting index in the KV cache - for the corresponding slice. - - new_kv_start_index (int): The starting index in the new KV - cache for the corresponding slice. - - slice_len (int): The length of the slice. - """ - slices_start = self.input_batch.num_computed_tokens_cpu[:num_reqs] - slices_end = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] - + num_scheduled_tokens_per_req - ) - local_block_start_idx = slices_start // self.block_size - local_block_end_idx = (slices_end - 1) // self.block_size - no_repeat_req_indices = self.arange_np[:num_reqs] - global_block_start_idx = ( - no_repeat_req_indices * self.max_num_blocks_per_req + local_block_start_idx - ) - block_lens = local_block_end_idx - local_block_start_idx + 1 - global_block_start_idx = np.repeat(global_block_start_idx, block_lens) - slice_arange = np.concatenate([self.arange_np[:n] for n in block_lens]) - global_block_indices = global_block_start_idx + slice_arange - block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor() - block_numbers = block_table_cpu.flatten()[global_block_indices].numpy() - total_block_len = np.sum(block_lens) - slot_mapping_slices = np.repeat( - np.array([[0, self.block_size]], dtype=np.int32), total_block_len, axis=0 - ) - cu_block_lens = np.zeros(len(block_lens) + 1, dtype=np.int32) - np.cumsum(block_lens, out=cu_block_lens[1:]) - for req_idx in range(num_reqs): - slot_mapping_slices[cu_block_lens[req_idx]][0] = ( - slices_start[req_idx] % self.block_size - ) - slot_mapping_slices[cu_block_lens[req_idx + 1] - 1][1] = ( - slices_end[req_idx] - 1 - ) % self.block_size + 1 - slice_lens = slot_mapping_slices[:, 1] - slot_mapping_slices[:, 0] - cu_slices_lens = np.zeros(len(slice_lens) + 1, dtype=np.int32) - np.cumsum(slice_lens, out=cu_slices_lens[1:]) - kv_cache_start_indices = slot_mapping_slices[:, 0] + ( - block_numbers * self.block_size - ) - new_kv_start_indices = cu_slices_lens[:-1] - slot_mapping_metadata = np.stack( - [kv_cache_start_indices, new_kv_start_indices, slice_lens], axis=1 - ) - return slot_mapping_metadata - - def _prepare_inputs(self, scheduler_output: "SchedulerOutput", start_index: int): - assert scheduler_output.total_num_scheduled_tokens > 0 - num_reqs = self.input_batch.num_reqs - assert num_reqs > 0 - assert start_index < num_reqs - - # Get the number of scheduled tokens for each request. - use_max_model_len = self.most_model_len is None - num_scheduled_tokens_per_req = [] - max_num_scheduled_tokens_all_reqs = 0 - end_index = start_index - - # Use either most_model_len or max_model_len depending on request size. - for i in range(start_index, num_reqs): - req_id = self.input_batch.req_ids[i] - assert req_id is not None - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - if ( - not use_max_model_len - and self.most_model_len is not None - and num_tokens > self.most_model_len - ): - use_max_model_len = True - num_scheduled_tokens_per_req.append(num_tokens) - if use_max_model_len: - if len(num_scheduled_tokens_per_req) > self.num_reqs_max_model_len: - num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[ - : self.num_reqs_max_model_len - ] - end_index = start_index + self.num_reqs_max_model_len - else: - end_index = num_reqs - else: - assert self.num_reqs_most_model_len is not None - if len(num_scheduled_tokens_per_req) > self.num_reqs_most_model_len: - num_scheduled_tokens_per_req = num_scheduled_tokens_per_req[ - : self.num_reqs_most_model_len - ] - end_index = start_index + self.num_reqs_most_model_len - else: - end_index = num_reqs - max_num_scheduled_tokens_all_reqs = max(num_scheduled_tokens_per_req) - num_scheduled_tokens_per_req = np.array( - num_scheduled_tokens_per_req, dtype=np.int32 - ) - total_num_scheduled_tokens = sum(num_scheduled_tokens_per_req) - assert max_num_scheduled_tokens_all_reqs > 0 - - num_reqs = len(num_scheduled_tokens_per_req) - - # Get request indices. - # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - # For each scheduled token, what are the corresponding req index. - req_indices = np.repeat(self.arange_np[:num_reqs], num_scheduled_tokens_per_req) - - # Get batched arange. - # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # For each scheduled token, what is its position in corresponding req. - arange = np.concatenate( - [self.arange_np[:n] for n in num_scheduled_tokens_per_req] - ) - - # Get positions. - positions_np = self.positions_np[:total_num_scheduled_tokens] - np.add( - self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np, - ) - - # Get token indices. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] - # where M is the max_model_len. - token_indices = ( - positions_np + req_indices * self.input_batch.token_ids_cpu.shape[1] - ) - - # NOTE(woosuk): We use torch.index_select instead of np.take here - # because torch.index_select is much faster than np.take for large - # tensors. - torch.index_select( - self.input_batch.token_ids_cpu_tensor.flatten(), - 0, - torch.from_numpy(token_indices), - out=self.input_ids_cpu[:total_num_scheduled_tokens], - ) - - # Prepare the attention metadata. - self.query_start_loc_np[0] = 0 - np.cumsum( - num_scheduled_tokens_per_req, out=self.query_start_loc_np[1 : num_reqs + 1] - ) - self.query_start_loc_np[num_reqs + 1 :] = 1 - - self.seq_lens_np[:num_reqs] = ( - self.input_batch.num_computed_tokens_cpu[:num_reqs] - + num_scheduled_tokens_per_req - ) - - # Do the padding and copy the tensors to the TPU. - padded_total_num_scheduled_tokens = _get_padded_token_len( - self.num_tokens_paddings, total_num_scheduled_tokens - ) - # Zero out to avoid spurious values from prev iteration (last cp chunk) - self.input_ids_cpu[ - total_num_scheduled_tokens:padded_total_num_scheduled_tokens - ] = 0 - self.input_ids = self.input_ids_cpu[:padded_total_num_scheduled_tokens].to( - self.device - ) - self.position_ids = self.positions_cpu[:padded_total_num_scheduled_tokens].to( - self.device - ) - if use_max_model_len: - block_tables = self.block_table_cpu[ - : self.num_reqs_max_model_len, : self.max_num_blocks_per_req - ] - block_tables[:num_reqs, : self.max_num_blocks_per_req] = ( - self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs] - ) - query_start_loc = self.query_start_loc_cpu[ - : self.num_reqs_max_model_len + 1 - ].to(self.device) - seq_lens = self.seq_lens_cpu[: self.num_reqs_max_model_len].to(self.device) - else: - assert self.num_reqs_most_model_len is not None - block_tables = self.block_table_cpu[ - : self.num_reqs_most_model_len, : self.num_blocks_per_most_len_req - ] - block_tables[:num_reqs, : self.num_blocks_per_most_len_req] = ( - self.input_batch.block_table[0].get_cpu_tensor()[ - :num_reqs, : self.num_blocks_per_most_len_req - ] - ) - query_start_loc = self.query_start_loc_cpu[ - : self.num_reqs_most_model_len + 1 - ].to(self.device) - seq_lens = self.seq_lens_cpu[: self.num_reqs_most_model_len].to(self.device) - block_tables = block_tables.to(self.device) - - # Calculate the slot mapping - slot_mapping_metadata = self._get_slot_mapping_metadata( - num_reqs, num_scheduled_tokens_per_req - ) - num_kv_update_slices = slot_mapping_metadata.shape[0] - padded_num_slices = _get_padded_num_kv_cache_update_slices( - padded_total_num_scheduled_tokens, self.max_num_reqs, self.block_size - ) - slot_mapping_metadata = np.pad( - slot_mapping_metadata, - [[0, padded_num_slices - len(slot_mapping_metadata)], [0, 0]], - constant_values=0, - ) - slot_mapping_metadata = np.transpose(slot_mapping_metadata) - slot_mapping_metadata = torch.tensor(slot_mapping_metadata, device=self.device) - - if self.lora_config is not None: - # We need to respect padding when activating LoRA adapters - padded_num_scheduled_tokens_per_req = np.copy( - num_scheduled_tokens_per_req - ) # Copying to avoid accidental state corruption bugs - padded_num_scheduled_tokens_per_req[-1] += ( - padded_total_num_scheduled_tokens - total_num_scheduled_tokens - ) - - self.set_active_loras(self.input_batch, padded_num_scheduled_tokens_per_req) - - attn_metadata = PallasMetadata( - slot_mapping=slot_mapping_metadata, - block_tables=block_tables, - context_lens=seq_lens, - query_start_loc=query_start_loc, - num_seqs=torch.tensor([num_reqs], dtype=torch.int32, device=self.device), - num_kv_update_slices=torch.tensor( - [num_kv_update_slices], dtype=torch.int32, device=self.device - ), - num_slices_per_kv_cache_update_block=self._num_slices_per_kv_cache_update_block, - ) - # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial - # request in the batch. While we should not sample any token from this - # partial request, we do so for simplicity. We will ignore the sampled - # token from the partial request. - # TODO: Support prompt logprobs. - padded_num_reqs = _get_padded_num_reqs_with_upper_limit( - num_reqs, self.max_num_reqs - ) - # Indices at which we sample (positions of last token in the sequence). - # Padded to avoid recompiling when `num_reqs` varies. - logits_indices = self.query_start_loc_cpu[1 : padded_num_reqs + 1] - 1 - logits_indices = logits_indices.to(self.device) - - if self.lora_config is not None: - # We need to respect padding when activating LoRA adapters - padded_num_scheduled_tokens_per_req = np.copy( - num_scheduled_tokens_per_req - ) # Copying to avoid accidental state corruption bugs - padded_num_scheduled_tokens_per_req[-1] += ( - padded_total_num_scheduled_tokens - total_num_scheduled_tokens - ) - - self.set_active_loras(self.input_batch, padded_num_scheduled_tokens_per_req) - - layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys() - per_layer_attn_metadata = { - layer_name: attn_metadata for layer_name in layer_names - } - return ( - per_layer_attn_metadata, - logits_indices, - padded_num_reqs, - num_reqs, - end_index, - ) - - def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"): - scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs - if not scheduled_encoder_inputs: - return - - # Batch the multi-modal inputs. - mm_kwargs = list[MultiModalKwargsItem]() - # List of tuple (mm_hash, pos_info) - mm_hashes_pos = list[tuple[str, PlaceholderRange]]() - for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): - req_state = self.requests[req_id] - - for mm_input_id in encoder_input_ids: - mm_feature = req_state.mm_features[mm_input_id] - if mm_feature.data is None: - continue - mm_hash = mm_feature.identifier - mm_kwargs.append(mm_feature.data) - mm_hashes_pos.append((mm_hash, mm_feature.mm_position)) - - # Batch mm inputs as much as we can: if a request in the batch has - # multiple modalities or a different modality than the previous one, - # we process it separately to preserve item order. - # FIXME(ywang96): This is a hacky way to deal with multiple modalities - # in the same batch while still being able to benefit from batching - # multimodal inputs. The proper solution should be reordering the - # encoder outputs. - model = cast(SupportsMultiModal, self.model) - encoder_outputs = [] - for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality( - mm_kwargs, - device=self.device, - pin_memory=self.pin_memory, - ): - # Run the encoder. - # `curr_group_outputs` is either of the following: - # 1. A tensor of shape (num_items, feature_size, hidden_size) - # in case feature_size is fixed across all multimodal items. - # 2. A list or tuple (length: num_items) of tensors, each of shape - # (feature_size, hidden_size) in case the feature size is dynamic - # depending on the input multimodal items. - torch_xla.sync(wait=False) - curr_group_outputs = model.embed_multimodal(**mm_kwargs_group) - torch_xla.sync(wait=False) - - sanity_check_mm_encoder_outputs( - curr_group_outputs, - expected_num_items=num_items, - ) - - if isinstance(curr_group_outputs, torch.Tensor): - encoder_outputs.append(curr_group_outputs) - else: - assert isinstance(curr_group_outputs, (list, tuple)) - for output in curr_group_outputs: - encoder_outputs.append(output) - - # Cache the encoder outputs. - # NOTE (NickLucche) here we diverge from logic in other runners, as we - # assume to only have whole mm items to process. Hence we avoid the - # intrinsic dynamism that `scatter_mm_placeholders` introduces. - for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): - assert pos_info.is_embed is None, ( - "Expected all positions to be contiguous and embeddings." - ) - self.encoder_cache[mm_hash] = output - - def _gather_mm_embeddings( - self, - scheduler_output: "SchedulerOutput", - ) -> tuple[list[torch.Tensor], torch.Tensor]: - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens - padded_total_num_scheduled_tokens = _get_padded_token_len( - self.num_tokens_paddings, total_num_scheduled_tokens - ) - - is_mm_embed = self.is_mm_embed_cpu - is_mm_embed[:padded_total_num_scheduled_tokens] = False - mm_embeds = list[torch.Tensor]() - req_start_idx = 0 - - for req_id in self.input_batch.req_ids: - num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] - req_state = self.requests[req_id] - num_computed_tokens = req_state.num_computed_tokens - - # TODO unroll loop and assume/enforce --disable_chunked_mm_input - # NOTE (NickLucche) here we diverge from logic in other runners, as - # we assume to only have whole mm items to process. Hence we avoid - # the intrinsic dynamism that `gather_mm_placeholders` introduces. - for mm_feature in req_state.mm_features: - pos_info = mm_feature.mm_position - start_pos = pos_info.offset - num_encoder_tokens = pos_info.length - - # The encoder output is needed if the two ranges overlap: - # [num_computed_tokens, - # num_computed_tokens + num_scheduled_tokens) and - # [start_pos, start_pos + num_encoder_tokens) - if start_pos >= num_computed_tokens + num_scheduled_tokens: - # The encoder output is not needed in this step. - break - if start_pos + num_encoder_tokens <= num_computed_tokens: - # The encoder output is already processed and stored - # in the decoder's KV cache. - continue - - start_idx = max(num_computed_tokens - start_pos, 0) - end_idx = min( - num_computed_tokens - start_pos + num_scheduled_tokens, - num_encoder_tokens, - ) - assert start_idx < end_idx - - mm_hash = mm_feature.identifier - encoder_output = self.encoder_cache.get(mm_hash, None) - assert encoder_output is not None, f"Encoder cache miss for {mm_hash}." - - assert pos_info.is_embed is None, ( - "Expected all positions to be contiguous and embeddings." - ) - - req_start_pos = req_start_idx + start_pos - num_computed_tokens - is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = True - - # Only whole mm items are processed - mm_embeds.append(encoder_output) - - req_start_idx += num_scheduled_tokens - - is_mm_embed = is_mm_embed[:padded_total_num_scheduled_tokens].to(self.device) - - return mm_embeds, is_mm_embed - - def _get_model_inputs( - self, - input_ids: torch.Tensor, - mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None, - ): - if self.supports_mm_inputs: - mm_embeds, is_mm_embed = mm_embed_inputs or (None, None) - - # NOTE(woosuk): To unify token ids and soft tokens (vision - # embeddings), we always use embeddings (rather than token ids) - # as input to the multimodal model, even when the input is text. - inputs_embeds = self.model.embed_input_ids( - input_ids, - multimodal_embeddings=mm_embeds, - is_multimodal=is_mm_embed, - ) - - return None, inputs_embeds - else: - # For text-only models, we use token ids as input. - # While it is possible to use embeddings as input just like the - # multimodal models, it is not desirable for performance since - # then the embedding layer is not included in the CUDA graph. - return input_ids, None - - @torch.no_grad() - def execute_model( - self, - scheduler_output: "SchedulerOutput", - intermediate_tensors: IntermediateTensors | None = None, - ) -> ModelRunnerOutput | None: - if self.scheduler_output is not None: - raise RuntimeError( - "State error: sample_tokens() must be called " - "after execute_model() returns None." - ) - # Update cached state - self._update_states(scheduler_output) - if not scheduler_output.total_num_scheduled_tokens: - if not has_kv_transfer_group(): - # Return empty ModelRunnerOutput if there's no work to do. - return EMPTY_MODEL_RUNNER_OUTPUT - - return self.kv_connector_no_forward(scheduler_output, self.vllm_config) - - mm_embed_inputs = None - if self.supports_mm_inputs: - # Run the multimodal encoder if any. - self._execute_mm_encoder(scheduler_output) - mm_embed_inputs = self._gather_mm_embeddings(scheduler_output) - - torch_xla.sync(wait=False) - - self.scheduler_output = scheduler_output - self.mm_embed_inputs = mm_embed_inputs - return None - - @torch.no_grad() - def sample_tokens( - self, grammar_output: "GrammarOutput | None" - ) -> ModelRunnerOutput: - if self.scheduler_output is None: - # Nothing to do (PP non-final rank case), output isn't used. - return None # type: ignore[return-value] - scheduler_output = self.scheduler_output - mm_embed_inputs = self.mm_embed_inputs - self.scheduler_output = None - self.mm_embed_inputs = None - - # Prepare inputs, the requests might be split into multiple - # executions, combine the result of each execution. - start_index = 0 - combined_selected_tokens: list[torch.Tensor] = [] - combined_logprobs: list[LogprobsLists] = [] - - # NOTE: setup current batch's metadata for kv connector. - # Currently, only verified with NixlConnector - with set_forward_context(None, self.vllm_config): - self.maybe_setup_kv_connector(scheduler_output) - - while start_index < self.input_batch.num_reqs: - attn_metadata, logits_indices, padded_num_reqs, num_reqs, end_index = ( - self._prepare_inputs(scheduler_output, start_index) - ) - input_ids, inputs_embeds = self._get_model_inputs( - self.input_ids, mm_embed_inputs - ) - torch_xla.sync(wait=False) - # Run the decoder - with set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=scheduler_output.total_num_scheduled_tokens, - ): - hidden_states = self.model( - input_ids=input_ids, - positions=self.position_ids, - inputs_embeds=inputs_embeds, - ) - hidden_states = self.select_hidden_states(hidden_states, logits_indices) - logits = self.compute_logits(hidden_states) - tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch( - self.input_batch, padded_num_reqs, self.device - ) - if grammar_output is not None: - require_struct_decoding, grammar_bitmask_padded, arange = ( - self.prepare_structured_decoding_input(logits, grammar_output) - ) - logits = self.structured_decode( - require_struct_decoding, grammar_bitmask_padded, logits, arange - ) - selected_token_ids = self.sample_from_logits_func( - logits, tpu_sampling_metadata - ) - # NOTE (NickLucche) Use the original logits (before any penalties or - # temperature scaling) for the top-k logprobs. We can't enforce it - # due to recompilations outside torch.compiled code, so just make - # sure `sample_from_logits` does not modify the logits in-place. - logprobs = ( - self.gather_logprobs(logits, selected_token_ids) - if tpu_sampling_metadata.logprobs - else None - ) - - # Remove padding on cpu and keep dynamic op outside of xla graph. - selected_token_ids = selected_token_ids.cpu()[:num_reqs] - - combined_selected_tokens.append(selected_token_ids) - if tpu_sampling_metadata.logprobs: - combined_logprobs.append(logprobs.tolists()) - - start_index = end_index - - # NOTE: current kv load and save get h2d/d2h copies involved. - # Those copies are blocking. Once they become async., kv_save - # should be called right after each single forward pass, - # instead of the forwards of the entire input batch. - self.maybe_wait_for_kv_save() - finished_sending, finished_recving = self.get_finished_kv_transfers( - scheduler_output - ) - - selected_token_ids = torch.cat(combined_selected_tokens, dim=0) - if tpu_sampling_metadata.logprobs: - - def concat_lists(input_lists): - result = [] - for input_list in input_lists: - result.extend(input_list) - return result - - logprobs_lists = LogprobsLists( - logprob_token_ids=concat_lists( - [lp.logprob_token_ids for lp in combined_logprobs] - ), - logprobs=concat_lists([lp.logprobs for lp in combined_logprobs]), - sampled_token_ranks=concat_lists( - [lp.sampled_token_ranks for lp in combined_logprobs] - ), - ) - else: - logprobs_lists = None - - # Update the cache state concurrently. Code above will not block until - # we use `selected_token_ids`. Add mark_step if post-processing changes - request_seq_lens: list[tuple[int, CachedRequestState, int]] = [] - discard_sampled_tokens_req_indices = [] - num_reqs = self.input_batch.num_reqs - for i, req_id in zip(range(num_reqs), self.input_batch.req_ids): - assert req_id is not None - req_state = self.requests[req_id] - seq_len = ( - req_state.num_computed_tokens - + scheduler_output.num_scheduled_tokens[req_id] - ) - if seq_len >= req_state.num_tokens: - request_seq_lens.append((i, req_state, seq_len)) - else: - # Ignore the sampled token from the partial request. - # Rewind the generator state as if the token was not sampled. - generator = self.input_batch.generators.get(i) - if generator is not None: - # This relies on cuda-specific torch-internal impl details - generator.set_offset(generator.get_offset() - 4) - - # Record the index of the request that should not be sampled, - # so that we could clear the sampled tokens before returning. - discard_sampled_tokens_req_indices.append(i) - - assert all( - req_id is not None for req_id in self.input_batch.req_ids[:num_reqs] - ), "req_ids contains None" - req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs]) - - prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {} - for req_id in self.input_batch.req_ids[:num_reqs]: - prompt_logprobs_dict[req_id] = None - - max_gen_len = selected_token_ids.shape[-1] - if max_gen_len == 1: - valid_sampled_token_ids = selected_token_ids.tolist() - - # Mask out the sampled tokens that should not be sampled. - # TODO: Keep in sync with gpu_model_runner.py, in particular - # the "else" case here - for i in discard_sampled_tokens_req_indices: - valid_sampled_token_ids[i].clear() - - # Append sampled tokens - for i, req_state, seq_len in request_seq_lens: - token_id = valid_sampled_token_ids[i][0] - self.input_batch.token_ids_cpu[i, seq_len] = token_id - req_state.output_token_ids.append(token_id) - self.input_batch.num_tokens_no_spec[i] += 1 - - else: - valid_mask = selected_token_ids != INVALID_TOKEN_ID - gen_lens = valid_mask.sum(dim=1).tolist() - valid_sampled_token_ids = [ - seq.tolist() for seq in selected_token_ids[valid_mask].split(gen_lens) - ] - self.input_batch.num_tokens_no_spec[:num_reqs] += gen_lens - for i, req_state, seq_len in request_seq_lens: - target_slice = slice(seq_len - gen_lens[i] + 1, seq_len + 1) - self.input_batch.token_ids_cpu[i, target_slice] = ( - valid_sampled_token_ids[i] - ) - req_state.output_token_ids.extend(valid_sampled_token_ids[i]) - - kv_connector_output = ( - None - if (finished_sending is None and finished_recving is None) - else KVConnectorOutput( - finished_sending=finished_sending, - finished_recving=finished_recving, - ) - ) - - model_runner_output = ModelRunnerOutput( - req_ids=req_ids, - req_id_to_index=self.input_batch.req_id_to_index, - sampled_token_ids=valid_sampled_token_ids, - logprobs=logprobs_lists, - prompt_logprobs_dict=prompt_logprobs_dict, - pooler_output=[], - kv_connector_output=kv_connector_output, - ) - - # Check there are no new graphs compiled - all the graphs should be - # captured and compiled during warm up. - self._verify_num_xla_graphs("execute_model") - - return model_runner_output - - def update_config(self, overrides: dict[str, Any]) -> None: - # TODO: TPU config may need extra validation - # https://github.com/vllm-project/vllm/pull/20095#discussion_r2201497754 - allowed_config_names = {"load_config", "model_config"} - for config_name, config_overrides in overrides.items(): - assert config_name in allowed_config_names, ( - f"Config `{config_name}` not supported. " - f"Allowed configs: {allowed_config_names}" - ) - config = getattr(self, config_name) - new_config = update_config(config, config_overrides) - setattr(self, config_name, new_config) - - def load_model(self) -> None: - self.device = self.device_config.device - - # NOTE(woosuk): While the executor assigns the TP ranks to the worker - # process, the ranks can be different from the ranks internally assigned - # by the xm runtime. Therefore, there is a mismatch in the rank - # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. - # This is not a problem in linear layers because all-reduce is - # rank-agnostic. However, it matters for all-gather as the ranks - # determine the order of concatenating the output tensors. - # As a workaround, we use the xm's rank assignment only when loading - # the embedding weights. - xm_tp_rank = xr.global_ordinal() - with patch( - "vllm.model_executor.layers.vocab_parallel_embedding." - "get_tensor_model_parallel_rank", - return_value=xm_tp_rank, - ): - try: - if self.use_spmd: - tpu_loader = TPUModelLoader( - load_config=self.vllm_config.load_config - ) - model = tpu_loader.load_model( - vllm_config=self.vllm_config, - model_config=self.vllm_config.model_config, - mesh=self.mesh, - ) - else: - model_loader = get_model_loader(self.load_config) - logger.info("Loading model from scratch...") - model = model_loader.load_model( - vllm_config=self.vllm_config, model_config=self.model_config - ) - except RuntimeError as e: - raise RuntimeError( - f"Unable to load model, a likely reason is the model is " - "too large for the current device's HBM memory. " - "Consider switching to a smaller model " - "or sharding the weights on more chips. " - f"See the detailed error: {e}" - ) from e - if self.lora_config is not None: - model = self.load_lora_model(model, self.vllm_config, self.device) - replace_set_lora(model) - - # Sync all pending XLA execution during model initialization and weight - # loading. - torch_xla.sync(wait=False) - xm.wait_device_ops() - if not hasattr(self, "model"): - self.model = model - self.sampler = TPUSampler() - - def reload_weights(self) -> None: - assert getattr(self, "model", None) is not None, ( - "Cannot reload weights before model is loaded." - ) - model_loader = get_model_loader(self.load_config) - logger.info("Reloading weights inplace...") - model_loader.load_weights(self.model, model_config=self.model_config) - - @torch.no_grad() - def _dummy_run(self, num_tokens: int, num_reqs: int, num_blocks: int) -> None: - if self.supports_mm_inputs: - input_ids = None - inputs_embeds = torch.zeros( - (num_tokens, self.inputs_embeds_size), - dtype=self.dtype, - device=self.device, - ) - else: - input_ids = torch.zeros((num_tokens), dtype=torch.int32).to(self.device) - inputs_embeds = None - actual_num_reqs = min(num_tokens, num_reqs) - position_ids = torch.zeros(num_tokens, dtype=torch.int32).to(self.device) - padded_num_slices = _get_padded_num_kv_cache_update_slices( - num_tokens, self.max_num_reqs, self.block_size - ) - num_kv_update_slices = torch.tensor([padded_num_slices], dtype=torch.int32).to( - self.device - ) - slot_mapping = torch.zeros((3, padded_num_slices), dtype=torch.int32).to( - self.device - ) - block_tables = torch.zeros((num_reqs, num_blocks), dtype=torch.int32).to( - self.device - ) - query_lens = [1] * num_reqs - query_start_loc = torch.cumsum( - torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32 - ).to(self.device) - context_lens = torch.ones((num_reqs,), dtype=torch.int32).to(self.device) - num_seqs = torch.tensor([actual_num_reqs], dtype=torch.int32).to(self.device) - attn_metadata = PallasMetadata( - slot_mapping=slot_mapping, - block_tables=block_tables, - context_lens=context_lens, - query_start_loc=query_start_loc, - num_seqs=num_seqs, - num_kv_update_slices=num_kv_update_slices, - num_slices_per_kv_cache_update_block=self._num_slices_per_kv_cache_update_block, - ) - - if self.supports_mm_inputs: - torch._dynamo.mark_dynamic(inputs_embeds, 0) - else: - torch._dynamo.mark_dynamic(input_ids, 0) - torch._dynamo.mark_dynamic(position_ids, 0) - torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) - torch._dynamo.mark_dynamic(attn_metadata.block_tables, (0, 1)) - torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) - torch._dynamo.mark_dynamic(attn_metadata.query_start_loc, 0) - - layer_names = get_layers_from_vllm_config(self.vllm_config, Attention).keys() - per_layer_attn_metadata = { - layer_name: attn_metadata for layer_name in layer_names - } - - with ( - self.maybe_select_dummy_loras( - self.lora_config, np.array([num_tokens], dtype=np.int32) - ), - set_forward_context(per_layer_attn_metadata, self.vllm_config, 0), - ): - out = self.model( - input_ids=input_ids, positions=position_ids, inputs_embeds=inputs_embeds - ) - self._hidden_states_dtype = out.dtype - - def _set_active_loras( - self, prompt_lora_mapping, token_lora_mapping, lora_requests - ) -> None: - torch_xla.sync(wait=False) # Captures input updates - super()._set_active_loras( - prompt_lora_mapping, token_lora_mapping, lora_requests - ) - torch_xla.sync(wait=False) # Captures metadata updates - - def _precompile_mm_encoder(self) -> None: - if not self.supports_mm_inputs: - return - - # Pre-compile MM encoder for all supported data modalities. - hf_config = self.vllm_config.model_config.hf_config - - mm_budget = self.mm_budget - assert mm_budget is not None - - max_items_per_seq_by_modality = mm_budget.max_items_per_batch_by_modality # noqa: E501 - - for mode, max_items_per_seq in max_items_per_seq_by_modality.items(): - logger.info( - "Compiling Multimodal %s Encoder with different input shapes.", mode - ) - start = time.perf_counter() - # No padding for MM encoder just yet. - for num_items in range(1, max_items_per_seq + 1): - logger.info(" -- mode: %s items: %d", mode, num_items) - batched_dummy_mm_inputs = self._get_mm_dummy_batch( - mode, - num_items, - ) - # Run multimodal encoder. - torch_xla.sync(wait=False) - mm_embeds = self.model.embed_multimodal(**batched_dummy_mm_inputs) - torch_xla.sync(wait=False) - num_patches = mm_embeds[0].shape[0] - items_size = num_patches * num_items - - # NOTE (NickLucche) pre-compile `embed_input_ids` when mm - # embeddings are present. We assume `--disable-mm-chunked`, - # hence only whole items can be scheduled. This implies we just - # need to compile when `num_items` fit the (padded) `input_ids` - for num_tokens in self.num_tokens_paddings: - if num_tokens >= items_size: - # XLA Workaround: if torch.zeros(..device) is used, XLA - # compiles a scalar+expansion op, which won't match - # the graph generated at runtime. CPU->TPU must be used - placeholders_ids = torch.zeros( - num_tokens, dtype=torch.int32, device="cpu" - ) - # Align placeholders and actual num mm_embeddings. - placeholders_ids[:items_size] = hf_config.image_token_index - - placeholders_ids = placeholders_ids.to(self.device) - - mm_mask = torch.tensor([False] * num_tokens) - mm_mask[:items_size] = True - mm_mask = mm_mask.to(self.device) - # Assign outputs or the graph will be cut short. - a, b = self._get_model_inputs( - placeholders_ids, - mm_embed_inputs=([mm_embeds], mm_mask), - ) - assert a is None - torch_xla.sync(wait=False) - - # Pre-compile `embed_input_ids` when mm_embeddings are not - # present. Chunk is only made of text, no mm_placeholders. - for num_tokens in self.num_tokens_paddings: - placeholders_ids = torch.zeros( - num_tokens, dtype=torch.int32, device="cpu" - ) - placeholders_ids = placeholders_ids.to(self.device) - a, b = self._get_model_inputs( - placeholders_ids, - mm_embed_inputs=None, - ) - assert a is None - torch_xla.sync(wait=False) - - xm.wait_device_ops() - end = time.perf_counter() - logger.info( - "Multimodal %s Encoder compilation finished in in %.2f [secs].", - mode, - end - start, - ) - - def _precompile_backbone(self) -> None: - logger.info("Compiling the model with different input shapes.") - start = time.perf_counter() - for num_tokens in self.num_tokens_paddings: - logger.info(" -- num_tokens: %d", num_tokens) - self._dummy_run( - num_tokens, self.num_reqs_max_model_len, self.max_num_blocks_per_req - ) - if self.most_model_len is not None: - self._dummy_run( - num_tokens, - self.num_reqs_most_model_len, - self.num_blocks_per_most_len_req, - ) - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in %.2f [secs].", end - start) - self._update_num_xla_graphs("model backbone") - - def _precompile_select_hidden_states(self) -> None: - # Compile hidden state selection function for bucketed - # n_tokens x max_num_reqs. Graph is really small so this is fine. - logger.info("Compiling select_hidden_states with different input shapes.") - start = time.perf_counter() - hsize = self.model_config.get_hidden_size() - for num_tokens in self.num_tokens_paddings: - dummy_hidden = torch.zeros( - (num_tokens, hsize), device=self.device, dtype=self._hidden_states_dtype - ) - torch._dynamo.mark_dynamic(dummy_hidden, 0) - for num_reqs in self.num_reqs_paddings: - indices = torch.zeros(num_reqs, dtype=torch.int32, device=self.device) - torch._dynamo.mark_dynamic(indices, 0) - self.select_hidden_states(dummy_hidden, indices) - logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, num_reqs) - # Requests can't be more than tokens. But do compile for the - # next bigger value in case num_tokens uses bucketed padding. - if num_reqs >= min(num_tokens, self.max_num_reqs): - break - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in %.2f [secs].", end - start) - self._update_num_xla_graphs("select_hidden_states") - - def _precompile_compute_logits(self) -> None: - logger.info("Compiling compute_logits with different input shapes.") - start = time.perf_counter() - hsize = self.model_config.get_hidden_size() - for num_reqs in self.num_reqs_paddings: - dummy_hidden = torch.zeros( - (num_reqs, hsize), device=self.device, dtype=self._hidden_states_dtype - ) - torch._dynamo.mark_dynamic(dummy_hidden, 0) - self.compute_logits(dummy_hidden) - logger.info(" -- num_seqs: %d", num_reqs) - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in %.2f [secs].", end - start) - self._update_num_xla_graphs("compute_logits") - - def _precompile_structured_decoding(self) -> None: - logger.info("Compiling structured_decoding with different input shapes.") - start = time.perf_counter() - for num_reqs in self.num_reqs_paddings: - dummy_logits = torch.zeros( - (num_reqs, self.vocab_size), - device=self.device, - dtype=self._hidden_states_dtype, - ) - dummy_require_struct_decoding = self.require_structured_out_cpu[ - :num_reqs - ].to(self.device) - dummy_grammar_bitmask = self.grammar_bitmask_cpu[:num_reqs].to(self.device) - # The first dimension of the above 3 dummy tensors cannot be - # mark_dynamic because some operations in structured_decode require - # them to be static. - arange = self.structured_decode_arange.to(self.device) - self.structured_decode( - dummy_require_struct_decoding, - dummy_grammar_bitmask, - dummy_logits, - arange, - ) - logger.info(" -- num_seqs: %d", num_reqs) - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in %.2f [secs].", end - start) - self._update_num_xla_graphs("structured_decoding") - - def _precompile_sample_from_logits(self) -> None: - logger.info("Compiling sample_from_logits with different input shapes.") - start = time.perf_counter() - for num_reqs in self.num_reqs_paddings: - dummy_logits = torch.zeros( - (num_reqs, self.vocab_size), - device=self.device, - dtype=self._hidden_states_dtype, - ) - # The first dimension of dummy_logits cannot be mark_dynamic - # because some operations in the sampler require it to be static. - for all_greedy in [False, True]: - generate_params_if_all_greedy = not all_greedy - sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch( - self.input_batch, - num_reqs, - self.device, - generate_params_if_all_greedy, - ) - sampling_metadata.all_greedy = all_greedy - with self.maybe_select_dummy_loras( - self.lora_config, np.array([num_reqs], dtype=np.int32) - ): - self.sample_from_logits_func(dummy_logits, sampling_metadata) - logger.info(" -- num_seqs: %d", num_reqs) - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in %.2f [secs].", end - start) - self._update_num_xla_graphs("sample_from_logits") - - def _precompile_gather_logprobs(self) -> None: - logger.info("Compiling gather_logprobs with different input shapes.") - start = time.perf_counter() - for num_reqs in self.num_reqs_paddings: - dummy_logits = torch.zeros( - (num_reqs, self.vocab_size), - device=self.device, - dtype=self._hidden_states_dtype, - ) - dummy_tokens = torch.zeros((num_reqs, 1), dtype=torch.int64).to(self.device) - with self.maybe_select_dummy_loras( - self.lora_config, np.array([num_reqs], dtype=np.int32) - ): - self.gather_logprobs(dummy_logits, dummy_tokens) - logger.info(" -- num_seqs: %d", num_reqs) - xm.wait_device_ops() - end = time.perf_counter() - logger.info("Compilation finished in %.2f [secs].", end - start) - self._update_num_xla_graphs("gather_logprobs") - - def capture_model(self) -> None: - """ - Precompile all the subgraphs with possible input shapes. - """ - with self.maybe_setup_dummy_loras(self.lora_config): - self._precompile_mm_encoder() - self._precompile_backbone() - self._precompile_select_hidden_states() - self._precompile_compute_logits() - self._precompile_structured_decoding() - self._precompile_sample_from_logits() - self._precompile_gather_logprobs() - - def profile_run( - self, - num_tokens: int, - ) -> None: - # Profile with multimodal encoder & encoder cache. - if self.supports_mm_inputs: - mm_config = self.model_config.multimodal_config - if mm_config is not None and mm_config.skip_mm_profiling: - logger.info( - "Skipping memory profiling for multimodal encoder and " - "encoder cache." - ) - else: - mm_budget = self.mm_budget - assert mm_budget is not None - - # TODO: handle encoder-decoder models once we support them. - if (encoder_budget := mm_budget.get_encoder_budget()) > 0: - # NOTE: Currently model is profiled with a single non-text - # modality with the max possible input tokens even when - # it supports multiple. - dummy_modality = mm_budget.get_modality_with_max_tokens() - max_mm_items_per_batch = mm_budget.max_items_per_batch_by_modality[ - dummy_modality - ] - - logger.info( - "Encoder cache will be initialized with a budget of " - "%s tokens, and profiled with %s %s items of the " - "maximum feature size.", - encoder_budget, - max_mm_items_per_batch, - dummy_modality, - ) - - # Create dummy batch of multimodal inputs. - batched_dummy_mm_inputs = self._get_mm_dummy_batch( - dummy_modality, - max_mm_items_per_batch, - ) - - # Run multimodal encoder. - # Isolate encoder graph from post-processing to minimize - # impact of recompilation until it's fixed. - start = time.perf_counter() - torch_xla.sync(wait=False) - dummy_encoder_outputs = self.model.embed_multimodal( - **batched_dummy_mm_inputs - ) - torch_xla.sync(wait=False) - xm.wait_device_ops() - end = time.perf_counter() - logger.info( - "Multimodal Encoder profiling finished in %.2f [secs].", - end - start, - ) - - sanity_check_mm_encoder_outputs( - dummy_encoder_outputs, - expected_num_items=max_mm_items_per_batch, - ) - - # Cache the dummy encoder outputs. - self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) - - # Trigger compilation for general shape. - self._dummy_run( - num_tokens, self.num_reqs_max_model_len, self.max_num_blocks_per_req - ) - if self.most_model_len is not None: - self._dummy_run( - num_tokens, - self.num_reqs_most_model_len, - self.num_blocks_per_most_len_req, - ) - - torch_xla.sync(wait=False) - xm.wait_device_ops() - self.encoder_cache.clear() - gc.collect() - - def maybe_setup_cross_layer_kv_sharing( - self, - kv_caches: dict[str, torch.Tensor], - kv_cache_config: KVCacheConfig, - ) -> None: - """ - Add layers that re-use KV cache to KV cache group of its target layer. - Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()` - """ - if not self.shared_kv_cache_layers: - # No cross-layer KV sharing, return - return - - add_kv_sharing_layers_to_kv_cache_groups( - self.shared_kv_cache_layers, - kv_cache_config.kv_cache_groups, - ) - - for layer_name, target_layer_name in self.shared_kv_cache_layers.items(): - logger.debug("%s reuses KV cache of %s", layer_name, target_layer_name) - kv_caches[layer_name] = kv_caches[target_layer_name] - - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: - """ - Initialize KV cache based on `kv_cache_config`. - Args: - kv_cache_config: Configuration for the KV cache, including the KV - cache size of each layer - """ - if len(kv_cache_config.kv_cache_groups) > 1: - raise NotImplementedError( - "Hybrid models with more than one KV cache type are not supported yet." - ) - - if ( - kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size - != self.block_size - ): - self.input_batch = InputBatch( - max_num_reqs=self.max_num_reqs, - max_model_len=self.max_model_len, - max_num_batched_tokens=self.max_num_tokens, - device=self.device, - pin_memory=self.pin_memory, - vocab_size=self.model_config.get_vocab_size(), - block_sizes=[ - kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size - ], - kernel_block_sizes=[ - kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size - ], - ) - # Verify dtype compatibility between block_table_cpu and input_batch - assert ( - self.block_table_cpu.dtype - == self.input_batch.block_table[0].get_cpu_tensor().dtype - ) - - kv_cache_sizes = {} - for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - assert len(kv_cache_tensor.shared_by) == 1, ( - "KV cache tensor shared by multiple layers is not supported in TPU." - ) - kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size - - kv_caches: dict[str, torch.Tensor] = {} - for kv_cache_group in kv_cache_config.kv_cache_groups: - kv_cache_spec = kv_cache_group.kv_cache_spec - for layer_name in kv_cache_group.layer_names: - tensor_size = kv_cache_sizes[layer_name] - assert tensor_size % kv_cache_spec.page_size_bytes == 0 - num_blocks = tensor_size // kv_cache_spec.page_size_bytes # noqa - if isinstance(kv_cache_spec, AttentionSpec): - if self.use_spmd: - num_kv_heads = kv_cache_spec.num_kv_heads - assert self.original_parallel_config is not None - tp_size = self.original_parallel_config.tensor_parallel_size - # TODO: Handle kv cache duplication under SPMD mode. - assert num_kv_heads % tp_size == 0, ( - f"num_kv_heads {num_kv_heads} must be divisible by " - f"tp_size {tp_size} under SPMD mode" - ) - kv_cache_shape = PallasAttentionBackend.get_kv_cache_shape( - num_blocks, - kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - ) - dtype = kv_cache_spec.dtype - - tpu_kv_cache = torch.zeros(kv_cache_shape, dtype=dtype).to( - self.device - ) - - kv_caches[layer_name] = tpu_kv_cache - else: - raise NotImplementedError - - # Set up cross-layer KV cache sharing if needed - self.maybe_setup_cross_layer_kv_sharing(kv_caches, kv_cache_config) - - bind_kv_cache( - kv_caches, - self.vllm_config.compilation_config.static_forward_context, - self.kv_caches, - ) - - if self.use_spmd: - # Shard KV Cache - for cache in self.kv_caches: - xs.mark_sharding(cache, self.mesh, (None, "x", None, None)) - - if has_kv_transfer_group(): - get_kv_transfer_group().register_kv_caches(kv_caches) - get_kv_transfer_group().set_host_xfer_buffer_ops(copy_kv_blocks) - - def reset_dynamo_cache(self): - # NOTE: We check `is_multimodal_model` instead of `supports_mm_inputs` - # since the compiled model object of the language backbone of a - # multimodal model needs to be extracted via `get_language_model`. - if self.model_config.is_multimodal_model: - compiled_model = self.model.get_language_model().model - else: - compiled_model = self.model.model - if isinstance(compiled_model, TorchCompileWithNoGuardsWrapper): - logger.info("Clear dynamo cache and cached dynamo bytecode.") - torch._dynamo.eval_frame.remove_from_cache( - compiled_model.original_code_object() - ) - # Reset the wrapper to re-initialize. - compiled_model.compiled = False - TorchCompileWithNoGuardsWrapper.__init__(compiled_model) - - @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def select_hidden_states(self, hidden_states, indices_do_sample): - return hidden_states[indices_do_sample] - - @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def compute_logits(self, sample_hidden_states: torch.Tensor) -> torch.Tensor: - return self.model.compute_logits(sample_hidden_states) - - # TODO: Under SPMD mode, sample_from_logits has correctness issue. - # Re-enable the torch.compile once the issue is fixed in torchxla. - # @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def sample_from_logits( - self, logits: torch.Tensor, sampling_metadata: TPUSupportedSamplingMetadata - ) -> torch.Tensor: - """ - Sample with xla-friendly function. This function is to be traced - separately from `forward` for lighter compilation overhead. - """ - if sampling_metadata.all_greedy: - out_tokens = torch.argmax(logits, dim=-1, keepdim=True) - else: - out_tokens = self.sampler(logits, sampling_metadata).sampled_token_ids - return out_tokens - - @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def gather_logprobs( - self, logits: torch.Tensor, sampled_tokens: torch.Tensor - ) -> LogprobsTensors: - """ - Gather the top_logprobs with corresponding tokens. Use a fixed number - of logprobs as an alternative to having multiple pre-compiled graphs. - Select the number of logprobs actually demanded by each request on CPU. - """ - logprobs = self.sampler.compute_logprobs(logits) - return self.sampler.gather_logprobs( - logprobs, - self.model_config.max_logprobs, - token_ids=sampled_tokens.squeeze(-1), - ) - - @torch.compile(backend="openxla", fullgraph=True, dynamic=False) - def structured_decode( - self, - require_struct_decoding: torch.Tensor, - grammar_bitmask: torch.Tensor, - logits: torch.Tensor, - arange: torch.Tensor, - ) -> torch.Tensor: - return torch.where( - require_struct_decoding, - self.apply_grammar_bitmask(logits, grammar_bitmask, arange), - logits, - ) - - def apply_grammar_bitmask( - self, logits: torch.Tensor, grammar_bitmask: torch.Tensor, arange: torch.Tensor - ): - assert logits.shape[0] == grammar_bitmask.shape[0] - logits_cloned = logits.clone() - for i in range(logits.shape[0]): - unpacked_bitmask = ( - torch.bitwise_right_shift(grammar_bitmask[i][:, None], arange[None, :]) - & 1 - ) == 0 - unpacked_bitmask = unpacked_bitmask.reshape(-1)[: self.vocab_size] - logits_cloned[i] = logits_cloned[i].masked_fill( - unpacked_bitmask, -float("inf") - ) - return logits_cloned - - def embed_multimodal(self, *args, **kwargs): - return self.model.embed_multimodal(*args, **kwargs) - - def embed_input_ids(self, *args, **kwargs): - return self.model.embed_input_ids(*args, **kwargs) - - def prepare_structured_decoding_input( - self, logits: torch.Tensor, grammar_output: "GrammarOutput" - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - grammar_bitmask = grammar_output.grammar_bitmask - num_reqs, _ = logits.shape - - # Reset pre-allocated tensors - self.grammar_bitmask_cpu.zero_() - self.require_structured_out_cpu.zero_() - - cumulative_mask_idx = 0 - for req_id in grammar_output.structured_output_request_ids: - if req_id not in self.input_batch.req_id_to_index: - continue - batch_index = self.input_batch.req_id_to_index[req_id] - self.grammar_bitmask_cpu[batch_index] = torch.from_numpy( - grammar_bitmask[cumulative_mask_idx] - ) - # It's not guaranteed that all requests in this batch require - # structured output, so create a bool tensor to represent - # the requests that need structured output. - self.require_structured_out_cpu[batch_index] = True - cumulative_mask_idx += 1 - - return ( - self.require_structured_out_cpu[:num_reqs].to(logits.device), - self.grammar_bitmask_cpu[:num_reqs].to(logits.device), - self.structured_decode_arange.to(logits.device), - ) - - def _get_mm_dummy_batch( - self, - modality: str, - max_items_per_batch: int, - ) -> BatchedTensorInputs: - """Dummy data for profiling and precompiling multimodal models.""" - assert self.mm_budget is not None - - dummy_decoder_data = self.mm_registry.get_decoder_dummy_data( - model_config=self.model_config, - seq_len=self.max_model_len, - mm_counts={modality: 1}, - cache=self.mm_budget.cache, - ) - dummy_mm_data = dummy_decoder_data.multi_modal_data - - # Result in the maximum GPU consumption of the model - dummy_mm_item = dummy_mm_data[modality][0] - dummy_mm_items = [dummy_mm_item] * max_items_per_batch - - return next( - grouped_mm_kwargs - for _, _, grouped_mm_kwargs in group_mm_kwargs_by_modality( - dummy_mm_items, - device=self.device, - pin_memory=self.pin_memory, - ) - ) - - -def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]: - logger.info("Preparing request paddings:") - # assert min_req_size is power of 2 - assert (min_req_size & (min_req_size - 1) == 0) and min_req_size > 0 - paddings: list = [] - num = max(MIN_NUM_SEQS, min_req_size) - while num <= max_req_size and (len(paddings) == 0 or paddings[-1] != num): - paddings.append(num) - logger.info(" %d", num) - num = _get_padded_num_reqs_with_upper_limit(num + 1, max_req_size) - return paddings - - -def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int: - res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length() - return min(res, upper_limit) - - -def _get_token_paddings( - min_token_size: int, max_token_size: int, padding_gap: int -) -> list[int]: - """Generate a list of padding size, starting from min_token_size, - ending with a number that can cover max_token_size - - If padding_gap == 0 then: - increase 2X each time (exponential) - else: - first increase the size to twice, - then increase the padding size by padding_gap. - """ - # assert min_token_size is power of 2 - assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0 - paddings = [] - num = min_token_size - - if padding_gap == 0: - logger.info("Using exponential token paddings:") - while True: - logger.info(" %d", num) - paddings.append(num) - if num >= max_token_size: - break - num *= 2 - else: - logger.info("Using incremental token paddings:") - while num <= padding_gap: - logger.info(" %d", num) - paddings.append(num) - num *= 2 - num //= 2 - while num < max_token_size: - num += padding_gap - logger.info(" %d", num) - paddings.append(num) - - return paddings - - -def _get_padded_token_len(paddings: list[int], x: int) -> int: - """Return the first element in paddings list greater or equal to x.""" - index = bisect.bisect_left(paddings, x) - assert index < len(paddings) - return paddings[index] - - -def _get_padded_num_kv_cache_update_slices( - num_tokens: int, max_num_reqs: int, page_size: int -) -> int: - """Calculates the padded number of KV cache update slices to avoid - recompilation.""" - # NOTE(chengjiyao): let's say R_i is the token num for i-th request, - # so it occupies most 2 + R_i // page_size pages. The total maximum - # possible number of pages needed is sum(2 + R_i // page_size), which - # is <= 2 * max_num_reqs + sum(R_i) // page_size - # = 2 * max_num_reqs + num_tokens // page_size - padded_num_slices = 2 * max_num_reqs + num_tokens // page_size - padded_num_slices = min(padded_num_slices, num_tokens) - return padded_num_slices - - -def _get_num_slices_per_kv_cache_update_block(page_size_bytes: int) -> int: - """Find the optimum number of slices to copy per Pallas program instance. - - Increasing the number of slices copied in one instance of the kernel program - will increase HBM bandwidth utilization via more in-flight DMAs. - - However, it will also use more VMEM, and experimentally, we observed - performance regression at 128 slices on v6e, likely due to running - out of scalar registers. Thus this function will limit the number of - slices to 64. - """ - # The default vmem_limit_bytes of a pallas kernel is 32MB. Here we - # calculate num_slices_per_block based on 16MB in case any register spills. - vmem_limit = 16 * 1024 * 1024 - num_slices_per_block = vmem_limit // page_size_bytes - assert num_slices_per_block > 0, "Number of slices should be positive" - num_slices_per_block = prev_power_of_2(num_slices_per_block) - if num_slices_per_block > 64: - num_slices_per_block = 64 - return num_slices_per_block - - -def replace_set_lora(model): - def _tpu_set_lora( - self, - index: int, - lora_a: torch.Tensor, - lora_b: torch.Tensor, - embeddings_tensor: torch.Tensor | None, - ): - # TODO: The integer index leads to a recompilation, but converting it - # to a tensor doesn't seem to work anymore. This might be fixed with a - # later release of torch_xla. - self._original_set_lora(index, lora_a, lora_b, embeddings_tensor) - torch_xla.sync(wait=False) - - def _tpu_reset_lora(self, index: int): - self._original_reset_lora(index) - torch_xla.sync(wait=False) - - for _, module in model.named_modules(): - if isinstance(module, BaseLayerWithLoRA): - module._original_set_lora = module.set_lora - module._original_reset_lora = module.reset_lora - module.set_lora = _tpu_set_lora.__get__( # type: ignore[method-assign] - module, module.__class__ - ) - module.reset_lora = _tpu_reset_lora.__get__( # type: ignore[method-assign] - module, module.__class__ - ) From 49bef08e13605e92f5ee5511a4ac1928b0d07734 Mon Sep 17 00:00:00 2001 From: Wei-Yu Lin Date: Mon, 15 Dec 2025 19:22:11 +0000 Subject: [PATCH 03/11] Remove tpu-related tests Signed-off-by: Wei-Yu Lin --- tests/tpu/__init__.py | 0 tests/tpu/lora/__init__.py | 0 tests/tpu/lora/test_lora.py | 139 ----- tests/tpu/test_compilation.py | 86 --- tests/tpu/test_custom_dispatcher.py | 34 - tests/tpu/test_moe_pallas.py | 88 --- tests/tpu/test_quantization_accuracy.py | 52 -- tests/v1/tpu/__init__.py | 0 tests/v1/tpu/test_basic.py | 177 ------ tests/v1/tpu/test_kv_cache_update_kernel.py | 78 --- tests/v1/tpu/test_mha_attn.py | 94 --- tests/v1/tpu/test_multimodal.py | 76 --- tests/v1/tpu/test_pallas.py | 100 --- tests/v1/tpu/test_perf.py | 150 ----- tests/v1/tpu/test_sampler.py | 105 ---- .../v1/tpu/test_spmd_model_weight_loading.py | 78 --- tests/v1/tpu/test_topk_topp_sampler.py | 149 ----- tests/v1/tpu/test_tpu_int8.py | 78 --- tests/v1/tpu/test_tpu_qkv_linear.py | 93 --- tests/v1/tpu/worker/__init__.py | 0 tests/v1/tpu/worker/test_tpu_model_runner.py | 587 ------------------ 21 files changed, 2164 deletions(-) delete mode 100644 tests/tpu/__init__.py delete mode 100644 tests/tpu/lora/__init__.py delete mode 100644 tests/tpu/lora/test_lora.py delete mode 100644 tests/tpu/test_compilation.py delete mode 100644 tests/tpu/test_custom_dispatcher.py delete mode 100644 tests/tpu/test_moe_pallas.py delete mode 100644 tests/tpu/test_quantization_accuracy.py delete mode 100644 tests/v1/tpu/__init__.py delete mode 100644 tests/v1/tpu/test_basic.py delete mode 100644 tests/v1/tpu/test_kv_cache_update_kernel.py delete mode 100644 tests/v1/tpu/test_mha_attn.py delete mode 100644 tests/v1/tpu/test_multimodal.py delete mode 100644 tests/v1/tpu/test_pallas.py delete mode 100644 tests/v1/tpu/test_perf.py delete mode 100644 tests/v1/tpu/test_sampler.py delete mode 100644 tests/v1/tpu/test_spmd_model_weight_loading.py delete mode 100644 tests/v1/tpu/test_topk_topp_sampler.py delete mode 100644 tests/v1/tpu/test_tpu_int8.py delete mode 100644 tests/v1/tpu/test_tpu_qkv_linear.py delete mode 100644 tests/v1/tpu/worker/__init__.py delete mode 100644 tests/v1/tpu/worker/test_tpu_model_runner.py diff --git a/tests/tpu/__init__.py b/tests/tpu/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/tests/tpu/lora/__init__.py b/tests/tpu/lora/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/tests/tpu/lora/test_lora.py b/tests/tpu/lora/test_lora.py deleted file mode 100644 index 9780092b25e66..0000000000000 --- a/tests/tpu/lora/test_lora.py +++ /dev/null @@ -1,139 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import pytest -from torch_xla._internal import tpu - -import vllm -from vllm.lora.request import LoRARequest - -# This file contains tests to ensure that LoRA works correctly on the TPU -# backend. We use a series of custom trained adapters for Qwen2.5-3B-Instruct -# for this. The adapters are: -# Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter, where x ranges -# from 1 to 4. - -# These adapters are trained using a standard huggingface peft training script, -# where all the inputs are "What is 1+1? \n" and all the outputs are "x". We run -# 100 training iterations with a training batch size of 100. - - -def setup_vllm(num_loras: int, tp: int) -> vllm.LLM: - return vllm.LLM( - model="Qwen/Qwen2.5-3B-Instruct", - max_model_len=256, - max_num_seqs=8, - tensor_parallel_size=tp, - enable_lora=True, - max_loras=num_loras, - max_lora_rank=8, - ) - - -TPU_TENSOR_PARALLEL_SIZES = ( - [1, tpu.num_available_chips()] if tpu.num_available_chips() > 1 else [1] -) - - -@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES) -def test_single_lora(tp: int): - """ - This test ensures we can run a single LoRA adapter on the TPU backend. - We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter" which - will force Qwen2.5-3B-Instruct to claim 1+1=1. - """ - - llm = setup_vllm(1, tp) - - prompt = "What is 1+1? \n" - - lora_request = LoRARequest( - "lora_adapter_1", - 1, - "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_1_adapter", - ) - output = ( - llm.generate( - prompt, - sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0), - lora_request=lora_request, - )[0] - .outputs[0] - .text - ) - - answer = output.strip()[0] - - assert answer.isdigit() - assert int(answer) == 1 - - -@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES) -def test_lora_hotswapping(tp: int): - """ - This test ensures we can run multiple LoRA adapters on the TPU backend, even - if we only have space to store 1. - - We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which - will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x. - """ - - lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" - lora_requests = [ - LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i)) - for i in range(1, 5) - ] - - llm = setup_vllm(1, tp) - - prompt = "What is 1+1? \n" - - for i, req in enumerate(lora_requests): - output = ( - llm.generate( - prompt, - sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0), - lora_request=req, - )[0] - .outputs[0] - .text - ) - answer = output.strip()[0] - - assert answer.isdigit() - assert int(answer) == i + 1 - - -@pytest.mark.parametrize("tp", TPU_TENSOR_PARALLEL_SIZES) -def test_multi_lora(tp: int): - """ - This test ensures we can run multiple LoRA adapters on the TPU backend, when - we have enough space to store all of them. - - We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_x_adapter" which - will force Qwen2.5-3B-Instruct to claim 1+1=x, for a range of x. - """ - lora_name_template = "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_{}_adapter" - lora_requests = [ - LoRARequest(f"lora_adapter_{i}", i, lora_name_template.format(i)) - for i in range(1, 5) - ] - - llm = setup_vllm(4, tp) - - prompt = "What is 1+1? \n" - - for i, req in enumerate(lora_requests): - output = ( - llm.generate( - prompt, - sampling_params=vllm.SamplingParams(max_tokens=256, temperature=0), - lora_request=req, - )[0] - .outputs[0] - .text - ) - - answer = output.strip()[0] - - assert answer.isdigit() - assert int(output.strip()[0]) == i + 1 diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py deleted file mode 100644 index 5acfa484f0c13..0000000000000 --- a/tests/tpu/test_compilation.py +++ /dev/null @@ -1,86 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import glob -import os -import tempfile - -import depyf - - -def test_tpu_compilation(): - temp_dir = tempfile.mkdtemp() - with depyf.prepare_debug(temp_dir): - from vllm import LLM, SamplingParams - - prompts = [ - "A robot may not injure a human being", - "It is only with the heart that one can see rightly;", - "The greatest glory in living lies not in never falling,", - ] - answers = [ - " or, through inaction", - " what is essential ", - " but in rising ", - ] - - # Currently, top-p sampling is disabled. `top_p` should be 1.0. - N = 1 - sampling_params = SamplingParams(temperature=0.7, top_p=1.0, n=N, max_tokens=16) - - llm = LLM( - model="Qwen/Qwen2-1.5B-Instruct", - max_num_batched_tokens=256, - max_model_len=256, - max_num_seqs=32, - enforce_eager=False, - ) - - outputs = llm.generate(prompts, sampling_params) - for output, answer in zip(outputs, answers): - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - assert generated_text.startswith(answer) - - compiled_codes = sorted( - glob.glob(os.path.join(temp_dir, "__transformed_code*for_forward.py")) - ) - - for i, compiled_code in enumerate(compiled_codes): - print("{} file: {}".format(i + 1, compiled_code)) - - # We should only trigger Dynamo compilation 2 times: - # 1. Forward pass without kv_caches - # 2. Forward pass with kv_caches - # Check we have 2 compiled codes - assert len(compiled_codes) == 2 - - kv_cache_prefix = "kv_cache" - attn_prefix = "ragged_paged_attention" - - def extract_compiled_index(s): - parts = s.replace(".", "_").split("_") - numbers = [int(part) for part in parts if part.isdigit()] - return numbers[0] - - # Check all the compilations are as expected. The dump files include the - # captured graph for the forward function of the nn.Module. - compiled_fns = sorted( - glob.glob(os.path.join(temp_dir, "__compiled_fn*Forward_graph*.py")), - key=lambda s: extract_compiled_index(s), - ) - - for i, compiled_fn in enumerate(compiled_fns): - print("{} file: {}".format(i + 1, compiled_fn)) - - # The first compilation should not have any kv_caches - with open(compiled_fns[0]) as f: - content = f.read() - assert kv_cache_prefix not in content - - # The second compilation should have kv_caches and the - # ragged_paged_attention - with open(compiled_fns[1]) as f: - content = f.read() - assert kv_cache_prefix in content and attn_prefix in content diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py deleted file mode 100644 index cf455ff3edbd3..0000000000000 --- a/tests/tpu/test_custom_dispatcher.py +++ /dev/null @@ -1,34 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm.config import CompilationMode - -from ..utils import compare_two_settings - -# --enforce-eager on TPU causes graph compilation -# this times out default Health Check in the MQLLMEngine, -# so we set the timeout here to 30s - - -def test_custom_dispatcher(monkeypatch: pytest.MonkeyPatch): - with monkeypatch.context() as m: - m.setenv("VLLM_RPC_TIMEOUT", "30000") - compare_two_settings( - "Qwen/Qwen2.5-1.5B-Instruct", - arg1=[ - "--max-model-len=256", - "--max-num-seqs=32", - "--enforce-eager", - f"-O{CompilationMode.DYNAMO_TRACE_ONCE}", - ], - arg2=[ - "--max-model-len=256", - "--max-num-seqs=32", - "--enforce-eager", - f"-O{CompilationMode.STOCK_TORCH_COMPILE}", - ], - env1={}, - env2={}, - ) diff --git a/tests/tpu/test_moe_pallas.py b/tests/tpu/test_moe_pallas.py deleted file mode 100644 index e3236d20bf673..0000000000000 --- a/tests/tpu/test_moe_pallas.py +++ /dev/null @@ -1,88 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests for the Pallas MOE implementation. - -Run `pytest tests/kernels/moe/test_moe_pallas.py`. -""" - -import pytest -import torch -import torch_xla - -from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe as pallas_moe -from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( - fused_moe as torch_moe, -) -from vllm.platforms import current_platform - -if not current_platform.is_tpu(): - pytest.skip("This test needs a TPU.", allow_module_level=True) - -NUM_EXPERTS = [8, 64] -EP_SIZE = [1] -TOP_KS = [2, 6] - - -# The Pallas GMM kernel requires num_tokens * topk to be a multiple of 16 -@pytest.mark.parametrize("m", [8, 16, 64, 2048]) -@pytest.mark.parametrize("n", [128, 1024, 2048]) -@pytest.mark.parametrize("k", [128, 511, 1024]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("ep_size", EP_SIZE) -@pytest.mark.parametrize("dtype", [torch.bfloat16]) -def test_pallas_moe( - m: int, - n: int, - k: int, - e: int, - topk: int, - ep_size: int, - dtype: torch.dtype, -): - import torch_xla.core.xla_model as xm - - with torch.device(xm.xla_device()): - a = torch.randn((m, k), dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), dtype=dtype) / 10 - w2 = torch.randn((e, k, n), dtype=dtype) / 10 - - score = torch.randn((m, e), dtype=dtype) - - # TODO: Support ep - if ep_size > 1: - pytest.skip("No support for ep_size > 1 yet") - else: - e_map = None - - # Run both implementations - torch_output = torch_moe( - hidden_states=a, - w1=w1, - w2=w2, - gating_output=score, - topk=topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False, - ) - - pallas_output = pallas_moe( - hidden_states=a, - w1=w1, - w2=w2, - gating_output=score, - topk=topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False, - ) - torch_xla.sync(wait=False) - - # Compare outputs - torch.testing.assert_close( - pallas_output.cpu(), - torch_output.cpu(), - atol=2e-2, - rtol=0, - ) diff --git a/tests/tpu/test_quantization_accuracy.py b/tests/tpu/test_quantization_accuracy.py deleted file mode 100644 index 151be5f17fe89..0000000000000 --- a/tests/tpu/test_quantization_accuracy.py +++ /dev/null @@ -1,52 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass - -import lm_eval -import pytest - -TASK = "gsm8k" -FILTER = "exact_match,strict-match" -RTOL = 0.03 - - -@dataclass -class GSM8KAccuracyTestConfig: - model_name: str - expected_value: float - - def get_model_args(self) -> str: - return f"pretrained={self.model_name},max_model_len=4096,max_num_seqs=32" - - -# NOTE: Accuracy scores measured on GPUs. -ACCURACY_CONFIGS = [ - GSM8KAccuracyTestConfig( - model_name="neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", - expected_value=0.76, - ), # no bias - # NOTE(rob): We cannot re-initialize vLLM in the same process for TPU, - # so only one of these tests can run in a single call to pytest. As - # a follow-up, move this into the LM-EVAL section of the CI. - # GSM8KAccuracyTestConfig( - # model_name="neuralmagic/Qwen2-7B-Instruct-quantized.w8a8", - # expected_value=0.66), # bias in QKV layers -] - - -@pytest.mark.parametrize("config", ACCURACY_CONFIGS) -def test_gsm8k_correctness(config: GSM8KAccuracyTestConfig): - results = lm_eval.simple_evaluate( - model="vllm", - model_args=config.get_model_args(), - tasks="gsm8k", - batch_size="auto", - ) - - EXPECTED_VALUE = config.expected_value - measured_value = results["results"][TASK][FILTER] - assert ( - measured_value - RTOL < EXPECTED_VALUE - and measured_value + RTOL > EXPECTED_VALUE - ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}" diff --git a/tests/v1/tpu/__init__.py b/tests/v1/tpu/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/tests/v1/tpu/test_basic.py b/tests/v1/tpu/test_basic.py deleted file mode 100644 index 0d53a02476fab..0000000000000 --- a/tests/v1/tpu/test_basic.py +++ /dev/null @@ -1,177 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""A basic correctness check for TPUs - -Run `pytest tests/v1/tpu/test_basic.py`. -""" - -from typing import TYPE_CHECKING - -import pytest -from torch_xla._internal import tpu - -from vllm.platforms import current_platform - -if TYPE_CHECKING: - from tests.conftest import VllmRunner -else: - VllmRunner = object - -MODELS = [ - "Qwen/Qwen2.5-1.5B-Instruct", - # TODO: Enable this model when fixed. - # "Qwen/Qwen1.5-MoE-A2.7B", - # TODO: Enable this models with v6e - # "Qwen/Qwen2-7B-Instruct", - # "meta-llama/Llama-3.1-8B", -] - -TENSOR_PARALLEL_SIZES = [1] -MAX_NUM_REQS = [16, 1024] - -# TODO: Enable when CI/CD will have a multi-tpu instance -# TENSOR_PARALLEL_SIZES = [1, 4] - - -@pytest.mark.skipif( - not current_platform.is_tpu(), reason="This is a basic test for TPU only" -) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("max_tokens", [5]) -@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES) -@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS) -def test_basic( - vllm_runner: type[VllmRunner], - model: str, - max_tokens: int, - tensor_parallel_size: int, - max_num_seqs: int, -) -> None: - prompt = ( - "The next numbers of the sequence " - + ", ".join(str(i) for i in range(1024)) - + " are:" - ) - example_prompts = [prompt] - - with vllm_runner( - model, - # Note: max_num_batched_tokens == 1024 is needed here to - # actually test chunked prompt - max_num_batched_tokens=1024, - max_model_len=8192, - gpu_memory_utilization=0.7, - max_num_seqs=max_num_seqs, - tensor_parallel_size=tensor_parallel_size, - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - output = vllm_outputs[0][1] - - assert "1024" in output or "0, 1" in output - - -@pytest.mark.skip(reason="Temporarily disabled due to timeout") -@pytest.mark.skipif( - not current_platform.is_tpu(), reason="This is a basic test for TPU only" -) -@pytest.mark.parametrize("max_tokens", [8]) -@pytest.mark.parametrize("max_num_seqs", [16]) -def test_phi3( - vllm_runner: type[VllmRunner], - max_tokens: int, - max_num_seqs: int, -) -> None: - prompts = [ - "A robot may not injure a human being", - "It is only with the heart that one can see rightly;", - "The greatest glory in living lies not in never falling,", - ] - answers = [ - " or, by violating privacy", - " what is essential is love.", - " but in rising every time we fall.", - ] - # test head dim = 96 - model = "microsoft/Phi-3-mini-128k-instruct" - - with vllm_runner( - model, max_num_batched_tokens=256, max_num_seqs=max_num_seqs - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens) - # vllm_outputs is a list of tuples whose first element is the token id - # and the second element is the output (including the prompt). - for output, answer in zip(vllm_outputs, answers): - generated_text = output[1] - assert answer in generated_text - - -TP_SIZE_8 = 8 - - -@pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a test for TPU only") -@pytest.mark.skipif( - tpu.num_available_chips() < TP_SIZE_8, - reason=f"This test requires {TP_SIZE_8} TPU chips.", -) -def test_gemma3_27b_with_text_input_and_tp( - vllm_runner: type[VllmRunner], -) -> None: - model = "google/gemma-3-27b-it" - max_tokens = 16 - tensor_parallel_size = TP_SIZE_8 - max_num_seqs = 4 - prompts = [ - "A robot may not injure a human being", - "It is only with the heart that one can see rightly;", - "The greatest glory in living lies not in never falling,", - ] - answers = [ - " or, through inaction, allow a human being to come to harm.", - " what is essential is invisible to the eye.", - " but in rising every time we fall.", - ] - - with vllm_runner( - model, - max_num_batched_tokens=256, - max_num_seqs=max_num_seqs, - tensor_parallel_size=tensor_parallel_size, - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens) - # vllm_outputs is a list of tuples whose first element is the token id - # and the second element is the output (including the prompt). - for output, answer in zip(vllm_outputs, answers): - generated_text = output[1] - assert answer in generated_text - - -@pytest.mark.skipif( - not current_platform.is_tpu(), reason="This is a basic test for TPU only" -) -def test_w8a8_quantization( - vllm_runner: type[VllmRunner], -) -> None: - model = "neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8" - max_tokens = 5 - tensor_parallel_size = 1 - max_num_seqs = 4 - - prompt = ( - "The next numbers of the sequence " - + ", ".join(str(i) for i in range(1024)) - + " are:" - ) - example_prompts = [prompt] - - with vllm_runner( - model, - max_num_batched_tokens=64, - max_model_len=4096, - gpu_memory_utilization=0.7, - max_num_seqs=max_num_seqs, - tensor_parallel_size=tensor_parallel_size, - ) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) - output = vllm_outputs[0][1] - - assert "1024" in output or "0, 1" in output diff --git a/tests/v1/tpu/test_kv_cache_update_kernel.py b/tests/v1/tpu/test_kv_cache_update_kernel.py deleted file mode 100644 index 99d5f98351ad2..0000000000000 --- a/tests/v1/tpu/test_kv_cache_update_kernel.py +++ /dev/null @@ -1,78 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import numpy as np -import pytest -import torch -import torch_xla - -import vllm.v1.attention.backends.pallas # noqa: F401 -from vllm.platforms import current_platform - - -@pytest.mark.skipif(not current_platform.is_tpu(), reason="This is a test for TPU only") -@pytest.mark.parametrize("page_size", [32, 33]) -@pytest.mark.parametrize("combined_kv_head_num", [2, 16]) -@pytest.mark.parametrize("head_dim", [128, 256]) -@pytest.mark.parametrize("num_slices_per_block", [4, 8]) -def test_kv_cache_update_kernel( - page_size: int, combined_kv_head_num: int, head_dim: int, num_slices_per_block: int -): - page_num = 1000 - padded_num_tokens = 128 - kv_cache_cpu = torch.zeros( - (page_num * page_size, combined_kv_head_num, head_dim), - dtype=torch.bfloat16, - device="cpu", - ) - kv_cache_xla = kv_cache_cpu.to(torch_xla.device()) - new_kv_cpu = torch.randn( - (padded_num_tokens, combined_kv_head_num, head_dim), - dtype=torch.bfloat16, - device="cpu", - ) - new_kv_xla = new_kv_cpu.to(torch_xla.device()) - slice_lens = np.array([7, page_size, page_size, 1, 1, 1, 9], dtype=np.int32) - num_kv_update_slices = len(slice_lens) - kv_cache_start_indices = np.array( - [ - page_size * 2 - 7, - page_size * 2, - page_size * 3, - page_size * 4 + 6, - page_size * 5 + 7, - page_size * 6 + 8, - page_size * 15 + 3, - ], - dtype=np.int32, - ) - new_kv_cache_indices = np.concatenate( - [np.array([0], dtype=np.int32), np.cumsum(slice_lens[:-1])] - ) - slot_mapping = np.stack( - [kv_cache_start_indices, new_kv_cache_indices, slice_lens], axis=1 - ) - slot_mapping = np.transpose(slot_mapping) - slot_mapping_cpu = torch.tensor(slot_mapping, device="cpu", dtype=torch.int32) - slot_mapping_xla = slot_mapping_cpu.to(torch_xla.device()) - num_kv_update_slices_xla = torch.tensor( - [num_kv_update_slices], device=torch_xla.device(), dtype=torch.int32 - ) - torch_xla.sync() - - torch.ops.xla.dynamo_set_buffer_donor_(kv_cache_xla, True) - new_kv_cache_xla = torch.ops.xla.kv_cache_update_op( - new_kv_xla, - slot_mapping_xla, - kv_cache_xla, - num_kv_update_slices_xla, - page_size, - num_slices_per_block, - ) - kv_cache_xla.copy_(new_kv_cache_xla) - torch_xla.sync() - - for ni, ci, sl in zip(new_kv_cache_indices, kv_cache_start_indices, slice_lens): - kv_cache_cpu[ci : ci + sl, :, :] = new_kv_cpu[ni : ni + sl, :, :] - - assert torch.allclose(kv_cache_xla.cpu(), kv_cache_cpu, atol=1e-4, rtol=1e-4) diff --git a/tests/v1/tpu/test_mha_attn.py b/tests/v1/tpu/test_mha_attn.py deleted file mode 100644 index 84968dee6b60c..0000000000000 --- a/tests/v1/tpu/test_mha_attn.py +++ /dev/null @@ -1,94 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Test: - -* Tests for MMEncoderAttention layer -""" - -import pytest -import torch -import torch_xla -import torch_xla.core -import torch_xla.core.xla_model - -from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention -from vllm.attention.selector import _cached_get_attn_backend -from vllm.platforms import current_platform - - -@pytest.fixture(autouse=True) -def clear_cache(): - """Clear lru cache to ensure each test case runs without caching.""" - _cached_get_attn_backend.cache_clear() - - -def ref_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - scale: float, -) -> torch.Tensor: - """ - Native implementation of scaled dot product attention without mask: - - query, key, value: [batch_size, seq_len, num_heads, head_size] - - attn_mask: [batch_size, seq_len, seq_len] - """ - query, key, value = (x.transpose(1, 2) for x in (query, key, value)) - attn_weights = scale * torch.matmul(query, key.transpose(2, 3)) - attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) - out = torch.matmul(attn_weights, value).transpose(1, 2) - return out - - -BATCH_SIZES = [1, 16] -SEQ_LENS = [1] -NUM_HEADS = [1, 16] -NUM_KV_HEADS = [1] -HEAD_SIZES = [64, 80] - - -@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU") -@pytest.mark.parametrize("batch_size", BATCH_SIZES) -@pytest.mark.parametrize("seq_len", SEQ_LENS) -@pytest.mark.parametrize("num_heads", NUM_HEADS) -@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS) -@pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("device", [torch_xla.core.xla_model.xla_device()]) -def test_mha_attn_forward( - batch_size: int, - seq_len: int, - num_heads: int, - num_kv_heads: int, - head_size: int, - device: str, -): - current_platform.seed_everything(0) - # These are expected to be f32 - q = torch.randn(batch_size, seq_len, num_heads * head_size, device=device) - k = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device) - v = torch.randn(batch_size, seq_len, num_kv_heads * head_size, device=device) - scale = 1.0 / head_size**0.5 - attn = MMEncoderAttention( - num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads - ) - output = attn(q, k, v) - - assert num_heads % num_kv_heads == 0 - num_queries_per_kv = num_heads // num_kv_heads - - q = q.reshape(batch_size, seq_len, num_heads, head_size) - k = k.reshape(batch_size, seq_len, num_kv_heads, head_size) - v = v.reshape(batch_size, seq_len, num_kv_heads, head_size) - if num_queries_per_kv > 1: - k = torch.repeat_interleave(k, num_queries_per_kv, dim=2) - v = torch.repeat_interleave(v, num_queries_per_kv, dim=2) - - ref_output = ref_attention( - q, - k, - v, - scale=scale, - ).reshape(batch_size, seq_len, num_heads * head_size) - # torch_xla flash_attn kernel is less accurate but much faster - torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-3) diff --git a/tests/v1/tpu/test_multimodal.py b/tests/v1/tpu/test_multimodal.py deleted file mode 100644 index 3caa7c14b393b..0000000000000 --- a/tests/v1/tpu/test_multimodal.py +++ /dev/null @@ -1,76 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import openai -import pytest - -from vllm.multimodal.utils import encode_image_url -from vllm.platforms import current_platform - -from ...entrypoints.openai.test_vision import TEST_IMAGE_ASSETS -from ...utils import RemoteOpenAIServer - - -@pytest.fixture(scope="session") -def url_encoded_image(local_asset_server) -> dict[str, str]: - return { - image_asset: encode_image_url(local_asset_server.get_image_asset(image_asset)) - for image_asset in TEST_IMAGE_ASSETS - } - - -@pytest.mark.asyncio -@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU") -@pytest.mark.parametrize("model_name", ["llava-hf/llava-1.5-7b-hf"]) -async def test_basic_vision(model_name: str, url_encoded_image: dict[str, str]): - pytest.skip("Skip this test until it's fixed.") - - def whats_in_this_image_msg(url): - return [ - { - "role": "user", - "content": [ - {"type": "text", "text": "What's in this image?"}, - {"type": "image_url", "image_url": {"url": url}}, - ], - } - ] - - server_args = [ - "--max-model-len", - "1024", - "--max-num-seqs", - "16", - "--gpu-memory-utilization", - "0.95", - "--trust-remote-code", - "--max-num-batched-tokens", - "576", - # NOTE: max-num-batched-tokens>=mm_item_size - "--disable_chunked_mm_input", - ] - - # Server will pre-compile on first startup (takes a long time). - with RemoteOpenAIServer( - model_name, server_args, max_wait_seconds=600 - ) as remote_server: - client: openai.AsyncOpenAI = remote_server.get_async_client() - - # Other requests now should be much faster - for image_url in TEST_IMAGE_ASSETS: - image_url = url_encoded_image[image_url] - chat_completion_from_url = await client.chat.completions.create( - model=model_name, - messages=whats_in_this_image_msg(image_url), - max_completion_tokens=24, - temperature=0.0, - ) - result = chat_completion_from_url - assert result - choice = result.choices[0] - assert choice.finish_reason == "length" - - message = choice.message - message = result.choices[0].message - assert message.content is not None and len(message.content) >= 10 - assert message.role == "assistant" diff --git a/tests/v1/tpu/test_pallas.py b/tests/v1/tpu/test_pallas.py deleted file mode 100644 index 0a994e99bade1..0000000000000 --- a/tests/v1/tpu/test_pallas.py +++ /dev/null @@ -1,100 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from unittest.mock import ANY, patch - -import torch - -from vllm.attention.backends.abstract import AttentionType -from vllm.v1.attention.backends.pallas import PallasAttentionBackendImpl, PallasMetadata - - -def test_ragged_paged_attention(): - # We verify that the kernel inputs such as sliding_window, etc. are passed - # in from the model correctly. - # The correctness of the paged attention kernel is tested in the kernel - # library. - num_heads = 4 - head_size = 128 - scale = 1.0 - num_kv_heads = 4 - sliding_window = 128 - logits_soft_cap = 50.0 - attn_impl = PallasAttentionBackendImpl( - num_heads=num_heads, - head_size=head_size, - scale=scale, - num_kv_heads=num_kv_heads, - alibi_slopes=None, - sliding_window=sliding_window, - kv_cache_dtype="auto", - logits_soft_cap=logits_soft_cap, - attn_type=AttentionType.DECODER, - ) - - class FakeAttentionLayer: - _q_scale_float: float - _k_scale_float: float - _v_scale_float: float - - layer = FakeAttentionLayer() - layer._q_scale_float = 1.0 - layer._k_scale_float = 1.0 - layer._v_scale_float = 1.0 - - num_tokens = 16 - num_blocks = 1024 - block_size = 16 - query = torch.zeros(num_tokens, num_heads * head_size) - key = torch.zeros(num_tokens, num_kv_heads * head_size) - value = torch.zeros(num_tokens, num_kv_heads * head_size) - kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size) - slot_mapping = torch.zeros((3, num_tokens), dtype=torch.int64) - max_num_reqs = 8 - max_num_blocks_per_req = 8 - num_kv_update_slices = torch.tensor([num_tokens], dtype=torch.int32) - block_tables = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), dtype=torch.int32 - ) - context_lens = torch.ones((max_num_reqs,), dtype=torch.int32) - query_lens = [1] * max_num_reqs - query_start_loc = torch.cumsum( - torch.tensor([0] + query_lens, dtype=torch.int32), dim=0, dtype=torch.int32 - ) - num_seqs = torch.tensor([max_num_reqs], dtype=torch.int32) - attn_metadata = PallasMetadata( - slot_mapping=slot_mapping, - block_tables=block_tables, - context_lens=context_lens, - query_start_loc=query_start_loc, - num_seqs=num_seqs, - num_kv_update_slices=num_kv_update_slices, - num_slices_per_kv_cache_update_block=8, - ) - - with patch("torch.ops.xla.ragged_paged_attention") as mock_ragged_paged_attention: - attn_impl.forward( - layer=layer, - query=query, - key=key, - value=value, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - - mock_ragged_paged_attention.assert_called_once_with( - ANY, # query - ANY, # kv_cache - ANY, # context_lens - ANY, # block_tables - ANY, # query_start_loc - ANY, # num_seqs - num_kv_pages_per_block=None, - num_queries_per_block=None, - vmem_limit_bytes=None, - use_kernel=True, - sm_scale=scale, - sliding_window=sliding_window, - soft_cap=logits_soft_cap, - k_scale=1.0, - v_scale=1.0, - ) diff --git a/tests/v1/tpu/test_perf.py b/tests/v1/tpu/test_perf.py deleted file mode 100644 index e62b969fe3b95..0000000000000 --- a/tests/v1/tpu/test_perf.py +++ /dev/null @@ -1,150 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""A basic performance regression test for TPUs - -Run `pytest tests/v1/tpu/test_perf.py`. -""" - -import time -from dataclasses import dataclass -from typing import TYPE_CHECKING - -import numpy as np -import pytest - -from vllm.platforms import current_platform -from vllm.sampling_params import SamplingParams -from vllm.tokenizers import get_tokenizer - -if TYPE_CHECKING: - from tests.conftest import VllmRunner -else: - VllmRunner = object - - -@dataclass -class TestParams: - model: str - num_prompts: int - prefix_len: int - decode_len: int - expected_avg_time: float - err_tol: float - - -TEST_PARAMS = [ - # TODO: Cannot run a series of tests because: - # RuntimeError: Bad StatusOr access: UNKNOWN: TPU initialization failed: - # open(/dev/vfio/0): Device or resource busy: Device or resource busy; - # Couldn't open iommu group /dev/vfio/0 - # => Investigate - # TestParams( - # model="Qwen/Qwen2.5-1.5B-Instruct", - # num_prompts=1, - # prefix_len=10, - # decode_len=5, - # expected_avg_time=0.03, - # err_tol=0.01, - # ), - # TestParams( - # model="Qwen/Qwen2.5-1.5B-Instruct", - # num_prompts=10, - # prefix_len=100, - # decode_len=50, - # expected_avg_time=0.234, - # err_tol=0.020, - # ), - TestParams( - model="Qwen/Qwen2.5-1.5B-Instruct", - num_prompts=64, - prefix_len=500, - decode_len=50, - # commit id: ccb246776d93ef105904a8ec015b3587240a1183 - # tpu: v5lite (old vllm CI/CD) - # expected_avg_time=1.4, - # err_tol=0.30, - # (This is the active CI/CD instance) - # commit id: ccb246776d93ef105904a8ec015b3587240a1183 - # tpu: v6e (current vllm CI/CD) - expected_avg_time=1.7, # measured with VLLM_XLA_CACHE_PATH= - err_tol=0.20, - ), -] - -NUM_WARMUPS = 5 -NUM_RUNS = 10 - -MAX_MODEL_LEN = 1024 -MAX_NUM_SEQS = 32 -GPU_UTIL = 0.9 - - -@pytest.mark.skipif( - not current_platform.is_tpu(), - reason="This is a basic performance test for TPU only", -) -@pytest.mark.parametrize("params", TEST_PARAMS) -def test_perf( - vllm_runner: type[VllmRunner], - params: TestParams, -) -> None: - tokenizer = get_tokenizer( - params.model, tokenizer_mode="auto", trust_remote_code=True - ) - - prompts = [] - for i in range(params.num_prompts): - prefix_token_ids = np.random.randint( - 0, tokenizer.vocab_size, size=params.prefix_len - ).tolist() - prompt = tokenizer.decode(prefix_token_ids) - prompts.append(prompt) - - print( - "-- Running: num_prompts = {} prefix_len = {} decode_len = {}".format( - len(prompts), params.prefix_len, params.decode_len - ) - ) - - sampling_params = SamplingParams( - max_tokens=params.decode_len, temperature=1.0, min_p=0.0 - ) - - with vllm_runner( - params.model, - max_num_batched_tokens=MAX_MODEL_LEN, - max_model_len=MAX_MODEL_LEN, - max_num_seqs=MAX_NUM_SEQS, - gpu_memory_utilization=GPU_UTIL, - enforce_eager=False, - tensor_parallel_size=1, - ) as vllm_model: - print(" -- Warmup / Compile") - for i in range(NUM_WARMUPS): - _ = vllm_model.generate(prompts, sampling_params) - - print(" -- Benchmarking... ") - times = [] - for i in range(NUM_RUNS): - start_time = time.time() - _ = vllm_model.generate(prompts, sampling_params) - times.append(time.time() - start_time) - - avg_time = sum(times) / len(times) - - print(" -- avg_time = {}".format(avg_time)) - print( - " -- expected_avg_time = {} with err_tol = {}".format( - params.expected_avg_time, params.err_tol - ) - ) - diff = avg_time - params.expected_avg_time - ok = diff < params.err_tol - if diff < -params.err_tol: - print( - " !! WARNING !! Performance has improved by {}, " - "it may be necessary to fine-tune the " - "expected_avg_time = {}".format(-diff, params.expected_avg_time) - ) - - assert ok, " !! ERROR !! Regression detected" diff --git a/tests/v1/tpu/test_sampler.py b/tests/v1/tpu/test_sampler.py deleted file mode 100644 index 58f6292b05a72..0000000000000 --- a/tests/v1/tpu/test_sampler.py +++ /dev/null @@ -1,105 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import random - -import pytest - -from vllm import LLM -from vllm.platforms import current_platform -from vllm.sampling_params import SamplingParams - - -@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"]) -@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU") -def test_sampler_different(model_name: str): - """ - Test significantly different sampling params to assert the model produces - different results. - """ - llm = LLM( - model_name, - enforce_eager=False, - max_num_seqs=1, - max_model_len=512, - max_num_batched_tokens=256, - ) - prompts = ["Write a short story about a robot that dreams for the first time."] - sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64) - output = llm.generate(prompts, sampling_params) - - sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64) - output2 = llm.generate(prompts, sampling_params) - assert output[0].outputs[0].text != output2[0].outputs[0].text - - with pytest.raises(ValueError): - # Unsupported `seed` param. - sampling_params = SamplingParams(temperature=0.3, seed=42) - output2 = llm.generate(prompts, sampling_params) - - # Batch-case with TopK/P - for B in [4, 16]: - p = prompts * B - sampling_params = [ - SamplingParams( - temperature=0.1, - min_p=0.8, - max_tokens=64, - # Vary number of ks - top_k=random.randint(4, 12), - top_p=random.random(), - ) - for _ in range(B) - ] - # Make sure first two reqs have the same K/P - sampling_params[0] = sampling_params[1] - output = llm.generate(p, sampling_params) - # There are natural numerical instabilities that make it difficult - # to have deterministic results over many tokens, tests the first ~20 - # tokens match. - assert output[0].outputs[0].text[:20] == output[1].outputs[0].text[:20] - - -@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"]) -# TODO TPU will appear busy if we fan-out test params here -@pytest.mark.parametrize("n_prompts", [1]) -@pytest.mark.skipif(not current_platform.is_tpu(), reason="This test needs a TPU") -def test_logprobs(model_name: str, n_prompts: int): - """ - Request top logprobs with different sampling settings and check - that results contains the requested number, ordered ascendingly. - """ - - def check_num_logprobs(logprobs, expected_num: int): - for step in logprobs: - prev_logp = 1.0 - # order by rank - sorted_step = dict(sorted(step.items(), key=lambda item: item[1].rank)) - - # Can contain the sampled token - assert len(step) == expected_num or len(step) == expected_num + 1 - # Check results are ordered by prob value - for rankno, (tid, logp) in enumerate(sorted_step.items()): - assert logp.logprob <= prev_logp - prev_logp = logp.logprob - assert logp.rank == rankno + 1 - - llm = LLM( - model_name, - enforce_eager=False, - max_num_seqs=1, - max_model_len=128, - max_num_batched_tokens=128, - ) - prompts = [ - "Write a short story about a robot that dreams for the first time." - ] * n_prompts - greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64, logprobs=4) - regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64, logprobs=4) - topkp_sampling_params = SamplingParams( - temperature=0.4, max_tokens=64, logprobs=4, top_k=12, top_p=0.5 - ) - - for sp in [greedy_sampling_params, regular_sampling_params, topkp_sampling_params]: - output = llm.generate(prompts, sp) - for o in output: - check_num_logprobs(o.outputs[0].logprobs, 4) diff --git a/tests/v1/tpu/test_spmd_model_weight_loading.py b/tests/v1/tpu/test_spmd_model_weight_loading.py deleted file mode 100644 index be866bf90a792..0000000000000 --- a/tests/v1/tpu/test_spmd_model_weight_loading.py +++ /dev/null @@ -1,78 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import gc -import tempfile - -import numpy as np -import pytest -import torch_xla.distributed.spmd as xs -import torch_xla.runtime as xr - -from vllm.config import set_current_vllm_config -from vllm.distributed.parallel_state import ( - ensure_model_parallel_initialized, - init_distributed_environment, -) -from vllm.engine.arg_utils import EngineArgs -from vllm.model_executor.model_loader.tpu import TPUModelLoader - - -def _setup_environment(model): - engine_args = EngineArgs( - model=model, - ) - vllm_config = engine_args.create_engine_config() - with set_current_vllm_config(vllm_config): - temp_file = tempfile.mkstemp()[1] - init_distributed_environment( - 1, - 0, - local_rank=0, - distributed_init_method=f"file://{temp_file}", - backend="gloo", - ) - # Under single worker mode, full model is init first and then - # partitioned using GSPMD. - ensure_model_parallel_initialized(1, 1) - return vllm_config - - -MESH = None - - -def _get_spmd_mesh(): - global MESH - if MESH is None: - xr.use_spmd() - num_devices = xr.global_runtime_device_count() - mesh_shape = (num_devices, 1) - device_ids = np.array(range(num_devices)) - MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y")) - return MESH - - -@pytest.mark.parametrize( - "model", - [ - "Qwen/Qwen2-1.5B-Instruct", - # Skip large models due to CI runner disk space limitations - # "meta-llama/Llama-3.1-8B-Instruct", - # "meta-llama/Llama-3.1-70B-Instruct", - ], -) -def test_tpu_model_loader(model): - # Skip the 70B test if there are less than 8 chips - # TODO: Query using torch xla API, the query API is not working - # with SPMD now. However, This test is running under SPMD mode. - if "70B" in model and xr.global_runtime_device_count() < 8: - pytest.skip( - "Skipping 70B model if the TPU VM has less than 8 chips to \ - avoid OOM." - ) - - vllm_config = _setup_environment(model) - loader = TPUModelLoader(load_config=vllm_config.load_config) - mesh = _get_spmd_mesh() - model = loader.load_model(vllm_config, vllm_config.model_config, mesh) - del model - gc.collect() diff --git a/tests/v1/tpu/test_topk_topp_sampler.py b/tests/v1/tpu/test_topk_topp_sampler.py deleted file mode 100644 index c6634395bb167..0000000000000 --- a/tests/v1/tpu/test_topk_topp_sampler.py +++ /dev/null @@ -1,149 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import math - -import pytest -import torch -import torch_xla - -from vllm.platforms import current_platform -from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p -from vllm.v1.sample.tpu.sampler import apply_top_k_top_p as apply_top_k_top_p_tpu - -if not current_platform.is_tpu(): - pytest.skip("This test needs a TPU.", allow_module_level=True) -import torch_xla.core.xla_model as xm - -BATCH_SIZE = 1024 -VOCAB_SIZE = 128 * 1024 -TOLERANCE = 1e-6 - - -def test_topk_equivalence_to_native_impl(): - with torch.device(xm.xla_device()): - xm.set_rng_state(seed=33) - - logits = torch.rand((BATCH_SIZE, VOCAB_SIZE)) - - # Random top-k values between 1 and 10. - k = torch.randint(1, 10, (BATCH_SIZE,)) - - # Set k=vocab_size for ~50% of requests in the batch (top-k disabled). - k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), VOCAB_SIZE) - - result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None) - - result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None) - assert torch.allclose(result_native, result_tpu) - - -def test_topp_result_sums_past_p(): - with torch.device(xm.xla_device()): - xm.set_rng_state(seed=33) - - logits = torch.rand((BATCH_SIZE, VOCAB_SIZE)) - probs = logits.softmax(dim=-1) - - # Random top-p values between 0 and 1. - p = torch.rand((BATCH_SIZE,)) - - # Set p=1 for ~50% of requests in the batch (top-p disabled). - p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), 1) - - no_op_k = torch.tensor([VOCAB_SIZE]) - logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(), k=no_op_k, p=p) - - # Verify that the masked logit's probability sums to at least p. - probs.masked_fill_(logits_masked.isinf(), 0) - masked_prob_sum = probs.sum(dim=-1) - - torch_xla.sync() - - # Perform assertion on CPU. - assert torch.all(torch.ge(masked_prob_sum.cpu() + TOLERANCE, p.cpu())) - - -def test_topp_basic(): - with torch.device(xm.xla_device()): - logits = torch.tensor( - [ - [math.log(0.2), math.log(0.3), math.log(0.5)], - [math.log(0.5), math.log(0.1), math.log(0.4)], - ] - ) - - result = apply_top_k_top_p_tpu( - logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([0.79, 0.79]) - ) - - torch_xla.sync() - - # Expect the smallest elements to be dropped. - expected_result = logits.clone().cpu() - expected_result[0, 0] = float("-inf") - expected_result[1, 1] = float("-inf") - assert torch.allclose(expected_result, result.cpu()) - - -def test_topp_select_all(): - with torch.device(xm.xla_device()): - logits = torch.tensor( - [ - [math.log(0.2), math.log(0.3), math.log(0.5)], - [math.log(0.5), math.log(0.1), math.log(0.4)], - ] - ) - - result = apply_top_k_top_p_tpu( - logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([1.0, 1.0]) - ) - - torch_xla.sync() - - assert torch.allclose(logits.cpu(), result.cpu()) - - -def test_topp_with_ties(): - with torch.device(xm.xla_device()): - # Input has multiple math.log(0.3). - logits = torch.tensor( - [[math.log(0.3), math.log(0.3), math.log(0.3), math.log(0.1)]] - ) - - result = apply_top_k_top_p_tpu( - logits=logits.clone(), k=torch.tensor([4]), p=torch.tensor([0.2]) - ) - - torch_xla.sync() - - # All tie values are included in the top-p set. Tie breaking is left - # to be done during final sampling (all tie tokens have equal - # probability of being chosen). - expected_result = logits.clone().cpu() - expected_result[0, 3] = float("-inf") - assert torch.allclose(expected_result, result.cpu()) - - -def test_both_topk_topp(): - with torch.device(xm.xla_device()): - logits = torch.tensor( - [ - [math.log(0.2), math.log(0.3), math.log(0.5)], - [math.log(0.5), math.log(0.1), math.log(0.4)], - ] - ) - - # Set k=1 for the first batch. - result = apply_top_k_top_p_tpu( - logits=logits.clone(), k=torch.tensor([1, 3]), p=torch.tensor([0.79, 0.79]) - ) - - torch_xla.sync() - - # Since for the first batch k=1, expect only the largest element gets - # selected. - expected_result = logits.clone().cpu() - expected_result[0, 0] = float("-inf") - expected_result[0, 1] = float("-inf") - expected_result[1, 1] = float("-inf") - assert torch.allclose(expected_result, result.cpu()) diff --git a/tests/v1/tpu/test_tpu_int8.py b/tests/v1/tpu/test_tpu_int8.py deleted file mode 100644 index 50001567a9588..0000000000000 --- a/tests/v1/tpu/test_tpu_int8.py +++ /dev/null @@ -1,78 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Tests whether TPU Int8 computation is enabled correctly. - -Run `pytest tests/quantization/test_tpu_int8.py`. -""" - -import pytest - -from vllm.model_executor.layers.linear import LinearBase -from vllm.model_executor.layers.quantization.tpu_int8 import TPUInt8LinearMethod -from vllm.platforms import current_platform - -from ...models.registry import HF_EXAMPLE_MODELS - -MODELS = ["Qwen/Qwen2.5-0.5B-Instruct"] - - -@pytest.mark.skipif( - not current_platform.is_tpu(), reason="TPU Int8 is only enabled for TPUs." -) -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [10]) -@pytest.mark.parametrize( - "hf_overrides", - [ - # w8a8 dynamic activation - { - "quantization_config": { - "quant_method": "tpu_int8", - "activation_scheme": "dynamic", - } - } - ], -) -def test_model_tpu_int8( - vllm_runner, - model: str, - dtype: str, - max_tokens: int, - hf_overrides: dict, - monkeypatch, -) -> None: - model_info = HF_EXAMPLE_MODELS.find_hf_info(model) - model_info.check_transformers_version(on_fail="skip") - - activation_scheme = hf_overrides.get("quantization_config", {}).get( - "activation_scheme" - ) - quantize_activation = activation_scheme == "dynamic" - - # Allows using apply_model - monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") - # Prevent error from re-initializing cache - monkeypatch.setenv("VLLM_XLA_CACHE_PATH", "") - - prompts = [ - "A robot may not injure a human being", - ] - answers = [ - "or kill a human being", - ] - - with vllm_runner(model, dtype=dtype, hf_overrides=hf_overrides) as vllm: - - def check_model(model): - for name, module in model.named_modules(): - if not isinstance(module, LinearBase): - continue - quant_method = module.quant_method - assert isinstance(quant_method, TPUInt8LinearMethod) - assert quant_method.quantize_activation == quantize_activation - - vllm.apply_model(check_model) - outputs = vllm.generate_greedy(prompts, max_tokens) - for (_, output), answer in zip(outputs, answers): - assert answer in output diff --git a/tests/v1/tpu/test_tpu_qkv_linear.py b/tests/v1/tpu/test_tpu_qkv_linear.py deleted file mode 100644 index 098d925505424..0000000000000 --- a/tests/v1/tpu/test_tpu_qkv_linear.py +++ /dev/null @@ -1,93 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import tempfile - -import numpy as np -import pytest -import torch -import torch_xla.distributed.spmd as xs -import torch_xla.runtime as xr - -from vllm.config import set_current_vllm_config -from vllm.distributed.parallel_state import ( - ensure_model_parallel_initialized, - init_distributed_environment, -) -from vllm.distributed.tpu_distributed_utils import XlaQKVParallelLinear -from vllm.engine.arg_utils import EngineArgs -from vllm.model_executor.layers.linear import QKVParallelLinear - - -@pytest.fixture(autouse=True) -def setup_environment(): - # This is a fake config used for init dist env. - # QKVParallelLinear needs dist env to be initialized. - engine_args = EngineArgs( - model="Qwen/Qwen2-1.5B-Instruct", - max_model_len=64, - max_num_batched_tokens=64, - max_num_seqs=4, - ) - - vllm_config = engine_args.create_engine_config() - - with set_current_vllm_config(vllm_config): - temp_file = tempfile.mkstemp()[1] - init_distributed_environment( - 1, - 0, - local_rank=0, - distributed_init_method=f"file://{temp_file}", - backend="gloo", - ) - ensure_model_parallel_initialized(1, 1) - yield - - -MESH = None - - -def _get_spmd_mesh(): - global MESH - if MESH is None: - xr.use_spmd() - num_devices = xr.global_runtime_device_count() - mesh_shape = (num_devices, 1) - device_ids = np.array(range(num_devices)) - MESH = xs.Mesh(device_ids, mesh_shape, ("x", "y")) - return MESH - - -@pytest.mark.parametrize("bias", [False, True]) -# `xr.use_spmd()` will set a global state, and this state is not reversible. -# Therefore, non-SPMD tests should be run before SPMD tests. -@pytest.mark.parametrize("mesh", [None, _get_spmd_mesh()]) -@pytest.mark.parametrize("device", ["cpu", "xla"]) -@torch.no_grad() -def test_xla_qkv_linear(bias, mesh, device): - torch.manual_seed(123) - - qkv_linear = QKVParallelLinear( - hidden_size=4096, - head_size=128, - total_num_heads=32, - total_num_kv_heads=8, - bias=bias, - params_dtype=torch.bfloat16, - return_bias=False, - ) - - qkv_linear.weight.data = torch.rand_like(qkv_linear.weight.data) / 10 - if bias: - qkv_linear.bias.data = torch.rand_like(qkv_linear.bias.data) - - xla_qkv_linear = XlaQKVParallelLinear(qkv_linear, mesh=mesh) - - qkv_linear = qkv_linear.to(device) - xla_qkv_linear = xla_qkv_linear.to(device) - input_tensor = torch.rand(10, 4096, dtype=torch.bfloat16) / 10 - input_tensor = input_tensor.to(device) - - output = qkv_linear(input_tensor) - xla_output = xla_qkv_linear(input_tensor) - assert torch.allclose(output.cpu(), xla_output.cpu()) diff --git a/tests/v1/tpu/worker/__init__.py b/tests/v1/tpu/worker/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py deleted file mode 100644 index cfc06666e7984..0000000000000 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ /dev/null @@ -1,587 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm.attention.layer import Attention -from vllm.config import ( - CacheConfig, - ModelConfig, - SchedulerConfig, - VllmConfig, - set_current_vllm_config, -) -from vllm.pooling_params import PoolingParams -from vllm.sampling_params import SamplingParams -from vllm.utils.mem_constants import GiB_bytes -from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs -from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput -from vllm.v1.worker.tpu_model_runner import ( - TPUModelRunner, - _get_padded_num_reqs_with_upper_limit, - _get_padded_token_len, - _get_req_paddings, - _get_token_paddings, -) - - -def get_vllm_config(): - model_config = ModelConfig( - model="facebook/opt-125m", - dtype="bfloat16", # TPUs typically use bfloat16 - seed=42, - ) - scheduler_config = SchedulerConfig( - max_num_seqs=10, - max_num_batched_tokens=512, - max_model_len=512, - is_encoder_decoder=model_config.is_encoder_decoder, - ) - cache_config = CacheConfig( - block_size=16, - gpu_memory_utilization=0.9, - swap_space=0, - cache_dtype="auto", - ) - vllm_config = VllmConfig( - model_config=model_config, - cache_config=cache_config, - scheduler_config=scheduler_config, - ) - return vllm_config - - -def get_model_runner(vllm_config): - device = "xla:0" # Mocking TPU device - return TPUModelRunner(vllm_config, device) - - -@pytest.fixture -def model_runner(): - # Patchers have already been started at module level. - vllm_config = get_vllm_config() - return get_model_runner(vllm_config) - - -def _schedule_new_request(*req_ids: str) -> SchedulerOutput: - new_reqs = [] - num_scheduled_tokens = {} - total_num_scheduled_tokens = 0 - for req_id in req_ids: - new_reqs.append( - NewRequestData( - req_id=req_id, - prompt_token_ids=[1, 2, 3], - mm_features=[], - sampling_params=SamplingParams(), - pooling_params=PoolingParams(), - block_ids=([0],), # block_ids should be tuple[list[int]] - num_computed_tokens=0, - lora_request=None, - ) - ) - num_scheduled_tokens[req_id] = 3 - total_num_scheduled_tokens += num_scheduled_tokens[req_id] - - return SchedulerOutput( - scheduled_new_reqs=new_reqs, - scheduled_cached_reqs=CachedRequestData.make_empty(), - num_scheduled_tokens=num_scheduled_tokens, - total_num_scheduled_tokens=total_num_scheduled_tokens, - scheduled_spec_decode_tokens={}, - scheduled_encoder_inputs={}, - num_common_prefix_blocks=[], - finished_req_ids=set(), - free_encoder_mm_hashes=[], - ) - - -def _is_req_scheduled(model_runner, req_id: str) -> bool: - return req_id in model_runner.input_batch.req_id_to_index - - -def _is_req_added(model_runner, req_id: str) -> bool: - return req_id in model_runner.requests - - -def _is_req_state_block_table_match(model_runner, req_id: str) -> bool: - """Check if the request state block IDs match the block table. - - This function handles both legacy BlockTable and new MultiGroupBlockTable - structures for backward compatibility. - """ - - req_index = model_runner.input_batch.req_id_to_index[req_id] - multi_group_block_table = model_runner.input_batch.block_table - req_state = model_runner.requests[req_id] - - # Access the first block table from MultiGroupBlockTable - # This is safe since we currently only use single KV cache groups - block_table = multi_group_block_table[0] - - # req_state.block_ids is now tuple[list[int], ...] for MultiGroupBlockTable - # Extract the first group's block IDs - if isinstance(req_state.block_ids[0], list): - # New format: tuple[list[int], ...] - extract first group - req_block_ids = req_state.block_ids[0] - else: - # Legacy format: list[int] - use directly - req_block_ids = req_state.block_ids - - if block_table.num_blocks_per_row[req_index] != len(req_block_ids): - return False - - num_blocks = block_table.num_blocks_per_row[req_index] - block_table_values = block_table.block_table.np[req_index, :num_blocks] - return (block_table_values == req_block_ids).all() - - -def test_update_states_new_request(model_runner): - req_id = "req_0" - - # new req - scheduler_output = _schedule_new_request(req_id) - - model_runner._update_states(scheduler_output) - - assert _is_req_added(model_runner, req_id) - assert _is_req_scheduled(model_runner, req_id) - assert _is_req_state_block_table_match(model_runner, req_id) - - -def test_update_states_request_finished(model_runner): - req_id = "req_0" - - # new req - scheduler_output = _schedule_new_request(req_id) - - model_runner._update_states(scheduler_output) - assert _is_req_added(model_runner, req_id) - assert _is_req_scheduled(model_runner, req_id) - - # finish req - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[], - scheduled_cached_reqs=CachedRequestData.make_empty(), - num_scheduled_tokens={}, - total_num_scheduled_tokens=0, - scheduled_spec_decode_tokens={}, - scheduled_encoder_inputs={}, - num_common_prefix_blocks=[], - finished_req_ids={req_id}, - free_encoder_mm_hashes=[], - ) - - model_runner._update_states(scheduler_output) - assert not _is_req_added(model_runner, req_id) - assert not _is_req_scheduled(model_runner, req_id) - - -def test_update_states_request_resumed(model_runner): - req_id = "req_0" - - # new req - scheduler_output = _schedule_new_request(req_id) - - model_runner._update_states(scheduler_output) - assert _is_req_added(model_runner, req_id) - assert _is_req_scheduled(model_runner, req_id) - - # unschedule req - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[], - scheduled_cached_reqs=CachedRequestData.make_empty(), - num_scheduled_tokens={}, - total_num_scheduled_tokens=0, - scheduled_spec_decode_tokens={}, - scheduled_encoder_inputs={}, - num_common_prefix_blocks=[], - finished_req_ids=set(), - free_encoder_mm_hashes=[], - ) - - model_runner._update_states(scheduler_output) - assert _is_req_added(model_runner, req_id) - assert not _is_req_scheduled(model_runner, req_id) - - # resume req - cached_req_data = CachedRequestData( - req_ids=[req_id], - resumed_req_ids={req_id}, - new_token_ids=[[]], - all_token_ids={req_id: scheduler_output.scheduled_new_reqs[0].prompt_token_ids}, - new_block_ids=[([],)], - num_computed_tokens=[0], - num_output_tokens=[0], - ) - - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[], - scheduled_cached_reqs=cached_req_data, - num_scheduled_tokens={req_id: 1}, - total_num_scheduled_tokens=1, - scheduled_spec_decode_tokens={}, - scheduled_encoder_inputs={}, - num_common_prefix_blocks=[], - finished_req_ids=set(), - free_encoder_mm_hashes=[], - ) - - model_runner._update_states(scheduler_output) - assert _is_req_added(model_runner, req_id) - assert _is_req_scheduled(model_runner, req_id) - assert _is_req_state_block_table_match(model_runner, req_id) - - -def test_update_states_no_changes(model_runner): - req_id = "req_0" - - # new req - scheduler_output = _schedule_new_request(req_id) - - model_runner._update_states(scheduler_output) - assert _is_req_added(model_runner, req_id) - assert _is_req_scheduled(model_runner, req_id) - - # schedule req - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[], - scheduled_cached_reqs=CachedRequestData.make_empty(), - num_scheduled_tokens={req_id: 1}, - total_num_scheduled_tokens=1, - scheduled_spec_decode_tokens={}, - scheduled_encoder_inputs={}, - num_common_prefix_blocks=[], - finished_req_ids=set(), - free_encoder_mm_hashes=[], - ) - - model_runner._update_states(scheduler_output) - assert _is_req_added(model_runner, req_id) - assert _is_req_scheduled(model_runner, req_id) - assert _is_req_state_block_table_match(model_runner, req_id) - - -def test_update_states_request_unscheduled(model_runner): - req_ids = ("req_0", "req_1") - - # new reqs - scheduler_output = _schedule_new_request(*req_ids) - - model_runner._update_states(scheduler_output) - - assert _is_req_added(model_runner, req_ids[0]) - assert _is_req_scheduled(model_runner, req_ids[0]) - - assert _is_req_added(model_runner, req_ids[1]) - assert _is_req_scheduled(model_runner, req_ids[1]) - - # unschedule req_1 - scheduler_output = SchedulerOutput( - scheduled_new_reqs=[], - scheduled_cached_reqs=CachedRequestData.make_empty(), - num_scheduled_tokens={req_ids[0]: 1}, - total_num_scheduled_tokens=1, - scheduled_spec_decode_tokens={}, - scheduled_encoder_inputs={}, - num_common_prefix_blocks=[], - finished_req_ids=set(), - free_encoder_mm_hashes=[], - ) - - model_runner._update_states(scheduler_output) - - assert _is_req_added(model_runner, req_ids[0]) - assert _is_req_scheduled(model_runner, req_ids[0]) - - assert _is_req_added(model_runner, req_ids[1]) - assert not _is_req_scheduled(model_runner, req_ids[1]) - - -def test_get_paddings(): - # Bucketed padding - min_token_size, max_token_size, padding_gap = 16, 512, 64 - expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512] - actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap) - - # Bucketed padding with max_token_size not a power of two. - max_token_size = 317 - expected_paddings = [16, 32, 64, 128, 192, 256, 320] - actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap) - assert actual_paddings == expected_paddings - - # Exponential padding. - max_token_size, padding_gap = 1024, 0 - expected_paddings = [16, 32, 64, 128, 256, 512, 1024] - actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap) - assert actual_paddings == expected_paddings - # Exponential padding with max_token_size not a power of two. - max_token_size = 317 - expected_paddings = [16, 32, 64, 128, 256, 512] - actual_paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap) - assert actual_paddings == expected_paddings - - -def test_get_padded_token_len(): - min_token_size, max_token_size, padding_gap = 16, 512, 64 - paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap) - assert _get_padded_token_len(paddings, 1) == 16 - assert _get_padded_token_len(paddings, 16) == 16 - assert _get_padded_token_len(paddings, 20) == 32 - assert _get_padded_token_len(paddings, 300) == 320 - assert _get_padded_token_len(paddings, 512) == 512 - - -def test_get_padded_num_reqs_with_upper_limit(): - assert _get_padded_num_reqs_with_upper_limit(3, 32) == 8 - assert _get_padded_num_reqs_with_upper_limit(9, 32) == 16 - assert _get_padded_num_reqs_with_upper_limit(19, 32) == 32 - assert _get_padded_num_reqs_with_upper_limit(17, 28) == 28 - - -def test_get_req_paddings(): - assert _get_req_paddings(1, 32) == [8, 16, 32] - assert _get_req_paddings(8, 32) == [8, 16, 32] - assert _get_req_paddings(8, 36) == [8, 16, 32, 36] - - -def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(model_runner): - layer_0 = "model.layers.0.self_attn.attn" - layer_1 = "model.layers.1.self_attn.attn" - error_msg = f"{layer_1} must come before the current layer" - vllm_config = model_runner.vllm_config - with ( - pytest.raises(ValueError, match=error_msg), - set_current_vllm_config(vllm_config), - ): - fwd_context = { - # initialization below will fail because target layer is invalid; - # the target layer needs to come before layer 1 - layer_0: Attention( - num_heads=8, - head_size=128, - scale=1.0, - prefix=layer_0, - kv_sharing_target_layer_name=layer_1, - ), - layer_1: Attention( - num_heads=8, - head_size=128, - scale=1.0, - prefix=layer_1, - ), - } - # suppress var not used error - assert fwd_context is not None - - -def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(model_runner): - layer_0 = "model.layers.0.self_attn.attn" - layer_1 = "model.layers.1.self_attn.attn" - invalid_layer = "model.layers.0.cross_attn.attn" - error_msg = f"{invalid_layer} is not a valid Attention layer in the model" - vllm_config = model_runner.vllm_config - with ( - pytest.raises(ValueError, match=error_msg), - set_current_vllm_config(vllm_config), - ): - fwd_context = { - layer_0: Attention( - num_heads=8, - head_size=128, - scale=1.0, - prefix=layer_0, - ), - layer_1: Attention( - num_heads=8, - head_size=128, - scale=1.0, - prefix=layer_1, - # invalid layer: cross_attn.atn doesn't exist! - kv_sharing_target_layer_name=invalid_layer, - ), - } - # suppress var not used error - assert fwd_context is not None - - -def test_init_kv_cache_with_kv_sharing_target_same_as_current(model_runner): - layer_0 = "model.layers.0.self_attn.attn" - layer_1 = "model.layers.1.self_attn.attn" - error_msg = f"{layer_1} cannot be the same as the current layer" - vllm_config = model_runner.vllm_config - with ( - pytest.raises(ValueError, match=error_msg), - set_current_vllm_config(vllm_config), - ): - fwd_context = { - # initialization below will fail because target layer is invalid; - # the target layer needs to come before layer 1 - layer_0: Attention( - num_heads=8, - head_size=128, - scale=1.0, - prefix=layer_0, - ), - layer_1: Attention( - num_heads=8, - head_size=128, - scale=1.0, - prefix=layer_1, - kv_sharing_target_layer_name=layer_1, - ), - } - # suppress var not used error - assert fwd_context is not None - - -def test_init_kv_cache_without_kv_sharing(): - layer_0 = "model.layers.0.self_attn.attn" - layer_1 = "model.layers.1.self_attn.attn" - vllm_config = get_vllm_config() - with set_current_vllm_config(vllm_config): - fwd_context = { - layer_0: Attention( - num_heads=8, - head_size=128, - scale=1.0, - prefix=layer_0, - ), - layer_1: Attention( - num_heads=8, - head_size=128, - scale=1.0, - prefix=layer_1, - ), - } - # suppress var not used error - assert fwd_context is not None - # Set high context length to test max context length estimation - vllm_config.model_config.max_model_len = 1_000_000 - vllm_ctx = vllm_config.compilation_config.static_forward_context - model_runner = get_model_runner(vllm_config) - kv_cache_spec = model_runner.get_kv_cache_spec() - assert len(kv_cache_spec) == 2 - assert len(model_runner.shared_kv_cache_layers) == 0 - - available_memory = 20 * GiB_bytes - # page size for each layer KV can be calculated as - # 2 (non-MLA) * 8 (num_heads) * 128 (head_dim) - # * 2 (bfloat16, kv_cache dtype) * 128 (block_size) = 512KB - num_expected_blocks = 20480 # 20GB / 512KB / 2 (num layers) - kv_cache_config = get_kv_cache_configs( - vllm_config, [kv_cache_spec], [available_memory] - )[0] - assert kv_cache_config.num_blocks == num_expected_blocks - assert len(kv_cache_config.kv_cache_tensors) == 2 - assert kv_cache_config.kv_cache_tensors[0].size == available_memory // 2 - assert kv_cache_config.kv_cache_tensors[1].size == available_memory // 2 - - max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) - # max context len with KV sharing should be 2x as large as without - # max_context_len = available_memory / (page_size / block_size) / num_caches - # max_context_len = 5GB / (512KB / 128) / 2 = 655360 - assert max_context_len == 655360 - - # important: override tensor size to prevent large mem alloc during test - # this will only allocate 2 block worth of memory (2 * 512kb) - kv_cache_config.num_blocks = 1 - for kv_cache_tensor in kv_cache_config.kv_cache_tensors: - kv_cache_tensor.size = kv_cache_spec[ - kv_cache_tensor.shared_by[0] - ].page_size_bytes - - model_runner.initialize_kv_cache(kv_cache_config) - - layer_0_kv = vllm_ctx[layer_0].kv_cache[0] - layer_1_kv = vllm_ctx[layer_1].kv_cache[0] - # check layer 1 kv cache does NOT share memory with layer 0 - assert id(layer_1_kv) != id(layer_0_kv) - - # check layer 1 added to kv cache group's layer names - assert len(kv_cache_config.kv_cache_groups) == 1 - assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2 - assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0 - assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1 - - -def test_init_kv_cache_with_kv_sharing_valid(): - layer_0 = "model.layers.0.self_attn.attn" - layer_1 = "model.layers.1.self_attn.attn" - vllm_config = get_vllm_config() - with set_current_vllm_config(vllm_config): - fwd_context = { - layer_0: Attention( - num_heads=8, - head_size=128, - scale=1.0, - prefix=layer_0, - ), - layer_1: Attention( - num_heads=8, - head_size=128, - scale=1.0, - prefix=layer_1, - kv_sharing_target_layer_name="model.layers.0.self_attn.attn", - ), - } - # suppress var not used error - assert fwd_context is not None - # Set high context length to test max context length estimation - vllm_config.model_config.max_model_len = 3_000_000 - vllm_ctx = vllm_config.compilation_config.static_forward_context - model_runner = get_model_runner(vllm_config) - kv_cache_spec = model_runner.get_kv_cache_spec() - assert len(kv_cache_spec) == 1 - assert layer_0 in kv_cache_spec - assert model_runner.shared_kv_cache_layers[layer_1] == layer_0 - - available_memory = 20 * GiB_bytes - # page size for layer 0's kv_cache_spec is 512KB - # with KV sharing, we can allocate (available_mem//page_size//1) blocks - # which is twice as many as without KV sharing - num_expected_blocks = 2 * 20480 # 20GB / 512KB - kv_cache_config = get_kv_cache_configs( - vllm_config, [kv_cache_spec], [available_memory] - )[0] - assert kv_cache_config.num_blocks == num_expected_blocks - assert len(kv_cache_config.kv_cache_tensors) == 1 - # Each layer now has twice the available memory for KV cache - # compared to no KV sharing - assert kv_cache_config.kv_cache_tensors[0].size == available_memory - - max_context_len = estimate_max_model_len(vllm_config, kv_cache_spec, 5 * GiB_bytes) - # max context len with KV sharing should be 2x as large as without - assert max_context_len == (2 * 655360) - - # important: override tensor size to prevent large mem alloc during test - # this will only allocate 1 block worth of memory (512kb) - kv_cache_config.num_blocks = 1 - kv_cache_config.kv_cache_tensors[0].size = kv_cache_spec[layer_0].page_size_bytes - - model_runner.initialize_kv_cache(kv_cache_config) - - layer_0_kv = vllm_ctx[layer_0].kv_cache[0] - layer_1_kv = vllm_ctx[layer_1].kv_cache[0] - # check layer 1 kv cache shares memory with layer 0 - assert id(layer_1_kv) == id(layer_0_kv) - - # check layer 1 added to kv cache group's layer names - assert len(kv_cache_config.kv_cache_groups) == 1 - assert len(kv_cache_config.kv_cache_groups[0].layer_names) == 2 - assert kv_cache_config.kv_cache_groups[0].layer_names[0] == layer_0 - assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1 - - -def test_most_model_len(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setenv("VLLM_TPU_MOST_MODEL_LEN", "2048") - vllm_config = get_vllm_config() - vllm_config.model_config.max_model_len = 32000 - vllm_config.scheduler_config.max_num_seqs = 1200 - model_runner = get_model_runner(vllm_config) - - # verify model runner will adjust num_reqs to avoid SMEM OOM. - assert model_runner.num_reqs_most_model_len == 1200 - # num_page_per_req = 32k // 128 - # num_reqs = 1024 ** 2 // 2 // num_page_per_req // 4 = 524 - assert model_runner.num_reqs_max_model_len == 524 From df62da8da2dea0c45558f7f044aff4e74625a963 Mon Sep 17 00:00:00 2001 From: Wei-Yu Lin Date: Tue, 16 Dec 2025 01:48:51 +0000 Subject: [PATCH 04/11] Remove tpu_int8 as it is related to deleted quantization config and implementation Signed-off-by: Wei-Yu Lin --- vllm/model_executor/layers/quantization/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 1a4378f5df3db..48db0d1bbbd47 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -11,7 +11,6 @@ logger = init_logger(__name__) QuantizationMethods = Literal[ "awq", "deepspeedfp", - "tpu_int8", "fp8", "ptpc_fp8", "fbgemm_fp8", From 6125a172f93d84c8aff9ece20517ec874aaef377 Mon Sep 17 00:00:00 2001 From: Wei-Yu Lin Date: Wed, 17 Dec 2025 01:10:49 +0000 Subject: [PATCH 05/11] Remove pallas.py as this is migrate to tpu-inference Signed-off-by: Wei-Yu Lin --- vllm/v1/attention/backends/pallas.py | 316 --------------------------- 1 file changed, 316 deletions(-) delete mode 100644 vllm/v1/attention/backends/pallas.py diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py deleted file mode 100644 index e5a0cf7420497..0000000000000 --- a/vllm/v1/attention/backends/pallas.py +++ /dev/null @@ -1,316 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from dataclasses import dataclass - -import torch - -from vllm.attention.backends.abstract import ( - AttentionBackend, - AttentionImpl, - AttentionLayer, - AttentionType, -) -from vllm.config import VllmConfig -from vllm.logger import init_logger -from vllm.utils.math_utils import cdiv, next_power_of_2 - -logger = init_logger(__name__) - -# TPU requires the head size to be a multiple of 128. -TPU_HEAD_SIZE_ALIGNMENT = 128 - -# Note: TPU can fp8 as storage dtype but doesn't support converting from uint8 -# from to fp32 directly. That's why it has a dtype mapping different from GPU -TPU_STR_DTYPE_TO_TORCH_DTYPE = { - "half": torch.half, - "bfloat16": torch.bfloat16, - "float": torch.float, - "fp8": torch.float8_e4m3fn, - "fp8_e4m3": torch.float8_e4m3fn, - "fp8_e5m2": torch.float8_e5m2, - "int8": torch.int8, - "uint8": torch.uint8, -} - -import tpu_inference # noqa: F401 - -# Note(weiyulin): some static functions are still used by tpu-inference -class PallasAttentionBackend(AttentionBackend): - @staticmethod - def get_name() -> str: - return "PALLAS" - - @staticmethod - def get_impl_cls() -> type["PallasAttentionBackendImpl"]: - return PallasAttentionBackendImpl - - @staticmethod - def get_kv_cache_shape( - num_blocks: int, - block_size: int, - num_kv_heads: int, - head_size: int, - cache_dtype_str: str = "auto", - ) -> tuple[int, ...]: - padded_head_size = ( - cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT - ) - return (num_blocks, block_size, num_kv_heads * 2, padded_head_size) - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: torch.Tensor, - ) -> None: - raise RuntimeError("swap_blocks is not used for the TPU backend.") - - # In recent TPU generations, up to v6e, the SMEM size is 1MB. The - # block_tables within the PallasMetadata constitute almost the entire SMEM - # requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here - # we simply make sure that the size is smaller than half of SMEM capacity. - @staticmethod - def get_min_page_size(vllm_config: VllmConfig) -> int: - max_num_page_per_req = ( - 1024 * 1024 // 2 // vllm_config.scheduler_config.max_num_seqs // 4 - ) - min_page_size = cdiv( - vllm_config.model_config.max_model_len, max_num_page_per_req - ) - min_page_size = 1 << (min_page_size - 1).bit_length() - return min_page_size - - @staticmethod - def get_max_num_seqs(model_len: int, page_size: int) -> int: - num_page_per_req = cdiv(model_len, page_size) - return 1024 * 1024 // 2 // num_page_per_req // 4 - - # TPU has limited SREGs (scalar registers), if page_size is too small, we - # can spill SREGs easily which leads to bad performance. The strategy we - # apply here is trying to split max-model-len to 16 pages which make the - # spill less likely. Meanwhile we make sure the page size is in [16, 256]. - @staticmethod - def get_page_size(vllm_config: VllmConfig) -> int: - # TODO: This is a temporary fix for vmem OOM. - # For long model length, we use 16 page-size to avoid too much - # VMEM spill. A more robust solution should be implemented to - # handle VREG spills. - if vllm_config.model_config.max_model_len > 8192: - return 16 - page_size = next_power_of_2(vllm_config.model_config.max_model_len) // 16 - if page_size <= 16: - return 16 - if page_size >= 256: - return 256 - return page_size - - -@dataclass -class PallasMetadata: - # NOTE(sang): Definition of context_len, query_len, and seq_len. - # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - # Used in the PallasAttentionBackendImpl - slot_mapping: torch.Tensor - block_tables: torch.Tensor - context_lens: torch.Tensor - query_start_loc: torch.Tensor - num_seqs: torch.Tensor - num_kv_update_slices: torch.Tensor - num_slices_per_kv_cache_update_block: int - - -class PallasAttentionBackendImpl(AttentionImpl): - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: list[float] | None, - sliding_window: int | None, - kv_cache_dtype: str, - logits_soft_cap: float | None = None, - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: int | None = None, - ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - self.sliding_window = sliding_window - self.logits_soft_cap = logits_soft_cap - self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - if alibi_slopes is not None: - raise NotImplementedError("Alibi slopes is not supported.") - - if attn_type != AttentionType.DECODER: - raise NotImplementedError( - "Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "PallasAttentionBackendImpl" - ) - - self.kv_cache_quantized_dtype = None - if kv_cache_dtype != "auto": - self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get( - kv_cache_dtype.lower().strip() - ) - - def forward( - self, - layer: AttentionLayer, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: PallasMetadata, - output: torch.Tensor | None = None, - output_scale: torch.Tensor | None = None, - output_block_scale: torch.Tensor | None = None, - ) -> torch.Tensor: - """Forward pass with Pallas attention. - - Args: - query: shape = [num_tokens, num_heads * head_size] - key: shape = [num_tokens, num_kv_heads * head_size] - value: shape = [num_tokens, num_kv_heads * head_size] - kv_cache: shape = - [num_blocks, block_size, num_kv_heads * 2, head_size] - attn_metadata: Metadata for attention. - Returns: - shape = [num_tokens, num_heads * head_size] - """ - if output_scale is not None or output_block_scale is not None: - raise NotImplementedError( - "fused output quantization is not yet supported" - " for PallasAttentionBackendImpl" - ) - - # For determine_available_memory case. - if kv_cache.numel() == 0: - if output is None: - output = torch.ones_like(query) - return output - - num_tokens, hidden_size = query.shape - query = query.view(num_tokens, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0: - padded_head_size = ( - cdiv(self.head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT - ) - query = torch.nn.functional.pad( - query, (0, padded_head_size - self.head_size), value=0.0 - ) - key = torch.nn.functional.pad( - key, (0, padded_head_size - self.head_size), value=0.0 - ) - value = torch.nn.functional.pad( - value, (0, padded_head_size - self.head_size), value=0.0 - ) - - if self.kv_sharing_target_layer_name is None and kv_cache.numel() > 0: - # Write input keys and values to the KV cache. - # Skip this if sharing KV cache with an earlier attention layer. - slot_mapping = attn_metadata.slot_mapping - write_to_kv_cache( - key, - value, - kv_cache, - slot_mapping, - attn_metadata.num_slices_per_kv_cache_update_block, - attn_metadata.num_kv_update_slices, - self.kv_cache_quantized_dtype, - layer._k_scale_float, - layer._v_scale_float, - ) - - if self.kv_cache_quantized_dtype is not None and ( - layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0 - ): - raise ValueError("k_scale_float and v_scale_float must be non-zero") - output = torch.ops.xla.ragged_paged_attention( - query, - kv_cache, - attn_metadata.context_lens, - attn_metadata.block_tables, - attn_metadata.query_start_loc, - attn_metadata.num_seqs, - # By default, the system utilizes optimized block size and - # vmem_limit_bytes parameters from the kernel repository. However, - # these can be manually adjusted for debugging if necessary. - num_kv_pages_per_block=None, - num_queries_per_block=None, - vmem_limit_bytes=None, - use_kernel=True, - sm_scale=self.scale, - sliding_window=self.sliding_window, - soft_cap=self.logits_soft_cap, - k_scale=layer._k_scale_float, - v_scale=layer._v_scale_float, - ) - - if self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0: - output = output[:, :, : self.head_size] - - return output.reshape(num_tokens, hidden_size) - - -def write_to_kv_cache( - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, - num_slices_per_kv_cache_update_block: int, - num_kv_update_slices: torch.Tensor, - kv_cache_quantized_dtype: torch.dtype | None = None, - k_scale: float = 1.0, - v_scale: float = 1.0, -) -> None: - """Write the key and values to the KV cache. - - Args: - key: shape = [num_tokens, num_kv_heads, head_size] - value: shape = [num_tokens, num_kv_heads, head_size] - kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2, head_size] - num_slices_per_kv_cache_update_block: int - """ - _, page_size, num_combined_kv_heads, head_size = kv_cache.shape - head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT - - if kv_cache_quantized_dtype is not None: - dtype_info = torch.finfo(kv_cache_quantized_dtype) - key = key.to(torch.float32) / k_scale - # NOTE: clamp is added here to avoid out of range of quantized dtype - key = torch.clamp(key, dtype_info.min, dtype_info.max) - key = key.to(kv_cache_quantized_dtype) - value = value.to(torch.float32) / v_scale - value = torch.clamp(value, dtype_info.min, dtype_info.max) - value = value.to(kv_cache_quantized_dtype) - - kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads, head_size) - - torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True) - - kv_cache = kv_cache.flatten(0, 1) - new_kv_cache = torch.ops.xla.kv_cache_update_op( - kv, - slot_mapping, - kv_cache, - num_kv_update_slices, - page_size, - num_slices_per_kv_cache_update_block, - ) - # NOTE: the in-place copy will be optimized away by XLA compiler. - kv_cache.copy_(new_kv_cache) From 603a1bf9bcd35f0a329f2cbe45123e6e26bbe9a8 Mon Sep 17 00:00:00 2001 From: Wei-Yu Lin Date: Wed, 17 Dec 2025 01:26:42 +0000 Subject: [PATCH 06/11] Run pre-commit to format files Signed-off-by: Wei-Yu Lin --- .../model_loader/default_loader.py | 2 +- vllm/platforms/tpu.py | 30 ++----------------- vllm/v1/worker/tpu_worker.py | 27 +---------------- 3 files changed, 5 insertions(+), 54 deletions(-) diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 4d85f8e3b478c..deee2324960dd 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -242,7 +242,7 @@ class DefaultModelLoader(BaseModelLoader): ) if current_platform.is_tpu(): - from vllm.platforms.tpu import USE_TPU_INFERENCE + pass if self.counter_before_loading_weights == 0.0: self.counter_before_loading_weights = time.perf_counter() diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index f7a11d2c557c4..455aceb3269eb 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -1,34 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import contextlib -from typing import TYPE_CHECKING, Optional, cast - -import torch -from tpu_info import device - -from vllm.attention.backends.registry import AttentionBackendEnum -from vllm.inputs import ProcessorInputs, PromptType from vllm.logger import init_logger -from .interface import Platform, PlatformEnum - -if TYPE_CHECKING: - from typing import TypeAlias - - from vllm.attention.selector import AttentionSelectorConfig - from vllm.config import VllmConfig - from vllm.config.cache import BlockSize - from vllm.pooling_params import PoolingParams - from vllm.sampling_params import SamplingParams - - ParamsType: TypeAlias = SamplingParams | PoolingParams -else: - BlockSize = None - VllmConfig = None - PoolingParams = None - ParamsType = None - logger = init_logger(__name__) @@ -40,5 +14,7 @@ try: TpuPlatform = TpuInferencePlatform # type: ignore USE_TPU_INFERENCE = True except ImportError: - logger.error("tpu_inference not found, please install tpu_inference to run vllm on TPU") + logger.error( + "tpu_inference not found, please install tpu_inference to run vllm on TPU" + ) pass diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index b50def0e17de4..085b119e12600 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -2,35 +2,10 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A TPU worker class.""" -import os -from collections.abc import Callable -from typing import Any, TypeVar +from typing import TypeVar -import torch -import torch.nn as nn - -import vllm.envs as envs -from vllm.config import VllmConfig, set_current_vllm_config -from vllm.distributed import ( - ensure_model_parallel_initialized, - init_distributed_environment, -) -from vllm.distributed.kv_transfer import ( - ensure_kv_transfer_initialized, -) from vllm.logger import init_logger -from vllm.lora.request import LoRARequest -from vllm.model_executor import set_random_seed -from vllm.platforms import current_platform from vllm.platforms.tpu import USE_TPU_INFERENCE -from vllm.tasks import SupportedTask -from vllm.utils.math_utils import cdiv -from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE -from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput -from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import ModelRunnerOutput -from vllm.v1.utils import report_usage_stats -from vllm.v1.worker.utils import bind_kv_cache logger = init_logger(__name__) From b280f38b15e7f12034d2e1d1fb8a57d1c95a940f Mon Sep 17 00:00:00 2001 From: Wei-Yu Lin Date: Thu, 18 Dec 2025 20:07:42 +0000 Subject: [PATCH 07/11] Add TODO and remove unused codepath Signed-off-by: Wei-Yu Lin --- vllm/model_executor/model_loader/default_loader.py | 3 --- vllm/v1/worker/tpu_worker.py | 1 + 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index deee2324960dd..5d24e72919693 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -241,9 +241,6 @@ class DefaultModelLoader(BaseModelLoader): self.load_config.pt_load_map_location, ) - if current_platform.is_tpu(): - pass - if self.counter_before_loading_weights == 0.0: self.counter_before_loading_weights = time.perf_counter() # Apply the prefix. diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 085b119e12600..4c73d6c92d391 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -11,6 +11,7 @@ logger = init_logger(__name__) _R = TypeVar("_R") +# TODO(weiyulin) Remove this file after adding an official way to use hardware plugin if USE_TPU_INFERENCE: from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker From fb58f7fd5b9933873269a589f37d86fb4483a4a4 Mon Sep 17 00:00:00 2001 From: Wei-Yu Lin Date: Thu, 18 Dec 2025 20:31:55 +0000 Subject: [PATCH 08/11] Run pre-commit to fix format error Signed-off-by: Wei-Yu Lin --- vllm/model_executor/layers/fused_moe/layer.py | 7 +++---- vllm/model_executor/model_loader/default_loader.py | 1 - 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index d6226da76eaed..559f1a87d9777 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -67,16 +67,15 @@ else: eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk -from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 - rocm_aiter_grouped_topk, -) - from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( FusedMoEMethodBase, ) from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import ( FusedMoEModularMethod, ) +from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 + rocm_aiter_grouped_topk, +) from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import ( UnquantizedFusedMoEMethod, ) diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 5d24e72919693..c4e961581ef3f 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -30,7 +30,6 @@ from vllm.model_executor.model_loader.weight_utils import ( pt_weights_iterator, safetensors_weights_iterator, ) -from vllm.platforms import current_platform from vllm.transformers_utils.repo_utils import list_filtered_repo_files logger = init_logger(__name__) From decf3e69bcd5cd6c10340a6a45c77dde9a586ae3 Mon Sep 17 00:00:00 2001 From: Wei-Yu Lin Date: Thu, 18 Dec 2025 22:42:02 +0000 Subject: [PATCH 09/11] Remove MOE xla implementation Signed-off-by: Wei-Yu Lin --- docs/design/moe_kernel_features.md | 1 - vllm/attention/layers/mm_encoder_attention.py | 25 ------------------- 2 files changed, 26 deletions(-) diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 6c02dcb76bec2..11c6e488f958f 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -92,7 +92,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels | gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],
[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | | marlin | standard,
batched | 3 / N/A | 3 / N/A | silu,
swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],
[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],
[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] | | trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | -| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] | | iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] | | rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] | | cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] | diff --git a/vllm/attention/layers/mm_encoder_attention.py b/vllm/attention/layers/mm_encoder_attention.py index 1c1623b13f55a..138fc99114127 100644 --- a/vllm/attention/layers/mm_encoder_attention.py +++ b/vllm/attention/layers/mm_encoder_attention.py @@ -227,28 +227,3 @@ class MMEncoderAttention(CustomOp): "XPU only supports FLASH_ATTN for vision attention." ) return self._forward_fa(query, key, value, cu_seqlens, max_seqlen) - - def forward_tpu( - self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention - ) -> torch.Tensor: - assert self.attn_backend == AttentionBackendEnum.PALLAS, ( - f"MMEncoderAttention on TPU only supports PALLAS backend, " - f"but got {self.attn_backend}." - ) - if cu_seqlens is None: - query, key, value = (x.transpose(1, 2) for x in (query, key, value)) - from torch_xla.experimental.custom_kernel import flash_attention - - out = flash_attention(query, key, value, sm_scale=self.scale) - out = out.transpose(1, 2) - return out - logger.warning_once( - "PALLAS backend with cu_seqlens is not supported for ViT yet. ", - "Falling back to SDPA implementation.", - ) - return self._forward_sdpa(query, key, value, cu_seqlens) From 0f7ee9d24730eb805d575e16b9e4cc3b19056a60 Mon Sep 17 00:00:00 2001 From: Wei-Yu Lin Date: Fri, 19 Dec 2025 23:38:53 +0000 Subject: [PATCH 10/11] Remove unused pallas registration Signed-off-by: Wei-Yu Lin --- vllm/attention/backends/registry.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/attention/backends/registry.py b/vllm/attention/backends/registry.py index 416b996df9f22..77724a3a1915c 100644 --- a/vllm/attention/backends/registry.py +++ b/vllm/attention/backends/registry.py @@ -66,7 +66,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend" ) FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend" - PALLAS = "vllm.v1.attention.backends.pallas.PallasAttentionBackend" IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend" NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend" FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" From 9aaed80cc85f65438f859bdd19fe90d6b712be5c Mon Sep 17 00:00:00 2001 From: Wei-Yu Lin Date: Mon, 22 Dec 2025 23:42:26 +0000 Subject: [PATCH 11/11] Remove _use_pallas var as PALLAS attention backend is deprecated Signed-off-by: Wei-Yu Lin --- vllm/distributed/kv_transfer/kv_connector/utils.py | 6 +----- .../kv_transfer/kv_connector/v1/mooncake_connector.py | 1 - .../kv_transfer/kv_connector/v1/nixl_connector.py | 8 +------- 3 files changed, 2 insertions(+), 13 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 4f1ea1a0240c4..914ab91b1563c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Literal import torch from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.backends.registry import AttentionBackendEnum from vllm.config import get_current_vllm_config from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.logger import init_logger @@ -251,9 +250,6 @@ class TpKVTopology: len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1 ) - attn_backend = AttentionBackendEnum[self.attn_backend.get_name()] - self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS - @property def is_kv_layout_blocks_first(self) -> bool: return self._is_kv_layout_blocks_first @@ -261,7 +257,7 @@ class TpKVTopology: @property def split_k_and_v(self) -> bool: # Whether to register regions for K and V separately (when present). - return not (self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first) + return not (self.is_mla or self.is_kv_layout_blocks_first) @property def tp_size(self) -> int: diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py index 9a15d3fa6ed09..38ce02a2fef76 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/mooncake_connector.py @@ -499,7 +499,6 @@ class MooncakeConnectorWorker: total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backend=backend, ) - self._use_pallas = self.kv_topo._use_pallas self.zmq_ctx = zmq.Context() self.async_zmq_ctx = zmq.asyncio.Context() diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 757ca41e9844b..0f33cde7d3221 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -983,7 +983,6 @@ class NixlConnectorWorker: total_num_kv_heads=self.model_config.get_total_num_kv_heads(), attn_backend=backend, ) - self._use_pallas = self.kv_topo._use_pallas self._physical_blocks_per_logical_kv_block = 1 def _nixl_handshake( @@ -1641,9 +1640,6 @@ class NixlConnectorWorker: # Num kv_heads > tp_size and P TP > D TP case, not supported assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id)) - assert not self._use_pallas or tp_ratio == 1, ( - "TPU (pallas_v1) DOES NOT support heterogeneous TP yet." - ) kv_cache_layout = ( self.kv_cache_layout if not self.use_host_buffer @@ -1814,9 +1810,7 @@ class NixlConnectorWorker: if len(self.device_kv_caches) == 0: return - split_k_and_v = not ( - self.use_mla or self._use_pallas or self.kv_topo.is_kv_layout_blocks_first - ) + split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first) sample_cache = list(self.device_kv_caches.values())[0][0] for block_size_ratio, block_ids_list in block_ids_per_ratio.items(): assert block_size_ratio > 1, "Only nP < nD supported currently."