Merge branch 'main' into v1-sched-interface-2

This commit is contained in:
Woosuk Kwon 2025-03-12 21:45:44 -07:00
commit 6cd1b1a18c
9 changed files with 51 additions and 23 deletions

View File

@ -9,6 +9,7 @@ setuptools-scm>=8
wheel wheel
jinja2 jinja2
ray[default] ray[default]
ray[data]
# Install torch_xla # Install torch_xla
--pre --pre

View File

@ -2837,6 +2837,9 @@ class KVTransferConfig(BaseModel):
# The KV connector port, used to build distributed connection # The KV connector port, used to build distributed connection
kv_port: int = 14579 kv_port: int = 14579
# any extra config that the connector may need
kv_connector_extra_config: dict[str, Any] = {}
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
WARNING: Whenever a new field is added to this config, WARNING: Whenever a new field is added to this config,
@ -2896,6 +2899,9 @@ class KVTransferConfig(BaseModel):
return self.kv_connector is not None and \ return self.kv_connector is not None and \
self.kv_role in ["kv_consumer", "kv_both"] self.kv_role in ["kv_consumer", "kv_both"]
def get_from_extra_config(self, key, default) -> Any:
return self.kv_connector_extra_config.get(key, default)
class CompilationLevel: class CompilationLevel:
# constants for the levels of the compilation process # constants for the levels of the compilation process

View File

@ -59,11 +59,13 @@ class PyNcclPipe(KVPipeBase):
self.device = self._select_device(device) self.device = self._select_device(device)
# build distributed connection and send/recv implementation # build distributed connection and send/recv implementation
store_timeout = self.config.get_from_extra_config("store_timeout", 300)
self.group = StatelessProcessGroup.create( self.group = StatelessProcessGroup.create(
host=self.config.kv_ip, host=self.config.kv_ip,
port=self.config.kv_port + port_offset, port=self.config.kv_port + port_offset,
rank=self.kv_rank, rank=self.kv_rank,
world_size=self.kv_parallel_size, world_size=self.kv_parallel_size,
store_timeout=store_timeout,
) )
# add a barrier to make sure the connection is initiated properly # add a barrier to make sure the connection is initiated properly
self.group.barrier() self.group.barrier()

View File

@ -5,6 +5,7 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses import dataclasses
import datetime
import pickle import pickle
import time import time
from collections import deque from collections import deque
@ -217,6 +218,7 @@ class StatelessProcessGroup:
rank: int, rank: int,
world_size: int, world_size: int,
data_expiration_seconds: int = 3600, data_expiration_seconds: int = 3600,
store_timeout: int = 300,
) -> "StatelessProcessGroup": ) -> "StatelessProcessGroup":
"""A replacement for `torch.distributed.init_process_group` that does not """A replacement for `torch.distributed.init_process_group` that does not
pollute the global state. pollute the global state.
@ -238,6 +240,7 @@ class StatelessProcessGroup:
port=port, port=port,
world_size=world_size, world_size=world_size,
is_master=(rank == 0), is_master=(rank == 0),
timeout=datetime.timedelta(seconds=store_timeout),
) )
return StatelessProcessGroup( return StatelessProcessGroup(

View File

@ -50,7 +50,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
# We prefer to use separate k_scale and v_scale if present # We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist() k_scale = layer.k_scale.to("cpu").tolist()
v_scale = layer.v_scale.to("cpu").tolist() v_scale = layer.v_scale.to("cpu").tolist()
if current_platform.is_rocm(): if current_platform.is_fp8_fnuz():
k_scale *= 2 k_scale *= 2
v_scale *= 2 v_scale *= 2
elif layer.k_scale < 0.0 and layer.v_scale < 0.0: elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
@ -66,7 +66,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
scale_to_duplicate = max(layer.k_scale, layer.v_scale) scale_to_duplicate = max(layer.k_scale, layer.v_scale)
k_scale = scale_to_duplicate.to("cpu").tolist() k_scale = scale_to_duplicate.to("cpu").tolist()
v_scale = scale_to_duplicate.to("cpu").tolist() v_scale = scale_to_duplicate.to("cpu").tolist()
if current_platform.is_rocm(): if current_platform.is_fp8_fnuz():
k_scale *= 2 k_scale *= 2
v_scale *= 2 v_scale *= 2

View File

@ -1330,11 +1330,14 @@ class GGUFModelLoader(BaseModelLoader):
local_model_path, gguf_weights_map): local_model_path, gguf_weights_map):
model_config.hf_config.update({"tie_word_embeddings": True}) model_config.hf_config.update({"tie_word_embeddings": True})
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with target_device:
model = _initialize_model(vllm_config=vllm_config) model = _initialize_model(vllm_config=vllm_config)
model.load_weights( model.load_weights(
self._get_weights_iterator(local_model_path, gguf_weights_map)) self._get_weights_iterator(local_model_path, gguf_weights_map))
_process_weights_after_loading(model, model_config, target_device)
return model return model

View File

@ -91,13 +91,19 @@ class TpuPlatform(Platform):
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config scheduler_config = vllm_config.scheduler_config
if parallel_config.worker_cls == "auto": if parallel_config.worker_cls == "auto":
if envs.VLLM_USE_V1: if scheduler_config.is_multi_step:
parallel_config.worker_cls = \ if envs.VLLM_USE_V1:
"vllm.v1.worker.tpu_worker.TPUWorker" raise NotImplementedError(
else: "Multi-step scheduling is not supported (and not "
if scheduler_config.is_multi_step: "needed) on vLLM V1. Please launch without "
"--num-scheduler-steps.")
else:
parallel_config.worker_cls = \ parallel_config.worker_cls = \
"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
else:
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.v1.worker.tpu_worker.TPUWorker"
else: else:
parallel_config.worker_cls = \ parallel_config.worker_cls = \
"vllm.worker.tpu_worker.TPUWorker" "vllm.worker.tpu_worker.TPUWorker"

View File

@ -23,6 +23,7 @@ from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
zmq_socket_ctx) zmq_socket_ctx)
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput) EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer from vllm.v1.engine.mm_input_cache import MMInputCacheServer
@ -67,15 +68,21 @@ class EngineCore:
# Setup scheduler. # Setup scheduler.
if isinstance(vllm_config.scheduler_config.scheduler_cls, str): if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
Scheduler = resolve_obj_by_qualname(
vllm_config.scheduler_config.scheduler_cls)
else:
Scheduler = vllm_config.scheduler_config.scheduler_cls
# This warning can be removed once the V1 Scheduler interface is
# finalized and we can maintain support for scheduler classes that
# implement it
if Scheduler is not V1Scheduler:
logger.warning( logger.warning(
"Using configured V1 scheduler class %s. " "Using configured V1 scheduler class %s. "
"This scheduler interface is not public and " "This scheduler interface is not public and "
"compatibility may not be maintained.", "compatibility may not be maintained.",
vllm_config.scheduler_config.scheduler_cls) vllm_config.scheduler_config.scheduler_cls)
Scheduler = resolve_obj_by_qualname(
vllm_config.scheduler_config.scheduler_cls)
else:
Scheduler = vllm_config.scheduler_config.scheduler_cls
self.scheduler = Scheduler( self.scheduler = Scheduler(
scheduler_config=vllm_config.scheduler_config, scheduler_config=vllm_config.scheduler_config,
model_config=vllm_config.model_config, model_config=vllm_config.model_config,