Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 07:34:57 +08:00)
Simplify broadcast logic for control messages (#2501)

parent 2709c0009a
commit ef9b636e2d
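Summary of the change: instead of broadcasting each control value separately (one broadcast_object_list call for Python metadata plus one broadcast call per tensor), the source rank now packs CUDA tensors and plain Python values into a single dictionary and calls the new broadcast_tensor_dict helper; the other ranks call the same helper with no arguments and receive the dictionary back. A runnable usage sketch is included after the communication_op.py hunks below.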
tests/distributed/test_comm_ops.py

@@ -11,6 +11,7 @@ from vllm.utils import get_open_port
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_reduce,
     tensor_model_parallel_all_gather,
+    broadcast_tensor_dict,
 )
 from vllm.worker.worker import _init_distributed_environment

@@ -64,11 +65,41 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
     assert torch.allclose(t, expected)


+@ray.remote(num_gpus=1, max_calls=1)
+def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
+                                      distributed_init_port: str):
+    init_test_distributed_environment(1, tensor_parallel_size, rank,
+                                      distributed_init_port)
+    test_dict = {
+        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
+        "b": torch.arange(16, dtype=torch.int8, device="cuda"),
+        "c": "test",
+        "d": [1, 2, 3],
+        "e": {
+            "a": 1,
+            "b": 2
+        },
+    }
+
+    if rank == 0:
+        broadcast_tensor_dict(test_dict, src=0)
+    else:
+        recv_dict = broadcast_tensor_dict(src=0)
+        assert len(recv_dict) == len(test_dict)
+        assert torch.allclose(recv_dict["a"], test_dict["a"])
+        assert torch.allclose(recv_dict["b"], test_dict["b"])
+        assert recv_dict["c"] == test_dict["c"]
+        assert recv_dict["d"] == test_dict["d"]
+        assert recv_dict["e"] == test_dict["e"]
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 @pytest.mark.parametrize("tensor_parallel_size", [2])
-@pytest.mark.parametrize("test_target",
-                         [all_reduce_test_worker, all_gather_test_worker])
+@pytest.mark.parametrize("test_target", [
+    all_reduce_test_worker, all_gather_test_worker,
+    broadcast_tensor_dict_test_worker
+])
 def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
     # Using ray helps debugging the error when it failed
     # as compared to multiprocessing.
vllm/model_executor/parallel_utils/communication_op.py

@@ -1,3 +1,6 @@
+from collections import namedtuple
+from typing import Any, Dict, List, Optional, Union
+
 import torch

 from vllm.model_executor.parallel_utils.parallel_state import (

@@ -7,7 +10,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 )


-def tensor_model_parallel_all_reduce(input_):
+def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     """All-reduce the input tensor across model parallel group.

     NOTE: This operation is applied in-place on the input tensor.

@@ -21,7 +24,8 @@ def tensor_model_parallel_all_reduce(input_):
     return input_


-def tensor_model_parallel_all_gather(input_, dim=-1):
+def tensor_model_parallel_all_gather(input_: torch.Tensor,
+                                     dim: int = -1) -> torch.Tensor:
     """All-gather the input tensor across model parallel group."""
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.

@@ -48,7 +52,9 @@ def tensor_model_parallel_all_gather(input_, dim=-1):
     return output_tensor


-def tensor_model_parallel_gather(input_, dst=0, dim=-1):
+def tensor_model_parallel_gather(input_: torch.Tensor,
+                                 dst: int = 0,
+                                 dim: int = -1) -> torch.Tensor:
     """Gather the input tensor across model parallel group.

     NOTE: We assume that the input tensor is on the same device across

@@ -80,7 +86,7 @@ def tensor_model_parallel_gather(input_, dst=0, dim=-1):
     return output_tensor


-def broadcast(input_, src=0):
+def broadcast(input_: torch.Tensor, src: int = 0):
     """Broadcast the input tensor."""
     world_size = torch.distributed.get_world_size()
     assert 0 <= src < world_size, f"Invalid src rank ({src})"

@@ -93,7 +99,7 @@ def broadcast(input_, src=0):
     return input_


-def broadcast_object_list(obj_list, src=0):
+def broadcast_object_list(obj_list: List[Any], src: int = 0):
     """Broadcast the input object list."""
     world_size = torch.distributed.get_world_size()
     assert 0 <= src < world_size, f"Invalid src rank ({src})"

@@ -104,3 +110,60 @@ def broadcast_object_list(obj_list, src=0):
     # Broadcast.
     torch.distributed.broadcast_object_list(obj_list, src=src)
     return obj_list
+
+
+TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
+
+
+def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
+                                                                 Any]]] = None,
+                          src: int = 0) -> Dict[Any, Union[torch.Tensor, Any]]:
+    """Broadcast the input tensor dictionary."""
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
+    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return tensor_dict
+
+    if rank == src:
+        assert isinstance(
+            tensor_dict,
+            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
+        metadata_list = []
+        for key, value in tensor_dict.items():
+            if isinstance(value, torch.Tensor):
+                assert value.is_cuda, (
+                    f"Tensor {key}: {value} is not on cuda. Currently we only "
+                    f"support broadcasting tensors on cuda.")
+                metadata_list.append(
+                    (key, TensorMetadata(value.dtype, value.size())))
+            else:
+                metadata_list.append((key, value))
+        torch.distributed.broadcast_object_list([metadata_list], src=src)
+        for key, value in metadata_list:
+            if isinstance(value, TensorMetadata):
+                tensor = tensor_dict[key]
+                torch.distributed.broadcast(tensor, src=src)
+    else:
+        recv_metadata_list = [None]
+        torch.distributed.broadcast_object_list(recv_metadata_list, src=src)
+        metadata_list = recv_metadata_list[0]
+        tensor_dict = {}
+        async_handles = []
+        for key, value in metadata_list:
+            if isinstance(value, TensorMetadata):
+                tensor = torch.empty(value.size,
+                                     dtype=value.dtype,
+                                     device="cuda")
+                async_handle = torch.distributed.broadcast(tensor,
+                                                           src=src,
+                                                           async_op=True)
+                async_handles.append(async_handle)
+                tensor_dict[key] = tensor
+            else:
+                tensor_dict[key] = value
+        for async_handle in async_handles:
+            async_handle.wait()
+    return tensor_dict
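For reference, here is a minimal usage sketch of the new helper outside of the test suite. It is a sketch, not part of this commit: it assumes a single node with 2 GPUs, vLLM installed at this revision, a launch via `torchrun --nproc_per_node=2 demo.py`, and illustrative dictionary keys ("tokens", "is_prompt", "block_tables") rather than the exact keys vLLM uses.

# demo.py -- sketch of broadcast_tensor_dict over the default process group
# (assumes 2 GPUs; launch with: torchrun --nproc_per_node=2 demo.py).
import torch
import torch.distributed as dist

from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)


def main():
    # torchrun sets MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE for us.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())

    if dist.get_rank() == 0:
        # Source rank: CUDA tensors and plain Python values travel together.
        broadcast_tensor_dict(
            {
                "tokens": torch.arange(8, device="cuda"),
                "is_prompt": True,
                "block_tables": None,
            },
            src=0)
    else:
        # Other ranks: call with no dict; tensors are received via (async)
        # broadcasts, plain values via the pickled metadata list.
        received = broadcast_tensor_dict(src=0)
        assert received["tokens"].shape == (8, )
        assert received["is_prompt"] is True


if __name__ == "__main__":
    main()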
vllm/worker/model_runner.py

@@ -9,7 +9,7 @@ from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
 from vllm.model_executor.parallel_utils.communication_op import (
-    broadcast, broadcast_object_list)
+    broadcast_tensor_dict)
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.utils import in_wsl

@@ -393,121 +393,43 @@ class ModelRunner:
                                                      prompt_lens,
                                                      subquery_lens)

-            def get_size_or_none(x: Optional[torch.Tensor]):
-                return x.size() if x is not None else None
-
-            # Broadcast the input data. For input tensors, we first broadcast
-            # its shape and then broadcast the tensor to avoid high
-            # serialization cost.
-            py_data = {
-                "input_tokens_size":
-                input_tokens.size(),
-                "input_positions_size":
-                input_positions.size(),
-                "is_prompt":
-                input_metadata.is_prompt,
-                "slot_mapping_size":
-                get_size_or_none(input_metadata.slot_mapping),
-                "prompt_lens_size":
-                get_size_or_none(input_metadata.prompt_lens),
-                "max_seq_len":
-                input_metadata.max_seq_len,
-                "start_loc_size":
-                get_size_or_none(input_metadata.start_loc),
-                "max_context_len":
-                input_metadata.max_context_len,
-                "context_lens_size":
-                get_size_or_none(input_metadata.context_lens),
-                "block_tables_size":
-                get_size_or_none(input_metadata.block_tables),
-                "use_cuda_graph":
-                input_metadata.use_cuda_graph,
-                "selected_token_indices_size":
-                sampling_metadata.selected_token_indices.size(),
+            # Broadcast the metadata.
+            metadata_dict = {
+                "input_tokens": input_tokens,
+                "input_positions": input_positions,
+                "is_prompt": input_metadata.is_prompt,
+                "slot_mapping": input_metadata.slot_mapping,
+                "prompt_lens": input_metadata.prompt_lens,
+                "max_seq_len": input_metadata.max_seq_len,
+                "start_loc": input_metadata.start_loc,
+                "max_context_len": input_metadata.max_context_len,
+                "context_lens": input_metadata.context_lens,
+                "block_tables": input_metadata.block_tables,
+                "use_cuda_graph": input_metadata.use_cuda_graph,
+                "selected_token_indices":
+                sampling_metadata.selected_token_indices,
             }
-            broadcast_object_list([py_data], src=0)
-            # TODO(zhuohan): Combine the broadcasts or set async_op=True.
-            broadcast(input_tokens, src=0)
-            broadcast(input_positions, src=0)
-            if input_metadata.slot_mapping is not None:
-                broadcast(input_metadata.slot_mapping, src=0)
-            if input_metadata.prompt_lens is not None:
-                broadcast(input_metadata.prompt_lens, src=0)
-            if input_metadata.start_loc is not None:
-                broadcast(input_metadata.start_loc, src=0)
-            if input_metadata.context_lens is not None:
-                broadcast(input_metadata.context_lens, src=0)
-            if input_metadata.block_tables is not None:
-                broadcast(input_metadata.block_tables, src=0)
-            broadcast(sampling_metadata.selected_token_indices, src=0)
+            broadcast_tensor_dict(metadata_dict, src=0)
         else:
-            receving_list = [None]
-            broadcast_object_list(receving_list, src=0)
-            py_data = receving_list[0]
-            input_tokens = torch.empty(*py_data["input_tokens_size"],
-                                       dtype=torch.long,
-                                       device="cuda")
-            broadcast(input_tokens, src=0)
-            input_positions = torch.empty(*py_data["input_positions_size"],
-                                          dtype=torch.long,
-                                          device="cuda")
-            broadcast(input_positions, src=0)
-            if py_data["slot_mapping_size"] is not None:
-                slot_mapping = torch.empty(*py_data["slot_mapping_size"],
-                                           dtype=torch.long,
-                                           device="cuda")
-                broadcast(slot_mapping, src=0)
-            else:
-                slot_mapping = None
-            if py_data["prompt_lens_size"] is not None:
-                prompt_lens = torch.empty(*py_data["prompt_lens_size"],
-                                          dtype=torch.long,
-                                          device="cuda")
-                broadcast(prompt_lens, src=0)
-            else:
-                prompt_lens = None
-            if py_data["start_loc_size"] is not None:
-                start_loc = torch.empty(*py_data["start_loc_size"],
-                                        dtype=torch.long,
-                                        device="cuda")
-                broadcast(start_loc, src=0)
-            else:
-                start_loc = None
-            if py_data["context_lens_size"] is not None:
-                context_lens = torch.empty(*py_data["context_lens_size"],
-                                           dtype=torch.int,
-                                           device="cuda")
-                broadcast(context_lens, src=0)
-            else:
-                context_lens = None
-            if py_data["block_tables_size"] is not None:
-                block_tables = torch.empty(*py_data["block_tables_size"],
-                                           dtype=torch.int,
-                                           device="cuda")
-                broadcast(block_tables, src=0)
-            else:
-                block_tables = None
-            selected_token_indices = torch.empty(
-                *py_data["selected_token_indices_size"],
-                dtype=torch.long,
-                device="cuda")
-            broadcast(selected_token_indices, src=0)
+            metadata_dict = broadcast_tensor_dict(src=0)
+            input_tokens = metadata_dict["input_tokens"]
+            input_positions = metadata_dict["input_positions"]
             input_metadata = InputMetadata(
-                is_prompt=py_data["is_prompt"],
-                slot_mapping=slot_mapping,
-                prompt_lens=prompt_lens,
-                max_seq_len=py_data["max_seq_len"],
-                start_loc=start_loc,
-                max_context_len=py_data["max_context_len"],
-                context_lens=context_lens,
-                block_tables=block_tables,
-                use_cuda_graph=py_data["use_cuda_graph"],
+                is_prompt=metadata_dict["is_prompt"],
+                slot_mapping=metadata_dict["slot_mapping"],
+                prompt_lens=metadata_dict["prompt_lens"],
+                max_seq_len=metadata_dict["max_seq_len"],
+                start_loc=metadata_dict["start_loc"],
+                max_context_len=metadata_dict["max_context_len"],
+                context_lens=metadata_dict["context_lens"],
+                block_tables=metadata_dict["block_tables"],
+                use_cuda_graph=metadata_dict["use_cuda_graph"],
             )
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
                 seq_data=None,
                 prompt_lens=None,
-                selected_token_indices=selected_token_indices,
+                selected_token_indices=metadata_dict["selected_token_indices"],
                 categorized_sample_indices=None,
                 perform_sampling=False,
             )
vllm/worker/worker.py

@@ -9,7 +9,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.parallel_utils.communication_op import (
-    broadcast_object_list)
+    broadcast_tensor_dict)
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata

@@ -175,20 +175,21 @@ class Worker:
             assert blocks_to_swap_in is not None
             assert blocks_to_swap_out is not None
             assert blocks_to_copy is not None
-            block_swapping_info = [
-                blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy
-            ]
-            broadcast_object_list([num_seq_groups] + block_swapping_info,
-                                  src=0)
+            data = {
+                "num_seq_groups": num_seq_groups,
+                "blocks_to_swap_in": blocks_to_swap_in,
+                "blocks_to_swap_out": blocks_to_swap_out,
+                "blocks_to_copy": blocks_to_copy,
+            }
+            broadcast_tensor_dict(data, src=0)
         else:
-            # num_seq_groups, blocks_to_swap_in, blocks_to_swap_out,
-            # blocks_to_copy (4 elements)
-            recv_data = [None] * 4
-            broadcast_object_list(recv_data, src=0)
-            num_seq_groups = recv_data[0]
-            block_swapping_info = recv_data[1:]
+            data = broadcast_tensor_dict(src=0)
+            num_seq_groups = data["num_seq_groups"]
+            blocks_to_swap_in = data["blocks_to_swap_in"]
+            blocks_to_swap_out = data["blocks_to_swap_out"]
+            blocks_to_copy = data["blocks_to_copy"]

-        self.cache_swap(*block_swapping_info)
+        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

         # If there is no input, we don't need to execute the model.
         if num_seq_groups == 0: