mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 08:37:02 +08:00
format
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
b3b195a540
commit
ad5678b056
@ -240,7 +240,9 @@ class MoRIIOConfig:
|
||||
|
||||
# TODO : merge notify_port and handshake_port to simplify port management
|
||||
# supports non-contiguous ports
|
||||
|
||||
assert vllm_config.kv_transfer_config is not None, (
|
||||
"kv_transfer_config must be set for MoRIIOConnector"
|
||||
)
|
||||
kv_transfer_config = vllm_config.kv_transfer_config
|
||||
extra_config = kv_transfer_config.kv_connector_extra_config
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
@ -745,12 +747,12 @@ class MoRIIOWrapper:
|
||||
else:
|
||||
self.done_write_cache_req_ids.append(msg)
|
||||
|
||||
def send_notify(self, req_ids, remote_ip=None, remote_port=None):
|
||||
def send_notify(self, req_ids, remote_ip, remote_port):
|
||||
if not remote_ip or not remote_port:
|
||||
logger.warning("Missing remote_ip or remote_port for notification")
|
||||
return
|
||||
|
||||
path = make_zmq_path("tcp", remote_ip, str(remote_port))
|
||||
path = make_zmq_path("tcp", remote_ip, remote_port)
|
||||
|
||||
if path not in self.paths:
|
||||
ctx = zmq.Context.instance()
|
||||
@ -872,18 +874,18 @@ class MoRIIOConnector(KVConnectorBase_V1):
|
||||
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||
):
|
||||
super().__init__(vllm_config, role)
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
assert vllm_config.kv_transfer_config is not None, (
|
||||
"kv_transfer_config must be set for MoRIIOConnector"
|
||||
)
|
||||
|
||||
self.kv_transfer_config = vllm_config.kv_transfer_config
|
||||
# assert vllm_config.kv_transfer_config.engine_id is not None
|
||||
self._set_port_defaults(vllm_config)
|
||||
|
||||
self.engine_id = (
|
||||
str(get_ip())
|
||||
+ ":"
|
||||
+ str(
|
||||
vllm_config.kv_transfer_config.kv_connector_extra_config[
|
||||
"handshake_port"
|
||||
]
|
||||
)
|
||||
+ str(self.kv_transfer_config.kv_connector_extra_config["handshake_port"])
|
||||
)
|
||||
self.mode = get_moriio_mode()
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
@ -905,6 +907,9 @@ class MoRIIOConnector(KVConnectorBase_V1):
|
||||
############################################################
|
||||
|
||||
def _set_port_defaults(self, vllm_config: VllmConfig):
|
||||
assert vllm_config.kv_transfer_config is not None, (
|
||||
"kv_transfer_config must be set for MoRIIOConnector"
|
||||
)
|
||||
kv_transfer_config = vllm_config.kv_transfer_config
|
||||
extra_config = kv_transfer_config.kv_connector_extra_config
|
||||
|
||||
@ -1011,23 +1016,26 @@ class MoRIIOConnectorScheduler:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||
self.vllm_config = vllm_config
|
||||
|
||||
assert vllm_config.kv_transfer_config is not None, (
|
||||
"kv_transfer_config must be set for MoRIIOConnector"
|
||||
)
|
||||
self.kv_transfer_config = vllm_config.kv_transfer_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.engine_id: EngineId = engine_id
|
||||
self.mode = get_moriio_mode()
|
||||
self.host_ip = get_ip()
|
||||
self.handshake_port = (
|
||||
self.vllm_config.kv_transfer_config.kv_connector_extra_config[
|
||||
"handshake_port"
|
||||
]
|
||||
)
|
||||
self.handshake_port = self.kv_transfer_config.kv_connector_extra_config[
|
||||
"handshake_port"
|
||||
]
|
||||
logger.info("Initializing MoRIIO Scheduler engine_id = %s", engine_id)
|
||||
|
||||
self.side_notify_port = (
|
||||
self.vllm_config.kv_transfer_config.kv_connector_extra_config["notify_port"]
|
||||
)
|
||||
self.side_notify_port = self.kv_transfer_config.kv_connector_extra_config[
|
||||
"notify_port"
|
||||
]
|
||||
self.tp_size = self.vllm_config.parallel_config.tensor_parallel_size
|
||||
self.dp_rank = self.vllm_config.parallel_config.data_parallel_rank
|
||||
self.is_producer = vllm_config.kv_transfer_config.kv_role == "kv_producer"
|
||||
self.is_producer = self.kv_transfer_config.kv_role == "kv_producer"
|
||||
# Requests that need to start recv/send.
|
||||
# New requests are added by update_state_after_alloc in
|
||||
# the scheduler. Used to make metadata passed to Worker.
|
||||
@ -1045,7 +1053,6 @@ class MoRIIOConnectorScheduler:
|
||||
# Reqs to send and their expiration time
|
||||
self._reqs_need_send: dict[ReqId, float] = {}
|
||||
self.sock = None
|
||||
self.is_producer = vllm_config.kv_transfer_config.kv_role == "kv_producer"
|
||||
self.paths: dict[str, zmq.Socket] = {}
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
@ -1070,12 +1077,13 @@ class MoRIIOConnectorScheduler:
|
||||
if self.is_producer:
|
||||
return 0, False
|
||||
|
||||
token_ids = request.prompt_token_ids or []
|
||||
if self.mode == MoRIIOMode.WRITE:
|
||||
# MoriiO in write mode, no remote prefill
|
||||
|
||||
return len(request.prompt_token_ids) - num_computed_tokens, True
|
||||
return len(token_ids) - num_computed_tokens, True
|
||||
|
||||
return len(request.prompt_token_ids) - 1 - num_computed_tokens, False
|
||||
return len(token_ids) - 1 - num_computed_tokens, False
|
||||
|
||||
def send_notify_block(
|
||||
self, req_id: str, block_notify_list: list[int], host=None, port=None
|
||||
@ -1105,6 +1113,8 @@ class MoRIIOConnectorScheduler:
|
||||
connector_worker: Optional["MoRIIOConnectorWorker"] = None,
|
||||
):
|
||||
params = request.kv_transfer_params
|
||||
if not params:
|
||||
return
|
||||
if params.get("do_remote_decode"):
|
||||
local_block_ids = blocks.get_block_ids()[0]
|
||||
self._reqs_need_save[request.request_id] = (request, local_block_ids)
|
||||
@ -1140,6 +1150,10 @@ class MoRIIOConnectorScheduler:
|
||||
)
|
||||
|
||||
else:
|
||||
assert request.kv_transfer_params is not None, (
|
||||
"kv_transfer_params should not be None"
|
||||
)
|
||||
|
||||
remote_dp_rank = request.kv_transfer_params.get("remote_dp_rank", 0)
|
||||
|
||||
for tp_index in range(self.tp_size):
|
||||
@ -1178,9 +1192,11 @@ class MoRIIOConnectorScheduler:
|
||||
assert hasattr(new_req.sampling_params, "extra_args"), (
|
||||
f"sampling_params missing extra_args for req {new_req.req_id}"
|
||||
)
|
||||
kv_transfer_params = new_req.sampling_params.extra_args[
|
||||
"kv_transfer_params"
|
||||
]
|
||||
kv_transfer_params = (
|
||||
new_req.sampling_params.extra_args.get("kv_transfer_params", {})
|
||||
if new_req.sampling_params.extra_args
|
||||
else {}
|
||||
)
|
||||
meta.add_new_req(
|
||||
red_id,
|
||||
local_block_ids,
|
||||
@ -1212,7 +1228,7 @@ class MoRIIOConnectorScheduler:
|
||||
meta.add_new_req(
|
||||
request_id=req_id,
|
||||
local_block_ids=self._reqs_need_pending_save[req_id][1],
|
||||
kv_transfer_params=req.kv_transfer_params,
|
||||
kv_transfer_params=req.kv_transfer_params or {},
|
||||
write_mode=True,
|
||||
)
|
||||
del self._reqs_need_pending_save[req_id]
|
||||
@ -1328,6 +1344,9 @@ class MoRIIOConnectorWorker:
|
||||
|
||||
# Config.
|
||||
self.vllm_config = vllm_config
|
||||
assert vllm_config.kv_transfer_config is not None, (
|
||||
"kv_transfer_config must be set for MoRIIOConnector"
|
||||
)
|
||||
self.kv_transfer_config = vllm_config.kv_transfer_config
|
||||
self.is_producer = self.kv_transfer_config.is_kv_producer
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user