From 165290d357bf0b95a78a96a45cb7e0dd98d494de Mon Sep 17 00:00:00 2001 From: Joe Runde Date: Wed, 12 Mar 2025 21:12:13 -0600 Subject: [PATCH 1/6] [bugfix] fixup warning message for plugged schedulers for v1 (#14700) Signed-off-by: Joe Runde --- vllm/v1/engine/core.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 5a4e67a2dd78f..174d96ec43776 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -22,6 +22,7 @@ from vllm.transformers_utils.config import ( from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname, zmq_socket_ctx) from vllm.v1.core.kv_cache_utils import get_kv_cache_configs +from vllm.v1.core.scheduler import Scheduler as V1Scheduler from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) @@ -67,15 +68,21 @@ class EngineCore: # Setup scheduler. if isinstance(vllm_config.scheduler_config.scheduler_cls, str): + Scheduler = resolve_obj_by_qualname( + vllm_config.scheduler_config.scheduler_cls) + else: + Scheduler = vllm_config.scheduler_config.scheduler_cls + + # This warning can be removed once the V1 Scheduler interface is + # finalized and we can maintain support for scheduler classes that + # implement it + if Scheduler is not V1Scheduler: logger.warning( "Using configured V1 scheduler class %s. " "This scheduler interface is not public and " "compatibility may not be maintained.", vllm_config.scheduler_config.scheduler_cls) - Scheduler = resolve_obj_by_qualname( - vllm_config.scheduler_config.scheduler_cls) - else: - Scheduler = vllm_config.scheduler_config.scheduler_cls + self.scheduler = Scheduler( scheduler_config=vllm_config.scheduler_config, model_config=vllm_config.model_config, From ab426ec9c04505d311ef222d8609c6eec729248e Mon Sep 17 00:00:00 2001 From: Richard Liu <39319471+richardsliu@users.noreply.github.com> Date: Wed, 12 Mar 2025 20:13:48 -0700 Subject: [PATCH 2/6] Add ray[data] as tpu dependency (#14691) Signed-off-by: Signed-off-by: Richard Liu --- requirements/tpu.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/tpu.txt b/requirements/tpu.txt index e8e3b0af5db8c..e071c604b5c0b 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -9,6 +9,7 @@ setuptools-scm>=8 wheel jinja2 ray[default] +ray[data] # Install torch_xla --pre From a94a699c3ff9bfc23a35f147150a826b753bbf6a Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Wed, 12 Mar 2025 23:14:04 -0400 Subject: [PATCH 3/6] [ROCm][FP8] Fix for adjustments needed only for fnuz (#14689) Signed-off-by: Gregory Shtrasberg --- vllm/model_executor/layers/quantization/kv_cache.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index 388a4f16699c5..92990487885b9 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -50,7 +50,7 @@ class BaseKVCacheMethod(QuantizeMethodBase): # We prefer to use separate k_scale and v_scale if present k_scale = layer.k_scale.to("cpu").tolist() v_scale = layer.v_scale.to("cpu").tolist() - if current_platform.is_rocm(): + if current_platform.is_fp8_fnuz(): k_scale *= 2 v_scale *= 2 elif layer.k_scale < 0.0 and layer.v_scale < 0.0: @@ -66,7 +66,7 @@ class BaseKVCacheMethod(QuantizeMethodBase): scale_to_duplicate = max(layer.k_scale, layer.v_scale) k_scale = scale_to_duplicate.to("cpu").tolist() v_scale = scale_to_duplicate.to("cpu").tolist() - if current_platform.is_rocm(): + if current_platform.is_fp8_fnuz(): k_scale *= 2 v_scale *= 2 From 128bf7528370d792099c66f301c6c5deef8f4110 Mon Sep 17 00:00:00 2001 From: TY-AMD Date: Thu, 13 Mar 2025 11:14:36 +0800 Subject: [PATCH 4/6] [BugFix][TritonMLA] Process weights after model loading for GGUF (#14555) Signed-off-by: TianyuanWu --- vllm/model_executor/model_loader/loader.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index bf226f6611262..c88af56e18053 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1330,11 +1330,14 @@ class GGUFModelLoader(BaseModelLoader): local_model_path, gguf_weights_map): model_config.hf_config.update({"tie_word_embeddings": True}) + target_device = torch.device(device_config.device) with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): + with target_device: model = _initialize_model(vllm_config=vllm_config) model.load_weights( self._get_weights_iterator(local_model_path, gguf_weights_map)) + + _process_weights_after_loading(model, model_config, target_device) return model From 1bd32bc8dd685b1bcddc0e7408a46a7c637dae8a Mon Sep 17 00:00:00 2001 From: Mathis Felardos Date: Thu, 13 Mar 2025 04:15:20 +0100 Subject: [PATCH 5/6] [Config][Disaggregated] Add timeout configuration for the torch.store and add KVTransferConfig.kv_connector_extra_config (#14367) Signed-off-by: Mathis Felardos --- vllm/config.py | 6 +++++ .../kv_lookup_buffer/simple_buffer.py | 2 +- .../kv_transfer/kv_pipe/pynccl_pipe.py | 22 ++++++++++--------- vllm/distributed/utils.py | 3 +++ 4 files changed, 22 insertions(+), 11 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index aa8b16920a97f..3ac7ceabd8d3d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2837,6 +2837,9 @@ class KVTransferConfig(BaseModel): # The KV connector port, used to build distributed connection kv_port: int = 14579 + # any extra config that the connector may need + kv_connector_extra_config: dict[str, Any] = {} + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, @@ -2896,6 +2899,9 @@ class KVTransferConfig(BaseModel): return self.kv_connector is not None and \ self.kv_role in ["kv_consumer", "kv_both"] + def get_from_extra_config(self, key, default) -> Any: + return self.kv_connector_extra_config.get(key, default) + class CompilationLevel: # constants for the levels of the compilation process diff --git a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py index 3462f7de020ef..10bbfe1ddd8a2 100644 --- a/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +++ b/vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py @@ -6,7 +6,7 @@ - Distributed KV cache transmission using PyNccl pipes. - Non-blocking `insert`, blocking `drop_select`. - Use CPU signal pipe to avoid racing condition - - Handles buffer size constraints and provide backpressure mechanism to + - Handles buffer size constraints and provide backpressure mechanism to stop the prefill instance when the decode instance is slow. """ import threading diff --git a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py index 7aa53d07a9ef2..e8bf607eb8993 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 """ - This module implements a PyNccl pipe for sending and receiving - Optional[torch.Tensor] between distributed ranks with advanced + This module implements a PyNccl pipe for sending and receiving + Optional[torch.Tensor] between distributed ranks with advanced communication features. Key Features: @@ -59,11 +59,13 @@ class PyNcclPipe(KVPipeBase): self.device = self._select_device(device) # build distributed connection and send/recv implementation + store_timeout = self.config.get_from_extra_config("store_timeout", 300) self.group = StatelessProcessGroup.create( host=self.config.kv_ip, port=self.config.kv_port + port_offset, rank=self.kv_rank, world_size=self.kv_parallel_size, + store_timeout=store_timeout, ) # add a barrier to make sure the connection is initiated properly self.group.barrier() @@ -134,11 +136,11 @@ class PyNcclPipe(KVPipeBase): Create a buffer to receive the tensor based on the provided metadata. Parameters: - - metadata: A dictionary with keys "dtype" and "shape", describing + - metadata: A dictionary with keys "dtype" and "shape", describing the tensor's data type and shape. Returns: - - buffer: A tensor of the specified type and shape, allocated on + - buffer: A tensor of the specified type and shape, allocated on self.device. """ return torch.empty(metadata["shape"], @@ -159,18 +161,18 @@ class PyNcclPipe(KVPipeBase): Receive the metadata dictionary from the target rank. Returns: - - metadata: A dictionary with keys "dtype" and "shape" describing + - metadata: A dictionary with keys "dtype" and "shape" describing the tensor. """ return self.group.recv_obj(self.target_rank_for_recv) def _send_impl(self, tensor: Optional[torch.Tensor]) -> None: """ - The actual implementation of sending the tensor and its metadata to the + The actual implementation of sending the tensor and its metadata to the target rank. Parameters: - - tensor: The input tensor to be sent, or None if no tensor is + - tensor: The input tensor to be sent, or None if no tensor is being sent. """ metadata = self._make_metadata(tensor) @@ -181,7 +183,7 @@ class PyNcclPipe(KVPipeBase): def _recv_impl(self) -> Optional[torch.Tensor]: """ - The actual implementation of receiving a tensor and its metadata from + The actual implementation of receiving a tensor and its metadata from the target rank. Returns: @@ -213,7 +215,7 @@ class PyNcclPipe(KVPipeBase): def block_if_full(self): """ - Block the current thread if the buffer size is larger than the + Block the current thread if the buffer size is larger than the threshold. """ while self.buffer_size > self.buffer_size_thresh: @@ -222,7 +224,7 @@ class PyNcclPipe(KVPipeBase): def send_tensor(self, tensor: Optional[torch.Tensor]) -> None: """ - Sends a tensor and its metadata to the destination rank in a + Sends a tensor and its metadata to the destination rank in a non-blocking way. Parameters: diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index d6fca4f0221b8..25202062e9757 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -5,6 +5,7 @@ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. import dataclasses +import datetime import pickle import time from collections import deque @@ -217,6 +218,7 @@ class StatelessProcessGroup: rank: int, world_size: int, data_expiration_seconds: int = 3600, + store_timeout: int = 300, ) -> "StatelessProcessGroup": """A replacement for `torch.distributed.init_process_group` that does not pollute the global state. @@ -238,6 +240,7 @@ class StatelessProcessGroup: port=port, world_size=world_size, is_master=(rank == 0), + timeout=datetime.timedelta(seconds=store_timeout), ) return StatelessProcessGroup( From 1bc3b739c4421a5d63b527fc2b5335e90e450204 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Wed, 12 Mar 2025 21:37:58 -0700 Subject: [PATCH 6/6] [V1][TPU] Add assertion on multi-step-scheduler (#14707) Signed-off-by: Siyuan Liu --- vllm/platforms/tpu.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 0b66b52713e97..fc68e5d63a6e5 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -91,13 +91,19 @@ class TpuPlatform(Platform): parallel_config = vllm_config.parallel_config scheduler_config = vllm_config.scheduler_config if parallel_config.worker_cls == "auto": - if envs.VLLM_USE_V1: - parallel_config.worker_cls = \ - "vllm.v1.worker.tpu_worker.TPUWorker" - else: - if scheduler_config.is_multi_step: + if scheduler_config.is_multi_step: + if envs.VLLM_USE_V1: + raise NotImplementedError( + "Multi-step scheduling is not supported (and not " + "needed) on vLLM V1. Please launch without " + "--num-scheduler-steps.") + else: parallel_config.worker_cls = \ "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" + else: + if envs.VLLM_USE_V1: + parallel_config.worker_cls = \ + "vllm.v1.worker.tpu_worker.TPUWorker" else: parallel_config.worker_cls = \ "vllm.worker.tpu_worker.TPUWorker"