FFN server: use `vllm serve` and data parallelism (DP)

Signed-off-by: jiangkuaixue123 <jiangxiaozhou111@163.com>
This commit is contained in:
jiangkuaixue123 2025-12-13 10:36:03 +08:00
parent 28cba040c7
commit eb2355c600
6 changed files with 56 additions and 5 deletions

View File

@ -557,7 +557,6 @@ class ParallelConfig:
if self.distributed_executor_backend is None and self.world_size > 1: if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the # We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group. # current node and we aren't in a ray placement group.
from vllm.v1.executor import ray_utils from vllm.v1.executor import ray_utils
backend: DistributedExecutorBackend = "mp" backend: DistributedExecutorBackend = "mp"

View File

@ -191,7 +191,6 @@ def run_multi_api_server(args: argparse.Namespace):
assert external_dp_lb or hybrid_dp_lb or dp_rank == 0 assert external_dp_lb or hybrid_dp_lb or dp_rank == 0
api_server_manager: APIServerProcessManager | None = None api_server_manager: APIServerProcessManager | None = None
with launch_core_engines( with launch_core_engines(
vllm_config, executor_class, log_stats, num_api_servers vllm_config, executor_class, log_stats, num_api_servers
) as (local_engine_manager, coordinator, addresses): ) as (local_engine_manager, coordinator, addresses):

View File

@ -1402,7 +1402,6 @@ async def run_server_worker(
listen_address, sock, args, client_config=None, **uvicorn_kwargs listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None: ) -> None:
"""Run a single API server worker.""" """Run a single API server worker."""
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin) ToolParserManager.import_tool_parser(args.tool_parser_plugin)

View File

@ -103,6 +103,11 @@ class EngineCore:
if executor_fail_callback is not None: if executor_fail_callback is not None:
self.model_executor.register_failure_callback(executor_fail_callback) self.model_executor.register_failure_callback(executor_fail_callback)
self.afd_config = vllm_config.afd_config
if self.afd_config and self.afd_config.afd_role == "ffn":
logger.info("jcz EngineCore ffn role")
return
self.available_gpu_memory_for_kv_cache = -1 self.available_gpu_memory_for_kv_cache = -1
# Setup KV Caches and update CacheConfig after profiling. # Setup KV Caches and update CacheConfig after profiling.
@ -601,6 +606,7 @@ class EngineCoreProc(EngineCore):
executor_fail_callback = lambda: self.input_queue.put_nowait( executor_fail_callback = lambda: self.input_queue.put_nowait(
(EngineCoreRequestType.EXECUTOR_FAILED, b"") (EngineCoreRequestType.EXECUTOR_FAILED, b"")
) )
self.afd_config = vllm_config.afd_config
self.engine_index = engine_index self.engine_index = engine_index
identity = self.engine_index.to_bytes(length=2, byteorder="little") identity = self.engine_index.to_bytes(length=2, byteorder="little")
@ -855,7 +861,6 @@ class EngineCoreProc(EngineCore):
set_process_title("EngineCore") set_process_title("EngineCore")
decorate_logs() decorate_logs()
engine_core = EngineCoreProc(*args, **kwargs) engine_core = EngineCoreProc(*args, **kwargs)
engine_core.run_busy_loop() engine_core.run_busy_loop()
except SystemExit: except SystemExit:
@ -878,6 +883,23 @@ class EngineCoreProc(EngineCore):
def run_busy_loop(self): def run_busy_loop(self):
"""Core busy loop of the EngineCore.""" """Core busy loop of the EngineCore."""
if self.afd_config and self.afd_config.afd_role == "ffn":
logger.info("AFD FFN Server started, workers running...")
try:
# Tell workers to start FFN server loops (one-time call)
self.model_executor.collective_rpc("start_ffn_server_loop")
# Main thread waits without busy polling
shutdown_event = threading.Event()
shutdown_event.wait() # Block until interrupted
except KeyboardInterrupt:
logger.info("Server shutting down...")
self.model_executor.collective_rpc("stop_ffn_server_loop")
except Exception as e:
logger.error("Server error: %s", e)
raise
# Loop until process is sent a SIGINT or SIGTERM # Loop until process is sent a SIGINT or SIGTERM
while True: while True:
# 1) Poll the input queue until there is work to do. # 1) Poll the input queue until there is work to do.
@ -1156,6 +1178,7 @@ class DPEngineCoreProc(EngineCoreProc):
# Initialize the engine. # Initialize the engine.
dp_rank = vllm_config.parallel_config.data_parallel_rank dp_rank = vllm_config.parallel_config.data_parallel_rank
self.afd_config = vllm_config.afd_config
super().__init__( super().__init__(
vllm_config, vllm_config,
local_client, local_client,
@ -1238,6 +1261,22 @@ class DPEngineCoreProc(EngineCoreProc):
def run_busy_loop(self): def run_busy_loop(self):
"""Core busy loop of the EngineCore for data parallel case.""" """Core busy loop of the EngineCore for data parallel case."""
if self.afd_config and self.afd_config.afd_role == "ffn":
logger.info("AFD FFN Server started, workers running...")
try:
# Tell workers to start FFN server loops (one-time call)
self.model_executor.collective_rpc("start_ffn_server_loop")
# Main thread waits without busy polling
shutdown_event = threading.Event()
shutdown_event.wait() # Block until interrupted
except KeyboardInterrupt:
logger.info("Server shutting down...")
self.model_executor.collective_rpc("stop_ffn_server_loop")
except Exception as e:
logger.error("Server error: %s", e)
raise
# Loop until process is sent a SIGINT or SIGTERM # Loop until process is sent a SIGINT or SIGTERM
while True: while True:

View File

@ -16,7 +16,7 @@ import msgspec
import zmq import zmq
from vllm import envs from vllm import envs
from vllm.config import CacheConfig, ParallelConfig, VllmConfig from vllm.config import AFDConfig, CacheConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.ray.ray_env import get_env_vars_to_copy from vllm.ray.ray_env import get_env_vars_to_copy
@ -908,6 +908,7 @@ def launch_core_engines(
vllm_config.cache_config, vllm_config.cache_config,
local_engine_manager, local_engine_manager,
coordinator.proc if coordinator else None, coordinator.proc if coordinator else None,
vllm_config.afd_config,
) )
@ -919,6 +920,7 @@ def wait_for_engine_startup(
cache_config: CacheConfig, cache_config: CacheConfig,
proc_manager: CoreEngineProcManager | None, proc_manager: CoreEngineProcManager | None,
coord_process: Process | None, coord_process: Process | None,
afd_config: AFDConfig | None = None,
): ):
# Wait for engine core process(es) to send ready messages. # Wait for engine core process(es) to send ready messages.
local_count = parallel_config.data_parallel_size_local local_count = parallel_config.data_parallel_size_local
@ -1020,6 +1022,13 @@ def wait_for_engine_startup(
conn_pending[0 if local else 1] -= 1 conn_pending[0 if local else 1] -= 1
start_pending[0 if local else 1] += 1 start_pending[0 if local else 1] += 1
engine.state = CoreEngineState.CONNECTED engine.state = CoreEngineState.CONNECTED
elif (
status == "READY"
and engine.state == CoreEngineState.CONNECTED
and afd_config
and afd_config.afd_role == "ffn"
):
engine.state = CoreEngineState.READY
elif status == "READY" and engine.state == CoreEngineState.CONNECTED: elif status == "READY" and engine.state == CoreEngineState.CONNECTED:
# Setup KV cache config with initialization state from # Setup KV cache config with initialization state from
# engine core process. Sum values from all engines in DP case. # engine core process. Sum values from all engines in DP case.

View File

@ -213,6 +213,9 @@ class MultiprocExecutor(Executor):
self.output_rank = self._get_output_rank() self.output_rank = self._get_output_rank()
self.afd_config = self.vllm_config.afd_config
self.afd_role = self.afd_config.afd_role if self.afd_config else None
def start_worker_monitor(self, inline=False) -> None: def start_worker_monitor(self, inline=False) -> None:
workers = self.workers workers = self.workers
self_ref = weakref.ref(self) self_ref = weakref.ref(self)
@ -565,6 +568,9 @@ class WorkerProc:
# environment variable overrides after this point) # environment variable overrides after this point)
enable_envs_cache() enable_envs_cache()
self.afd_config = vllm_config.afd_config
self.afd_role = self.afd_config.afd_role if self.afd_config else None
@staticmethod @staticmethod
def make_worker_process( def make_worker_process(
vllm_config: VllmConfig, vllm_config: VllmConfig,