diff --git a/vllm/v1/sample/ops/utils.py b/vllm/v1/sample/ops/utils.py
deleted file mode 100644
index a54e20603064f..0000000000000
--- a/vllm/v1/sample/ops/utils.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-from typing import Union
-
-import torch
-
-
-def compiled_softmax(
-    logits: torch.Tensor,
-    temperature: Union[float, torch.Tensor] = 1.0,
-) -> torch.Tensor:
-    """Faster softmax kernel generated by torch.compile.
-
-    Args:
-        logits: [n, vocab_size]
-        temperature: [n] or float
-    """
-    # NOTE(woosuk): Avoid recompilation by marking the first dim as dynamic.
-    torch._dynamo.mark_dynamic(logits, index=0)
-    if isinstance(temperature, torch.Tensor):
-        torch._dynamo.mark_dynamic(temperature, index=0)
-    return _softmax(logits, temperature)
-
-
-@torch.compile
-def _softmax(
-    logits: torch.Tensor,
-    temperature: Union[float, torch.Tensor],
-) -> torch.Tensor:
-    logits = logits / temperature
-    return torch.softmax(logits, dim=-1, dtype=torch.float32)
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index c8327f36a5853..e0db9474f61cb 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -9,7 +9,6 @@ import triton.language as tl
 from vllm.logger import init_logger
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
-from vllm.v1.sample.ops.utils import compiled_softmax
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
 logger = init_logger(__name__)
@@ -275,8 +274,7 @@
     # NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
     # which is slow for large vocab sizes. This may cause performance issues.
     logits = apply_top_k_top_p(logits, top_k, top_p)
-
-    output_prob = compiled_softmax(logits)
+    output_prob = logits.softmax(dim=-1, dtype=torch.float32)
     return output_prob
 
 