[V1] Multiprocessing Tensor Parallel Support for v1 (#9856)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Parent: bc192a2b09
Commit: 28b3a1c7e5
@@ -26,6 +26,14 @@ MODELS = [
 TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 def test_vllm_gc_ed():
     """Verify vllm instance is GC'ed when it is deleted"""
     llm = LLM("facebook/opt-125m")
@@ -36,6 +44,7 @@ def test_vllm_gc_ed():
     assert weak_llm() is None
 
 
+@pytest.mark.skip_v1
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
 @pytest.mark.parametrize("dtype", ["half"])
@@ -118,6 +127,11 @@ def test_models_distributed(
     if attention_backend:
         os.environ["VLLM_ATTENTION_BACKEND"] = attention_backend
 
+    # Import VLLM_USE_V1 dynamically to handle patching
+    from vllm.envs import VLLM_USE_V1
+    if VLLM_USE_V1 and distributed_executor_backend != "mp":
+        pytest.skip(f"Skip {distributed_executor_backend} for V1")
+
     dtype = "half"
     max_tokens = 5
 
@@ -143,6 +157,7 @@ def test_models_distributed(
     )
 
 
+@pytest.mark.skip_v1
 def test_model_with_failure(vllm_runner) -> None:
     try:
         with patch("vllm.model_executor.models.opt.OPTForCausalLM.forward",
@@ -169,6 +184,7 @@ def test_model_with_failure(vllm_runner) -> None:
         os.remove(filename)
 
 
+@pytest.mark.skip_v1
 def test_failure_with_async_out_proc(vllm_runner) -> None:
 
     filename = None
@@ -5,7 +5,6 @@ from collections import UserList
 from enum import Enum
 from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
                     TypedDict, TypeVar, Union)
-from unittest.mock import patch
 
 import numpy as np
 import pytest
@@ -110,7 +109,7 @@ VIDEO_ASSETS = _VideoAssets()
 
 
 @pytest.fixture(params=[True, False])
-def run_with_both_engines(request):
+def run_with_both_engines(request, monkeypatch):
     # Automatically runs tests twice, once with V1 and once without
     use_v1 = request.param
     # Tests decorated with `@skip_v1` are only run without v1
@@ -119,11 +118,11 @@ def run_with_both_engines(request):
     if use_v1:
         if skip_v1:
             pytest.skip("Skipping test on vllm V1")
-        with patch('vllm.envs.VLLM_USE_V1', True):
-            yield
+        monkeypatch.setenv('VLLM_USE_V1', '1')
     else:
-        with patch('vllm.envs.VLLM_USE_V1', False):
-            yield
+        monkeypatch.setenv('VLLM_USE_V1', '0')
+
+    yield
 
 
 @pytest.fixture(autouse=True)
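Switching from patch('vllm.envs.VLLM_USE_V1', ...) to monkeypatch.setenv matters once the engine runs in a separate process: an environment variable is inherited by subprocesses, while a patched module attribute is not. A minimal, standalone pytest sketch of the fixture pattern (not the actual vLLM conftest):

import os

import pytest


@pytest.fixture(params=[True, False])
def run_with_both_engines(request, monkeypatch):
    # Parametrized twice: once with V1 enabled, once disabled.
    use_v1 = request.param
    monkeypatch.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
    yield


def test_engine_flag(run_with_both_engines):
    # Any subprocess started here would see the same value.
    assert os.environ["VLLM_USE_V1"] in ("0", "1")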
@@ -1,10 +1,11 @@
 import os
 import pickle
+import sys
 import time
 from contextlib import contextmanager
 from dataclasses import dataclass, field
 from multiprocessing import shared_memory
-from typing import List, Optional
+from typing import List, Optional, Tuple
 from unittest.mock import patch
 
 import torch
@@ -21,6 +22,20 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
 
 logger = init_logger(__name__)
 
+# We prefer to use os.sched_yield as it results in tighter polling loops,
+# measured to be around 3e-7 seconds. However on earlier versions of Python
+# os.sched_yield() does not release the GIL, so we fall back to time.sleep(0)
+USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
+                   or (sys.version_info[:2] == (3, 10)
+                       and sys.version_info[2] >= 8))
+
+
+def sched_yield():
+    if USE_SCHED_YIELD:
+        os.sched_yield()
+    else:
+        time.sleep(0)
+
 
 class ShmRingBuffer:
 
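For context, a standalone sketch of the same version-gated yield (the 3.11.1 / 3.10.8 thresholds come from the comment above; the timing loop is illustrative only and not part of the commit):

import os
import sys
import time

# os.sched_yield (POSIX only) releases the GIL on sufficiently new CPython;
# older interpreters fall back to time.sleep(0), mirroring the diff above.
USE_SCHED_YIELD = ((sys.version_info[:3] >= (3, 11, 1))
                   or (sys.version_info[:2] == (3, 10)
                       and sys.version_info[2] >= 8))


def sched_yield():
    if USE_SCHED_YIELD:
        os.sched_yield()
    else:
        time.sleep(0)


if __name__ == "__main__":
    # Rough measurement of the per-iteration cost of yielding.
    n = 100_000
    start = time.monotonic()
    for _ in range(n):
        sched_yield()
    print(f"avg yield cost: {(time.monotonic() - start) / n:.2e} s")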
@@ -114,11 +129,14 @@ class ShmRingBuffer:
                     # and we should suppress the error
                     pass
 
+    def handle(self):
+        return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
+                self.shared_memory.name)
+
     def __reduce__(self):
         return (
             self.__class__,
-            (self.n_reader, self.max_chunk_bytes, self.max_chunks,
-             self.shared_memory.name),
+            self.handle(),
         )
 
     def __del__(self):
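The handle()/__reduce__ change means a child process re-attaches to the ring buffer by name rather than receiving a pickled buffer object. A minimal illustration of the same idea using multiprocessing.shared_memory directly (names and sizes here are made up for the example, not vLLM API):

from multiprocessing import shared_memory

# Creator: allocate a segment and derive a small, picklable handle,
# analogous to ShmRingBuffer.handle() above.
creator_shm = shared_memory.SharedMemory(create=True, size=1024)
creator_shm.buf[0] = 42
handle = (creator_shm.size, creator_shm.name)

# Any process given the handle: attach by name, no data is copied in pickle.
size, name = handle
attached_shm = shared_memory.SharedMemory(name=name)
assert attached_shm.buf[0] == 42

attached_shm.close()
creator_shm.close()
creator_shm.unlink()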
@@ -147,7 +165,7 @@ class Handle:
     connect_ip: str
     local_reader_ranks: List[int] = field(default_factory=list)
 
-    buffer: Optional[ShmRingBuffer] = None
+    buffer_handle: Optional[Tuple[int, int, int, str]] = None
     local_subscribe_port: Optional[int] = None
     remote_subscribe_port: Optional[int] = None
 
@@ -228,7 +246,7 @@ class MessageQueue:
         self.handle = Handle(
             connect_ip=connect_ip,
             local_reader_ranks=local_reader_ranks,
-            buffer=self.buffer,
+            buffer_handle=self.buffer.handle(),
             local_subscribe_port=local_subscribe_port,
             remote_subscribe_port=remote_subscribe_port,
         )
@@ -247,8 +265,8 @@ class MessageQueue:
         context = Context()
 
         if rank in handle.local_reader_ranks:
-            assert handle.buffer is not None
-            self.buffer = handle.buffer
+            assert handle.buffer_handle is not None
+            self.buffer = ShmRingBuffer(*handle.buffer_handle)
             self.current_idx = 0
             self.local_reader_rank = handle.local_reader_ranks.index(rank)
             self._is_local_reader = True
@@ -314,7 +332,7 @@ class MessageQueue:
             assert recv == b"READY"
 
     @contextmanager
-    def acquire_write(self):
+    def acquire_write(self, timeout: Optional[float] = None):
         assert self._is_writer, "Only writers can acquire write"
         start_time = time.monotonic()
         n_warning = 1
@@ -329,16 +347,20 @@ class MessageQueue:
                     # we need to wait until it is read by all readers
 
                     # Release the processor to other threads
-                    os.sched_yield()
+                    sched_yield()
 
-                    # if we wait for a long time, we should warn the user
+                    # if we wait for a long time, log a message
                     if (time.monotonic() - start_time >
                             VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
-                        logger.warning(
-                            "No available block found in %s second. ",
-                            VLLM_RINGBUFFER_WARNING_INTERVAL)
+                        logger.debug("No available block found in %s second. ",
+                                     VLLM_RINGBUFFER_WARNING_INTERVAL)
                         n_warning += 1
 
+                    # if we time out, raise an exception
+                    if (timeout is not None
+                            and time.monotonic() - start_time > timeout):
+                        raise TimeoutError
+
                     continue
                 # found a block that is either
                 # (1) not written
@@ -365,7 +387,7 @@ class MessageQueue:
                 break
 
     @contextmanager
-    def acquire_read(self):
+    def acquire_read(self, timeout: Optional[float] = None):
         assert self._is_local_reader, "Only readers can acquire read"
         start_time = time.monotonic()
         n_warning = 1
@@ -383,16 +405,20 @@ class MessageQueue:
                     # we need to wait until it is written
 
                     # Release the processor to other threads
-                    os.sched_yield()
+                    sched_yield()
 
-                    # if we wait for a long time, we should warn the user
+                    # if we wait for a long time, log a message
                     if (time.monotonic() - start_time >
                             VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
-                        logger.warning(
-                            "No available block found in %s second. ",
-                            VLLM_RINGBUFFER_WARNING_INTERVAL)
+                        logger.debug("No available block found in %s second. ",
+                                     VLLM_RINGBUFFER_WARNING_INTERVAL)
                         n_warning += 1
 
+                    # if we time out, raise an exception
+                    if (timeout is not None
+                            and time.monotonic() - start_time > timeout):
+                        raise TimeoutError
+
                     continue
                 # found a block that is not read by this reader
                 # let caller read from the buffer
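The timeout plumbing above follows a simple deadline pattern: record time.monotonic() once, keep yielding while the resource is busy, and raise TimeoutError once the deadline passes. A generic, standalone sketch of that pattern (not vLLM API):

import time


def wait_until(condition, timeout=None, poll=lambda: time.sleep(0)):
    """Spin until condition() is true, yielding between checks.

    Raises TimeoutError if `timeout` seconds elapse first, mirroring the
    deadline check added to acquire_write/acquire_read above.
    """
    start = time.monotonic()
    while not condition():
        poll()
        if timeout is not None and time.monotonic() - start > timeout:
            raise TimeoutError
    return True


if __name__ == "__main__":
    deadline = time.monotonic() + 0.05
    wait_until(lambda: time.monotonic() > deadline, timeout=1.0)
    try:
        wait_until(lambda: False, timeout=0.01)
    except TimeoutError:
        print("timed out as expected")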
@@ -406,24 +432,26 @@ class MessageQueue:
                                     1) % self.buffer.max_chunks
                 break
 
-    def enqueue(self, obj):
+    def enqueue(self, obj, timeout: Optional[float] = None):
+        """ Write to message queue with optional timeout (in seconds) """
         assert self._is_writer, "Only writers can enqueue"
         serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
         if self.n_local_reader > 0:
             if len(serialized_obj) >= self.buffer.max_chunk_bytes:
-                with self.acquire_write() as buf:
+                with self.acquire_write(timeout) as buf:
                     buf[0] = 1  # overflow
                 self.local_socket.send(serialized_obj)
             else:
-                with self.acquire_write() as buf:
+                with self.acquire_write(timeout) as buf:
                     buf[0] = 0  # not overflow
                     buf[1:len(serialized_obj) + 1] = serialized_obj
         if self.n_remote_reader > 0:
             self.remote_socket.send(serialized_obj)
 
-    def dequeue(self):
+    def dequeue(self, timeout: Optional[float] = None):
+        """ Read from message queue with optional timeout (in seconds) """
         if self._is_local_reader:
-            with self.acquire_read() as buf:
+            with self.acquire_read(timeout) as buf:
                 overflow = buf[0] == 1
                 if not overflow:
                     # no need to know the size of serialized object
@@ -3,25 +3,19 @@ import os
 from functools import partial
 from typing import Any, List, Optional
 
-import torch
-
 from vllm.executor.distributed_gpu_executor import (  # yapf: disable
     DistributedGPUExecutor, DistributedGPUExecutorAsync)
 from vllm.executor.gpu_executor import create_worker
-from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
-                                                  ResultHandler, WorkerMonitor)
+from vllm.executor.multiproc_worker_utils import (
+    ProcessWorkerWrapper, ResultHandler, WorkerMonitor,
+    set_multiprocessing_worker_envs)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
-from vllm.triton_utils.importing import HAS_TRITON
 from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
-                        cuda_is_initialized, get_distributed_init_method,
-                        get_open_port, make_async,
+                        get_distributed_init_method, get_open_port, make_async,
                         update_environment_variables)
 
-if HAS_TRITON:
-    from vllm.triton_utils import maybe_set_triton_cache_manager
-
 logger = init_logger(__name__)
 
 
@@ -37,30 +31,8 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
         world_size = self.parallel_config.world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
 
-        # Disable torch async compiling which won't work with daemonic processes
-        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
-
-        # Configure thread parallelism if OMP_NUM_THREADS isn't set
-        #
-        # Helps to avoid CPU contention. The default of spawning a thread per
-        # core combined with multiprocessing for each GPU can have a negative
-        # impact on performance. The contention is amplified when running in a
-        # container where CPU limits can cause throttling.
-        default_omp_num_threads = 1
-        if "OMP_NUM_THREADS" not in os.environ and (
-                current_parallelism :=
-                torch.get_num_threads()) > default_omp_num_threads:
-            logger.warning(
-                "Reducing Torch parallelism from %d threads to %d to avoid "
-                "unnecessary CPU contention. Set OMP_NUM_THREADS in the "
-                "external environment to tune this value as needed.",
-                current_parallelism, default_omp_num_threads)
-            os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
-            torch.set_num_threads(default_omp_num_threads)
-
-        # workaround for https://github.com/vllm-project/vllm/issues/6103
-        if HAS_TRITON and world_size > 1:
-            maybe_set_triton_cache_manager()
+        # Set multiprocessing envs that are common to V0 and V1
+        set_multiprocessing_worker_envs(self.parallel_config)
 
         # Multiprocessing-based executor does not support multi-node setting.
         # Since it only works for single node, we can use the loopback address
@@ -122,13 +94,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
                 "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
             })
 
-        if (cuda_is_initialized()
-                and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
-            logger.warning("CUDA was previously initialized. We must use "
-                           "the `spawn` multiprocessing start method. Setting "
-                           "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
-            os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-
         cuda_device_count = cuda_device_count_stateless()
         # Use confusing message for more common TP-only case.
         assert tensor_parallel_size <= cuda_device_count, (
@@ -11,8 +11,15 @@ from multiprocessing.process import BaseProcess
 from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
                     TypeVar, Union)
 
+import torch
+
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.triton_utils.importing import HAS_TRITON
+from vllm.utils import cuda_is_initialized
+
+if HAS_TRITON:
+    from vllm.triton_utils import maybe_set_triton_cache_manager
 
 logger = init_logger(__name__)
 
@@ -270,3 +277,38 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
 def get_mp_context():
     mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
     return multiprocessing.get_context(mp_method)
+
+
+def set_multiprocessing_worker_envs(parallel_config):
+    """ Set up environment variables that should be used when there are workers
+    in a multiprocessing environment. This should be called by the parent
+    process before worker processes are created"""
+
+    if (cuda_is_initialized()
+            and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
+        logger.warning("CUDA was previously initialized. We must use "
+                       "the `spawn` multiprocessing start method. Setting "
+                       "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
+        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+
+    # Configure thread parallelism if OMP_NUM_THREADS isn't set
+    #
+    # Helps to avoid CPU contention. The default of spawning a thread per
+    # core combined with multiprocessing for each GPU can have a negative
+    # impact on performance. The contention is amplified when running in a
+    # container where CPU limits can cause throttling.
+    default_omp_num_threads = 1
+    if "OMP_NUM_THREADS" not in os.environ and (
+            current_parallelism :=
+            torch.get_num_threads()) > default_omp_num_threads:
+        logger.warning(
+            "Reducing Torch parallelism from %d threads to %d to avoid "
+            "unnecessary CPU contention. Set OMP_NUM_THREADS in the "
+            "external environment to tune this value as needed.",
+            current_parallelism, default_omp_num_threads)
+        os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
+        torch.set_num_threads(default_omp_num_threads)
+
+    # workaround for https://github.com/vllm-project/vllm/issues/6103
+    if HAS_TRITON and parallel_config.world_size > 1:
+        maybe_set_triton_cache_manager()
@@ -5,6 +5,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
+import vllm.envs as envs
 from vllm.distributed import (tensor_model_parallel_all_gather,
                               tensor_model_parallel_gather)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -42,7 +43,9 @@ class LogitsProcessor(nn.Module):
         # Soft cap the logits. Used in Gemma 2.
         self.soft_cap = soft_cap
         # Whether to use gather or all-gather to gather the logits.
-        self.use_gather = not current_platform.is_tpu()
+
+        self.use_gather = not current_platform.is_tpu(
+        ) and not envs.VLLM_USE_V1
 
     def forward(
         self,
@@ -12,6 +12,7 @@ from typing_extensions import ParamSpec
 
 # import custom ops, trigger op registration
 import vllm._C  # noqa
+import vllm.envs as envs
 from vllm.logger import init_logger
 
 from .interface import DeviceCapability, Platform, PlatformEnum
@@ -110,17 +111,28 @@ class CudaPlatformBase(Platform):
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
 
         if parallel_config.worker_cls == "auto":
             if scheduler_config.is_multi_step:
-                parallel_config.worker_cls = \
-                    "vllm.worker.multi_step_worker.MultiStepWorker"
+                if envs.VLLM_USE_V1:
+                    raise NotImplementedError
+                else:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.multi_step_worker.MultiStepWorker"
             elif vllm_config.speculative_config:
-                parallel_config.worker_cls = \
-                    "vllm.spec_decode.spec_decode_worker.create_spec_worker"
-                parallel_config.sd_worker_cls = \
-                    "vllm.worker.worker.Worker"
+                if envs.VLLM_USE_V1:
+                    raise NotImplementedError
+                else:
+                    parallel_config.worker_cls = \
+                        "vllm.spec_decode.spec_decode_worker.create_spec_worker"
+                    parallel_config.sd_worker_cls = \
+                        "vllm.worker.worker.Worker"
             else:
-                parallel_config.worker_cls = "vllm.worker.worker.Worker"
+                if envs.VLLM_USE_V1:
+                    parallel_config.worker_cls = \
+                        "vllm.v1.worker.gpu_worker.Worker"
+                else:
+                    parallel_config.worker_cls = "vllm.worker.worker.Worker"
 
+
 # NVML utils
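check_and_update_config only stores a dotted class path; the worker class is imported lazily elsewhere. A minimal sketch of that resolve-by-qualname step, mirroring resolve_obj_by_qualname from vllm/utils.py (which appears later in this diff):

import importlib


def resolve_obj_by_qualname(qualname: str):
    """Import `pkg.module.Object` and return the object."""
    module_name, obj_name = qualname.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, obj_name)


# Example with a stdlib class; vLLM does the same with strings such as
# "vllm.v1.worker.gpu_worker.Worker" selected above.
odict_cls = resolve_obj_by_qualname("collections.OrderedDict")
print(odict_cls(a=1))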
@@ -249,4 +261,4 @@ try:
     if not isinstance(pynvml, _MockModule):
         CudaPlatform.log_warnings()
 except ModuleNotFoundError:
     CudaPlatform.log_warnings()
@@ -10,6 +10,7 @@ import importlib.util
 import inspect
 import ipaddress
 import os
+import signal
 import socket
 import subprocess
 import sys
@@ -1652,3 +1653,28 @@ def resolve_obj_by_qualname(qualname: str) -> Any:
     module_name, obj_name = qualname.rsplit(".", 1)
     module = importlib.import_module(module_name)
     return getattr(module, obj_name)
+
+
+def kill_process_tree(pid: int):
+    """
+    Kills all descendant processes of the given pid by sending SIGKILL.
+
+    Args:
+        pid (int): Process ID of the parent process
+    """
+    try:
+        parent = psutil.Process(pid)
+    except psutil.NoSuchProcess:
+        return
+
+    # Get all children recursively
+    children = parent.children(recursive=True)
+
+    # Send SIGKILL to all children first
+    for child in children:
+        with contextlib.suppress(ProcessLookupError):
+            os.kill(child.pid, signal.SIGKILL)
+
+    # Finally kill the parent
+    with contextlib.suppress(ProcessLookupError):
+        os.kill(pid, signal.SIGKILL)
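A small demonstration of the new helper's behaviour: a parent spawns a child which spawns a grandchild, and killing the tree removes both. It uses psutil, which the added utils.py code already relies on; this is a simplified standalone version (without contextlib.suppress), Linux-oriented, with arbitrary sleep durations:

import multiprocessing
import os
import signal
import time

import psutil


def grandchild():
    time.sleep(60)


def child():
    p = multiprocessing.Process(target=grandchild)
    p.start()
    time.sleep(60)


def kill_process_tree(pid: int):
    # Same logic as the helper added above, minus error suppression.
    try:
        parent = psutil.Process(pid)
    except psutil.NoSuchProcess:
        return
    for c in parent.children(recursive=True):
        os.kill(c.pid, signal.SIGKILL)
    os.kill(pid, signal.SIGKILL)


if __name__ == "__main__":
    proc = multiprocessing.Process(target=child)
    proc.start()
    time.sleep(1)  # give the grandchild time to start
    kill_process_tree(proc.pid)
    proc.join()
    print("exit code:", proc.exitcode)  # negative => killed by signal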
@@ -5,6 +5,8 @@ from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set,
 
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.logger import init_logger
+from vllm.multimodal import MultiModalKwargs
+from vllm.multimodal.base import PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
 from vllm.v1.core.kv_cache_manager import KVCacheManager
@@ -383,7 +385,7 @@ class Scheduler:
         model_runner_output: "ModelRunnerOutput",
     ) -> List[EngineCoreOutput]:
         # NOTE(woosuk): This method doesn't consider speculative decoding.
-        sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist()
+        sampled_token_ids = model_runner_output.sampled_token_ids
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
         new_running: List[Request] = []
         engine_core_outputs: List[EngineCoreOutput] = []
@@ -20,7 +20,7 @@ from vllm.v1.engine.async_stream import AsyncStream
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.detokenizer import Detokenizer
 from vllm.v1.engine.processor import Processor
-from vllm.v1.executor.gpu_executor import GPUExecutor
+from vllm.v1.executor.abstract import Executor
 
 logger = init_logger(__name__)
 
@@ -30,7 +30,7 @@ class AsyncLLM(EngineClient):
     def __init__(
         self,
         vllm_config: VllmConfig,
-        executor_class: Type[GPUExecutor],
+        executor_class: Type[Executor],
         log_stats: bool,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
@@ -119,14 +119,24 @@ class AsyncLLM(EngineClient):
     def shutdown(self):
         """Shutdown, cleaning up the background proc and IPC."""
 
-        self.engine_core.shutdown()
+        if engine_core := getattr(self, "engine_core", None):
+            engine_core.shutdown()
 
         if handler := getattr(self, "output_handler", None):
             handler.cancel()
 
     @classmethod
     def _get_executor_cls(cls, vllm_config: VllmConfig):
-        return GPUExecutor
+        distributed_executor_backend = (
+            vllm_config.parallel_config.distributed_executor_backend)
+        if distributed_executor_backend == "mp":
+            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
+            executor_class = MultiprocExecutor
+        else:
+            assert (distributed_executor_backend is None)
+            from vllm.v1.executor.uniproc_executor import UniprocExecutor
+            executor_class = UniprocExecutor
+        return executor_class
 
     async def add_request(
         self,
@@ -1,12 +1,12 @@
 import multiprocessing
 import pickle
 import queue
+import signal
 import threading
 import time
-from contextlib import contextmanager
 from multiprocessing.process import BaseProcess
 from multiprocessing.sharedctypes import Synchronized
-from typing import Any, Iterator, List, Tuple, Type, Union
+from typing import List, Tuple, Type, Union
 
 import zmq
 import zmq.asyncio
@@ -20,9 +20,10 @@ from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
                             EngineCoreProfile, EngineCoreRequest,
                             EngineCoreRequestType)
 from vllm.v1.engine.mm_input_mapper import MMInputMapper
-from vllm.v1.executor.gpu_executor import GPUExecutor
+from vllm.v1.executor.abstract import Executor
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import PickleEncoder
+from vllm.v1.utils import make_zmq_socket
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -38,7 +39,7 @@ class EngineCore:
     def __init__(
         self,
         vllm_config: VllmConfig,
-        executor_class: Type[GPUExecutor],
+        executor_class: Type[Executor],
         usage_context: UsageContext,
     ):
         assert vllm_config.model_config.task != "embedding"
@@ -80,7 +81,7 @@ class EngineCore:
             num_gpu_blocks = num_gpu_blocks_override
 
         num_cpu_blocks = 0
-        self.model_executor.initialize_cache(num_gpu_blocks)
+        self.model_executor.initialize(num_gpu_blocks)
         elapsed = time.time() - start
         logger.info(("init engine (profile, create kv cache, "
                      "warmup model) took %.2f seconds"), elapsed)
@@ -112,8 +113,11 @@ class EngineCore:
             scheduler_output, output)
         return engine_core_outputs
 
+    def shutdown(self):
+        self.model_executor.shutdown()
+
     def profile(self, is_start=True):
-        self.model_executor.worker.profile(is_start)
+        self.model_executor.profile(is_start)
 
 
 class EngineCoreProc(EngineCore):
@@ -124,7 +128,7 @@ class EngineCoreProc(EngineCore):
     def __init__(
         self,
         vllm_config: VllmConfig,
-        executor_class: Type[GPUExecutor],
+        executor_class: Type[Executor],
         usage_context: UsageContext,
         input_path: str,
         output_path: str,
@@ -151,32 +155,9 @@ class EngineCoreProc(EngineCore):
                          daemon=True).start()
 
         # Send Readiness signal to EngineClient.
-        with self.make_socket(ready_path, zmq.constants.PUSH) as ready_socket:
+        with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket:
             ready_socket.send_string(EngineCoreProc.READY_STR)
 
-    @contextmanager
-    def make_socket(self, path: str, type: Any) -> Iterator[zmq.Socket]:
-        """Context manager for use """
-
-        ctx = zmq.Context()
-        try:
-            socket = ctx.socket(type)
-
-            if type == zmq.constants.PULL:
-                socket.connect(path)
-            elif type == zmq.constants.PUSH:
-                socket.bind(path)
-            else:
-                raise ValueError(f"Unknown Socket Type: {type}")
-
-            yield socket
-
-        except KeyboardInterrupt:
-            logger.debug("EngineCore had Keyboard Interrupt.")
-
-        finally:
-            ctx.destroy(linger=0)
-
     @staticmethod
     def wait_for_startup(
         proc: BaseProcess,
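The removed per-class make_socket context manager is replaced by a shared make_zmq_socket helper from vllm.v1.utils, a file that is not part of this diff. The sketch below shows roughly what such a helper looks like, reconstructed from the deleted code above; it is an assumption, not the actual vllm.v1.utils implementation:

from contextlib import contextmanager
from typing import Any, Iterator

import zmq


@contextmanager
def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
    """Create and clean up a ZMQ socket bound/connected to an IPC path."""
    ctx = zmq.Context()
    try:
        socket = ctx.socket(type)
        if type == zmq.constants.PULL:
            socket.connect(path)
        elif type == zmq.constants.PUSH:
            socket.bind(path)
        else:
            raise ValueError(f"Unknown Socket Type: {type}")
        yield socket
    finally:
        ctx.destroy(linger=0)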
@@ -209,7 +190,7 @@ class EngineCoreProc(EngineCore):
     @staticmethod
     def make_engine_core_process(
         vllm_config: VllmConfig,
-        executor_class: Type[GPUExecutor],
+        executor_class: Type[Executor],
         usage_context: UsageContext,
         input_path: str,
         output_path: str,
@@ -244,17 +225,38 @@ class EngineCoreProc(EngineCore):
     def run_engine_core(*args, **kwargs):
         """Launch EngineCore busy loop in background process."""
 
+        # Signal handler used for graceful termination.
+        # SystemExit exception is only raised once to allow this and worker
+        # processes to terminate without error
+        shutdown_requested = False
+
+        def signal_handler(signum, frame):
+            nonlocal shutdown_requested
+            if not shutdown_requested:
+                shutdown_requested = True
+                raise SystemExit()
+
+        # Either SIGTERM or SIGINT will terminate the engine_core
+        signal.signal(signal.SIGTERM, signal_handler)
+        signal.signal(signal.SIGINT, signal_handler)
+
+        engine_core = None
         try:
             engine_core = EngineCoreProc(*args, **kwargs)
             engine_core.run_busy_loop()
 
-        except KeyboardInterrupt:
+        except SystemExit:
             logger.debug("EngineCore interrupted.")
 
         except BaseException as e:
             logger.exception(e)
             raise e
 
+        finally:
+            if engine_core is not None:
+                engine_core.shutdown()
+                engine_core = None
+
     def run_busy_loop(self):
         """Core busy loop of the EngineCore."""
 
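The SIGTERM/SIGINT handler converts the first signal into SystemExit so that the try/finally above runs shutdown exactly once. A standalone sketch of the same pattern, with arbitrary sleep durations:

import os
import signal
import time
from multiprocessing import Process


def run_loop():
    shutdown_requested = False

    def signal_handler(signum, frame):
        nonlocal shutdown_requested
        if not shutdown_requested:
            shutdown_requested = True
            raise SystemExit()

    # Either SIGTERM or SIGINT terminates the loop gracefully.
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    try:
        while True:
            time.sleep(0.1)
    except SystemExit:
        print("interrupted, cleaning up")
    finally:
        print("shutdown complete")


if __name__ == "__main__":
    p = Process(target=run_loop)
    p.start()
    time.sleep(0.5)
    os.kill(p.pid, signal.SIGTERM)
    p.join()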
@@ -272,6 +274,8 @@ class EngineCoreProc(EngineCore):
                         logger.debug("EngineCore busy loop waiting.")
                         if self.should_shutdown:
                             return
+                    except BaseException:
+                        raise
 
             # 2) Handle any new client requests (Abort or Add).
             while not self.input_queue.empty():
@@ -321,7 +325,7 @@ class EngineCoreProc(EngineCore):
         decoder_add_req = PickleEncoder()
         decoder_abort_req = PickleEncoder()
 
-        with self.make_socket(input_path, zmq.constants.PULL) as socket:
+        with make_zmq_socket(input_path, zmq.constants.PULL) as socket:
             while True:
                 # (RequestType, RequestData)
                 type_frame, data_frame = socket.recv_multipart(copy=False)
@@ -349,7 +353,7 @@ class EngineCoreProc(EngineCore):
         # Reuse send buffer.
         buffer = bytearray()
 
-        with self.make_socket(output_path, zmq.constants.PUSH) as socket:
+        with make_zmq_socket(output_path, zmq.constants.PUSH) as socket:
             while True:
                 engine_core_outputs = self.output_queue.get()
                 outputs = EngineCoreOutputs(outputs=engine_core_outputs)
@@ -1,5 +1,4 @@
 import multiprocessing
-import time
 from typing import List, Union
 
 import msgspec
@@ -7,7 +6,7 @@ import zmq
 import zmq.asyncio
 
 from vllm.logger import init_logger
-from vllm.utils import get_open_zmq_ipc_path
+from vllm.utils import get_open_zmq_ipc_path, kill_process_tree
 from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
                             EngineCoreProfile, EngineCoreRequest,
                             EngineCoreRequestType)
@@ -99,6 +98,12 @@ class InprocClient(EngineCoreClient):
     def abort_requests(self, request_ids: List[str]) -> None:
         self.engine_core.abort_requests(request_ids)
 
+    def shutdown(self):
+        self.engine_core.shutdown()
+
+    def __del__(self):
+        self.shutdown()
+
     async def profile(self, is_start=True) -> None:
         self.engine_core.profile(is_start)
 
@@ -163,10 +168,10 @@ class MPClient(EngineCoreClient):
         # Shutdown the process if needed.
         if hasattr(self, "proc") and self.proc.is_alive():
             self.proc.terminate()
+            self.proc.join(5)
 
-            time.sleep(5)
             if self.proc.is_alive():
-                self.proc.kill()
+                kill_process_tree(self.proc.pid)
 
     def __del__(self):
         self.shutdown()
@@ -20,7 +20,7 @@ from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.detokenizer import Detokenizer
 from vllm.v1.engine.processor import Processor
-from vllm.v1.executor.gpu_executor import GPUExecutor
+from vllm.v1.executor.abstract import Executor
 
 logger = init_logger(__name__)
 
@@ -33,7 +33,7 @@ class LLMEngine:
     def __init__(
         self,
        vllm_config: VllmConfig,
-        executor_class: Type[GPUExecutor],
+        executor_class: Type[Executor],
         log_stats: bool,
         usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
         stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
@@ -104,10 +104,17 @@ class LLMEngine:
 
     @classmethod
     def _get_executor_cls(cls, vllm_config: VllmConfig):
-        return GPUExecutor
+        distributed_executor_backend = (
+            vllm_config.parallel_config.distributed_executor_backend)
+        if distributed_executor_backend == "mp":
+            from vllm.v1.executor.multiproc_executor import MultiprocExecutor
+            executor_class = MultiprocExecutor
+        else:
+            assert (distributed_executor_backend is None)
+            from vllm.v1.executor.uniproc_executor import UniprocExecutor
+            executor_class = UniprocExecutor
 
-    def stop_remote_worker_execution_loop(self) -> None:
-        raise NotImplementedError("TP not implemented yet.")
+        return executor_class
 
     def get_num_unfinished_requests(self) -> int:
         return self.detokenizer.get_num_unfinished_requests()
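For reference, how this selection is exercised from user code: with the V1 engine enabled, a tensor-parallel run routes through the new multiprocessing executor, while a single-GPU run keeps the in-process one. A hedged usage sketch (the flags follow the existing vLLM LLM/EngineArgs API, the model name is only an example, and running it requires two GPUs):

import os

from vllm import LLM, SamplingParams

# Opt in to the V1 engine (the run_with_both_engines fixture above does the
# same thing for tests).
os.environ["VLLM_USE_V1"] = "1"

# distributed_executor_backend="mp" selects MultiprocExecutor; leaving it
# unset with tensor_parallel_size=1 selects UniprocExecutor instead.
llm = LLM(model="facebook/opt-125m",
          tensor_parallel_size=2,
          distributed_executor_backend="mp")

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=5))
print(outputs[0].outputs[0].text)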
vllm/v1/executor/abstract.py (new file, 48 lines)
@@ -0,0 +1,48 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Optional, Tuple
+
+from vllm.config import VllmConfig
+from vllm.v1.outputs import ModelRunnerOutput
+
+
+class Executor(ABC):
+    """Abstract class for executors."""
+
+    @abstractmethod
+    def __init__(self, vllm_config: VllmConfig) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def initialize(self, num_gpu_blocks: int) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def execute_model(
+        self,
+        scheduler_output,
+    ) -> ModelRunnerOutput:
+        raise NotImplementedError
+
+    @abstractmethod
+    def profile(self, is_start=True):
+        raise NotImplementedError
+
+    @abstractmethod
+    def shutdown(self):
+        pass
+
+    @abstractmethod
+    def check_health(self) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def collective_rpc(self,
+                       method: str,
+                       timeout: Optional[float] = None,
+                       args: Tuple = (),
+                       kwargs: Optional[Dict] = None) -> []:
+        raise NotImplementedError
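To make the contract concrete, a minimal stub subclass that satisfies the abstract interface above; the return values are placeholders, not a working executor:

from typing import Dict, Optional, Tuple

from vllm.v1.executor.abstract import Executor


class EchoExecutor(Executor):
    """Toy executor that fulfils the interface with stub behaviour."""

    def __init__(self, vllm_config) -> None:
        self.vllm_config = vllm_config

    def initialize(self, num_gpu_blocks: int) -> None:
        self.num_gpu_blocks = num_gpu_blocks

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        return 1024, 0  # placeholder (num_gpu_blocks, num_cpu_blocks)

    def execute_model(self, scheduler_output):
        raise NotImplementedError("stub only")

    def profile(self, is_start=True):
        pass

    def shutdown(self):
        pass

    def check_health(self) -> None:
        pass

    def collective_rpc(self,
                       method: str,
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None):
        return []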
375
vllm/v1/executor/multiproc_executor.py
Normal file
375
vllm/v1/executor/multiproc_executor.py
Normal file
@ -0,0 +1,375 @@
|
|||||||
|
import atexit
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import signal
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum, auto
|
||||||
|
from multiprocessing.process import BaseProcess
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed import (destroy_distributed_environment,
|
||||||
|
destroy_model_parallel)
|
||||||
|
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
|
||||||
|
MessageQueue)
|
||||||
|
from vllm.executor.multiproc_worker_utils import (
|
||||||
|
_add_prefix, get_mp_context, set_multiprocessing_worker_envs)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils import (get_distributed_init_method, get_open_port,
|
||||||
|
get_open_zmq_ipc_path)
|
||||||
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
|
from vllm.v1.utils import make_zmq_socket
|
||||||
|
from vllm.worker.worker_base import WorkerWrapperBase
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
POLLING_TIMEOUT_MS = 5000
|
||||||
|
POLLING_TIMEOUT_S = POLLING_TIMEOUT_MS // 1000
|
||||||
|
|
||||||
|
|
||||||
|
class MultiprocExecutor:
|
||||||
|
|
||||||
|
def __init__(self, vllm_config: VllmConfig) -> None:
|
||||||
|
# Call self.shutdown at exit to clean up
|
||||||
|
# and ensure workers will be terminated.
|
||||||
|
atexit.register(self.shutdown)
|
||||||
|
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
self.parallel_config = vllm_config.parallel_config
|
||||||
|
|
||||||
|
self.world_size = self.parallel_config.world_size
|
||||||
|
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
||||||
|
assert self.world_size == tensor_parallel_size, (
|
||||||
|
f"world_size ({self.world_size}) must be equal to the "
|
||||||
|
f"tensor_parallel_size ({tensor_parallel_size}). "
|
||||||
|
f"Pipeline parallelism is not yet implemented in v1")
|
||||||
|
|
||||||
|
# Set multiprocessing envs that are common to V0 and V1
|
||||||
|
set_multiprocessing_worker_envs(self.parallel_config)
|
||||||
|
|
||||||
|
# Multiprocessing-based executor does not support multi-node setting.
|
||||||
|
# Since it only works for single node, we can use the loopback address
|
||||||
|
# 127.0.0.1 for communication.
|
||||||
|
distributed_init_method = get_distributed_init_method(
|
||||||
|
"127.0.0.1", get_open_port())
|
||||||
|
|
||||||
|
# Initialize worker and set up message queues for SchedulerOutputs
|
||||||
|
# and ModelRunnerOutputs
|
||||||
|
self.rpc_broadcast_mq = MessageQueue(self.world_size, self.world_size)
|
||||||
|
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
|
||||||
|
|
||||||
|
# Create workers
|
||||||
|
self.workers: List[WorkerProcHandle] = []
|
||||||
|
for rank in range(self.world_size):
|
||||||
|
worker = WorkerProc.make_worker_process(vllm_config, rank, rank,
|
||||||
|
distributed_init_method,
|
||||||
|
scheduler_output_handle)
|
||||||
|
self.workers.append(worker)
|
||||||
|
|
||||||
|
# Ensure message queues are ready. Will deadlock if re-ordered
|
||||||
|
# Must be kept consistent with the WorkerProc
|
||||||
|
self.rpc_broadcast_mq.wait_until_ready()
|
||||||
|
for w in self.workers:
|
||||||
|
w.worker_response_mq.wait_until_ready()
|
||||||
|
|
||||||
|
def initialize(self, num_gpu_blocks: int) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the KV caches and begin the model execution loop of the
|
||||||
|
underlying workers.
|
||||||
|
"""
|
||||||
|
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
|
||||||
|
self.collective_rpc("compile_or_warm_up_model")
|
||||||
|
|
||||||
|
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Determine the number of available KV blocks by invoking the
|
||||||
|
underlying worker.
|
||||||
|
"""
|
||||||
|
num_blocks = self.collective_rpc("determine_num_available_blocks")
|
||||||
|
|
||||||
|
# Since we use a shared centralized controller, we take the minimum
|
||||||
|
# number of blocks across all workers to make sure all the memory
|
||||||
|
# operators can be applied to all workers.
|
||||||
|
num_gpu_blocks = min(b[0] for b in num_blocks)
|
||||||
|
num_cpu_blocks = min(b[1] for b in num_blocks)
|
||||||
|
|
||||||
|
return num_gpu_blocks, num_cpu_blocks
|
||||||
|
|
||||||
|
def collective_rpc(self,
|
||||||
|
method: str,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
args: Tuple = (),
|
||||||
|
kwargs: Optional[Dict] = None) -> []:
|
||||||
|
"""
|
||||||
|
Execute an RPC call on workers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: Name of the worker method to execute
|
||||||
|
timeout: Maximum time in seconds to wait for execution. Rases a
|
||||||
|
TimeoutError on timeout. None means wait indefinitely.
|
||||||
|
args: Positional arguments to pass to the worker method
|
||||||
|
kwargs: Keyword arguments to pass to the worker method
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of results from each worker
|
||||||
|
"""
|
||||||
|
start_time = time.monotonic()
|
||||||
|
kwargs = kwargs or {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.rpc_broadcast_mq.enqueue((method, args, kwargs))
|
||||||
|
|
||||||
|
responses = [None] * self.world_size
|
||||||
|
for w in self.workers:
|
||||||
|
dequeue_timeout = timeout - (time.monotonic() - start_time()
|
||||||
|
) if timeout is not None else None
|
||||||
|
status, result = w.worker_response_mq.dequeue(
|
||||||
|
timeout=dequeue_timeout)
|
||||||
|
|
||||||
|
if status != WorkerProc.ResponseStatus.SUCCESS:
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
raise result
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Worker failed")
|
||||||
|
|
||||||
|
responses[w.rank] = result
|
||||||
|
|
||||||
|
return responses
|
||||||
|
except TimeoutError as e:
|
||||||
|
raise TimeoutError(f"RPC call to {method} timed out.") from e
|
||||||
|
except Exception as e:
|
||||||
|
# Re-raise any other exceptions
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def execute_model(
|
||||||
|
self,
|
||||||
|
scheduler_output,
|
||||||
|
) -> ModelRunnerOutput:
|
||||||
|
model_output = self.collective_rpc("execute_model",
|
||||||
|
args=(scheduler_output, ))[0]
|
||||||
|
return model_output
|
||||||
|
|
||||||
|
def profile(self, is_start=True):
|
||||||
|
self.collective_rpc("profile", args=(is_start, ))
|
||||||
|
return
|
||||||
|
|
||||||
|
def _ensure_worker_termination(self):
|
||||||
|
"""Ensure that all worker processes are terminated. Assumes workers have
|
||||||
|
received termination requests. Waits for processing, then sends
|
||||||
|
termination and kill signals if needed."""
|
||||||
|
|
||||||
|
def wait_for_termination(procs, timeout):
|
||||||
|
start_time = time.time()
|
||||||
|
while time.time() - start_time < timeout:
|
||||||
|
if all(not proc.is_alive() for proc in procs):
|
||||||
|
return True
|
||||||
|
time.sleep(0.1)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Send SIGTERM if still running
|
||||||
|
active_procs = [w.proc for w in self.workers if w.proc.is_alive()]
|
||||||
|
self.workers = None
|
||||||
|
for p in active_procs:
|
||||||
|
p.terminate()
|
||||||
|
if wait_for_termination(active_procs, 4):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Send SIGKILL if still running
|
||||||
|
active_procs = [p for p in active_procs if p.is_alive()]
|
||||||
|
for p in active_procs:
|
||||||
|
p.kill()
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
"""Properly shut down the executor and its workers"""
|
||||||
|
if (hasattr(self, 'workers') and self.workers is not None):
|
||||||
|
for w in self.workers: #TODO: not sure if needed
|
||||||
|
w.worker_response_mq = None
|
||||||
|
self._ensure_worker_termination()
|
||||||
|
|
||||||
|
self.rpc_broadcast_mq = None
|
||||||
|
|
||||||
|
def check_health(self) -> None:
|
||||||
|
self.collective_rpc("check_health", timeout=10)
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WorkerProcHandle:
|
||||||
|
proc: BaseProcess
|
||||||
|
rank: int
|
||||||
|
ready_path: str
|
||||||
|
worker_response_mq: MessageQueue # The worker process writes to this MQ
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerProc:
|
||||||
|
"""Wrapper that runs one Worker in a separate process."""
|
||||||
|
|
||||||
|
READY_STR = "READY"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
local_rank: int,
|
||||||
|
rank: int,
|
||||||
|
distributed_init_method: str,
|
||||||
|
input_shm_handle: Handle,
|
||||||
|
ready_path: str,
|
||||||
|
):
|
||||||
|
self.rank = rank
|
||||||
|
wrapper = WorkerWrapperBase(vllm_config=vllm_config)
|
||||||
|
wrapper.init_worker(vllm_config, local_rank, rank,
|
||||||
|
distributed_init_method)
|
||||||
|
self.worker = wrapper.worker
|
||||||
|
|
||||||
|
pid = os.getpid()
|
||||||
|
_add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
|
||||||
|
_add_prefix(sys.stderr, f"VllmWorker rank={rank}", pid)
|
||||||
|
|
||||||
|
# Initialize MessageQueue for receiving SchedulerOutput
|
||||||
|
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
|
||||||
|
input_shm_handle, self.worker.rank)
|
||||||
|
|
||||||
|
# Initializes a message queue for sending the model output
|
||||||
|
self.worker_response_mq = MessageQueue(1, 1)
|
||||||
|
worker_response_mq_handle = self.worker_response_mq.export_handle()
|
||||||
|
|
||||||
|
# Send Readiness signal to EngineCore process.
|
||||||
|
with make_zmq_socket(ready_path, zmq.constants.PUSH) as ready_socket:
|
||||||
|
payload = pickle.dumps(worker_response_mq_handle,
|
||||||
|
protocol=pickle.HIGHEST_PROTOCOL)
|
||||||
|
ready_socket.send_string(WorkerProc.READY_STR)
|
||||||
|
ready_socket.send(payload)
|
||||||
|
|
||||||
|
self.worker.initialize()
|
||||||
|
self.worker.load_model()
|
||||||
|
|
||||||
|
    @staticmethod
    def make_worker_process(
            vllm_config: VllmConfig,
            local_rank: int,
            rank: int,
            distributed_init_method: str,
            input_shm_handle,  # Receive SchedulerOutput
    ) -> WorkerProcHandle:
        context = get_mp_context()

        # ZMQ path for worker to send ready message and shm_broadcast handle
        # back to core process.
        ready_path = get_open_zmq_ipc_path()

        process_kwargs = {
            "vllm_config": vllm_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": distributed_init_method,
            "input_shm_handle": input_shm_handle,
            "ready_path": ready_path,
        }
        # Run EngineCore busy loop in background process.
        proc = context.Process(target=WorkerProc.worker_main,
                               kwargs=process_kwargs,
                               daemon=True)
        proc.start()

        # Wait for startup
        worker_response_mq_handle = WorkerProc.wait_for_startup(
            proc, ready_path)

        worker_response_mq = MessageQueue.create_from_handle(
            worker_response_mq_handle, 0)

        return WorkerProcHandle(proc, rank, ready_path, worker_response_mq)
    def shutdown(self):
        self.rpc_broadcast_mq = None
        self.worker_response_mq = None
        destroy_model_parallel()
        destroy_distributed_environment()
    @staticmethod
    def worker_main(*args, **kwargs):
        """ Worker initialization and execution loops.
        This runs a background process """

        # Signal handler used for graceful termination.
        # SystemExit exception is only raised once to allow this and worker
        # processes to terminate without error
        shutdown_requested = False

        def signal_handler(signum, frame):
            nonlocal shutdown_requested
            if not shutdown_requested:
                shutdown_requested = True
                raise SystemExit()

        # Either SIGTERM or SIGINT will terminate the worker
        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)

        worker = None
        try:
            worker = WorkerProc(*args, **kwargs)

            # Ensure message queues are ready. Will deadlock if re-ordered.
            # Must be kept consistent with the Executor
            worker.rpc_broadcast_mq.wait_until_ready()
            worker.worker_response_mq.wait_until_ready()

            worker.worker_busy_loop()

        except SystemExit:
            logger.debug("Worker interrupted.")

        except BaseException as e:
            logger.exception(e)
            raise

        finally:
            # Clean up once worker exits busy loop
            if worker is not None:
                worker.shutdown()
                worker = None
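worker_main turns the first SIGTERM or SIGINT into a single SystemExit so the busy loop unwinds cleanly and the finally block runs its cleanup exactly once. A minimal, stand-alone sketch of that pattern, independent of vLLM:

import os
import signal
import time

shutdown_requested = False


def signal_handler(signum, frame):
    # Convert only the first signal into SystemExit so cleanup runs once.
    global shutdown_requested
    if not shutdown_requested:
        shutdown_requested = True
        raise SystemExit()


signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)

try:
    print(f"busy loop, pid={os.getpid()} (Ctrl+C or kill this pid to stop early)")
    for _ in range(100):  # stand-in for worker_busy_loop()
        time.sleep(0.1)
except SystemExit:
    print("interrupted")
finally:
    print("cleanup runs exactly once")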
    @staticmethod
    def wait_for_startup(
        proc: BaseProcess,
        ready_path: str,
    ) -> Optional[Handle]:
        """Wait until the Worker is ready."""
        with make_zmq_socket(ready_path, zmq.constants.PULL) as socket:

            # Wait for Worker to send READY.
            while socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
                logger.debug("Waiting for WorkerProc to startup.")

                if not proc.is_alive():
                    raise RuntimeError("WorkerProc failed to start.")

            message = socket.recv_string()
            assert message == WorkerProc.READY_STR
            handle_frame = socket.recv(copy=False)
            handle = pickle.loads(handle_frame.buffer)
            return handle
    class ResponseStatus(Enum):
        SUCCESS = auto()
        FAILURE = auto()
    def worker_busy_loop(self):
        """Main busy loop for Multiprocessing Workers"""
        while True:
            method, args, kwargs = self.rpc_broadcast_mq.dequeue()

            try:
                output = getattr(self.worker, method)(*args, **kwargs)
            except BaseException as e:
                self.worker_response_mq.enqueue(
                    (WorkerProc.ResponseStatus.FAILURE, e))
                continue

            self.worker_response_mq.enqueue(
                (WorkerProc.ResponseStatus.SUCCESS, output))
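worker_busy_loop is the worker half of a broadcast/response RPC: the executor enqueues (method, args, kwargs) on the broadcast queue and each worker answers on its own response queue with a (status, result) tuple. The executor side is not part of this section, so the sketch below imitates the round trip with standard-library queues rather than vLLM's shared-memory MessageQueue; all names in it are illustrative only.

import multiprocessing as mp


class EchoWorker:
    def echo(self, x):
        return x

    def fail(self):
        raise RuntimeError("boom")


def busy_loop(requests: mp.Queue, responses: mp.Queue):
    worker = EchoWorker()
    while True:
        method, args, kwargs = requests.get()
        if method is None:  # shutdown sentinel
            return
        try:
            output = getattr(worker, method)(*args, **kwargs)
        except BaseException as e:  # mirrors the FAILURE branch above
            responses.put(("FAILURE", repr(e)))
            continue
        responses.put(("SUCCESS", output))


if __name__ == "__main__":
    requests, responses = mp.Queue(), mp.Queue()
    proc = mp.Process(target=busy_loop, args=(requests, responses))
    proc.start()

    requests.put(("echo", (42,), {}))
    print(responses.get())  # ('SUCCESS', 42)
    requests.put(("fail", (), {}))
    print(responses.get())  # ('FAILURE', "RuntimeError('boom')")

    requests.put((None, (), {}))
    proc.join()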
@@ -10,7 +10,7 @@ from vllm.v1.worker.gpu_worker import Worker
 logger = init_logger(__name__)


-class GPUExecutor:
+class UniprocExecutor:

     def __init__(self, vllm_config: VllmConfig) -> None:
         self.vllm_config = vllm_config
@@ -54,7 +54,7 @@ class GPUExecutor:
         """
         return self.worker.determine_num_available_blocks()

-    def initialize_cache(self, num_gpu_blocks: int) -> None:
+    def initialize(self, num_gpu_blocks: int) -> None:
         """Initialize the KV cache by invoking the underlying worker.
         """
         # NOTE: This is logged in the executor because there can be >1 worker
@@ -71,7 +71,13 @@ class GPUExecutor:
         output = self.worker.execute_model(scheduler_output)
         return output

+    def profile(self, is_start: bool = True):
+        self.worker.profile(is_start)
+
+    def shutdown(self):
+        self.worker = None
+
     def check_health(self) -> None:
-        # GPUExecutor will always be healthy as long as
+        # UniprocExecutor will always be healthy as long as
         # it's running.
         return
@@ -8,7 +8,7 @@ import torch
 class SamplerOutput:

     # [num_reqs]
-    sampled_token_ids: torch.Tensor
+    sampled_token_ids: List[int]

     # [num_reqs, max_num_logprobs + 1]
     logprob_token_ids: Optional[torch.Tensor]
@@ -20,6 +20,8 @@ class SamplerOutput:
     prompt_logprobs: Optional[torch.Tensor]


+# ModelRunnerOutput is serialized and sent to the scheduler process.
+# This is expensive for torch.Tensor so prefer to use List instead.
 @dataclass
 class ModelRunnerOutput:

@@ -29,7 +31,7 @@ class ModelRunnerOutput:
     req_id_to_index: Dict[str, int]

     # [num_reqs]
-    sampled_token_ids_cpu: torch.Tensor
+    sampled_token_ids: List[int]

     # [num_reqs, max_num_logprobs + 1]
     logprob_token_ids_cpu: Optional[torch.Tensor]
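The new comment explains the type change from torch.Tensor to List[int]: ModelRunnerOutput is pickled and sent to the scheduler process, and serializing a tensor carries more overhead than serializing a plain list of ids. A rough micro-benchmark to check this locally; it requires torch, the sizes are arbitrary, the numbers vary by machine, and it is not part of the diff.

import pickle
import time

import torch

ids_tensor = torch.randint(0, 32_000, (256,))
ids_list = ids_tensor.tolist()

for name, obj in (("tensor", ids_tensor), ("list", ids_list)):
    payload = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    start = time.perf_counter()
    for _ in range(1000):
        pickle.loads(pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))
    elapsed_ms = (time.perf_counter() - start) * 1000
    print(f"{name}: {len(payload)} bytes, {elapsed_ms:.1f} ms per 1000 round trips")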
@@ -37,8 +37,9 @@ class Sampler(nn.Module):
         topk_logprobs = None
         topk_indices = None

+        # NOTE: CPU-GPU synchronization happens here.
         sampler_output = SamplerOutput(
-            sampled_token_ids=sampled,
+            sampled_token_ids=sampled.tolist(),
             logprob_token_ids=topk_indices,
             logprobs=topk_logprobs,
             prompt_logprob_token_ids=None,
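The relocated NOTE marks where the CPU now waits on the GPU: calling .tolist() on a CUDA tensor forces a device synchronization before the Python list comes back. A small stand-alone illustration of that sync point; it needs a CUDA device, the shapes are arbitrary, and it is not part of the diff.

import time

import torch

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, device="cuda")
    y = (x @ x).argmax(dim=-1)  # kernels are queued asynchronously
    start = time.perf_counter()
    ids = y.tolist()  # implicit CPU-GPU synchronization happens here
    print(f"tolist() took {time.perf_counter() - start:.4f}s for {len(ids)} ids")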
@@ -1,4 +1,11 @@
-from typing import Generic, List, TypeVar, overload
+from contextlib import contextmanager
+from typing import Any, Generic, Iterator, List, TypeVar, overload
+
+import zmq
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)

 T = TypeVar("T")

@@ -62,3 +69,27 @@ class ConstantList(Generic[T]):

     def __len__(self):
         return len(self._x)
+
+
+@contextmanager
+def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
+    """Context manager for a ZMQ socket"""
+
+    ctx = zmq.Context()
+    try:
+        socket = ctx.socket(type)
+
+        if type == zmq.constants.PULL:
+            socket.connect(path)
+        elif type == zmq.constants.PUSH:
+            socket.bind(path)
+        else:
+            raise ValueError(f"Unknown Socket Type: {type}")
+
+        yield socket
+
+    except KeyboardInterrupt:
+        logger.debug("Worker had Keyboard Interrupt.")
+
+    finally:
+        ctx.destroy(linger=0)
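make_zmq_socket backs the one-shot readiness handshake seen earlier in this diff: the worker binds a PUSH socket on the IPC path and sends READY plus a pickled handle, while the parent connects a PULL socket and reads both. Below is a self-contained sketch of that handshake with plain pyzmq; the dict payload is a stand-in for the real MessageQueue handle, and the IPC transport assumes a POSIX system.

import os
import pickle
import threading

import zmq

READY_STR = "READY"
PATH = f"ipc:///tmp/ready_handshake_{os.getpid()}"


def worker_side():
    ctx = zmq.Context()
    sock = ctx.socket(zmq.PUSH)
    sock.bind(PATH)  # the worker binds, as in WorkerProc.__init__
    sock.send_string(READY_STR)
    sock.send(pickle.dumps({"rank": 0}))  # stand-in for the MQ handle
    sock.close()
    ctx.term()


t = threading.Thread(target=worker_side)
t.start()

ctx = zmq.Context()
sock = ctx.socket(zmq.PULL)
sock.connect(PATH)  # the parent connects, as in wait_for_startup
assert sock.recv_string() == READY_STR
print("received:", pickle.loads(sock.recv()))
sock.close()
ctx.term()
t.join()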
@@ -34,6 +34,7 @@ class GPUModelRunner:
     def __init__(
         self,
         vllm_config: VllmConfig,
+        device: torch.device,
         input_registry: InputRegistry = INPUT_REGISTRY,
     ):
         self.vllm_config = vllm_config
@@ -43,7 +44,6 @@ class GPUModelRunner:
         self.load_config = vllm_config.load_config
         self.parallel_config = vllm_config.parallel_config
         self.scheduler_config = vllm_config.scheduler_config
-        self.device_config = vllm_config.device_config
         self.speculative_config = vllm_config.speculative_config
         self.prompt_adapter_config = vllm_config.prompt_adapter_config
         self.observability_config = vllm_config.observability_config
@@ -52,7 +52,7 @@ class GPUModelRunner:
         cache_config = self.cache_config
         scheduler_config = self.scheduler_config
         parallel_config = self.parallel_config
-        self.device = self.device_config.device
+        self.device = device
         self.pin_memory = is_pin_memory_available()
         self.dtype = self.model_config.dtype
         if cache_config.cache_dtype == "auto":
@@ -477,9 +477,7 @@ class GPUModelRunner:
             sampling_metadata=sampling_metadata,
         )

-        # NOTE: CPU-GPU synchronization happens here.
-        sampled_token_ids = sampler_output.sampled_token_ids.cpu()
-        sampled_token_ids_list = sampled_token_ids.tolist()
+        sampled_token_ids = sampler_output.sampled_token_ids
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
         num_reqs = self.input_batch.num_reqs
@@ -490,7 +488,7 @@ class GPUModelRunner:
             assert seq_len <= req_state.num_tokens
             if seq_len == req_state.num_tokens:
                 # Append the sampled token to the output token ids.
-                token_id = sampled_token_ids_list[i]
+                token_id = sampled_token_ids[i]
                 self.input_batch.token_ids_cpu[i, seq_len] = token_id
                 req_state.output_token_ids.append(token_id)
             else:
@@ -512,7 +510,7 @@ class GPUModelRunner:
         model_runner_output = ModelRunnerOutput(
             req_ids=self.input_batch.req_ids[:num_reqs],
             req_id_to_index=self.input_batch.req_id_to_index,
-            sampled_token_ids_cpu=sampled_token_ids,
+            sampled_token_ids=sampled_token_ids,
             logprob_token_ids_cpu=logprob_token_ids,
             logprobs_cpu=logprobs,
         )
@@ -15,6 +15,7 @@ from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
+from vllm.v1.core.scheduler import SchedulerOutput
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner

@@ -56,7 +57,6 @@ class Worker:
             from vllm.utils import init_cached_hf_modules
             init_cached_hf_modules()

-        self.model_runner = GPUModelRunner(vllm_config)
         # Torch profiler. Enabled and configured through env vars:
         # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
         if envs.VLLM_TORCH_PROFILER_DIR:
@@ -103,6 +103,9 @@ class Worker:
         # Set random seed.
         set_random_seed(self.model_config.seed)

+        # Construct the model runner
+        self.model_runner = GPUModelRunner(self.vllm_config, self.device)
+
     def load_model(self) -> None:
         self.model_runner.load_model()

@@ -198,7 +201,7 @@ class Worker:
         scheduler_output: "SchedulerOutput",
     ) -> ModelRunnerOutput:
         output = self.model_runner.execute_model(scheduler_output)
-        # TODO(woosuk): Send the output to the engine process.
+        return output if self.rank == 0 else None
         return output

     def profile(self, is_start=True):
@@ -209,6 +212,10 @@ class Worker:
         else:
             self.profiler.stop()

+    def check_health(self) -> None:
+        # worker will always be healthy as long as it's running.
+        return
+

 def init_worker_distributed_environment(
     parallel_config: ParallelConfig,