Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-12-23 10:47:17 +00:00
parent 8ef0b5f509
commit d2a18332b7
2 changed files with 18 additions and 11 deletions

View File

@ -210,11 +210,16 @@ async def handle_request():
prefill_instance_endpoint = None prefill_instance_endpoint = None
decode_instance_endpoint = None decode_instance_endpoint = None
error_msg = (
"Service Unavailable: No prefill or decode instances are registered."
)
if not prefill_instances or not decode_instances: if not prefill_instances or not decode_instances:
return await make_response( return await make_response(
("Service Unavailable: No prefill or decode instances are registered.", (
503)) error_msg,
503,
)
)
pid = request_nums % len(prefill_instances) pid = request_nums % len(prefill_instances)
did = request_nums % len(decode_instances) did = request_nums % len(decode_instances)
prefill_instance_endpoint = prefill_instances[pid] prefill_instance_endpoint = prefill_instances[pid]
@ -297,10 +302,12 @@ async def handle_request():
return response return response
except Exception as e: except Exception as e:
logger.exception("An error occurred while handling the request: %s", e) logger.exception("An error occurred while handling the request: %s", e)
return await make_response(( return await make_response(
f"Internal Server Error: {e!s}", (
500, f"Internal Server Error: {e!s}",
)) 500,
)
)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -342,8 +342,8 @@ class MoRIIOConnectorScheduler:
local_block_ids = blocks.get_block_ids()[0] local_block_ids = blocks.get_block_ids()[0]
self._reqs_need_save[request.request_id] = (request, local_block_ids) self._reqs_need_save[request.request_id] = (request, local_block_ids)
if params is not None and params.get("do_remote_prefill"): # if params is not None and params.get("do_remote_prefill"):
if self.mode == MoRIIOMode.READ: #read mode decode if self.mode == MoRIIOMode.READ:
if remote_block_ids := params.get("remote_block_ids"): if remote_block_ids := params.get("remote_block_ids"):
if all( if all(
p in params p in params
@ -373,7 +373,7 @@ class MoRIIOConnectorScheduler:
) )
else: else:
assert request.kv_transfer_params is not None, ( #write mode decode assert request.kv_transfer_params is not None, (
"kv_transfer_params should not be None" "kv_transfer_params should not be None"
) )
@ -890,7 +890,7 @@ class MoRIIOConnectorWorker:
layer_name_to_local_kv_cache_metadata: dict, layer_name_to_local_kv_cache_metadata: dict,
): ):
"""Background thread for getting new MoRIIO handshakes.""" """Background thread for getting new MoRIIO handshakes."""
logger.info("tmp")
encoder = msgspec.msgpack.Encoder() encoder = msgspec.msgpack.Encoder()
encoded_data = encoder.encode(metadata) encoded_data = encoder.encode(metadata)
size_in_bytes = len(encoded_data) size_in_bytes = len(encoded_data)