diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 5bc22f36cbb6f..a411a465b0d57 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -299,6 +299,7 @@ class EngineArgs:
     data_parallel_size_local: Optional[int] = None
     data_parallel_address: Optional[str] = None
     data_parallel_rpc_port: Optional[int] = None
+    data_parallel_hybrid_lb: bool = False
     data_parallel_backend: str = ParallelConfig.data_parallel_backend
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
     enable_eplb: bool = ParallelConfig.enable_eplb
@@ -633,6 +634,9 @@ class EngineArgs:
                                     default='mp',
                                     help='Backend for data parallel, either '
                                     '"mp" or "ray".')
+        parallel_group.add_argument(
+            "--data-parallel-hybrid-lb",
+            **parallel_kwargs["data_parallel_hybrid_lb"])
         parallel_group.add_argument(
             "--enable-expert-parallel",
             **parallel_kwargs["enable_expert_parallel"])
@@ -1097,7 +1101,6 @@ class EngineArgs:
             # but we should not do this here.
             placement_group = ray.util.get_current_placement_group()
 
-        DATA_PARALLEL_HYBRID_LB = True
         data_parallel_external_lb = self.data_parallel_rank is not None
         if data_parallel_external_lb:
             assert self.data_parallel_size_local in (1, None), (
@@ -1105,8 +1108,8 @@ class EngineArgs:
                 "is set")
             data_parallel_size_local = 1
             # Use full external lb if we have local_size of 1.
-            DATA_PARALLEL_HYBRID_LB = False
-        elif DATA_PARALLEL_HYBRID_LB:
+            self.data_parallel_hybrid_lb = False
+        elif self.data_parallel_hybrid_lb:
             assert self.data_parallel_start_rank is not None, (
                 "data_parallel_start_rank must be set to use "
                 "data_parallel_hybrid_lb.")
@@ -1116,7 +1119,7 @@ class EngineArgs:
             # Use full external lb if we have local_size of 1.
             if self.data_parallel_size_local == 1:
                 data_parallel_external_lb = True
-                DATA_PARALLEL_HYBRID_LB = False
+                self.data_parallel_hybrid_lb = False
             data_parallel_size_local = self.data_parallel_size_local
             self.data_parallel_rank = self.data_parallel_start_rank
         elif self.data_parallel_size_local is not None:
@@ -1178,7 +1181,7 @@ class EngineArgs:
             data_parallel_master_ip=data_parallel_address,
             data_parallel_rpc_port=data_parallel_rpc_port,
             data_parallel_backend=self.data_parallel_backend,
-            data_parallel_hybrid_lb=DATA_PARALLEL_HYBRID_LB,
+            data_parallel_hybrid_lb=self.data_parallel_hybrid_lb,
             enable_expert_parallel=self.enable_expert_parallel,
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.num_redundant_experts,