diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py index 5135580924f14..b1063f5487955 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/moriio_connector.py @@ -240,7 +240,9 @@ class MoRIIOConfig: # TODO : merge notify_port and handshake_port to simplify port management # supports non-contiguous ports - + assert vllm_config.kv_transfer_config is not None, ( + "kv_transfer_config must be set for MoRIIOConnector" + ) kv_transfer_config = vllm_config.kv_transfer_config extra_config = kv_transfer_config.kv_connector_extra_config tp_rank = get_tensor_model_parallel_rank() @@ -745,12 +747,12 @@ class MoRIIOWrapper: else: self.done_write_cache_req_ids.append(msg) - def send_notify(self, req_ids, remote_ip=None, remote_port=None): + def send_notify(self, req_ids, remote_ip, remote_port): if not remote_ip or not remote_port: logger.warning("Missing remote_ip or remote_port for notification") return - path = make_zmq_path("tcp", remote_ip, str(remote_port)) + path = make_zmq_path("tcp", remote_ip, remote_port) if path not in self.paths: ctx = zmq.Context.instance() @@ -872,18 +874,18 @@ class MoRIIOConnector(KVConnectorBase_V1): kv_cache_config: Optional["KVCacheConfig"] = None, ): super().__init__(vllm_config, role) - assert vllm_config.kv_transfer_config is not None + assert vllm_config.kv_transfer_config is not None, ( + "kv_transfer_config must be set for MoRIIOConnector" + ) + + self.kv_transfer_config = vllm_config.kv_transfer_config # assert vllm_config.kv_transfer_config.engine_id is not None self._set_port_defaults(vllm_config) self.engine_id = ( str(get_ip()) + ":" - + str( - vllm_config.kv_transfer_config.kv_connector_extra_config[ - "handshake_port" - ] - ) + + str(self.kv_transfer_config.kv_connector_extra_config["handshake_port"]) ) self.mode = get_moriio_mode() if role == KVConnectorRole.SCHEDULER: @@ -905,6 +907,9 @@ class MoRIIOConnector(KVConnectorBase_V1): ############################################################ def _set_port_defaults(self, vllm_config: VllmConfig): + assert vllm_config.kv_transfer_config is not None, ( + "kv_transfer_config must be set for MoRIIOConnector" + ) kv_transfer_config = vllm_config.kv_transfer_config extra_config = kv_transfer_config.kv_connector_extra_config @@ -1011,23 +1016,26 @@ class MoRIIOConnectorScheduler: def __init__(self, vllm_config: VllmConfig, engine_id: str): self.vllm_config = vllm_config + + assert vllm_config.kv_transfer_config is not None, ( + "kv_transfer_config must be set for MoRIIOConnector" + ) + self.kv_transfer_config = vllm_config.kv_transfer_config self.block_size = vllm_config.cache_config.block_size self.engine_id: EngineId = engine_id self.mode = get_moriio_mode() self.host_ip = get_ip() - self.handshake_port = ( - self.vllm_config.kv_transfer_config.kv_connector_extra_config[ - "handshake_port" - ] - ) + self.handshake_port = self.kv_transfer_config.kv_connector_extra_config[ + "handshake_port" + ] logger.info("Initializing MoRIIO Scheduler engine_id = %s", engine_id) - self.side_notify_port = ( - self.vllm_config.kv_transfer_config.kv_connector_extra_config["notify_port"] - ) + self.side_notify_port = self.kv_transfer_config.kv_connector_extra_config[ + "notify_port" + ] self.tp_size = self.vllm_config.parallel_config.tensor_parallel_size self.dp_rank = self.vllm_config.parallel_config.data_parallel_rank - self.is_producer = vllm_config.kv_transfer_config.kv_role == "kv_producer" + self.is_producer = self.kv_transfer_config.kv_role == "kv_producer" # Requests that need to start recv/send. # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. @@ -1045,7 +1053,6 @@ class MoRIIOConnectorScheduler: # Reqs to send and their expiration time self._reqs_need_send: dict[ReqId, float] = {} self.sock = None - self.is_producer = vllm_config.kv_transfer_config.kv_role == "kv_producer" self.paths: dict[str, zmq.Socket] = {} def get_num_new_matched_tokens( @@ -1070,12 +1077,13 @@ class MoRIIOConnectorScheduler: if self.is_producer: return 0, False + token_ids = request.prompt_token_ids or [] if self.mode == MoRIIOMode.WRITE: # MoriiO in write mode, no remote prefill - return len(request.prompt_token_ids) - num_computed_tokens, True + return len(token_ids) - num_computed_tokens, True - return len(request.prompt_token_ids) - 1 - num_computed_tokens, False + return len(token_ids) - 1 - num_computed_tokens, False def send_notify_block( self, req_id: str, block_notify_list: list[int], host=None, port=None @@ -1105,6 +1113,8 @@ class MoRIIOConnectorScheduler: connector_worker: Optional["MoRIIOConnectorWorker"] = None, ): params = request.kv_transfer_params + if not params: + return if params.get("do_remote_decode"): local_block_ids = blocks.get_block_ids()[0] self._reqs_need_save[request.request_id] = (request, local_block_ids) @@ -1140,6 +1150,10 @@ class MoRIIOConnectorScheduler: ) else: + assert request.kv_transfer_params is not None, ( + "kv_transfer_params should not be None" + ) + remote_dp_rank = request.kv_transfer_params.get("remote_dp_rank", 0) for tp_index in range(self.tp_size): @@ -1178,9 +1192,11 @@ class MoRIIOConnectorScheduler: assert hasattr(new_req.sampling_params, "extra_args"), ( f"sampling_params missing extra_args for req {new_req.req_id}" ) - kv_transfer_params = new_req.sampling_params.extra_args[ - "kv_transfer_params" - ] + kv_transfer_params = ( + new_req.sampling_params.extra_args.get("kv_transfer_params", {}) + if new_req.sampling_params.extra_args + else {} + ) meta.add_new_req( red_id, local_block_ids, @@ -1212,7 +1228,7 @@ class MoRIIOConnectorScheduler: meta.add_new_req( request_id=req_id, local_block_ids=self._reqs_need_pending_save[req_id][1], - kv_transfer_params=req.kv_transfer_params, + kv_transfer_params=req.kv_transfer_params or {}, write_mode=True, ) del self._reqs_need_pending_save[req_id] @@ -1328,6 +1344,9 @@ class MoRIIOConnectorWorker: # Config. self.vllm_config = vllm_config + assert vllm_config.kv_transfer_config is not None, ( + "kv_transfer_config must be set for MoRIIOConnector" + ) self.kv_transfer_config = vllm_config.kv_transfer_config self.is_producer = self.kv_transfer_config.is_kv_producer