Signed-off-by: Robert Shaw <robshaw@redhat.com>
This commit is contained in:
Robert Shaw 2025-07-20 02:15:19 +00:00
parent 14f13ed690
commit b90d33163c
6 changed files with 43 additions and 16 deletions

View File

@ -1091,13 +1091,15 @@ class EngineArgs:
# but we should not do this here. # but we should not do this here.
placement_group = ray.util.get_current_placement_group() placement_group = ray.util.get_current_placement_group()
data_parallel_external_lb = self.data_parallel_rank is not None # data_parallel_external_lb = self.data_parallel_rank is not None
if data_parallel_external_lb: # if data_parallel_external_lb:
assert self.data_parallel_size_local in (1, None), ( # assert self.data_parallel_size_local in (1, None), (
"data_parallel_size_local must be 1 when data_parallel_rank " # "data_parallel_size_local must be 1 when data_parallel_rank "
"is set") # "is set")
data_parallel_size_local = 1 # data_parallel_size_local = 1
elif self.data_parallel_size_local is not None: # elif self.data_parallel_size_local is not None:
data_parallel_external_lb = False
if self.data_parallel_size_local is not None:
data_parallel_size_local = self.data_parallel_size_local data_parallel_size_local = self.data_parallel_size_local
else: else:
# Local DP size defaults to global DP size if not set. # Local DP size defaults to global DP size if not set.

View File

@ -45,11 +45,11 @@ class ServeSubcommand(CLISubcommand):
if args.headless or args.api_server_count < 1: if args.headless or args.api_server_count < 1:
run_headless(args) run_headless(args)
else: else:
if args.data_parallel_start_rank: # if args.data_parallel_start_rank:
raise ValueError( # raise ValueError(
"data_parallel_start_rank is only applicable " # "data_parallel_start_rank is only applicable "
"in headless mode. " # "in headless mode. "
"Add --headless flag to enable headless mode.") # "Add --headless flag to enable headless mode.")
if args.api_server_count > 1: if args.api_server_count > 1:
run_multi_api_server(args) run_multi_api_server(args)
else: else:

View File

@ -303,7 +303,7 @@ def download_weights_from_hf(
allow_patterns=allow_patterns, allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns, ignore_patterns=ignore_patterns,
cache_dir=cache_dir, cache_dir=cache_dir,
tqdm_class=DisabledTqdm, # tqdm_class=DisabledTqdm,
revision=revision, revision=revision,
local_files_only=local_only, local_files_only=local_only,
) )

View File

@ -411,10 +411,12 @@ class EngineCoreProc(EngineCore):
identity = self.engine_index.to_bytes(length=2, byteorder="little") identity = self.engine_index.to_bytes(length=2, byteorder="little")
self.engines_running = False self.engines_running = False
logger.info("======= HANDSHAKING:")
with self._perform_handshakes(handshake_address, identity, with self._perform_handshakes(handshake_address, identity,
local_client, vllm_config, local_client, vllm_config,
client_handshake_address) as addresses: client_handshake_address) as addresses:
self.client_count = len(addresses.outputs) self.client_count = len(addresses.outputs)
logger.info(f"{addresses.outputs=}")
# Set up data parallel environment. # Set up data parallel environment.
self.has_coordinator = addresses.coordinator_output is not None self.has_coordinator = addresses.coordinator_output is not None
@ -482,16 +484,21 @@ class EngineCoreProc(EngineCore):
""" """
input_ctx = zmq.Context() input_ctx = zmq.Context()
is_local = local_client and client_handshake_address is None is_local = local_client and client_handshake_address is None
logger.info(f"HS: {handshake_address=}, {is_local=}")
handshake = self._perform_handshake(input_ctx, handshake_address, handshake = self._perform_handshake(input_ctx, handshake_address,
identity, is_local, vllm_config, identity, is_local, vllm_config,
vllm_config.parallel_config) vllm_config.parallel_config)
logger.info(f"DONE HS: {handshake=}")
if client_handshake_address is None: if client_handshake_address is None:
with handshake as addresses: with handshake as addresses:
yield addresses yield addresses
else: else:
logger.info(f"HS: {client_handshake_address=}, {local_client=}")
local_handshake = self._perform_handshake( local_handshake = self._perform_handshake(
input_ctx, client_handshake_address, identity, local_client, input_ctx, client_handshake_address, identity, local_client,
vllm_config) vllm_config)
logger.info(f"DONE HS: {local_handshake=}")
with handshake as addresses, local_handshake as client_addresses: with handshake as addresses, local_handshake as client_addresses:
addresses.inputs = client_addresses.inputs addresses.inputs = client_addresses.inputs
addresses.outputs = client_addresses.outputs addresses.outputs = client_addresses.outputs
@ -517,6 +524,8 @@ class EngineCoreProc(EngineCore):
linger=5000, linger=5000,
bind=False) as handshake_socket: bind=False) as handshake_socket:
# Register engine with front-end. # Register engine with front-end.
logger.info(f"calling startup_handshake: {handshake_address=}")
logger.info(f"calling startup_handshake: {local_client=}")
addresses = self.startup_handshake(handshake_socket, local_client, addresses = self.startup_handshake(handshake_socket, local_client,
parallel_config_to_update) parallel_config_to_update)
yield addresses yield addresses

View File

@ -405,12 +405,15 @@ class MPClient(EngineCoreClient):
"stats_update_address") "stats_update_address")
else: else:
# Engines are managed by this client. # Engines are managed by this client.
print(f"{vllm_config.parallel_config=}")
with launch_core_engines(vllm_config, executor_class, with launch_core_engines(vllm_config, executor_class,
log_stats) as (engine_manager, log_stats) as (engine_manager,
coordinator, coordinator,
addresses): addresses):
self.resources.coordinator = coordinator self.resources.coordinator = coordinator
self.resources.engine_manager = engine_manager self.resources.engine_manager = engine_manager
print("========================================")
print(f"{vllm_config.parallel_config=}")
(input_address, ) = addresses.inputs (input_address, ) = addresses.inputs
(output_address, ) = addresses.outputs (output_address, ) = addresses.outputs

View File

@ -555,6 +555,8 @@ def launch_core_engines(
# sends requests only to colocated engines. # sends requests only to colocated engines.
client_local_only = offline_mode or external_dp_lb or (local_engine_count client_local_only = offline_mode or external_dp_lb or (local_engine_count
== dp_size) == dp_size)
# HACK: handle case with one pod per node.
client_local_only = True
# Set up input and output addresses. # Set up input and output addresses.
addresses = EngineZmqAddresses( addresses = EngineZmqAddresses(
@ -601,11 +603,17 @@ def launch_core_engines(
if offline_mode or (external_dp_lb and dp_rank > 0): if offline_mode or (external_dp_lb and dp_rank > 0):
assert local_engine_count == 1 assert local_engine_count == 1
engines_to_handshake = [CoreEngine(index=dp_rank, local=True)] engines_to_handshake = [CoreEngine(index=dp_rank, local=True)]
else: elif dp_rank == 0:
engines_to_handshake = [ engines_to_handshake = [
CoreEngine(index=i, local=(i < local_engine_count)) CoreEngine(index=i, local=(i < local_engine_count))
for i in range(dp_size) for i in range(dp_size)
] ]
else:
# Just handshake with local engines.
engines_to_handshake = [
CoreEngine(index=i, local=True) for i in
range(dp_rank, dp_rank + local_engine_count)
]
# Whether the started engines will handshake only with co-located # Whether the started engines will handshake only with co-located
# front-end processes. In external_dp_lb mode, ranks > 0 handshake with # front-end processes. In external_dp_lb mode, ranks > 0 handshake with
@ -616,7 +624,8 @@ def launch_core_engines(
handshake_address = get_engine_client_zmq_addr( handshake_address = get_engine_client_zmq_addr(
handshake_local_only, host, parallel_config.data_parallel_rpc_port) handshake_local_only, host, parallel_config.data_parallel_rpc_port)
if external_dp_lb and dp_rank > 0: # if external_dp_lb and dp_rank > 0:
if dp_rank > 0:
assert not handshake_local_only assert not handshake_local_only
local_handshake_address = get_open_zmq_ipc_path() local_handshake_address = get_open_zmq_ipc_path()
client_handshake_address = local_handshake_address client_handshake_address = local_handshake_address
@ -624,15 +633,18 @@ def launch_core_engines(
local_handshake_address = handshake_address local_handshake_address = handshake_address
client_handshake_address = None client_handshake_address = None
print(f"{local_handshake_address=}")
with zmq_socket_ctx(local_handshake_address, zmq.ROUTER, with zmq_socket_ctx(local_handshake_address, zmq.ROUTER,
bind=True) as handshake_socket: bind=True) as handshake_socket:
from vllm.v1.engine.core import EngineCoreProc from vllm.v1.engine.core import EngineCoreProc
print(f"{client_handshake_address=}")
print(f"{handshake_address=}")
# Start local engines. # Start local engines.
if local_engine_count: if local_engine_count:
# In server mode, start_index and local_start_index will # In server mode, start_index and local_start_index will
# both be 0. # both be 0. << todo: update
local_engine_manager = CoreEngineProcManager( local_engine_manager = CoreEngineProcManager(
EngineCoreProc.run_engine_core, EngineCoreProc.run_engine_core,
vllm_config=vllm_config, vllm_config=vllm_config,
@ -650,6 +662,7 @@ def launch_core_engines(
yield local_engine_manager, coordinator, addresses yield local_engine_manager, coordinator, addresses
# Now wait for engines to start. # Now wait for engines to start.
print(f"{engines_to_handshake=}")
wait_for_engine_startup( wait_for_engine_startup(
handshake_socket, handshake_socket,
addresses, addresses,