diff --git a/vllm/v1/worker/gpu/sampler.py b/vllm/v1/worker/gpu/sampler.py index d03f6edcdec97..c7c7c3c32ad2f 100644 --- a/vllm/v1/worker/gpu/sampler.py +++ b/vllm/v1/worker/gpu/sampler.py @@ -47,7 +47,6 @@ class Sampler(nn.Module): sampling_metadata.seeds, sampling_metadata.pos, ) - sampled = sampled.unsqueeze(-1) logprobs_tensors = None num_logprobs = sampling_metadata.max_num_logprobs @@ -63,7 +62,7 @@ class Sampler(nn.Module): # The sampled tokens are expanded to 2D tensor with shape # [num_requests, 1], where each row represents one generated # token per request. - sampled_token_ids=sampled.unsqueeze(-1), + sampled_token_ids=sampled.view(-1, 1), logprobs_tensors=logprobs_tensors, ) return sampler_output @@ -220,7 +219,8 @@ def _topk_logprobs_kernel( block = i + tl.arange(0, BLOCK_SIZE) l = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0) e = tl.exp(l - max_val) - se += tl.sum(tl.where(block < vocab_size, e, 0.0)) + e = tl.where(block < vocab_size, e, 0.0) + se += tl.sum(e) lse = tl.log(se) k_offset = tl.arange(0, PADDED_TOPK)