Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-08-28 15:02:07 -07:00
parent 405578121c
commit 9ee9d0e274
2 changed files with 12 additions and 6 deletions

View File

@ -311,6 +311,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Cached outputs.
self._draft_token_ids: Optional[Union[list[list[int]],
torch.Tensor]] = None
self._draft_req_ids: Optional[list[str]] = None
self.transfer_event = torch.cuda.Event()
self.sampled_token_ids_pinned_cpu = torch.empty(
(self.max_model_len, 1),
@ -997,13 +998,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# list of tuple (mm_hash, position_info)
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
req_data = self.requests.req_data[req_id]
for mm_input_id in encoder_input_ids:
mm_hash = req_state.mm_hashes[mm_input_id]
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
mm_hash = req_data.mm_hashes[mm_input_id]
mm_kwargs.append(req_data.mm_kwargs[mm_input_id])
mm_hashes_pos.append(
(mm_hash, req_state.mm_positions[mm_input_id]))
(mm_hash, req_data.mm_positions[mm_input_id]))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
@ -1576,6 +1577,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch.spec_decode_metadata,
input_batch.spec_decode_common_attn_metadata,
)
self._draft_req_ids = input_batch.req_ids
self.eplb_step()
@ -1593,12 +1595,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
if self._draft_token_ids is None:
return None
req_ids = self.requests.req_ids
if isinstance(self._draft_token_ids, torch.Tensor):
draft_token_ids = self._draft_token_ids.tolist()
else:
draft_token_ids = self._draft_token_ids
self._draft_token_ids = None
assert self._draft_req_ids
req_ids = self._draft_req_ids
self._draft_req_ids = None
return DraftTokenIds(req_ids, draft_token_ids)
def propose_draft_token_ids(
@ -1614,7 +1619,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
draft_token_ids = self.propose_ngram_draft_token_ids(
sampled_token_ids)
input_batch, sampled_token_ids)
elif self.speculative_config.method == "medusa":
assert isinstance(self.drafter, MedusaProposer)
if sample_hidden_states.shape[0] == len(sampled_token_ids):

View File

@ -28,6 +28,7 @@ class RequestData:
sampling_params: Optional[SamplingParams]
pooling_params: Optional[PoolingParams]
mm_hashes: list[str]
# M-RoPE (only for Qwen2/2.5-VL)
mrope_positions: Optional[torch.Tensor] = None
mrope_position_delta: Optional[int] = None