[BugFix] Fix async scheduling CPU tensor race take 2 (#25279)
Signed-off-by: Nick Hill <nhill@redhat.com>
commit 14c1432789
parent ee7a66dd9a
@@ -1903,7 +1903,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 **self._init_model_kwargs(num_scheduled_tokens),
                 **self._extract_mm_kwargs(scheduler_output),
             }
-        elif (self.enable_prompt_embeds and get_pp_group().is_first_rank):
+        elif self.enable_prompt_embeds and get_pp_group().is_first_rank:
             # Get the input embeddings for the tokens that are not input embeds,
             # then put them into the appropriate positions.
             # TODO(qthequartermasterman): Since even when prompt embeds are
@@ -2125,6 +2125,21 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             invalid_req_indices,
         )

+    @contextmanager
+    def synchronize_input_prep(self):
+        if self.prepare_inputs_event is None:
+            yield
+            return
+
+        # Ensure prior step has finished with reused CPU tensors.
+        # This is required in the async scheduling case because
+        # the CPU->GPU transfer happens async.
+        self.prepare_inputs_event.synchronize()
+        try:
+            yield
+        finally:
+            self.prepare_inputs_event.record()
+
     @torch.inference_mode()
     def execute_model(
         self,
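The new `synchronize_input_prep` context manager guards the CPU-side input tensors that the runner reuses across steps: with async scheduling, the next step can start preparing inputs while the previous step's host-to-device copies are still in flight, so the CPU tensors must not be overwritten until those copies have drained. The sketch below is a minimal, self-contained illustration of that event pattern, not vLLM's actual classes; `InputStager`, its buffer names, and sizes are invented for the example.

from contextlib import contextmanager

import torch


class InputStager:
    """Illustrative stand-in for a runner that reuses pinned CPU input buffers."""

    def __init__(self, num_slots: int, async_scheduling: bool = True):
        # Pinned CPU staging tensor that is overwritten on every step.
        self.cpu_buf = torch.empty(num_slots, dtype=torch.int64, pin_memory=True)
        self.gpu_buf = torch.empty(num_slots, dtype=torch.int64, device="cuda")
        # The event is only needed when input prep may overlap the previous
        # step's non-blocking CPU->GPU copy (the async scheduling case).
        self.prepare_inputs_event = (
            torch.cuda.Event() if async_scheduling else None)

    @contextmanager
    def synchronize_input_prep(self):
        if self.prepare_inputs_event is None:
            yield
            return
        # Ensure the prior step has finished with the reused CPU tensors.
        self.prepare_inputs_event.synchronize()
        try:
            yield
        finally:
            # Recorded after the async copy is enqueued, so the next step's
            # synchronize() waits until that copy has consumed cpu_buf.
            self.prepare_inputs_event.record()

    def prepare(self, token_ids: list[int]) -> None:
        with self.synchronize_input_prep():
            n = len(token_ids)
            self.cpu_buf[:n] = torch.tensor(token_ids, dtype=torch.int64)
            # Non-blocking H2D copy: returns before the data reaches the GPU.
            self.gpu_buf[:n].copy_(self.cpu_buf[:n], non_blocking=True)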
@@ -2132,33 +2147,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
         with record_function_or_nullcontext("Preprocess"):
-            self._update_states(scheduler_output)
-            if not scheduler_output.total_num_scheduled_tokens:
-                if not has_kv_transfer_group():
-                    # Return empty ModelRunnerOutput if there's no work to do.
-                    return EMPTY_MODEL_RUNNER_OUTPUT
-                return self.kv_connector_no_forward(scheduler_output,
-                                                    self.vllm_config)
-            if self.cache_config.kv_sharing_fast_prefill:
-                assert not self.input_batch.num_prompt_logprobs, (
-                    "--kv-sharing-fast-prefill produces incorrect logprobs for "
-                    "prompt tokens, tokens, please disable it when the requests"
-                    " need prompt logprobs")
+            with self.synchronize_input_prep():
+                # Update persistent batch states.
+                self._update_states(scheduler_output)
+
+                if not scheduler_output.total_num_scheduled_tokens:
+                    if not has_kv_transfer_group():
+                        # Return empty ModelRunnerOutput if no work to do.
+                        return EMPTY_MODEL_RUNNER_OUTPUT
+                    return self.kv_connector_no_forward(
+                        scheduler_output, self.vllm_config)
+                if self.cache_config.kv_sharing_fast_prefill:
+                    assert not self.input_batch.num_prompt_logprobs, (
+                        "--kv-sharing-fast-prefill produces incorrect "
+                        "logprobs for prompt tokens, tokens, please disable "
+                        "it when the requests need prompt logprobs")

-            if self.prepare_inputs_event is not None:
-                # Ensure prior step has finished with reused CPU tensors.
-                self.prepare_inputs_event.synchronize()
-            try:
                 # Prepare the decoder inputs.
                 (attn_metadata, logits_indices, spec_decode_metadata,
                  num_scheduled_tokens_np, spec_decode_common_attn_metadata,
                  max_query_len, ubatch_slices, num_tokens_after_padding
                  ) = self._prepare_inputs(scheduler_output)

-            finally:
-                if self.prepare_inputs_event is not None:
-                    self.prepare_inputs_event.record()
-
         (
             num_scheduled_tokens,
             num_input_tokens,
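Compared to the previous attempt, the guarded region in `execute_model` now begins before `self._update_states(scheduler_output)` rather than only around `_prepare_inputs`, so every write to the reused CPU-side state happens after the prior step's copies are known to have completed. A hypothetical driver loop for the `InputStager` sketch above shows how consecutive steps reuse the same pinned buffer safely:

# Hypothetical usage of the InputStager sketch above (not vLLM code).
stager = InputStager(num_slots=8192)

for step_batch in ([1, 2, 3], [4, 5, 6, 7], [8, 9]):
    # Each call overwrites the same pinned CPU buffer. synchronize_input_prep()
    # blocks until the previous step's non-blocking copy has drained, so the
    # overwrite can never race with a copy that is still in flight.
    stager.prepare(step_batch)
    # ... the forward pass would consume stager.gpu_buf here ...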