mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 13:47:00 +08:00
[bugfix] [spec-decoding] fix data race in sample_recovered_tokens_kernel (vLLM v1) (#23829)
Signed-off-by: He-Jingkai <he-jingkai@outlook.com>
This commit is contained in:
parent
04d1dd7f4a
commit
57d4ede520
@@ -598,17 +598,10 @@ def sample_recovered_tokens_kernel(
|
|||||||
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
|
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
|
||||||
if NO_DRAFT_PROBS:
|
if NO_DRAFT_PROBS:
|
||||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||||
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
|
|
||||||
draft_token_id)
|
|
||||||
# Temporarily zero out the probability of the draft token.
|
|
||||||
# This is essentially the same as target_prob - draft_prob, except that
|
|
||||||
# n-gram does not have draft_prob. We regard it as 1.
|
|
||||||
tl.store(
|
|
||||||
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
|
|
||||||
0)
|
|
||||||
prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
|
prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
|
||||||
vocab_offset,
|
vocab_offset,
|
||||||
mask=vocab_offset < vocab_size,
|
mask=((vocab_offset < vocab_size) &
|
||||||
|
(vocab_offset != draft_token_id)),
|
||||||
other=0)
|
other=0)
|
||||||
else:
|
else:
|
||||||
draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size +
|
draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size +
|
||||||
@@ -628,9 +621,3 @@ def sample_recovered_tokens_kernel(
|
|||||||
other=float("-inf"))
|
other=float("-inf"))
|
||||||
recovered_id = tl.argmax(prob / q, axis=-1)
|
recovered_id = tl.argmax(prob / q, axis=-1)
|
||||||
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
|
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
|
||||||
|
|
||||||
if NO_DRAFT_PROBS:
|
|
||||||
# Restore the original probability.
|
|
||||||
tl.store(
|
|
||||||
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,
|
|
||||||
orig_prob)
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user