From d5cab6f65c320c383fcd1521431002970715729f Mon Sep 17 00:00:00 2001 From: Wei-Yu Lin Date: Thu, 11 Dec 2025 19:59:02 +0000 Subject: [PATCH] Remove tpu_inference fall back logic Signed-off-by: Wei-Yu Lin --- .../device_communicators/tpu_communicator.py | 14 - .../model_loader/default_loader.py | 13 - vllm/platforms/tpu.py | 253 +------------- vllm/v1/attention/backends/pallas.py | 65 +--- vllm/v1/worker/tpu_worker.py | 310 ------------------ 5 files changed, 2 insertions(+), 653 deletions(-) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index fa99078e9ff0d..9581a3dbc7b74 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -19,20 +19,6 @@ USE_RAY = parallel_config = ( logger = init_logger(__name__) -if not USE_TPU_INFERENCE: - logger.info("tpu_inference not found, using vLLM's TpuCommunicator") - if current_platform.is_tpu(): - import torch_xla - import torch_xla.core.xla_model as xm - import torch_xla.runtime as xr - from torch_xla._internal import pjrt - from torch_xla.distributed.xla_multiprocessing import ( - create_optimized_replica_groups, - ) - - if USE_RAY: - from vllm.v1.executor import ray_utils - class TpuCommunicator(DeviceCommunicatorBase): def __init__( diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 88c6d1e27e39c..4d85f8e3b478c 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -244,19 +244,6 @@ class DefaultModelLoader(BaseModelLoader): if current_platform.is_tpu(): from vllm.platforms.tpu import USE_TPU_INFERENCE - if not USE_TPU_INFERENCE: - # In PyTorch XLA, we should call `torch_xla.sync` - # frequently so that not too many ops are accumulated - # in the XLA program. - import torch_xla - - def _xla_weights_iterator(iterator: Generator): - for weights in iterator: - yield weights - torch_xla.sync(wait=False) - - weights_iterator = _xla_weights_iterator(weights_iterator) - if self.counter_before_loading_weights == 0.0: self.counter_before_loading_weights = time.perf_counter() # Apply the prefix. diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 7c479bf2b6a0e..f7a11d2c557c4 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -31,257 +31,6 @@ else: logger = init_logger(__name__) -USE_TPU_INFERENCE = False - - -class TpuPlatform(Platform): - _enum = PlatformEnum.TPU - device_name: str = "tpu" - device_type: str = "tpu" - dispatch_key: str = "XLA" - ray_device_key: str = "TPU" - dist_backend: str = "gloo" - device_control_env_var: str = "TPU_VISIBLE_CHIPS" - simple_compile_backend: str = "openxla" - - supported_quantization: list[str] = ["fp8", "tpu_int8", "compressed-tensors"] - - additional_env_vars: list[str] = ["TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"] - - @classmethod - def import_kernels(cls) -> None: - # Do not import vllm._C - with contextlib.suppress(ImportError): - import vllm._moe_C # noqa: F401 - - @classmethod - def get_attn_backend_cls( - cls, - selected_backend: "AttentionBackendEnum", - attn_selector_config: "AttentionSelectorConfig", - ) -> str: - if attn_selector_config.use_sparse: - raise NotImplementedError("Sparse Attention is not supported on TPU.") - if selected_backend != AttentionBackendEnum.PALLAS: - logger.info("Cannot use %s backend on TPU.", selected_backend) - - logger.info("Using Pallas V1 backend.") - return AttentionBackendEnum.PALLAS.get_path() - - @classmethod - def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]: - return [ - AttentionBackendEnum.PALLAS, - ] - - @classmethod - def get_vit_attn_backend( - cls, - head_size: int, - dtype: torch.dtype, - backend: Optional["AttentionBackendEnum"] = None, - ) -> "AttentionBackendEnum": - if backend is not None: - assert backend in cls.get_supported_vit_attn_backends(), ( - f"Backend {backend} is not supported for vit attention" - f"Supported backends are: {cls.get_supported_vit_attn_backends()}." - ) - logger.info_once(f"Using backend {backend} for vit attention.") - return backend - - logger.info_once( - f"Using default backend {AttentionBackendEnum.PALLAS} for vit attention." - ) - return AttentionBackendEnum.PALLAS - - @classmethod - def set_device(cls, device: torch.device) -> None: - """ - Set the device for the current platform. - """ - torch.tpu.set_device(device) - - @classmethod - def get_device_name(cls, device_id: int = 0) -> str: - chip_type, _ = device.get_local_chips() - return f"TPU {chip_type.name}" - - @classmethod - def get_device_total_memory(cls, device_id: int = 0) -> int: - raise NotImplementedError - - @classmethod - def get_punica_wrapper(cls) -> str: - return "vllm.lora.punica_wrapper.punica_tpu.PunicaWrapperTPU" - - @classmethod - def get_infinity_values(cls, dtype: torch.dtype) -> tuple[float, float]: - return torch.finfo(dtype).min, torch.finfo(dtype).max - - @classmethod - def can_update_inplace(cls): - return False - - @classmethod - def get_lora_vocab_padding_size(cls) -> int: - return 1 - - @classmethod - def inference_mode(cls): - return torch.no_grad() - - @classmethod - def check_and_update_config(cls, vllm_config: VllmConfig) -> None: - from vllm.config import CompilationMode, CUDAGraphMode - - cache_config = vllm_config.cache_config - # For v0, the default block size is 16. - if cache_config and cache_config.block_size is None: - cache_config.block_size = cast(BlockSize, 16) - compilation_config = vllm_config.compilation_config - - # TPU only supports DYNAMO_TRACE_ONCE compilation mode - if compilation_config.mode != CompilationMode.DYNAMO_TRACE_ONCE: - logger.info( - "[TPU] Forcing DYNAMO_TRACE_ONCE compilation mode, and\ - disabling cudagraph." - ) - compilation_config.mode = CompilationMode.DYNAMO_TRACE_ONCE - - if ( - compilation_config.cudagraph_mode is None - or compilation_config.cudagraph_mode.max_cudagraph_mode() - != CUDAGraphMode.NONE - ): - logger.info( - "[TPU] CUDA graph is not supported on TPU, disabling cudagraphs." - ) - compilation_config.cudagraph_mode = CUDAGraphMode.NONE - - if compilation_config.backend == "": - compilation_config.backend = "openxla" - - assert vllm_config.speculative_config is None, ( - "TPU does not support speculative decoding" - ) - - model_config = vllm_config.model_config - if model_config is not None and model_config.dtype in ( - torch.float16, - torch.float32, - ): - logger.warning( - "The TPU backend currently does not support %s. " - "Using bfloat16 instead.", - model_config.dtype, - ) - model_config.dtype = torch.bfloat16 - - from vllm.v1.attention.backends.pallas import PallasAttentionBackend - - cache_config.block_size = PallasAttentionBackend.get_page_size(vllm_config) # type: ignore[assignment] - - parallel_config = vllm_config.parallel_config - scheduler_config = vllm_config.scheduler_config - if parallel_config.worker_cls == "auto": - parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker" - - assert not vllm_config.speculative_config, ( - "Speculative decoding is not yet supported for TPU backend" - ) - - if ( - scheduler_config.is_multimodal_model - and not scheduler_config.disable_chunked_mm_input - ): - logger.warning( - "TPU does not support running Multimodal models" - " without setting `--disable_chunked_mm_input`. " - "Forcing --disable_chunked_mm_input." - ) - scheduler_config.disable_chunked_mm_input = True - - if model_config and model_config.use_mla: - logger.info( - "MLA is enabled on a non-GPU platform; forcing chunked " - "prefill and prefix caching to be disabled." - ) - vllm_config.scheduler_config.enable_chunked_prefill = False - vllm_config.scheduler_config.max_num_batched_tokens = max( - vllm_config.model_config.max_model_len, - vllm_config.scheduler_config.DEFAULT_MAX_NUM_BATCHED_TOKENS, - ) - - @classmethod - def is_pin_memory_available(cls): - logger.warning("Pin memory is not supported on TPU.") - return False - - @classmethod - def get_device_communicator_cls(cls) -> str: - return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa - - @classmethod - def validate_request( - cls, - prompt: PromptType, - params: ParamsType, - processed_inputs: ProcessorInputs, - ) -> None: - """Raises if this request is unsupported on this platform""" - from vllm.sampling_params import SamplingParams, SamplingType - - if ( - isinstance(params, SamplingParams) - and params.sampling_type == SamplingType.RANDOM_SEED - ): - raise ValueError("Torch XLA does not support per-request seed.") - - @classmethod - @torch.compile(backend="openxla") - def insert_blocks_to_device( - cls, - src_cache: torch.Tensor, - dst_cache: torch.Tensor, - src_block_indices: torch.Tensor, - dst_block_indices: torch.Tensor, - ) -> None: - torch.ops.xla.dynamo_set_buffer_donor_(dst_cache, True) - dst_cache[dst_block_indices] = src_cache[src_block_indices].to(dst_cache.device) - - @classmethod - @torch.compile(backend="openxla") - def swap_out_blocks_to_host( - cls, - src_cache: torch.Tensor, - dst_cache: torch.Tensor, - src_block_indices: torch.Tensor, - dst_block_indices: torch.Tensor, - ) -> None: - """tpu blocks to cpu blocks""" - torch.ops.xla.dynamo_set_buffer_donor_(src_cache, True) - dst_cache[dst_block_indices] = src_cache[src_block_indices].cpu() - - @classmethod - def use_sync_weight_loader(cls) -> bool: - return True - - @classmethod - def check_max_model_len(cls, max_model_len: int) -> int: - """ - Check max_model_len for the current platform. - """ - logger.warning( - "--max-model-len is not specified, " - "it's currently using model's default length %d, " - "which might be too large." - "Please input with --max-model-len based on your " - "request input length and output length, to avoid " - "unnecessary degradation.", - max_model_len, - ) - return max_model_len - try: from tpu_inference.platforms import ( @@ -291,5 +40,5 @@ try: TpuPlatform = TpuInferencePlatform # type: ignore USE_TPU_INFERENCE = True except ImportError: - logger.info("tpu_inference not found, using vLLM's TpuPlatform") + logger.error("tpu_inference not found, please install tpu_inference to run vllm on TPU") pass diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 525026bac5a7e..5f7f8e81a24c4 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -33,70 +33,7 @@ TPU_STR_DTYPE_TO_TORCH_DTYPE = { "uint8": torch.uint8, } -try: - import tpu_inference # noqa: F401 -except ImportError: - # Lazy import torch_xla - import torch_xla.core.xla_builder as xb - import torch_xla.experimental.custom_kernel # noqa: F401 - from torch.library import impl - from torch_xla._internal.jax_workarounds import requires_jax - from torch_xla.experimental.custom_kernel import XLA_LIB - - @requires_jax - def kv_cache_update_op_impl( - kv: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, - num_slices_per_block: int, - ): - from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update - - new_kv_cache = xb.call_jax( - kv_cache_update, - (kv, slot_mapping, kv_cache, num_kv_update_slices), - {"page_size": page_size, "num_slices_per_block": num_slices_per_block}, - ) - return new_kv_cache - - XLA_LIB.define( - "kv_cache_update_op(Tensor kv, Tensor slot_mapping," - "Tensor kv_cache, Tensor num_kv_update_slices, int page_size," - "int num_slices_per_block)" - "-> Tensor", - ) - - @impl(XLA_LIB, "kv_cache_update_op", "XLA") - def kv_cache_update_op_xla( - kv: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, - num_slices_per_block: int, - ) -> torch.Tensor: - new_kv_cache = kv_cache_update_op_impl( - kv, - slot_mapping, - kv_cache, - num_kv_update_slices, - page_size, - num_slices_per_block, - ) - return new_kv_cache - - @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") - def kv_cache_update_op_non_xla( - kv: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, - num_slices_per_block: int, - ) -> torch.Tensor: - return kv_cache +import tpu_inference # noqa: F401 class PallasAttentionBackend(AttentionBackend): diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 5f6136b178b46..b50def0e17de4 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -36,316 +36,6 @@ logger = init_logger(__name__) _R = TypeVar("_R") -if not USE_TPU_INFERENCE: - logger.info("tpu_inference not found, using vLLM's TPUWorker.") - import torch_xla.core.xla_model as xm - import torch_xla.debug.profiler as xp - import torch_xla.runtime as xr - - from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT - from vllm.v1.worker.tpu_model_runner import TPUModelRunner - - -class TPUWorker: - def __init__( - self, - vllm_config: VllmConfig, - local_rank: int, - rank: int, - distributed_init_method: str, - is_driver_worker: bool = False, - ): - self.is_driver_worker = is_driver_worker - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.cache_config = vllm_config.cache_config - self.lora_config = vllm_config.lora_config - self.load_config = vllm_config.load_config - self.parallel_config = vllm_config.parallel_config - self.use_spmd = envs.VLLM_XLA_USE_SPMD - self.original_parallel_config = None - if self.use_spmd: - # Under SPMD mode, distributed env is initialized as if there is - # only one worker/device. - self.original_parallel_config = self.parallel_config - self.parallel_config.tensor_parallel_size = 1 - self.parallel_config.pipeline_parallel_size = 1 - self.parallel_config.world_size = 1 - self.scheduler_config = vllm_config.scheduler_config - self.device_config = vllm_config.device_config - self.speculative_config = vllm_config.speculative_config - self.observability_config = vllm_config.observability_config - - self.parallel_config.rank = rank - self.local_rank = local_rank - self.rank = rank - self.distributed_init_method = distributed_init_method - - if self.cache_config.cache_dtype == "auto": - self.cache_dtype = self.model_config.dtype - else: - self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[self.cache_config.cache_dtype] - - if self.model_config.trust_remote_code: - # note: lazy import to avoid importing torch before initializing - from vllm.utils.import_utils import init_cached_hf_modules - - init_cached_hf_modules() - - # Delay profiler initialization to the start of the profiling. - # This is because in vLLM V1, MP runtime is initialized before the - # TPU Worker is initialized. The profiler server needs to start after - # MP runtime is initialized. - self.profiler = None - self.profile_dir = None - if vllm_config.profiler_config.profiler == "torch" and self.rank < 1: - # For TPU, we can only have 1 active profiler session for 1 profiler - # server. So we only profile on rank0. - self.profile_dir = vllm_config.profiler_config.torch_profiler_dir - logger.info( - "Profiling enabled. Traces will be saved to: %s", self.profile_dir - ) - - def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - def init_device(self): - os.environ["PJRT_DEVICE"] = "TPU" - # Note: Currently the XLA compiler wrongly uses 2D ring strategy on 1D - # ring, the xla tpu compiler flag - # `xla_tpu_force_1d_allreduce_at_chunk_count` is a temporary solution to - # fix this. It will be removed after the bug in XLA compiler is fixed. - os.environ["LIBTPU_INIT_ARGS"] = ( - os.environ.get("LIBTPU_INIT_ARGS", "") - + " --xla_tpu_force_1d_allreduce_at_chunk_count=1" - " --xla_jf_conv_input_fusion=False" - ) - # --xla_jf_conv_input_fusion=False is used to improve the perf of - # quantized matmul. - torch.set_grad_enabled(False) - torch.set_default_dtype(self.model_config.dtype) - - # Initialize the distributed environment. - self._init_tpu_worker_distributed_environment( - self.vllm_config, self.rank, self.distributed_init_method, self.local_rank - ) - - # Device initialization should happen after initializing - # the distributed runtime. - self.device = xm.xla_device() - self.device_config.device = self.device - - # Set random seed. - set_random_seed(self.model_config.seed) - xm.set_rng_state(self.model_config.seed, self.device) - - # Increase the cache size limit, which is the maximum number of - # dynamo graphs that can be compiled. - # TODO (NickLucche) On gsm we compile 80+ graphs. - # Re-evaluate limit, with MM we may get close to this limit. - torch._dynamo.config.cache_size_limit = 128 - # Use persistent cache to avoid XLA recompilation. - # NOTE(woosuk): Set per-rank cache path since different ranks - # can have slightly different XLA graphs. - world_size = self.parallel_config.world_size - rank = xr.global_ordinal() - # The PyTorch/XLA compilation cache uses the Torch IR to generate keys. - # Consequently, changes in optimization flags, which affect compilation - # results, don't change the cache key. This can result in the wrong - # compilation being used. To prevent this, disabling the XLA compilation - # cache during development is recommended.We can disable it by - # `export VLLM_XLA_CACHE_PATH=` - if envs.VLLM_XLA_CACHE_PATH: - per_rank_path = os.path.join( - envs.VLLM_XLA_CACHE_PATH, f"tp{world_size}_rank{rank}" - ) - xr.initialize_cache(per_rank_path, readonly=False) - - # Init ModelRunner here, so that we have access to self.device. - self.model_runner = TPUModelRunner( - self.vllm_config, self.device, self.original_parallel_config - ) - - if rank == 0: - # If usage stat is enabled, collect relevant info. - report_usage_stats(self.vllm_config) - - def determine_available_memory(self) -> int: - kv_caches: dict[str, torch.Tensor] = {} - kv_cache_spec = self.model_runner.get_kv_cache_spec() - for layer_name, layer_spec in kv_cache_spec.items(): - if isinstance(layer_spec, AttentionSpec): - dtype = layer_spec.dtype - - # Use an empty tensor instead of `None` to force Dynamo to pass - # it by reference, rather by specializing on the value `None`. - tpu_kv_cache = torch.tensor([], dtype=dtype).to(self.device) - kv_caches[layer_name] = tpu_kv_cache - else: - raise NotImplementedError( - f"Unsupported KV cache spec '{type(layer_spec)}'" - ) - - runner_kv_caches: list[torch.Tensor] = [] - bind_kv_cache( - kv_caches, - self.vllm_config.compilation_config.static_forward_context, - runner_kv_caches, - ) - - # `max_num_tokens >= max_num_batched_tokens` due to padding. - with self.model_runner.maybe_setup_dummy_loras(self.lora_config): - self.model_runner.profile_run(self.model_runner.max_num_tokens) - - # Synchronize before measuring the memory usage. - xm.wait_device_ops() - - # During the profiling run, the model runs without KV cache. After - # the profiling run, the model always runs with KV cache. Here we clear - # the dynamo cache and cached bytecode to ensure the model always has - # one compiled bytecode. Having one FX graph/cached bytecode per - # compiled model is required for `support_torch_compile` decorator to - # skip dynamo guard. - with set_current_vllm_config(self.vllm_config): - self.model_runner.reset_dynamo_cache() - - # Get the maximum amount of memory used by the model weights and - # intermediate activations. - if self.use_spmd: - # This is a workaround for the TPU SPMD mode. The get_memory_info - # API doesn't work with SPMD mode in PyTorch/XLA. - # TODO: use xm.get_memory_info for SPMD once it's supported in - # PyTorch/XLA. - import tpu_info - - chip_type, _ = tpu_info.device.get_local_chips() - device_usage = tpu_info.metrics.get_chip_usage(chip_type) - total_memory_size = device_usage[0].total_memory - current_mem = device_usage[0].memory_usage - else: - m = xm.get_memory_info(self.device) - total_memory_size = m["bytes_limit"] - current_mem = m["bytes_used"] - # Ideally we would use profiled = m["peak_bytes_used"] to - # get weights + activations. But there is memory used during - # compilation / weight loading that impacts the peak and - # there is no way to reset peak memory in XLA, So we - # use the heuristic of 2% of weights. - profiled = current_mem * 1.02 - - # Calculate the TPU KV cache size based on profiling. - usable_memory_size = int( - total_memory_size * self.cache_config.gpu_memory_utilization - ) - tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) - head_size = self.model_config.get_head_size() - if head_size > 0: - padded_head_size = ( - cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT - ) - if padded_head_size != head_size: - logger.warning_once("head size is padded to %d", padded_head_size) - # We adjust the usable memory size for the KV cache to prevent OOM - # errors, even after padding the head_size. - tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size - return int(tpu_kv_cache_bytes) - - def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput: - return self.model_runner.sample_tokens(grammar_output) - - def execute_model( - self, scheduler_output: "SchedulerOutput" - ) -> ModelRunnerOutput | None: - return self.model_runner.execute_model(scheduler_output) - - def profile(self, is_start: bool = True): - if self.rank < 1: - if self.profile_dir is None: - raise RuntimeError("Profiler is not enabled.") - if is_start: - if self.profiler is None: - self.profiler = xp.start_server(9012) - xp.start_trace(self.profile_dir) - else: - xp.stop_trace() - - def add_lora(self, lora_request: LoRARequest) -> bool: - return self.model_runner.add_lora(lora_request) - - def load_model(self) -> None: - self.model_runner.load_model() - - def update_config(self, overrides: dict[str, Any]) -> None: - self.model_runner.update_config(overrides) - - def reload_weights(self) -> None: - self.model_runner.reload_weights() - - def compile_or_warm_up_model(self) -> None: - if not self.model_config.enforce_eager: - self.model_runner.capture_model() - - # Reset the seed to ensure that the random state is not affected by - # the model initialization and profiling. - set_random_seed(self.model_config.seed) - - def reset_mm_cache(self) -> None: - self.model_runner.reset_mm_cache() - - def get_model(self) -> nn.Module: - return self.model_runner.get_model() - - def get_supported_tasks(self) -> tuple[SupportedTask, ...]: - return self.model_runner.get_supported_tasks() - - def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: - return self.model_runner.get_kv_cache_spec() - - def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None: - """Allocate GPU KV cache with the specified kv_cache_config.""" - self.model_runner.initialize_kv_cache(kv_cache_config) - - def check_health(self) -> None: - # worker will always be healthy as long as it's running. - return - - def _init_tpu_worker_distributed_environment( - self, - vllm_config: VllmConfig, - rank: int, - distributed_init_method: str | None = None, - local_rank: int = -1, - ) -> None: - """Initialize the distributed environment.""" - if self.use_spmd: - xr.use_spmd() - # NOTE(woosuk): This is just to initialize the TP group and broadcast - # the input objects on CPU. The all-reduce and all-gather ops on TPU - # are invoked by `xm.all_reduce` and `xm.all_gather` which use their - # own context. - parallel_config = vllm_config.parallel_config - init_distributed_environment( - world_size=parallel_config.world_size, - rank=rank, - local_rank=local_rank, - distributed_init_method=distributed_init_method or "env://", - backend=current_platform.dist_backend, - ) - ensure_model_parallel_initialized( - parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size - ) - - ensure_kv_transfer_initialized(vllm_config) - - def shutdown(self) -> None: - self.model_runner.ensure_kv_transfer_shutdown() - - def apply_model(self, fn: Callable[[nn.Module], _R]) -> _R: - """Apply a function on the model inside this worker.""" - return fn(self.get_model()) - - if USE_TPU_INFERENCE: from tpu_inference.worker.tpu_worker import TPUWorker as TpuInferenceWorker