[BugFix] Fix handling of resumed reqs in SharedStorageConnector (#27719)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-10-29 13:16:52 -07:00 committed by GitHub
parent fcb1d570bb
commit d4aa144343
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -336,36 +336,34 @@ class SharedStorageConnector(KVConnectorBase_V1):
cached_reqs = scheduler_output.scheduled_cached_reqs cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids): for i, req_id in enumerate(cached_reqs.req_ids):
resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
if not resumed_from_preemption or req_id not in self._requests_need_load:
continue
num_computed_tokens = cached_reqs.num_computed_tokens[i] num_computed_tokens = cached_reqs.num_computed_tokens[i]
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id] num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
new_block_ids = cached_reqs.new_block_ids[i] new_block_ids = cached_reqs.new_block_ids[i]
resumed_from_preemption = cached_reqs.resumed_from_preemption[i]
# NOTE(rob): here we rely on the resumed requests being # NOTE(rob): cached_req_data does not have the full
# the first N requests in the list scheduled_cache_reqs. # list of token ids (only new tokens). So we look it
if not resumed_from_preemption: # up in the actual request object.
break request = self._requests_need_load[req_id]
if req_id in self._requests_need_load: total_tokens = num_computed_tokens + num_new_tokens
# NOTE(rob): cached_req_data does not have the full token_ids = request.all_token_ids[:total_tokens]
# list of token ids (only new tokens). So we look it
# up in the actual request object.
request = self._requests_need_load[req_id]
total_tokens = num_computed_tokens + num_new_tokens
token_ids = request.all_token_ids[:total_tokens]
# NOTE(rob): For resumed req, new_block_ids is all # NOTE(rob): For resumed req, new_block_ids is all
# of the block_ids for the request. # of the block_ids for the request.
assert new_block_ids is not None assert new_block_ids is not None
block_ids = new_block_ids[0] block_ids = new_block_ids[0]
meta.add_request( meta.add_request(
token_ids=token_ids, token_ids=token_ids,
block_ids=block_ids, block_ids=block_ids,
block_size=self._block_size, block_size=self._block_size,
is_store=False, is_store=False,
mm_hashes=[f.identifier for f in request.mm_features], mm_hashes=[f.identifier for f in request.mm_features],
) )
total_need_load += 1 total_need_load += 1
assert total_need_load == len(self._requests_need_load) assert total_need_load == len(self._requests_need_load)
self._requests_need_load.clear() self._requests_need_load.clear()