mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 22:05:44 +08:00
[Benchmark] Support benchmark throughput for external launcher DP (#25913)
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
parent
89e4050af4
commit
d3bd171123
@ -358,7 +358,23 @@ def get_requests(args, tokenizer):
|
||||
raise ValueError(f"Unknown dataset name: {args.dataset_name}")
|
||||
# Remove None values
|
||||
sample_kwargs = {k: v for k, v in sample_kwargs.items() if v is not None}
|
||||
return dataset_cls(**common_kwargs).sample(**sample_kwargs)
|
||||
requests = dataset_cls(**common_kwargs).sample(**sample_kwargs)
|
||||
requests = filter_requests_for_dp(requests, args.data_parallel_size)
|
||||
return requests
|
||||
|
||||
|
||||
def filter_requests_for_dp(requests, data_parallel_size):
|
||||
# Note(zhuohan): The way we get data_parallel_rank is hacky and only
|
||||
# works for external launcher mode. Should be cleaned up and deprecated
|
||||
# in the future with a better vLLM distributed process design.
|
||||
if data_parallel_size == 1:
|
||||
return requests
|
||||
|
||||
global_rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
data_parallel_rank = global_rank // (world_size // data_parallel_size)
|
||||
return [r for i, r in enumerate(requests)
|
||||
if i % data_parallel_size == data_parallel_rank]
|
||||
|
||||
|
||||
def validate_args(args):
|
||||
@ -453,12 +469,17 @@ def validate_args(args):
|
||||
if args.backend == "mii" and args.tokenizer != args.model:
|
||||
raise ValueError(
|
||||
"Tokenizer must be the same as the model for MII backend.")
|
||||
|
||||
# --data-parallel is not supported currently.
|
||||
# https://github.com/vllm-project/vllm/issues/16222
|
||||
if args.data_parallel_size > 1:
|
||||
|
||||
if args.data_parallel_size > 1 and (
|
||||
args.distributed_executor_backend != "external_launcher"
|
||||
or args.async_engine):
|
||||
# --data-parallel is not supported fully.
|
||||
# Old issue: https://github.com/vllm-project/vllm/issues/16222
|
||||
# Currently we only support data parallel with external launcher
|
||||
# mode (i.e., launch with toruchrun).
|
||||
raise ValueError(
|
||||
"Data parallel is not supported in offline benchmark, "
|
||||
"Data parallel is only supported with external launcher mode "
|
||||
"with synchronous engine in offline benchmark, "
|
||||
"please use benchmark serving instead"
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user