Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-11-20 10:46:37 +00:00
parent bba4c89ca4
commit 08cd2efbb6

View File

@ -694,9 +694,9 @@ class MoRIIOWrapper:
self._handle_completion_message(msg_str) self._handle_completion_message(msg_str)
handled = True handled = True
except UnicodeDecodeError: except UnicodeDecodeError:
logger.warning(f"Received non-UTF8 message: {msg}") logger.warning(f"Received non-UTF8 message: {msg_str}")
if not handled: if not handled:
raise MoRIIOError(f"Unhandled message format: {msg}") raise MoRIIOError(f"Unhandled message format: {msg_str}")
def _handle_structured_message(self, data: dict): def _handle_structured_message(self, data: dict):
req_id = data["req_id"] req_id = data["req_id"]
@ -784,7 +784,7 @@ class ReqMeta:
remote_host: str remote_host: str
remote_port: int remote_port: int
remote_handshake_port: int remote_handshake_port: int
remote_notify_port: int remote_notify_port: int | None
remote_engine_id: str remote_engine_id: str
tp_size: int tp_size: int
remote_dp_size: int remote_dp_size: int
@ -1011,7 +1011,7 @@ class MoRIIOConnectorScheduler:
self._reqs_need_send: dict[ReqId, float] = {} self._reqs_need_send: dict[ReqId, float] = {}
self.sock = None self.sock = None
self.is_producer = vllm_config.kv_transfer_config.kv_role == "kv_producer" self.is_producer = vllm_config.kv_transfer_config.kv_role == "kv_producer"
self.paths = {} self.paths: dict[str, zmq.Socket] = {}
def get_num_new_matched_tokens( def get_num_new_matched_tokens(
self, self,
@ -1043,7 +1043,7 @@ class MoRIIOConnectorScheduler:
return len(request.prompt_token_ids) - 1 - num_computed_tokens, False return len(request.prompt_token_ids) - 1 - num_computed_tokens, False
def send_notify_block( def send_notify_block(
self, req_id: str, block_notify_list: list[int] = None, host=None, port=None self, req_id: str, block_notify_list: list[int] , host=None, port=None
): ):
path = make_zmq_path("tcp", host, port) path = make_zmq_path("tcp", host, port)
if path not in self.paths: if path not in self.paths:
@ -1374,25 +1374,24 @@ class MoRIIOConnectorWorker:
self.moriio_wrapper.set_moriio_engine(self.moriio_engine) self.moriio_wrapper.set_moriio_engine(self.moriio_engine)
self.moriio_wrapper.set_backend_type(BackendType.RDMA) self.moriio_wrapper.set_backend_type(BackendType.RDMA)
self.moriio_wrapper.notify_port = self.moriio_config.notify_port self.moriio_wrapper.notify_port = self.moriio_config.notify_port
self.local_kv_cache_metadata = [] self.local_kv_cache_metadata: list[bytes] = []
self.local_kv_cache_size = [] self.local_kv_cache_size: list[int] = []
self.layer_name_to_local_kv_cache_metadata: dict[str, list[Any]] = dict() self.layer_name_to_local_kv_cache_metadata: dict[str, list[bytes]] = {}
self.remote_kv_cache_metadata = [] self.remote_kv_cache_metadata: list[bytes] = []
self.remote_kv_cache_size = [] self.remote_kv_cache_size: list[int] = []
self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[str, list[Any]]] = ( self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[str, list[Any]]] = (
dict() dict()
) )
self.slot_size_bytes = 0 self.slot_size_bytes = 0
self.load_ready_flag = {} self.load_ready_flag: dict[str, bool] = {}
self.write_ready_flags = {} self.write_ready_flags: dict[str, bool] = {}
self.kv_cache_shape = None self.kv_cache_shape = None
self.block_shape = None self.block_shape = None
self.kv_element_size = 0 self.kv_element_size = 0
self.done_sending_reqs = []
self.done_send_threads = []
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}. # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict) self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
@ -1452,7 +1451,6 @@ class MoRIIOConnectorWorker:
self.use_mla = self.model_config.use_mla self.use_mla = self.model_config.use_mla
self.built_session = False self.built_session = False
self.built_write_session: defaultdict[str, list] = defaultdict(list) self.built_write_session: defaultdict[str, list] = defaultdict(list)
self.debug_cache = []
backend = get_attn_backend( backend = get_attn_backend(
self.model_config.get_head_size(), self.model_config.get_head_size(),
self.model_config.dtype, self.model_config.dtype,