Use NCCL instead of ray for control-plane communication to remove serialization overhead (#2221)

This commit is contained in:
Zhuohan Li 2024-01-04 03:30:22 +08:00 committed by GitHub
parent 1066cbd152
commit fd4ea8ef5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 524 additions and 262 deletions

View File

@ -58,11 +58,10 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
+ positions: torch.Tensor, + positions: torch.Tensor,
+ kv_caches: List[KVCache], + kv_caches: List[KVCache],
+ input_metadata: InputMetadata, + input_metadata: InputMetadata,
+ cache_events: Optional[List[torch.cuda.Event]], +) -> Optional[SamplerOutput]:
+) -> SamplerOutput:
3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors. 1. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
4. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture. 2. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture.
.. note:: .. note::
Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings. Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.

View File

@ -3,8 +3,6 @@ typing-extensions>=4.8.0
starlette starlette
psutil psutil
ray >= 2.5.1 ray >= 2.5.1
pandas # Required for Ray data.
pyarrow # Required for Ray data.
sentencepiece # Required for LLaMA tokenizer. sentencepiece # Required for LLaMA tokenizer.
numpy numpy
tokenizers>=0.15.0 tokenizers>=0.15.0

View File

@ -1,8 +1,6 @@
ninja # For faster builds. ninja # For faster builds.
psutil psutil
ray >= 2.5.1 ray >= 2.5.1
pandas # Required for Ray data.
pyarrow # Required for Ray data.
sentencepiece # Required for LLaMA tokenizer. sentencepiece # Required for LLaMA tokenizer.
numpy numpy
torch == 2.1.2 torch == 2.1.2

View File

@ -8,11 +8,11 @@ import pytest
import requests import requests
def _query_server(prompt: str) -> dict: def _query_server(prompt: str, max_tokens: int = 5) -> dict:
response = requests.post("http://localhost:8000/generate", response = requests.post("http://localhost:8000/generate",
json={ json={
"prompt": prompt, "prompt": prompt,
"max_tokens": 100, "max_tokens": max_tokens,
"temperature": 0, "temperature": 0,
"ignore_eos": True "ignore_eos": True
}) })
@ -20,6 +20,10 @@ def _query_server(prompt: str) -> dict:
return response.json() return response.json()
def _query_server_long(prompt: str) -> dict:
return _query_server(prompt, max_tokens=500)
@pytest.fixture @pytest.fixture
def api_server(): def api_server():
script_path = Path(__file__).parent.joinpath( script_path = Path(__file__).parent.joinpath(
@ -68,10 +72,11 @@ def test_api_server(api_server):
for result in pool.map(_query_server, prompts): for result in pool.map(_query_server, prompts):
assert result assert result
with Pool(32) as pool:
# Cancel requests # Cancel requests
prompts = ["canceled requests"] * 100 prompts = ["canceled requests"] * 100
pool.map_async(_query_server, prompts) pool.map_async(_query_server_long, prompts)
time.sleep(0.001) time.sleep(0.01)
pool.terminate() pool.terminate()
pool.join() pool.join()

View File

@ -49,12 +49,13 @@ def test_copy_blocks(
src_blocks = random.sample(range(num_blocks), num_mappings) src_blocks = random.sample(range(num_blocks), num_mappings)
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
block_mapping = {} copy_src = []
copy_dst = []
for i in range(num_mappings): for i in range(num_mappings):
src = src_blocks[i] copy_src.append(src_blocks[i])
dst1 = dst_blocks[2 * i] copy_dst.append(dst_blocks[2 * i])
dst2 = dst_blocks[2 * i + 1] copy_src.append(src_blocks[i])
block_mapping[src] = [dst1, dst2] copy_dst.append(dst_blocks[2 * i + 1])
# Create the KV caches. # Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
@ -66,15 +67,14 @@ def test_copy_blocks(
cloned_value_caches = [value_cache.clone() for value_cache in value_caches] cloned_value_caches = [value_cache.clone() for value_cache in value_caches]
# Call the copy blocks kernel. # Call the copy blocks kernel.
cache_ops.copy_blocks(key_caches, value_caches, block_mapping) cache_ops.copy_blocks(key_caches, value_caches, copy_src, copy_dst)
# Run the reference implementation. # Run the reference implementation.
for src, dsts in block_mapping.items(): for src, dst in zip(copy_src, copy_dst):
for dst in dsts: for cloned_key_cache in cloned_key_caches:
for cloned_key_cache in cloned_key_caches: cloned_key_cache[dst].copy_(cloned_key_cache[src])
cloned_key_cache[dst].copy_(cloned_key_cache[src]) for cloned_value_cache in cloned_value_caches:
for cloned_value_cache in cloned_value_caches: cloned_value_cache[dst].copy_(cloned_value_cache[src])
cloned_value_cache[dst].copy_(cloned_value_cache[src])
# Compare the results. # Compare the results.
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):

View File

@ -33,8 +33,9 @@ def test_prepare_prompt():
expected_selected_token_indices.append(selected_token_start_idx + expected_selected_token_indices.append(selected_token_start_idx +
prompt_len - 1) prompt_len - 1)
selected_token_start_idx += max_seq_len selected_token_start_idx += max_seq_len
input_tokens, input_positions, _ = model_runner._prepare_prompt( input_tokens, input_positions, _, return_prompt_lens = (
seq_group_metadata_list) model_runner._prepare_prompt(seq_group_metadata_list))
assert return_prompt_lens == prompt_lens
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens) prompt_lens)
assert input_tokens.shape == (batch_size, max_seq_len) assert input_tokens.shape == (batch_size, max_seq_len)

View File

@ -185,14 +185,21 @@ class _AsyncLLMEngine(LLMEngine):
""" """
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
# Execute the model. if not scheduler_outputs.is_empty():
output = (await self._run_workers_async( # Execute the model.
"execute_model", all_outputs = await self._run_workers_async(
seq_group_metadata_list=seq_group_metadata_list, "execute_model",
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, driver_kwargs={
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, "seq_group_metadata_list": seq_group_metadata_list,
blocks_to_copy=scheduler_outputs.blocks_to_copy, "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
)) if not scheduler_outputs.is_empty() else [] "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
"blocks_to_copy": scheduler_outputs.blocks_to_copy,
})
# Only the driver worker returns the sampling results.
output = all_outputs[0]
else:
output = []
return self._process_model_outputs(output, scheduler_outputs) return self._process_model_outputs(output, scheduler_outputs)
@ -200,30 +207,29 @@ class _AsyncLLMEngine(LLMEngine):
self, self,
method: str, method: str,
*args, *args,
get_all_outputs: bool = False, driver_args: Optional[List[Any]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Runs the given method on all workers.""" """Runs the given method on all workers."""
coros = [] coros = []
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
# Run the driver worker asynchronously.
driver_executor = getattr(self.driver_worker, method)
coros.append(asyncio.get_event_loop().run_in_executor(
None, partial(driver_executor, *driver_args, **driver_kwargs)))
# Run the ray workers asynchronously.
for worker in self.workers: for worker in self.workers:
if self.parallel_config.worker_use_ray: coros.append(worker.execute_method.remote(method, *args, **kwargs))
coros.append(
worker.execute_method.remote(method, *args, **kwargs))
else:
executor = getattr(worker, method)
coros.append(asyncio.get_event_loop().run_in_executor(
None, partial(executor, *args, **kwargs)))
all_outputs = await asyncio.gather(*coros) all_outputs = await asyncio.gather(*coros)
return all_outputs
if get_all_outputs:
return all_outputs
# Make sure all workers have the same results.
output = all_outputs[0]
for other_output in all_outputs[1:]:
assert output == other_output
return output
class AsyncLLMEngine: class AsyncLLMEngine:
@ -488,13 +494,12 @@ class AsyncLLMEngine:
engine_configs = engine_args.create_engine_configs() engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2] parallel_config = engine_configs[2]
# Initialize the cluster. # Initialize the cluster.
distributed_init_method, placement_group = initialize_cluster( placement_group = initialize_cluster(parallel_config,
parallel_config, engine_args.engine_use_ray) engine_args.engine_use_ray)
# Create the async LLM engine. # Create the async LLM engine.
engine = cls(parallel_config.worker_use_ray, engine = cls(parallel_config.worker_use_ray,
engine_args.engine_use_ray, engine_args.engine_use_ray,
*engine_configs, *engine_configs,
distributed_init_method,
placement_group, placement_group,
log_requests=not engine_args.disable_log_requests, log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats, log_stats=not engine_args.disable_log_stats,

View File

@ -1,8 +1,9 @@
import copy import copy
from collections import defaultdict
import os import os
import time import time
from functools import partial from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union Union)
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig) SchedulerConfig)
@ -17,10 +18,9 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus) SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally, from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer) get_tokenizer)
from vllm.utils import Counter from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port
if ray: if ray:
from ray.air.util.torch_dist import init_torch_dist_process_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
if TYPE_CHECKING: if TYPE_CHECKING:
@ -53,8 +53,6 @@ class LLMEngine:
management. management.
parallel_config: The configuration related to distributed execution. parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler. scheduler_config: The configuration related to the request scheduler.
distributed_init_method: The initialization method for distributed
execution. See `torch.distributed.init_process_group` for details.
placement_group: Ray placement group for distributed execution. placement_group: Ray placement group for distributed execution.
Required for distributed execution. Required for distributed execution.
log_stats: Whether to log statistics. log_stats: Whether to log statistics.
@ -66,7 +64,6 @@ class LLMEngine:
cache_config: CacheConfig, cache_config: CacheConfig,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
distributed_init_method: str,
placement_group: Optional["PlacementGroup"], placement_group: Optional["PlacementGroup"],
log_stats: bool, log_stats: bool,
) -> None: ) -> None:
@ -111,7 +108,7 @@ class LLMEngine:
os.environ["RAY_USAGE_STATS_ENABLED"] = "0" os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
self._init_workers_ray(placement_group) self._init_workers_ray(placement_group)
else: else:
self._init_workers(distributed_init_method) self._init_workers()
# Profile the memory usage and initialize the cache. # Profile the memory usage and initialize the cache.
self._init_cache() self._init_cache()
@ -126,7 +123,7 @@ class LLMEngine:
# List of (timestamp, num_tokens) # List of (timestamp, num_tokens)
self.num_generation_tokens: List[Tuple[float, int]] = [] self.num_generation_tokens: List[Tuple[float, int]] = []
def _init_workers(self, distributed_init_method: str): def _init_workers(self):
# Lazy import the Worker to avoid importing torch.cuda/xformers # Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker # before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
@ -135,70 +132,122 @@ class LLMEngine:
"Ray is required if parallel_config.world_size > 1.") "Ray is required if parallel_config.world_size > 1.")
self.workers: List[Worker] = [] self.workers: List[Worker] = []
worker = Worker( distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
self.driver_worker = Worker(
self.model_config, self.model_config,
self.parallel_config, self.parallel_config,
self.scheduler_config, self.scheduler_config,
0, local_rank=0,
distributed_init_method, rank=0,
) distributed_init_method=distributed_init_method,
self.workers.append(worker) is_driver_worker=True,
self._run_workers(
"init_model",
get_all_outputs=True,
)
self._run_workers(
"load_model",
get_all_outputs=True,
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers,
) )
self._run_workers("init_model")
self._run_workers("load_model")
def _init_workers_ray(self, placement_group: "PlacementGroup", def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs): **ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1:
num_gpus = self.cache_config.gpu_memory_utilization
else:
num_gpus = 1
self.driver_dummy_worker: RayWorkerVllm = None
self.workers: List[RayWorkerVllm] = []
driver_ip = get_ip()
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0):
continue
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerVllm).remote(self.model_config.trust_remote_code)
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
else:
self.workers.append(worker)
if self.driver_dummy_worker is None:
raise ValueError(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"GPU node.")
driver_node_id, driver_gpu_ids = ray.get(
self.driver_dummy_worker.get_node_and_gpu_ids.remote())
worker_node_and_gpu_ids = ray.get(
[worker.get_node_and_gpu_ids.remote() for worker in self.workers])
node_workers = defaultdict(list)
node_gpus = defaultdict(list)
node_workers[driver_node_id].append(0)
node_gpus[driver_node_id].extend(driver_gpu_ids)
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
start=1):
node_workers[node_id].append(i)
node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
node_gpus[node_id] = sorted(gpu_ids)
# Set CUDA_VISIBLE_DEVICES for the driver.
set_cuda_visible_devices(node_gpus[driver_node_id])
for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
worker.set_cuda_visible_devices.remote(node_gpus[node_id])
distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}"
# Lazy import the Worker to avoid importing torch.cuda/xformers # Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker # before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker from vllm.worker.worker import Worker
self.workers: List[Worker] = []
for bundle in placement_group.bundle_specs:
if not bundle.get("GPU", 0):
continue
if self.parallel_config.tensor_parallel_size == 1:
num_gpus = self.cache_config.gpu_memory_utilization
else:
num_gpus = 1
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=PlacementGroupSchedulingStrategy(
placement_group=placement_group,
placement_group_capture_child_tasks=True),
**ray_remote_kwargs,
)(RayWorkerVllm).remote(self.model_config.trust_remote_code)
self.workers.append(worker)
# Initialize torch distributed process group for the workers. # Initialize torch distributed process group for the workers.
init_torch_dist_process_group(self.workers, backend="nccl")
model_config = copy.deepcopy(self.model_config) model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config) parallel_config = copy.deepcopy(self.parallel_config)
scheduler_config = copy.deepcopy(self.scheduler_config) scheduler_config = copy.deepcopy(self.scheduler_config)
self._run_workers("init_worker",
get_all_outputs=True, for rank, (worker, (node_id,
worker_init_fn=lambda: Worker( _)) in enumerate(zip(self.workers,
model_config, worker_node_and_gpu_ids),
parallel_config, start=1):
scheduler_config, local_rank = node_workers[node_id].index(rank)
None, worker.init_worker.remote(
None, lambda rank=rank, local_rank=local_rank: Worker(
)) model_config,
self._run_workers( parallel_config,
"init_model", scheduler_config,
get_all_outputs=True, local_rank,
rank,
distributed_init_method,
))
driver_rank = 0
driver_local_rank = node_workers[driver_node_id].index(driver_rank)
self.driver_worker = Worker(
model_config,
parallel_config,
scheduler_config,
driver_local_rank,
driver_rank,
distributed_init_method,
is_driver_worker=True,
) )
self._run_workers("init_model")
self._run_workers( self._run_workers(
"load_model", "load_model",
get_all_outputs=True,
max_concurrent_workers=self.parallel_config. max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers, max_parallel_loading_workers,
) )
@ -212,7 +261,6 @@ class LLMEngine:
# Get the maximum number of blocks that can be allocated on GPU and CPU. # Get the maximum number of blocks that can be allocated on GPU and CPU.
num_blocks = self._run_workers( num_blocks = self._run_workers(
"profile_num_available_blocks", "profile_num_available_blocks",
get_all_outputs=True,
block_size=self.cache_config.block_size, block_size=self.cache_config.block_size,
gpu_memory_utilization=self.cache_config.gpu_memory_utilization, gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
cpu_swap_space=self.cache_config.swap_space_bytes, cpu_swap_space=self.cache_config.swap_space_bytes,
@ -256,11 +304,9 @@ class LLMEngine:
engine_configs = engine_args.create_engine_configs() engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2] parallel_config = engine_configs[2]
# Initialize the cluster. # Initialize the cluster.
distributed_init_method, placement_group = initialize_cluster( placement_group = initialize_cluster(parallel_config)
parallel_config)
# Create the LLM engine. # Create the LLM engine.
engine = cls(*engine_configs, engine = cls(*engine_configs,
distributed_init_method,
placement_group, placement_group,
log_stats=not engine_args.disable_log_stats) log_stats=not engine_args.disable_log_stats)
return engine return engine
@ -577,14 +623,21 @@ class LLMEngine:
""" """
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
# Execute the model. if not scheduler_outputs.is_empty():
output = self._run_workers( # Execute the model.
"execute_model", all_outputs = self._run_workers(
seq_group_metadata_list=seq_group_metadata_list, "execute_model",
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, driver_kwargs={
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, "seq_group_metadata_list": seq_group_metadata_list,
blocks_to_copy=scheduler_outputs.blocks_to_copy, "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
) if not scheduler_outputs.is_empty() else [] "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
"blocks_to_copy": scheduler_outputs.blocks_to_copy,
})
# Only the driver worker returns the sampling results.
output = all_outputs[0]
else:
output = []
return self._process_model_outputs(output, scheduler_outputs) return self._process_model_outputs(output, scheduler_outputs)
@ -712,53 +765,38 @@ class LLMEngine:
seq.status = SequenceStatus.FINISHED_STOPPED seq.status = SequenceStatus.FINISHED_STOPPED
return return
def _run_workers_in_batch(
self,
workers,
method: str,
*args,
**kwargs,
):
all_outputs = []
for worker in workers:
if self.parallel_config.worker_use_ray:
executor = partial(worker.execute_method.remote, method)
else:
executor = getattr(worker, method)
output = executor(*args, **kwargs)
all_outputs.append(output)
if self.parallel_config.worker_use_ray:
all_outputs = ray.get(all_outputs)
return all_outputs
def _run_workers( def _run_workers(
self, self,
method: str, method: str,
*args, *args,
get_all_outputs: bool = False, driver_args: Optional[List[Any]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
max_concurrent_workers: Optional[int] = None, max_concurrent_workers: Optional[int] = None,
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Runs the given method on all workers.""" """Runs the given method on all workers."""
all_outputs = []
if max_concurrent_workers: if max_concurrent_workers:
work_groups = [ raise NotImplementedError(
self.workers[i:i + max_concurrent_workers] "max_concurrent_workers is not supported yet.")
for i in range(0, len(self.workers), max_concurrent_workers)
]
else:
work_groups = [self.workers]
for workers in work_groups: # Start the ray workers first.
all_outputs.extend( ray_worker_outputs = [
self._run_workers_in_batch(workers, method, *args, **kwargs)) worker.execute_method.remote(method, *args, **kwargs)
for worker in self.workers
]
if get_all_outputs: if driver_args is None:
return all_outputs driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
# Make sure all workers have the same results. # Start the driver worker after all the ray workers.
output = all_outputs[0] driver_worker_output = getattr(self.driver_worker,
for other_output in all_outputs[1:]: method)(*driver_args, **driver_kwargs)
assert output == other_output
return output # Get the results of the ray workers.
if self.workers:
ray_worker_outputs = ray.get(ray_worker_outputs)
return [driver_worker_output] + ray_worker_outputs

View File

@ -1,16 +1,15 @@
from typing import Optional, Tuple, TYPE_CHECKING from typing import Optional, List, Tuple, TYPE_CHECKING
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import get_open_port, is_hip from vllm.utils import is_hip, set_cuda_visible_devices, get_ip
logger = init_logger(__name__) logger = init_logger(__name__)
try: try:
import ray import ray
from ray.air.util.torch_dist import TorchDistributedWorker
class RayWorkerVllm(TorchDistributedWorker): class RayWorkerVllm:
"""Ray wrapper for vllm.worker.Worker, allowing Worker to be """Ray wrapper for vllm.worker.Worker, allowing Worker to be
lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES.""" lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
@ -30,12 +29,22 @@ try:
executor = getattr(self, method) executor = getattr(self, method)
return executor(*args, **kwargs) return executor(*args, **kwargs)
def get_node_ip(self) -> str:
return get_ip()
def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
node_id = ray.get_runtime_context().get_node_id()
gpu_ids = ray.get_gpu_ids()
return node_id, gpu_ids
def set_cuda_visible_devices(self, device_ids) -> None:
set_cuda_visible_devices(device_ids)
except ImportError as e: except ImportError as e:
logger.warning(f"Failed to import Ray with {e!r}. " logger.warning(f"Failed to import Ray with {e!r}. "
"For distributed inference, please install Ray with " "For distributed inference, please install Ray with "
"`pip install ray pandas pyarrow`.") "`pip install ray pandas pyarrow`.")
ray = None ray = None
TorchDistributedWorker = None
RayWorkerVllm = None RayWorkerVllm = None
if TYPE_CHECKING: if TYPE_CHECKING:
@ -75,13 +84,11 @@ def initialize_cluster(
ray.init(address=ray_address, ignore_reinit_error=True) ray.init(address=ray_address, ignore_reinit_error=True)
if not parallel_config.worker_use_ray: if not parallel_config.worker_use_ray:
# Initialize cluster locally. assert parallel_config.world_size == 1, (
port = get_open_port() "Ray is required if parallel_config.world_size > 1.")
# We need to setup the distributed init method to make sure return None
# the distributed megatron code (e.g., get world size) works correctly.
distributed_init_method = f"tcp://localhost:{port}"
return distributed_init_method, None
# Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group() current_placement_group = ray.util.get_current_placement_group()
if current_placement_group: if current_placement_group:
# We are in a placement group # We are in a placement group
@ -106,12 +113,12 @@ def initialize_cluster(
"The number of required GPUs exceeds the total number of " "The number of required GPUs exceeds the total number of "
"available GPUs in the cluster.") "available GPUs in the cluster.")
# Create a new placement group # Create a new placement group
current_placement_group = ray.util.placement_group([{ placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
"GPU": 1 current_placement_group = ray.util.placement_group(
}] * parallel_config.world_size) placement_group_specs)
# Wait until PG is ready - this will block until all # Wait until PG is ready - this will block until all
# requested resources are available, and will timeout # requested resources are available, and will timeout
# if they cannot be provisioned. # if they cannot be provisioned.
ray.get(current_placement_group.ready(), timeout=1800) ray.get(current_placement_group.ready(), timeout=1800)
return None, current_placement_group return current_placement_group

View File

@ -1,4 +1,4 @@
from typing import List, Optional from typing import Optional
import torch import torch
@ -16,28 +16,27 @@ class InputMetadata:
def __init__( def __init__(
self, self,
prompt_lens: List[int], is_prompt: bool,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
max_context_len: Optional[int], max_context_len: Optional[int],
context_lens: Optional[torch.Tensor], context_lens: Optional[torch.Tensor],
block_tables: Optional[torch.Tensor], block_tables: Optional[torch.Tensor],
use_cuda_graph: bool, use_cuda_graph: bool,
) -> None: ) -> None:
self.prompt_lens = prompt_lens self.is_prompt = is_prompt
self.max_context_len = max_context_len self.max_context_len = max_context_len
self.slot_mapping = slot_mapping self.slot_mapping = slot_mapping
self.context_lens = context_lens self.context_lens = context_lens
self.block_tables = block_tables self.block_tables = block_tables
self.use_cuda_graph = use_cuda_graph self.use_cuda_graph = use_cuda_graph
self.is_prompt = len(prompt_lens) > 0
# Set during the execution of the first attention op. # Set during the execution of the first attention op.
# FIXME(woosuk): This is a hack. # FIXME(woosuk): This is a hack.
self.attn_bias = None self.attn_bias = None
def __repr__(self) -> str: def __repr__(self) -> str:
return ("InputMetadata(" return ("InputMetadata("
f"prompt_lens={self.prompt_lens}, " f"is_prompt={self.is_prompt}, "
f"max_context_len={self.max_context_len}, " f"max_context_len={self.max_context_len}, "
f"slot_mapping={self.slot_mapping}, " f"slot_mapping={self.slot_mapping}, "
f"context_lens={self.context_lens}, " f"context_lens={self.context_lens}, "

View File

@ -5,7 +5,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather) tensor_model_parallel_gather)
from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput, from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
@ -37,7 +37,7 @@ class Sampler(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None, embedding_bias: Optional[torch.Tensor] = None,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
# Get the hidden states that we use for sampling. # Get the hidden states that we use for sampling.
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
@ -45,6 +45,14 @@ class Sampler(nn.Module):
logits = _get_logits(hidden_states, embedding, embedding_bias, logits = _get_logits(hidden_states, embedding, embedding_bias,
self.vocab_size) self.vocab_size)
# Only perform sampling in the driver worker.
# Note: `_get_logits` is still distributed across TP workers because
# the `embedding` weight is distributed across TP workers.
# TODO(zhuohan): Change the get_logits part to a separate stage.
if not sampling_metadata.perform_sampling:
return None
assert logits is not None
_, vocab_size = logits.shape _, vocab_size = logits.shape
# Apply logits processors (if any). # Apply logits processors (if any).
@ -92,14 +100,15 @@ class Sampler(nn.Module):
def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor, def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor], embedding_bias: Optional[torch.Tensor],
vocab_size: int) -> torch.Tensor: vocab_size: int) -> Optional[torch.Tensor]:
# Get the logits for the next tokens. # Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t()) logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None: if embedding_bias is not None:
logits += embedding_bias logits += embedding_bias
logits = tensor_model_parallel_all_gather(logits) logits = tensor_model_parallel_gather(logits)
# Remove paddings in vocab (if any). # Remove paddings in vocab (if any).
logits = logits[:, :vocab_size] if logits is not None:
logits = logits[:, :vocab_size]
return logits return logits

View File

@ -298,7 +298,7 @@ class AquilaForCausalLM(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens

View File

@ -313,7 +313,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens

View File

@ -290,7 +290,7 @@ class BloomForCausalLM(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens

View File

@ -349,7 +349,7 @@ class ChatGLMForCausalLM(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens

View File

@ -394,7 +394,7 @@ class FalconForCausalLM(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens

View File

@ -235,7 +235,7 @@ class GPT2LMHeadModel(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens

View File

@ -254,7 +254,7 @@ class GPTBigCodeForCausalLM(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens

View File

@ -240,7 +240,7 @@ class GPTJForCausalLM(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata, self.lm_head.bias) sampling_metadata, self.lm_head.bias)
return next_tokens return next_tokens

View File

@ -255,7 +255,7 @@ class GPTNeoXForCausalLM(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.embed_out.weight, hidden_states, next_tokens = self.sampler(self.embed_out.weight, hidden_states,
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens

View File

@ -255,7 +255,7 @@ class InternLMForCausalLM(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens

View File

@ -291,7 +291,7 @@ class LlamaForCausalLM(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens

View File

@ -287,7 +287,7 @@ class MistralForCausalLM(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens

View File

@ -320,7 +320,7 @@ class MixtralModel(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> SamplerOutput: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
for i in range(len(self.layers)): for i in range(len(self.layers)):
@ -361,7 +361,7 @@ class MixtralForCausalLM(nn.Module):
self, self,
hidden_states: Optional[torch.Tensor], hidden_states: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens

View File

@ -276,7 +276,7 @@ class MPTForCausalLM(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens

View File

@ -309,7 +309,7 @@ class OPTForCausalLM(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens

View File

@ -280,7 +280,7 @@ class PhiForCausalLM(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
head = self.lm_head.linear head = self.lm_head.linear
next_tokens = self.sampler(head.weight, hidden_states, next_tokens = self.sampler(head.weight, hidden_states,
sampling_metadata, head.bias) sampling_metadata, head.bias)

View File

@ -247,7 +247,7 @@ class QWenLMHeadModel(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens

View File

@ -286,7 +286,7 @@ class YiForCausalLM(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata) sampling_metadata)
return next_tokens return next_tokens

View File

@ -1,6 +1,7 @@
import torch import torch
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
get_tensor_model_parallel_group, get_tensor_model_parallel_group,
) )
@ -45,3 +46,61 @@ def tensor_model_parallel_all_gather(input_, dim=-1):
(world_size * input_size[dim], ) + (world_size * input_size[dim], ) +
input_size[dim + 1:]) input_size[dim + 1:])
return output_tensor return output_tensor
def tensor_model_parallel_gather(input_, dst=0, dim=-1):
"""Gather the input tensor across model parallel group.
NOTE: We assume that the input tensor is on the same device across
all the ranks.
"""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if get_tensor_model_parallel_rank() == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=dst,
group=get_tensor_model_parallel_group())
if get_tensor_model_parallel_rank() == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def broadcast(input_, src=0):
"""Broadcast the input tensor."""
world_size = torch.distributed.get_world_size()
assert 0 <= src < world_size, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Broadcast.
torch.distributed.broadcast(input_, src=src)
return input_
def broadcast_object_list(obj_list, src=0):
"""Broadcast the input object list."""
world_size = torch.distributed.get_world_size()
assert 0 <= src < world_size, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return obj_list
# Broadcast.
torch.distributed.broadcast_object_list(obj_list, src=src)
return obj_list

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Tuple from typing import Dict, List, Optional, Tuple
import torch import torch
@ -18,24 +18,29 @@ class SamplingMetadata:
seq_data: Seq_id -> SequenceData. seq_data: Seq_id -> SequenceData.
prompt_lens: Lengths of prompts. prompt_lens: Lengths of prompts.
selected_token_indices: Token indices selected for sampling. selected_token_indices: Token indices selected for sampling.
categorized_sample_indices: SamplingType -> token indicies to sample. categorized_sample_indices: SamplingType -> token indices to sample.
perform_sampling: Whether to perform sampling. This option is used to
make the sampling only happens in the driver worker, and disable
sampling in other worker processes.
""" """
def __init__( def __init__(
self, self,
seq_groups: List[Tuple[List[int], SamplingParams]], seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
seq_data: Dict[int, SequenceData], seq_data: Optional[Dict[int, SequenceData]],
prompt_lens: List[int], prompt_lens: Optional[List[int]],
selected_token_indices: torch.Tensor, selected_token_indices: torch.Tensor,
categorized_sample_indices: Dict[SamplingType, torch.Tensor], categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
perform_sampling: bool = True,
) -> None: ) -> None:
self.seq_groups = seq_groups self.seq_groups = seq_groups
self.seq_data = seq_data self.seq_data = seq_data
self.prompt_lens = prompt_lens self.prompt_lens = prompt_lens
self.selected_token_indices = selected_token_indices self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices self.categorized_sample_indices = categorized_sample_indices
self.perform_sampling = perform_sampling
self.num_prompts = len(prompt_lens) self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
@ -44,7 +49,8 @@ class SamplingMetadata:
f"seq_data={self.seq_data}, " f"seq_data={self.seq_data}, "
f"prompt_lens={self.prompt_lens}, " f"prompt_lens={self.prompt_lens}, "
f"selected_token_indices={self.selected_token_indices}, " f"selected_token_indices={self.selected_token_indices}, "
f"categorized_sample_indices={self.categorized_sample_indices})") f"categorized_sample_indices={self.categorized_sample_indices}), "
f"perform_sampling={self.perform_sampling})")
@dataclass @dataclass

View File

@ -1,7 +1,9 @@
import enum import enum
import os
import socket import socket
import uuid import uuid
from platform import uname from platform import uname
from typing import List
import psutil import psutil
import torch import torch
@ -55,7 +57,15 @@ def in_wsl() -> bool:
return "microsoft" in " ".join(uname()).lower() return "microsoft" in " ".join(uname()).lower()
def get_open_port(): def get_ip() -> str:
return socket.gethostbyname(socket.gethostname())
def get_open_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0)) s.bind(("", 0))
return s.getsockname()[1] return s.getsockname()[1]
def set_cuda_visible_devices(device_ids: List[int]) -> None:
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))

View File

@ -1,5 +1,5 @@
import time import time
from typing import Dict, List, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@ -8,6 +8,8 @@ import torch.nn as nn
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.model_executor.parallel_utils.communication_op import (
broadcast, broadcast_object_list)
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.utils import in_wsl from vllm.utils import in_wsl
@ -28,10 +30,12 @@ class ModelRunner:
model_config: ModelConfig, model_config: ModelConfig,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
is_driver_worker: bool = False,
): ):
self.model_config = model_config self.model_config = model_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.is_driver_worker = is_driver_worker
# model_config can be None in tests/samplers/test_sampler.py. # model_config can be None in tests/samplers/test_sampler.py.
# FIXME(woosuk): This is a hack to make the tests work. Refactor this. # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
@ -70,7 +74,7 @@ class ModelRunner:
def _prepare_prompt( def _prepare_prompt(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int]]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = [] input_tokens: List[List[int]] = []
input_positions: List[List[int]] = [] input_positions: List[List[int]] = []
@ -135,14 +139,14 @@ class ModelRunner:
dtype=torch.long) dtype=torch.long)
input_metadata = InputMetadata( input_metadata = InputMetadata(
prompt_lens=prompt_lens, is_prompt=True,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
max_context_len=None, max_context_len=None,
context_lens=None, context_lens=None,
block_tables=None, block_tables=None,
use_cuda_graph=False, use_cuda_graph=False,
) )
return input_tokens, input_positions, input_metadata return input_tokens, input_positions, input_metadata, prompt_lens
def _prepare_decode( def _prepare_decode(
self, self,
@ -203,32 +207,24 @@ class ModelRunner:
block_tables.append([]) block_tables.append([])
batch_size = graph_batch_size batch_size = graph_batch_size
# When using CUDA graph, we don't need to make the tensors on the GPU
# because they will be eventually copied to the designated GPU buffer.
device = "cpu" if use_captured_graph else "cuda"
pin_memory = use_captured_graph and not self.in_wsl
input_tokens = _make_tensor_with_pad(input_tokens, input_tokens = _make_tensor_with_pad(input_tokens,
max_len=1, max_len=1,
pad=0, pad=0,
dtype=torch.long, dtype=torch.long,
device=device, device="cuda")
pin_memory=pin_memory)
input_positions = _make_tensor_with_pad(input_positions, input_positions = _make_tensor_with_pad(input_positions,
max_len=1, max_len=1,
pad=0, pad=0,
dtype=torch.long, dtype=torch.long,
device=device, device="cuda")
pin_memory=pin_memory)
slot_mapping = _make_tensor_with_pad(slot_mapping, slot_mapping = _make_tensor_with_pad(slot_mapping,
max_len=1, max_len=1,
pad=_PAD_SLOT_ID, pad=_PAD_SLOT_ID,
dtype=torch.long, dtype=torch.long,
device=device, device="cuda")
pin_memory=pin_memory)
context_lens = torch.tensor(context_lens, context_lens = torch.tensor(context_lens,
dtype=torch.int, dtype=torch.int,
device=device, device="cuda")
pin_memory=pin_memory)
if use_captured_graph: if use_captured_graph:
# The shape of graph_block_tables is # The shape of graph_block_tables is
@ -237,17 +233,18 @@ class ModelRunner:
for i, block_table in enumerate(block_tables): for i, block_table in enumerate(block_tables):
if block_table: if block_table:
input_block_tables[i, :len(block_table)] = block_table input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device=device) block_tables = torch.tensor(input_block_tables, device="cuda")
else: else:
block_tables = _make_tensor_with_pad( block_tables = _make_tensor_with_pad(
block_tables, block_tables,
max_len=max_context_len, max_len=max_context_len,
pad=0, pad=0,
dtype=torch.int, dtype=torch.int,
device="cuda",
) )
input_metadata = InputMetadata( input_metadata = InputMetadata(
prompt_lens=[], is_prompt=False,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
max_context_len=max_context_len, max_context_len=max_context_len,
context_lens=context_lens, context_lens=context_lens,
@ -326,23 +323,127 @@ class ModelRunner:
) )
return sampling_metadata return sampling_metadata
def prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]:
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, input_metadata,
prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions, input_metadata
) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
prompt_lens)
def get_size_or_none(x: Optional[torch.Tensor]):
return x.size() if x is not None else None
# Broadcast the input data. For input tensors, we first broadcast
# its shape and then broadcast the tensor to avoid high
# serialization cost.
py_data = {
"input_tokens_size":
input_tokens.size(),
"input_positions_size":
input_positions.size(),
"is_prompt":
input_metadata.is_prompt,
"slot_mapping_size":
get_size_or_none(input_metadata.slot_mapping),
"max_context_len":
input_metadata.max_context_len,
"context_lens_size":
get_size_or_none(input_metadata.context_lens),
"block_tables_size":
get_size_or_none(input_metadata.block_tables),
"use_cuda_graph":
input_metadata.use_cuda_graph,
"selected_token_indices_size":
sampling_metadata.selected_token_indices.size(),
}
broadcast_object_list([py_data], src=0)
# TODO(zhuohan): Combine the broadcasts or set async_op=True.
broadcast(input_tokens, src=0)
broadcast(input_positions, src=0)
if input_metadata.slot_mapping is not None:
broadcast(input_metadata.slot_mapping, src=0)
if input_metadata.context_lens is not None:
broadcast(input_metadata.context_lens, src=0)
if input_metadata.block_tables is not None:
broadcast(input_metadata.block_tables, src=0)
broadcast(sampling_metadata.selected_token_indices, src=0)
else:
receving_list = [None]
broadcast_object_list(receving_list, src=0)
py_data = receving_list[0]
input_tokens = torch.empty(*py_data["input_tokens_size"],
dtype=torch.long,
device="cuda")
broadcast(input_tokens, src=0)
input_positions = torch.empty(*py_data["input_positions_size"],
dtype=torch.long,
device="cuda")
broadcast(input_positions, src=0)
if py_data["slot_mapping_size"] is not None:
slot_mapping = torch.empty(*py_data["slot_mapping_size"],
dtype=torch.long,
device="cuda")
broadcast(slot_mapping, src=0)
else:
slot_mapping = None
if py_data["context_lens_size"] is not None:
context_lens = torch.empty(*py_data["context_lens_size"],
dtype=torch.int,
device="cuda")
broadcast(context_lens, src=0)
else:
context_lens = None
if py_data["block_tables_size"] is not None:
block_tables = torch.empty(*py_data["block_tables_size"],
dtype=torch.int,
device="cuda")
broadcast(block_tables, src=0)
else:
block_tables = None
selected_token_indices = torch.empty(
*py_data["selected_token_indices_size"],
dtype=torch.long,
device="cuda")
broadcast(selected_token_indices, src=0)
input_metadata = InputMetadata(
is_prompt=py_data["is_prompt"],
slot_mapping=slot_mapping,
max_context_len=py_data["max_context_len"],
context_lens=context_lens,
block_tables=block_tables,
use_cuda_graph=py_data["use_cuda_graph"],
)
sampling_metadata = SamplingMetadata(
seq_groups=None,
seq_data=None,
prompt_lens=None,
selected_token_indices=selected_token_indices,
categorized_sample_indices=None,
perform_sampling=False,
)
return input_tokens, input_positions, input_metadata, sampling_metadata
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> SamplerOutput: ) -> Optional[SamplerOutput]:
# NOTE: We assume that all sequences in the group are all prompts or input_tokens, input_positions, input_metadata, sampling_metadata = (
# all decodes. self.prepare_input_tensors(seq_group_metadata_list))
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
inputs = self._prepare_prompt(seq_group_metadata_list)
input_tokens, input_positions, input_metadata = inputs
else:
inputs = self._prepare_decode(seq_group_metadata_list)
input_tokens, input_positions, input_metadata = inputs
# Execute the model. # Execute the model.
if input_metadata.use_cuda_graph: if input_metadata.use_cuda_graph:
graph_batch_size = input_tokens.shape[0] graph_batch_size = input_tokens.shape[0]
@ -356,9 +457,6 @@ class ModelRunner:
input_metadata=input_metadata, input_metadata=input_metadata,
) )
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
input_metadata.prompt_lens)
# Sample the next token. # Sample the next token.
output = self.model.sample( output = self.model.sample(
hidden_states=hidden_states, hidden_states=hidden_states,
@ -424,7 +522,7 @@ class ModelRunner:
for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE): for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE):
# Create dummy input_metadata. # Create dummy input_metadata.
input_metadata = InputMetadata( input_metadata = InputMetadata(
prompt_lens=[], is_prompt=False,
slot_mapping=slot_mapping[:batch_size], slot_mapping=slot_mapping[:batch_size],
max_context_len=self.max_context_len_to_capture, max_context_len=self.max_context_len_to_capture,
context_lens=context_lens[:batch_size], context_lens=context_lens[:batch_size],

View File

@ -8,6 +8,8 @@ import torch.distributed
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig) SchedulerConfig)
from vllm.model_executor import set_random_seed from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_object_list)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
initialize_model_parallel) initialize_model_parallel)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@ -28,17 +30,23 @@ class Worker:
model_config: ModelConfig, model_config: ModelConfig,
parallel_config: ParallelConfig, parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig, scheduler_config: SchedulerConfig,
rank: Optional[int] = None, local_rank: int,
distributed_init_method: Optional[str] = None, rank: int,
distributed_init_method: str,
is_driver_worker: bool = False,
) -> None: ) -> None:
self.model_config = model_config self.model_config = model_config
self.parallel_config = parallel_config self.parallel_config = parallel_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config
self.local_rank = local_rank
self.rank = rank self.rank = rank
self.distributed_init_method = distributed_init_method self.distributed_init_method = distributed_init_method
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
self.model_runner = ModelRunner(model_config, parallel_config, self.model_runner = ModelRunner(model_config, parallel_config,
scheduler_config) scheduler_config, is_driver_worker)
# Uninitialized cache engine. Will be initialized by # Uninitialized cache engine. Will be initialized by
# self.init_cache_engine(). # self.init_cache_engine().
self.cache_config = None self.cache_config = None
@ -57,13 +65,7 @@ class Worker:
# This env var set by Ray causes exceptions with graph building. # This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
# Env vars will be set by Ray. self.device = torch.device(f"cuda:{self.local_rank}")
self.rank = self.rank if self.rank is not None else int(
os.getenv("RANK", "-1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
self.device = torch.device(f"cuda:{local_rank}")
if self.rank < 0:
raise ValueError("Invalid or unspecified rank.")
torch.cuda.set_device(self.device) torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype) _check_if_gpu_supports_dtype(self.model_config.dtype)
@ -125,14 +127,12 @@ class Worker:
# the model initialization and profiling. # the model initialization and profiling.
set_random_seed(self.model_config.seed) set_random_seed(self.model_config.seed)
@torch.inference_mode() def cache_swap(
def execute_model(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int], blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int], blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]], blocks_to_copy: Dict[int, List[int]],
) -> SamplerOutput: ) -> None:
# Issue cache operations. # Issue cache operations.
issued_cache_op = False issued_cache_op = False
if blocks_to_swap_in: if blocks_to_swap_in:
@ -152,8 +152,38 @@ class Worker:
if cache_events is not None: if cache_events is not None:
for event in cache_events: for event in cache_events:
event.wait() event.wait()
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> Optional[SamplerOutput]:
if self.is_driver_worker:
assert seq_group_metadata_list is not None
num_seq_groups = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
block_swapping_info = [
blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy
]
broadcast_object_list([num_seq_groups] + block_swapping_info,
src=0)
else:
# num_seq_groups, blocks_to_swap_in, blocks_to_swap_out,
# blocks_to_copy (4 elements)
recv_data = [None] * 4
broadcast_object_list(recv_data, src=0)
num_seq_groups = recv_data[0]
block_swapping_info = recv_data[1:]
self.cache_swap(*block_swapping_info)
# If there is no input, we don't need to execute the model. # If there is no input, we don't need to execute the model.
if not seq_group_metadata_list: if num_seq_groups == 0:
return {} return {}
output = self.model_runner.execute_model(seq_group_metadata_list, output = self.model_runner.execute_model(seq_group_metadata_list,