Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-11-27 07:23:48 +00:00
parent b3b195a540
commit ad5678b056

View File

@ -240,7 +240,9 @@ class MoRIIOConfig:
# TODO : merge notify_port and handshake_port to simplify port management
# supports non-contiguous ports
assert vllm_config.kv_transfer_config is not None, (
"kv_transfer_config must be set for MoRIIOConnector"
)
kv_transfer_config = vllm_config.kv_transfer_config
extra_config = kv_transfer_config.kv_connector_extra_config
tp_rank = get_tensor_model_parallel_rank()
@ -745,12 +747,12 @@ class MoRIIOWrapper:
else:
self.done_write_cache_req_ids.append(msg)
def send_notify(self, req_ids, remote_ip=None, remote_port=None):
def send_notify(self, req_ids, remote_ip, remote_port):
if not remote_ip or not remote_port:
logger.warning("Missing remote_ip or remote_port for notification")
return
path = make_zmq_path("tcp", remote_ip, str(remote_port))
path = make_zmq_path("tcp", remote_ip, remote_port)
if path not in self.paths:
ctx = zmq.Context.instance()
@ -872,18 +874,18 @@ class MoRIIOConnector(KVConnectorBase_V1):
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(vllm_config, role)
assert vllm_config.kv_transfer_config is not None
assert vllm_config.kv_transfer_config is not None, (
"kv_transfer_config must be set for MoRIIOConnector"
)
self.kv_transfer_config = vllm_config.kv_transfer_config
# assert vllm_config.kv_transfer_config.engine_id is not None
self._set_port_defaults(vllm_config)
self.engine_id = (
str(get_ip())
+ ":"
+ str(
vllm_config.kv_transfer_config.kv_connector_extra_config[
"handshake_port"
]
)
+ str(self.kv_transfer_config.kv_connector_extra_config["handshake_port"])
)
self.mode = get_moriio_mode()
if role == KVConnectorRole.SCHEDULER:
@ -905,6 +907,9 @@ class MoRIIOConnector(KVConnectorBase_V1):
############################################################
def _set_port_defaults(self, vllm_config: VllmConfig):
assert vllm_config.kv_transfer_config is not None, (
"kv_transfer_config must be set for MoRIIOConnector"
)
kv_transfer_config = vllm_config.kv_transfer_config
extra_config = kv_transfer_config.kv_connector_extra_config
@ -1011,23 +1016,26 @@ class MoRIIOConnectorScheduler:
def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.vllm_config = vllm_config
assert vllm_config.kv_transfer_config is not None, (
"kv_transfer_config must be set for MoRIIOConnector"
)
self.kv_transfer_config = vllm_config.kv_transfer_config
self.block_size = vllm_config.cache_config.block_size
self.engine_id: EngineId = engine_id
self.mode = get_moriio_mode()
self.host_ip = get_ip()
self.handshake_port = (
self.vllm_config.kv_transfer_config.kv_connector_extra_config[
"handshake_port"
]
)
self.handshake_port = self.kv_transfer_config.kv_connector_extra_config[
"handshake_port"
]
logger.info("Initializing MoRIIO Scheduler engine_id = %s", engine_id)
self.side_notify_port = (
self.vllm_config.kv_transfer_config.kv_connector_extra_config["notify_port"]
)
self.side_notify_port = self.kv_transfer_config.kv_connector_extra_config[
"notify_port"
]
self.tp_size = self.vllm_config.parallel_config.tensor_parallel_size
self.dp_rank = self.vllm_config.parallel_config.data_parallel_rank
self.is_producer = vllm_config.kv_transfer_config.kv_role == "kv_producer"
self.is_producer = self.kv_transfer_config.kv_role == "kv_producer"
# Requests that need to start recv/send.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
@ -1045,7 +1053,6 @@ class MoRIIOConnectorScheduler:
# Reqs to send and their expiration time
self._reqs_need_send: dict[ReqId, float] = {}
self.sock = None
self.is_producer = vllm_config.kv_transfer_config.kv_role == "kv_producer"
self.paths: dict[str, zmq.Socket] = {}
def get_num_new_matched_tokens(
@ -1070,12 +1077,13 @@ class MoRIIOConnectorScheduler:
if self.is_producer:
return 0, False
token_ids = request.prompt_token_ids or []
if self.mode == MoRIIOMode.WRITE:
# MoriiO in write mode, no remote prefill
return len(request.prompt_token_ids) - num_computed_tokens, True
return len(token_ids) - num_computed_tokens, True
return len(request.prompt_token_ids) - 1 - num_computed_tokens, False
return len(token_ids) - 1 - num_computed_tokens, False
def send_notify_block(
self, req_id: str, block_notify_list: list[int], host=None, port=None
@ -1105,6 +1113,8 @@ class MoRIIOConnectorScheduler:
connector_worker: Optional["MoRIIOConnectorWorker"] = None,
):
params = request.kv_transfer_params
if not params:
return
if params.get("do_remote_decode"):
local_block_ids = blocks.get_block_ids()[0]
self._reqs_need_save[request.request_id] = (request, local_block_ids)
@ -1140,6 +1150,10 @@ class MoRIIOConnectorScheduler:
)
else:
assert request.kv_transfer_params is not None, (
"kv_transfer_params should not be None"
)
remote_dp_rank = request.kv_transfer_params.get("remote_dp_rank", 0)
for tp_index in range(self.tp_size):
@ -1178,9 +1192,11 @@ class MoRIIOConnectorScheduler:
assert hasattr(new_req.sampling_params, "extra_args"), (
f"sampling_params missing extra_args for req {new_req.req_id}"
)
kv_transfer_params = new_req.sampling_params.extra_args[
"kv_transfer_params"
]
kv_transfer_params = (
new_req.sampling_params.extra_args.get("kv_transfer_params", {})
if new_req.sampling_params.extra_args
else {}
)
meta.add_new_req(
red_id,
local_block_ids,
@ -1212,7 +1228,7 @@ class MoRIIOConnectorScheduler:
meta.add_new_req(
request_id=req_id,
local_block_ids=self._reqs_need_pending_save[req_id][1],
kv_transfer_params=req.kv_transfer_params,
kv_transfer_params=req.kv_transfer_params or {},
write_mode=True,
)
del self._reqs_need_pending_save[req_id]
@ -1328,6 +1344,9 @@ class MoRIIOConnectorWorker:
# Config.
self.vllm_config = vllm_config
assert vllm_config.kv_transfer_config is not None, (
"kv_transfer_config must be set for MoRIIOConnector"
)
self.kv_transfer_config = vllm_config.kv_transfer_config
self.is_producer = self.kv_transfer_config.is_kv_producer