Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-11-27 07:23:48 +00:00
parent b3b195a540
commit ad5678b056

View File

@ -240,7 +240,9 @@ class MoRIIOConfig:
# TODO : merge notify_port and handshake_port to simplify port management # TODO : merge notify_port and handshake_port to simplify port management
# supports non-contiguous ports # supports non-contiguous ports
assert vllm_config.kv_transfer_config is not None, (
"kv_transfer_config must be set for MoRIIOConnector"
)
kv_transfer_config = vllm_config.kv_transfer_config kv_transfer_config = vllm_config.kv_transfer_config
extra_config = kv_transfer_config.kv_connector_extra_config extra_config = kv_transfer_config.kv_connector_extra_config
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
@ -745,12 +747,12 @@ class MoRIIOWrapper:
else: else:
self.done_write_cache_req_ids.append(msg) self.done_write_cache_req_ids.append(msg)
def send_notify(self, req_ids, remote_ip=None, remote_port=None): def send_notify(self, req_ids, remote_ip, remote_port):
if not remote_ip or not remote_port: if not remote_ip or not remote_port:
logger.warning("Missing remote_ip or remote_port for notification") logger.warning("Missing remote_ip or remote_port for notification")
return return
path = make_zmq_path("tcp", remote_ip, str(remote_port)) path = make_zmq_path("tcp", remote_ip, remote_port)
if path not in self.paths: if path not in self.paths:
ctx = zmq.Context.instance() ctx = zmq.Context.instance()
@ -872,18 +874,18 @@ class MoRIIOConnector(KVConnectorBase_V1):
kv_cache_config: Optional["KVCacheConfig"] = None, kv_cache_config: Optional["KVCacheConfig"] = None,
): ):
super().__init__(vllm_config, role) super().__init__(vllm_config, role)
assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config is not None, (
"kv_transfer_config must be set for MoRIIOConnector"
)
self.kv_transfer_config = vllm_config.kv_transfer_config
# assert vllm_config.kv_transfer_config.engine_id is not None # assert vllm_config.kv_transfer_config.engine_id is not None
self._set_port_defaults(vllm_config) self._set_port_defaults(vllm_config)
self.engine_id = ( self.engine_id = (
str(get_ip()) str(get_ip())
+ ":" + ":"
+ str( + str(self.kv_transfer_config.kv_connector_extra_config["handshake_port"])
vllm_config.kv_transfer_config.kv_connector_extra_config[
"handshake_port"
]
)
) )
self.mode = get_moriio_mode() self.mode = get_moriio_mode()
if role == KVConnectorRole.SCHEDULER: if role == KVConnectorRole.SCHEDULER:
@ -905,6 +907,9 @@ class MoRIIOConnector(KVConnectorBase_V1):
############################################################ ############################################################
def _set_port_defaults(self, vllm_config: VllmConfig): def _set_port_defaults(self, vllm_config: VllmConfig):
assert vllm_config.kv_transfer_config is not None, (
"kv_transfer_config must be set for MoRIIOConnector"
)
kv_transfer_config = vllm_config.kv_transfer_config kv_transfer_config = vllm_config.kv_transfer_config
extra_config = kv_transfer_config.kv_connector_extra_config extra_config = kv_transfer_config.kv_connector_extra_config
@ -1011,23 +1016,26 @@ class MoRIIOConnectorScheduler:
def __init__(self, vllm_config: VllmConfig, engine_id: str): def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.vllm_config = vllm_config self.vllm_config = vllm_config
assert vllm_config.kv_transfer_config is not None, (
"kv_transfer_config must be set for MoRIIOConnector"
)
self.kv_transfer_config = vllm_config.kv_transfer_config
self.block_size = vllm_config.cache_config.block_size self.block_size = vllm_config.cache_config.block_size
self.engine_id: EngineId = engine_id self.engine_id: EngineId = engine_id
self.mode = get_moriio_mode() self.mode = get_moriio_mode()
self.host_ip = get_ip() self.host_ip = get_ip()
self.handshake_port = ( self.handshake_port = self.kv_transfer_config.kv_connector_extra_config[
self.vllm_config.kv_transfer_config.kv_connector_extra_config[ "handshake_port"
"handshake_port" ]
]
)
logger.info("Initializing MoRIIO Scheduler engine_id = %s", engine_id) logger.info("Initializing MoRIIO Scheduler engine_id = %s", engine_id)
self.side_notify_port = ( self.side_notify_port = self.kv_transfer_config.kv_connector_extra_config[
self.vllm_config.kv_transfer_config.kv_connector_extra_config["notify_port"] "notify_port"
) ]
self.tp_size = self.vllm_config.parallel_config.tensor_parallel_size self.tp_size = self.vllm_config.parallel_config.tensor_parallel_size
self.dp_rank = self.vllm_config.parallel_config.data_parallel_rank self.dp_rank = self.vllm_config.parallel_config.data_parallel_rank
self.is_producer = vllm_config.kv_transfer_config.kv_role == "kv_producer" self.is_producer = self.kv_transfer_config.kv_role == "kv_producer"
# Requests that need to start recv/send. # Requests that need to start recv/send.
# New requests are added by update_state_after_alloc in # New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker. # the scheduler. Used to make metadata passed to Worker.
@ -1045,7 +1053,6 @@ class MoRIIOConnectorScheduler:
# Reqs to send and their expiration time # Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {} self._reqs_need_send: dict[ReqId, float] = {}
self.sock = None self.sock = None
self.is_producer = vllm_config.kv_transfer_config.kv_role == "kv_producer"
self.paths: dict[str, zmq.Socket] = {} self.paths: dict[str, zmq.Socket] = {}
def get_num_new_matched_tokens( def get_num_new_matched_tokens(
@ -1070,12 +1077,13 @@ class MoRIIOConnectorScheduler:
if self.is_producer: if self.is_producer:
return 0, False return 0, False
token_ids = request.prompt_token_ids or []
if self.mode == MoRIIOMode.WRITE: if self.mode == MoRIIOMode.WRITE:
# MoriiO in write mode, no remote prefill # MoriiO in write mode, no remote prefill
return len(request.prompt_token_ids) - num_computed_tokens, True return len(token_ids) - num_computed_tokens, True
return len(request.prompt_token_ids) - 1 - num_computed_tokens, False return len(token_ids) - 1 - num_computed_tokens, False
def send_notify_block( def send_notify_block(
self, req_id: str, block_notify_list: list[int], host=None, port=None self, req_id: str, block_notify_list: list[int], host=None, port=None
@ -1105,6 +1113,8 @@ class MoRIIOConnectorScheduler:
connector_worker: Optional["MoRIIOConnectorWorker"] = None, connector_worker: Optional["MoRIIOConnectorWorker"] = None,
): ):
params = request.kv_transfer_params params = request.kv_transfer_params
if not params:
return
if params.get("do_remote_decode"): if params.get("do_remote_decode"):
local_block_ids = blocks.get_block_ids()[0] local_block_ids = blocks.get_block_ids()[0]
self._reqs_need_save[request.request_id] = (request, local_block_ids) self._reqs_need_save[request.request_id] = (request, local_block_ids)
@ -1140,6 +1150,10 @@ class MoRIIOConnectorScheduler:
) )
else: else:
assert request.kv_transfer_params is not None, (
"kv_transfer_params should not be None"
)
remote_dp_rank = request.kv_transfer_params.get("remote_dp_rank", 0) remote_dp_rank = request.kv_transfer_params.get("remote_dp_rank", 0)
for tp_index in range(self.tp_size): for tp_index in range(self.tp_size):
@ -1178,9 +1192,11 @@ class MoRIIOConnectorScheduler:
assert hasattr(new_req.sampling_params, "extra_args"), ( assert hasattr(new_req.sampling_params, "extra_args"), (
f"sampling_params missing extra_args for req {new_req.req_id}" f"sampling_params missing extra_args for req {new_req.req_id}"
) )
kv_transfer_params = new_req.sampling_params.extra_args[ kv_transfer_params = (
"kv_transfer_params" new_req.sampling_params.extra_args.get("kv_transfer_params", {})
] if new_req.sampling_params.extra_args
else {}
)
meta.add_new_req( meta.add_new_req(
red_id, red_id,
local_block_ids, local_block_ids,
@ -1212,7 +1228,7 @@ class MoRIIOConnectorScheduler:
meta.add_new_req( meta.add_new_req(
request_id=req_id, request_id=req_id,
local_block_ids=self._reqs_need_pending_save[req_id][1], local_block_ids=self._reqs_need_pending_save[req_id][1],
kv_transfer_params=req.kv_transfer_params, kv_transfer_params=req.kv_transfer_params or {},
write_mode=True, write_mode=True,
) )
del self._reqs_need_pending_save[req_id] del self._reqs_need_pending_save[req_id]
@ -1328,6 +1344,9 @@ class MoRIIOConnectorWorker:
# Config. # Config.
self.vllm_config = vllm_config self.vllm_config = vllm_config
assert vllm_config.kv_transfer_config is not None, (
"kv_transfer_config must be set for MoRIIOConnector"
)
self.kv_transfer_config = vllm_config.kv_transfer_config self.kv_transfer_config = vllm_config.kv_transfer_config
self.is_producer = self.kv_transfer_config.is_kv_producer self.is_producer = self.kv_transfer_config.is_kv_producer