mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-16 00:02:17 +08:00
[V1][TPU] Speed up top-k on TPU by using torch.topk (#15242)
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
This commit is contained in:
parent
6edbfa924d
commit
47195057e9
@ -39,7 +39,7 @@ def test_sampler_compilation(model_name: str, monkeypatch):
|
|||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
# top_p=0.6, # TODO too slow!
|
# top_p=0.6, # TODO too slow!
|
||||||
# top_k=10,
|
top_k=10,
|
||||||
min_p=0.2,
|
min_p=0.2,
|
||||||
max_tokens=16)
|
max_tokens=16)
|
||||||
s = time()
|
s = time()
|
||||||
@ -49,6 +49,7 @@ def test_sampler_compilation(model_name: str, monkeypatch):
|
|||||||
# Second request with different params, but for which we
|
# Second request with different params, but for which we
|
||||||
# compiled for in previous eager iteration.
|
# compiled for in previous eager iteration.
|
||||||
sampling_params = SamplingParams(temperature=0.1,
|
sampling_params = SamplingParams(temperature=0.1,
|
||||||
|
top_k=12,
|
||||||
min_p=0.8,
|
min_p=0.8,
|
||||||
max_tokens=24)
|
max_tokens=24)
|
||||||
s = time()
|
s = time()
|
||||||
|
|||||||
@ -95,6 +95,7 @@ if TYPE_CHECKING:
|
|||||||
VLLM_DP_MASTER_PORT: int = 0
|
VLLM_DP_MASTER_PORT: int = 0
|
||||||
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
|
VLLM_MARLIN_USE_ATOMIC_ADD: bool = False
|
||||||
VLLM_V0_USE_OUTLINES_CACHE: bool = False
|
VLLM_V0_USE_OUTLINES_CACHE: bool = False
|
||||||
|
VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION: bool = False
|
||||||
|
|
||||||
|
|
||||||
def get_default_cache_root():
|
def get_default_cache_root():
|
||||||
@ -623,6 +624,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
# an environment with potentially malicious users.
|
# an environment with potentially malicious users.
|
||||||
"VLLM_V0_USE_OUTLINES_CACHE":
|
"VLLM_V0_USE_OUTLINES_CACHE":
|
||||||
lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
|
lambda: os.environ.get("VLLM_V0_USE_OUTLINES_CACHE", "0") == "1",
|
||||||
|
|
||||||
|
# If set to a non-zero integer, disables the TPU-specific optimization for top-k & top-p sampling
|
||||||
|
"VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION":
|
||||||
|
lambda: bool(int(os.environ["VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION"]))
|
||||||
|
if "VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION" in os.environ else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
# end-env-vars-definition
|
# end-env-vars-definition
|
||||||
|
|||||||
@ -66,7 +66,14 @@ class TopKTopPSampler(nn.Module):
|
|||||||
"best performance, please install FlashInfer.")
|
"best performance, please install FlashInfer.")
|
||||||
self.forward = self.forward_native
|
self.forward = self.forward_native
|
||||||
elif current_platform.is_tpu():
|
elif current_platform.is_tpu():
|
||||||
self.forward = self.forward_tpu
|
if envs.VLLM_TPU_DISABLE_TOPK_TOPP_OPTIMIZATION:
|
||||||
|
logger.warning(
|
||||||
|
"TPU-specific optimization for top-k & top-p sampling are "
|
||||||
|
"disabled, falling back to PyTorch-native implementation "
|
||||||
|
"which could be very slow.")
|
||||||
|
self.forward = self.forward_native
|
||||||
|
else:
|
||||||
|
self.forward = self.forward_tpu
|
||||||
else:
|
else:
|
||||||
self.forward = self.forward_native
|
self.forward = self.forward_native
|
||||||
|
|
||||||
@ -105,8 +112,19 @@ class TopKTopPSampler(nn.Module):
|
|||||||
k: Optional[torch.Tensor],
|
k: Optional[torch.Tensor],
|
||||||
p: Optional[torch.Tensor],
|
p: Optional[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# TODO Placeholder for TPU optimized topk/p kernel
|
# If only top-k is specified, use PyTorch's built-in topk op. This leads
|
||||||
# logits = apply_top_k_top_p(logits, k, p)
|
# to a significant speedup on TPU compared to using apply_top_k_top_p.
|
||||||
|
if k is not None and p is None:
|
||||||
|
topk_values, topk_indices = torch.topk(logits, k, dim=-1)
|
||||||
|
|
||||||
|
mask = torch.ones_like(logits, dtype=torch.bool)
|
||||||
|
mask.scatter_(-1, topk_indices, False)
|
||||||
|
logits.masked_fill_(mask, float('-inf'))
|
||||||
|
else:
|
||||||
|
# TODO Placeholder for TPU optimized topp kernel
|
||||||
|
# logits = apply_top_k_top_p(logits, k, p)
|
||||||
|
pass
|
||||||
|
|
||||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||||
return random_sample(probs, generators)
|
return random_sample(probs, generators)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user