[Model Runner V2] Fix Triton warning on tl.where (#30355)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-12-09 09:59:54 -08:00 committed by GitHub
parent 0b6a8a304c
commit 9e6562a3f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -62,6 +62,7 @@ def _penalties_and_temperature_kernel(
mask=packed_block < tl.cdiv(vocab_size, 32),
)
prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1
prompt_bin_mask = prompt_bin_mask.to(tl.int1)
prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.