Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-11-20 10:46:37 +00:00
parent bba4c89ca4
commit 08cd2efbb6

View File

@ -694,9 +694,9 @@ class MoRIIOWrapper:
self._handle_completion_message(msg_str)
handled = True
except UnicodeDecodeError:
logger.warning(f"Received non-UTF8 message: {msg}")
logger.warning(f"Received non-UTF8 message: {msg_str}")
if not handled:
raise MoRIIOError(f"Unhandled message format: {msg}")
raise MoRIIOError(f"Unhandled message format: {msg_str}")
def _handle_structured_message(self, data: dict):
req_id = data["req_id"]
@ -784,7 +784,7 @@ class ReqMeta:
remote_host: str
remote_port: int
remote_handshake_port: int
remote_notify_port: int
remote_notify_port: int | None
remote_engine_id: str
tp_size: int
remote_dp_size: int
@ -1011,7 +1011,7 @@ class MoRIIOConnectorScheduler:
self._reqs_need_send: dict[ReqId, float] = {}
self.sock = None
self.is_producer = vllm_config.kv_transfer_config.kv_role == "kv_producer"
self.paths = {}
self.paths: dict[str, zmq.Socket] = {}
def get_num_new_matched_tokens(
self,
@ -1043,7 +1043,7 @@ class MoRIIOConnectorScheduler:
return len(request.prompt_token_ids) - 1 - num_computed_tokens, False
def send_notify_block(
self, req_id: str, block_notify_list: list[int] = None, host=None, port=None
self, req_id: str, block_notify_list: list[int] , host=None, port=None
):
path = make_zmq_path("tcp", host, port)
if path not in self.paths:
@ -1374,25 +1374,24 @@ class MoRIIOConnectorWorker:
self.moriio_wrapper.set_moriio_engine(self.moriio_engine)
self.moriio_wrapper.set_backend_type(BackendType.RDMA)
self.moriio_wrapper.notify_port = self.moriio_config.notify_port
self.local_kv_cache_metadata = []
self.local_kv_cache_size = []
self.layer_name_to_local_kv_cache_metadata: dict[str, list[Any]] = dict()
self.local_kv_cache_metadata: list[bytes] = []
self.local_kv_cache_size: list[int] = []
self.layer_name_to_local_kv_cache_metadata: dict[str, list[bytes]] = {}
self.remote_kv_cache_metadata = []
self.remote_kv_cache_size = []
self.remote_kv_cache_metadata: list[bytes] = []
self.remote_kv_cache_size: list[int] = []
self.layer_name_to_remote_kv_cache_metadata: dict[str, dict[str, list[Any]]] = (
dict()
)
self.slot_size_bytes = 0
self.load_ready_flag = {}
self.write_ready_flags = {}
self.load_ready_flag: dict[str, bool] = {}
self.write_ready_flags: dict[str, bool] = {}
self.kv_cache_shape = None
self.block_shape = None
self.kv_element_size = 0
self.done_sending_reqs = []
self.done_send_threads = []
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
@ -1452,7 +1451,6 @@ class MoRIIOConnectorWorker:
self.use_mla = self.model_config.use_mla
self.built_session = False
self.built_write_session: defaultdict[str, list] = defaultdict(list)
self.debug_cache = []
backend = get_attn_backend(
self.model_config.get_head_size(),
self.model_config.dtype,