mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-12 04:07:04 +08:00
[V1][TPU] Support V1 Sampler for ragged attention (#14227)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
40828ce5fe
commit
d8c6d7d6b5
94
tests/v1/tpu/test_sampler.py
Normal file
94
tests/v1/tpu/test_sampler.py
Normal file
@ -0,0 +1,94 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import tempfile
|
||||
from time import time
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, envs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
pytest.skip(
|
||||
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
|
||||
@pytest.mark.skipif(not current_platform.is_tpu(),
|
||||
reason="This test needs a TPU")
|
||||
def test_sampler_compilation(model_name: str, monkeypatch):
|
||||
"""
|
||||
Check that no recompilation happens despite changing sampling parameters.
|
||||
We can't read XLA metrics from the engine process, hence we measure time.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
monkeypatch.setenv("VLLM_XLA_CACHE_PATH", temp_dir)
|
||||
# Compiling model init may still take some time, enforce_eager to skip.
|
||||
llm = LLM(model_name,
|
||||
enforce_eager=True,
|
||||
max_num_seqs=16,
|
||||
max_model_len=1024,
|
||||
gpu_memory_utilization=0.5)
|
||||
prompts = [
|
||||
"A robot may not injure a human being",
|
||||
"It is only with the heart that one can see rightly;",
|
||||
]
|
||||
# First inference should be slow
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.7,
|
||||
# top_p=0.6, # TODO too slow!
|
||||
# top_k=10,
|
||||
min_p=0.2,
|
||||
max_tokens=16)
|
||||
s = time()
|
||||
_ = llm.generate(prompts, sampling_params)
|
||||
run1 = time() - s
|
||||
|
||||
# Second request with different params, but for which we
|
||||
# compiled for in previous eager iteration.
|
||||
sampling_params = SamplingParams(temperature=0.1,
|
||||
min_p=0.8,
|
||||
max_tokens=24)
|
||||
s = time()
|
||||
_ = llm.generate(prompts, sampling_params)
|
||||
run2 = time() - s
|
||||
# Much faster after compiling
|
||||
assert run1 * 0.1 > run2
|
||||
print("TIMES", run1, run2)
|
||||
|
||||
# Third request with min_p set to "None". It will not trigger
|
||||
# recompilation as a default 0 value will be used.
|
||||
sampling_params = SamplingParams(max_tokens=24, temperature=0.0)
|
||||
s = time()
|
||||
_ = llm.generate(prompts, sampling_params)
|
||||
run3 = time() - s
|
||||
assert run1 * 0.1 > run3
|
||||
print("TIMES", run1, run3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
|
||||
@pytest.mark.skipif(not current_platform.is_tpu(),
|
||||
reason="This test needs a TPU")
|
||||
def test_sampler_different(model_name: str):
|
||||
"""
|
||||
Test significantly different sampling params to assert the model produces
|
||||
different results.
|
||||
"""
|
||||
llm = LLM(
|
||||
model_name,
|
||||
enforce_eager=True,
|
||||
max_num_seqs=1,
|
||||
max_model_len=64,
|
||||
# TODO: setting to 0.5 or it will go OOM
|
||||
gpu_memory_utilization=0.5)
|
||||
prompts = [
|
||||
"Write a short story about a robot that dreams for the first time."
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64)
|
||||
output = llm.generate(prompts, sampling_params)
|
||||
|
||||
sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
|
||||
output2 = llm.generate(prompts, sampling_params)
|
||||
assert output[0].outputs[0].text != output2[0].outputs[0].text
|
||||
@ -65,6 +65,8 @@ class TopKTopPSampler(nn.Module):
|
||||
"native implementation of top-p & top-k sampling. For the "
|
||||
"best performance, please install FlashInfer.")
|
||||
self.forward = self.forward_native
|
||||
elif current_platform.is_tpu():
|
||||
self.forward = self.forward_tpu
|
||||
else:
|
||||
self.forward = self.forward_native
|
||||
|
||||
@ -96,6 +98,18 @@ class TopKTopPSampler(nn.Module):
|
||||
return random_sample(probs, generators)
|
||||
return flashinfer_sample(probs, k, p, generators)
|
||||
|
||||
def forward_tpu(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# TODO Placeholder for TPU optimized topk/p kernel
|
||||
# logits = apply_top_k_top_p(logits, k, p)
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return random_sample(probs, generators)
|
||||
|
||||
|
||||
def apply_top_k_top_p(
|
||||
logits: torch.Tensor,
|
||||
@ -112,7 +126,7 @@ def apply_top_k_top_p(
|
||||
|
||||
if k is not None:
|
||||
# Apply top-k.
|
||||
top_k_mask = logits_sort.size(1) - k.to(torch.long)
|
||||
top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B
|
||||
# Get all the top_k values.
|
||||
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
|
||||
top_k_mask = logits_sort < top_k_mask
|
||||
|
||||
0
vllm/v1/sample/tpu/__init__.py
Normal file
0
vllm/v1/sample/tpu/__init__.py
Normal file
159
vllm/v1/sample/tpu/metadata.py
Normal file
159
vllm/v1/sample/tpu/metadata.py
Normal file
@ -0,0 +1,159 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class TPUSupportedSamplingMetadata:
|
||||
# This class exposes a more xla-friendly interface than SamplingMetadata
|
||||
# on TPU, in particular all arguments should be traceable and no optionals
|
||||
# are allowed, to avoid graph recompilation on Nones.
|
||||
temperature: torch.Tensor
|
||||
|
||||
min_p: torch.Tensor
|
||||
# Still too slow on forward_native!
|
||||
top_k: torch.Tensor = None
|
||||
top_p: torch.Tensor = None
|
||||
|
||||
# XLA-unfriendly control flow in Sampler
|
||||
all_greedy: bool = False
|
||||
all_random: bool = False
|
||||
# Greedy sampling flag for compiling single xla graph.
|
||||
do_argmax: torch.Tensor = None
|
||||
|
||||
# speculation not supported
|
||||
spec_token_ids = None
|
||||
|
||||
# Generator not supported by xla
|
||||
generators: dict[int,
|
||||
torch.Generator] = field(default_factory=lambda: dict())
|
||||
|
||||
# unsupported, you need to return an extra tensor of static size BxV
|
||||
max_num_logprobs = None
|
||||
|
||||
# TODO No penalties for now
|
||||
no_penalties: bool = True
|
||||
prompt_token_ids = None
|
||||
frequency_penalties = None
|
||||
presence_penalties = None
|
||||
repetition_penalties = None
|
||||
# should use tensor
|
||||
output_token_ids: list[list[int]] = field(default_factory=lambda: list())
|
||||
|
||||
min_tokens = None # impl is not vectorized
|
||||
|
||||
logit_bias: list[Optional[dict[int, float]]] = field(
|
||||
default_factory=lambda: list())
|
||||
|
||||
allowed_token_ids_mask = None
|
||||
bad_words_token_ids = None
|
||||
indices_do_sample: torch.Tensor = None
|
||||
|
||||
def __post_init__(self):
|
||||
temp = self.temperature
|
||||
if self.indices_do_sample is None:
|
||||
self.indices_do_sample = torch.zeros(temp.shape[0],
|
||||
device=temp.device,
|
||||
dtype=torch.int32)
|
||||
if self.do_argmax is None:
|
||||
self.do_argmax = torch.tensor(0,
|
||||
dtype=torch.bool,
|
||||
device=temp.device)
|
||||
|
||||
@classmethod
|
||||
def from_sampling_metadata(
|
||||
cls, metadata: SamplingMetadata,
|
||||
padded_do_sample_indices: torch.Tensor, num_do_sample: int,
|
||||
device: torch.device) -> "TPUSupportedSamplingMetadata":
|
||||
"""
|
||||
Create an XLA-frienly SamplingMetadata structure. Do so by first
|
||||
instantiating an object with fixed-sized tensors and then writing the
|
||||
values in input `metadata`. Do that only for non-None values so that
|
||||
recompilation is not triggered for optional values (None/torch.Tensor).
|
||||
|
||||
In order to handle different sizes for the params that range from 1 up
|
||||
to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
|
||||
Same thing for `padded_do_sample_indices`, which contains the indices
|
||||
to be fed to the Sampler, padded to the closest pre-compiled shape.
|
||||
|
||||
Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
|
||||
do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
|
||||
"""
|
||||
metadata = cls._validate_sampling_metadata(metadata)
|
||||
# NOTE we have to initialize default tensor-based params first and
|
||||
# skip None values altogether to produce the same xla graph.
|
||||
num_samples = len(padded_do_sample_indices)
|
||||
do_argmax = torch.tensor(metadata.all_greedy,
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
new_metadata = cls.get_default_sampling_params(num_samples, device,
|
||||
indices_do_sample=\
|
||||
padded_do_sample_indices,
|
||||
do_argmax=do_argmax
|
||||
)
|
||||
supported_params = \
|
||||
TPUSupportedSamplingMetadata._get_default_params_values()
|
||||
# Copy input non-None values into `new_metadata` fixed-sized tensors.
|
||||
for p_name in supported_params:
|
||||
old_val = getattr(metadata, p_name)
|
||||
new_val = getattr(new_metadata, p_name)
|
||||
if isinstance(old_val, torch.Tensor):
|
||||
new_val[:num_do_sample] = old_val
|
||||
setattr(new_metadata, p_name, new_val)
|
||||
|
||||
xm.mark_step()
|
||||
xm.wait_device_ops()
|
||||
return new_metadata
|
||||
|
||||
@classmethod
|
||||
def get_default_sampling_params(
|
||||
cls,
|
||||
num_samples: int,
|
||||
device: torch.device,
|
||||
indices_do_sample=None,
|
||||
do_argmax=None) -> "TPUSupportedSamplingMetadata":
|
||||
# As sampling happens on a single traced graph, options
|
||||
# are "disabled" by having them evaluate to an Identity op.
|
||||
# Note that initialization is dependent on num_samples.
|
||||
sampling_metadata_disable_value = \
|
||||
TPUSupportedSamplingMetadata._get_default_params_values()
|
||||
init_kwargs = dict()
|
||||
for p_name, (default_val,
|
||||
dtype) in sampling_metadata_disable_value.items():
|
||||
default_tensor = torch.full((num_samples, ),
|
||||
default_val,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
init_kwargs[p_name] = default_tensor
|
||||
|
||||
return cls(**init_kwargs,
|
||||
indices_do_sample=indices_do_sample,
|
||||
do_argmax=do_argmax)
|
||||
|
||||
@staticmethod
|
||||
def _validate_sampling_metadata(
|
||||
sampling_metadata: SamplingMetadata) -> SamplingMetadata:
|
||||
if sampling_metadata.all_greedy:
|
||||
# Set to None since #13587. Make sure default isn't overruled.
|
||||
assert sampling_metadata.temperature is None
|
||||
return sampling_metadata
|
||||
|
||||
@staticmethod
|
||||
def _get_default_params_values():
|
||||
return dict(
|
||||
# Since #13587 greedy sampling requires branching off which leads
|
||||
# to separate graphs. We set temp to noop and handle argmax here.
|
||||
temperature=(1.0, torch.float32),
|
||||
min_p=(0.0, torch.float32),
|
||||
# strictly disabled for now
|
||||
# top_k=(-1, torch.int32),
|
||||
# top_p=(0.0, torch.float32),
|
||||
# frequency_penalties=(0.0, torch.float32),
|
||||
# presence_penalties=(0.0, torch.float32),
|
||||
# repetition_penalties=(0.0, torch.float32),
|
||||
)
|
||||
154
vllm/v1/sample/tpu/sampler.py
Normal file
154
vllm/v1/sample/tpu/sampler.py
Normal file
@ -0,0 +1,154 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Sampler layer implementing TPU supported operations."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
|
||||
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.topk_topp_sampler = TopKTopPSampler()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: TPUSupportedSamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
# NOTE(woosuk): Use the original logits (before any penalties or
|
||||
# temperature scaling) for the top-k logprobs.
|
||||
# This is different from the V0 sampler, which uses the logits that
|
||||
# is used for sampling (after penalties and temperature scaling).
|
||||
|
||||
# Use float32 for the logits.
|
||||
logits = logits.to(torch.float32)
|
||||
# Sample the next token.
|
||||
sampled = self.sample(logits, sampling_metadata)
|
||||
|
||||
# Use int32 to reduce the tensor size.
|
||||
sampled = sampled.to(torch.int32)
|
||||
|
||||
# These are GPU tensors.
|
||||
sampler_output = SamplerOutput(
|
||||
# The sampled tokens are expanded to 2D tensor with shape
|
||||
# [num_requests, 1], where each row represents one generated
|
||||
# token per request.
|
||||
sampled_token_ids=sampled.unsqueeze(-1),
|
||||
logprobs_tensors=None,
|
||||
)
|
||||
return sampler_output
|
||||
|
||||
def apply_temperature(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
temp: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Use in-place division to avoid creating a new tensor.
|
||||
return logits.div_(temp.unsqueeze(dim=1))
|
||||
|
||||
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
return logits.argmax(dim=-1).view(-1)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: TPUSupportedSamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
greedy_sampled = self.greedy_sample(logits)
|
||||
|
||||
assert sampling_metadata.temperature is not None
|
||||
|
||||
# Apply temperature.
|
||||
logits = self.apply_temperature(logits, sampling_metadata.temperature)
|
||||
|
||||
# Apply min_p.
|
||||
if sampling_metadata.min_p is not None:
|
||||
logits = self.apply_min_p(logits, sampling_metadata.min_p)
|
||||
|
||||
# Apply top_k and/or top_p.
|
||||
random_sampled = self.topk_topp_sampler(
|
||||
logits,
|
||||
sampling_metadata.generators,
|
||||
sampling_metadata.top_k,
|
||||
sampling_metadata.top_p,
|
||||
)
|
||||
|
||||
sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS,
|
||||
greedy_sampled, random_sampled)
|
||||
return sampled
|
||||
|
||||
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
|
||||
return logits.log_softmax(dim=-1, dtype=torch.float32)
|
||||
|
||||
def gather_logprobs(
|
||||
self,
|
||||
logprobs: torch.Tensor,
|
||||
num_logprobs: int,
|
||||
token_ids: torch.Tensor,
|
||||
) -> LogprobsTensors:
|
||||
"""
|
||||
Gather logprobs for topk and sampled/prompt token.
|
||||
|
||||
Args:
|
||||
logits: (num tokens) x (vocab) tensor
|
||||
num_logprobs: minimum number of logprobs to
|
||||
retain per token
|
||||
token_ids: prompt tokens (if prompt logprobs)
|
||||
or sampled tokens (if sampled
|
||||
logprobs); 1D token ID tensor
|
||||
with (num tokens) elements
|
||||
|
||||
Returns:
|
||||
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
|
||||
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
|
||||
Sampled token rank tensor, (num tokens)
|
||||
"""
|
||||
# Find the topK values.
|
||||
topk_logprobs, topk_indices = torch.topk(logprobs,
|
||||
num_logprobs,
|
||||
dim=-1)
|
||||
|
||||
# Get with the logprob of the prompt or sampled token.
|
||||
token_ids = token_ids.unsqueeze(-1)
|
||||
token_logprobs = logprobs.gather(-1, token_ids)
|
||||
|
||||
# Compute the ranks of the actual token.
|
||||
token_ranks = (logprobs >= token_logprobs).sum(-1)
|
||||
|
||||
# Concatenate together with the topk.
|
||||
indices = torch.cat((token_ids, topk_indices), dim=1)
|
||||
logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
|
||||
|
||||
# Use int32 to reduce the tensor size.
|
||||
indices = indices.to(torch.int32)
|
||||
|
||||
return LogprobsTensors(indices, logprobs, token_ranks)
|
||||
|
||||
def apply_min_p(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
min_p: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Filters logits using adaptive probability thresholding.
|
||||
"""
|
||||
# Convert logits to probability distribution
|
||||
probability_values = torch.nn.functional.softmax(logits, dim=-1)
|
||||
# Calculate maximum probabilities per sequence
|
||||
max_probabilities = torch.amax(probability_values,
|
||||
dim=-1,
|
||||
keepdim=True)
|
||||
# Reshape min_p for broadcasting
|
||||
adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
|
||||
# Identify valid tokens using threshold comparison
|
||||
valid_token_mask = probability_values >= adjusted_min_p
|
||||
# Apply mask using boolean indexing (xla friendly)
|
||||
logits.masked_fill_(~valid_token_mask, -float("inf"))
|
||||
return logits
|
||||
@ -23,13 +23,16 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
|
||||
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
|
||||
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
|
||||
PallasAttentionBackend,
|
||||
PallasMetadata)
|
||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec)
|
||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
||||
ModelRunnerOutput)
|
||||
ModelRunnerOutput, SamplerOutput)
|
||||
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
|
||||
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
@ -42,6 +45,8 @@ logger = init_logger(__name__)
|
||||
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
|
||||
_PAD_SLOT_ID = 1_000_000_000
|
||||
INVALID_TOKEN_ID = -1
|
||||
# Smallest output size
|
||||
MIN_NUM_SEQS = 8
|
||||
|
||||
|
||||
class TPUModelRunner:
|
||||
@ -138,8 +143,10 @@ class TPUModelRunner:
|
||||
device="cpu")
|
||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||
|
||||
padded_max_num_blocks_per_req = _get_padded_number(
|
||||
self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
|
||||
self.block_table_cpu = torch.zeros(
|
||||
(self.max_num_tokens, self.max_num_blocks_per_req),
|
||||
(self.max_num_tokens, padded_max_num_blocks_per_req),
|
||||
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
|
||||
device="cpu")
|
||||
|
||||
@ -267,6 +274,9 @@ class TPUModelRunner:
|
||||
req_data.num_computed_tokens)
|
||||
self.input_batch.block_table.append_row(req_data.new_block_ids,
|
||||
req_index)
|
||||
# Check if the batch has changed. If not, we can skip copying the
|
||||
# sampling metadata from CPU to GPU.
|
||||
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
|
||||
|
||||
# Add the new or resumed requests to the persistent batch.
|
||||
# The smaller empty indices are filled first.
|
||||
@ -284,6 +294,10 @@ class TPUModelRunner:
|
||||
# Condense the batched states if there are empty indices.
|
||||
if removed_req_indices:
|
||||
self.input_batch.condense(removed_req_indices)
|
||||
|
||||
# TODO This slices tensors to copy to device, triggering recompilation.
|
||||
if batch_changed:
|
||||
self.input_batch.refresh_sampling_metadata()
|
||||
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
@ -447,6 +461,8 @@ class TPUModelRunner:
|
||||
# TODO: Support prompt logprobs.
|
||||
padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
|
||||
num_reqs, self.max_num_reqs)
|
||||
# Indices at which we sample (positions of last token in the sequence).
|
||||
# Padded to avoid recompiling when `num_reqs` varies.
|
||||
logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
|
||||
logits_indices = logits_indices.to(self.device)
|
||||
return attn_metadata, logits_indices
|
||||
@ -576,7 +592,14 @@ class TPUModelRunner:
|
||||
# then the embedding layer is not included in the CUDA graph.
|
||||
input_ids = self.input_ids
|
||||
inputs_embeds = None
|
||||
|
||||
sampling_metadata = self.input_batch.sampling_metadata
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
# NOTE (NickLucche) here we sync with TPU: if there's any shape
|
||||
# mismatch in pre-processing, it will trigger a small recompilation
|
||||
# of the code thus far. Forward graph remains untouched.
|
||||
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
|
||||
from_sampling_metadata(sampling_metadata, logits_indices,
|
||||
num_reqs, self.device)
|
||||
# Run the decoder
|
||||
with set_forward_context(attn_metadata, self.vllm_config):
|
||||
hidden_states = self.model(
|
||||
@ -585,12 +608,13 @@ class TPUModelRunner:
|
||||
kv_caches=self.kv_caches,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
selected_token_ids = self.model.compute_logits(hidden_states,
|
||||
logits_indices, None)
|
||||
selected_token_ids = self.model.sample_from_hidden(
|
||||
hidden_states, tpu_sampling_metadata)
|
||||
# Remove padding on cpu and keep dynamic op outside of xla graph.
|
||||
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
|
||||
|
||||
# Then, let's update the cache state.
|
||||
# Update the cache state concurrently. Code above will not block until
|
||||
# we use `selected_token_ids`. Add mark_step if post-processing changes
|
||||
request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
|
||||
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
|
||||
assert req_id is not None
|
||||
@ -607,7 +631,6 @@ class TPUModelRunner:
|
||||
# This relies on cuda-specific torch-internal impl details
|
||||
generator.set_offset(generator.get_offset() - 4)
|
||||
|
||||
# num_reqs entries should be non-None
|
||||
assert all(
|
||||
req_id is not None for req_id in
|
||||
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
|
||||
@ -620,6 +643,7 @@ class TPUModelRunner:
|
||||
max_gen_len = selected_token_ids.shape[-1]
|
||||
if max_gen_len == 1:
|
||||
valid_sampled_token_ids = selected_token_ids.tolist()
|
||||
|
||||
for i, req_state, seq_len in request_seq_lens:
|
||||
token_id = valid_sampled_token_ids[i][0]
|
||||
self.input_batch.token_ids_cpu[i, seq_len] = token_id
|
||||
@ -676,11 +700,8 @@ class TPUModelRunner:
|
||||
fullgraph=True,
|
||||
dynamic=False)
|
||||
|
||||
def _dummy_run(
|
||||
self,
|
||||
kv_caches,
|
||||
num_tokens: int,
|
||||
) -> None:
|
||||
@torch.no_grad()
|
||||
def _dummy_run(self, kv_caches, num_tokens: int) -> None:
|
||||
if self.is_multimodal_model:
|
||||
input_ids = None
|
||||
inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
|
||||
@ -729,32 +750,10 @@ class TPUModelRunner:
|
||||
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
|
||||
|
||||
with set_forward_context(attn_metadata, self.vllm_config, 0):
|
||||
assert self.model is not None
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=position_ids,
|
||||
kv_caches=kv_caches,
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
num_reqs = _get_padded_num_reqs_with_upper_limit(
|
||||
64, self.max_num_reqs)
|
||||
# NOTE(chengjiyao): In total, the compute_logits function utilizes a
|
||||
# compilation cache size of token_bucket_num multiplied by
|
||||
# req_bucket_num. This is acceptable, given the graph's relatively
|
||||
# small size.
|
||||
while True:
|
||||
logits_indices = torch.zeros(
|
||||
num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
torch._dynamo.mark_dynamic(hidden_states, 0)
|
||||
torch._dynamo.mark_dynamic(logits_indices, 0)
|
||||
self.model.compute_logits(hidden_states, logits_indices, None)
|
||||
if num_reqs >= self.max_num_reqs:
|
||||
break
|
||||
num_reqs = _get_padded_num_reqs_with_upper_limit(
|
||||
num_reqs + 1, self.max_num_reqs)
|
||||
self.model(input_ids=input_ids,
|
||||
positions=position_ids,
|
||||
kv_caches=kv_caches,
|
||||
inputs_embeds=inputs_embeds)
|
||||
|
||||
def capture_model(self) -> None:
|
||||
"""Compile the model."""
|
||||
@ -764,13 +763,51 @@ class TPUModelRunner:
|
||||
start = time.perf_counter()
|
||||
num_tokens = 16
|
||||
while True:
|
||||
self._dummy_run(self.kv_caches, num_tokens)
|
||||
logger.info(" -- num_tokens: %d", num_tokens)
|
||||
self._dummy_run(self.kv_caches, num_tokens)
|
||||
xm.mark_step()
|
||||
xm.wait_device_ops()
|
||||
if num_tokens >= self.max_num_tokens:
|
||||
break
|
||||
num_tokens *= 2
|
||||
xm.wait_device_ops()
|
||||
end = time.perf_counter()
|
||||
logger.info("Compilation finished in in %.2f [secs].", end - start)
|
||||
|
||||
logger.info("Compiling sampling with different input shapes.")
|
||||
start = time.perf_counter()
|
||||
num_tokens = 16
|
||||
hsize = self.model_config.get_hidden_size()
|
||||
device = self.device
|
||||
# Compile sampling step for different model+sampler outputs in bucketed
|
||||
# n_tokens x max_num_reqs. Graph is really small so this is fine.
|
||||
while True:
|
||||
num_reqs_to_sample = MIN_NUM_SEQS
|
||||
dummy_hidden = torch.randn((num_tokens, hsize),
|
||||
device=device,
|
||||
dtype=torch.bfloat16)
|
||||
while True:
|
||||
# Default metadata is an all_greedy setup. But since the
|
||||
# `do_argmax` flag is a tensor, we still compile the full graph
|
||||
meta = self.input_batch.sampling_metadata
|
||||
indices = torch.zeros(
|
||||
num_reqs_to_sample,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
sampling_meta = TPUSupportedSamplingMetadata.\
|
||||
from_sampling_metadata(meta, indices,
|
||||
num_reqs_to_sample, device)
|
||||
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
|
||||
num_reqs_to_sample)
|
||||
self.model.sample_from_hidden(dummy_hidden, sampling_meta)
|
||||
xm.mark_step()
|
||||
if num_reqs_to_sample >= self.max_num_reqs:
|
||||
break
|
||||
num_reqs_to_sample *= 2
|
||||
if num_tokens >= self.max_num_tokens:
|
||||
break
|
||||
num_tokens *= 2
|
||||
xm.wait_device_ops()
|
||||
end = time.perf_counter()
|
||||
logger.info("Compilation finished in in %.2f [secs].", end - start)
|
||||
|
||||
@ -818,6 +855,13 @@ class ModelWrapperV1(nn.Module):
|
||||
def __init__(self, model: nn.Module):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.sampler = TPUSampler()
|
||||
|
||||
def sample(
|
||||
self, logits: torch.Tensor,
|
||||
sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
|
||||
sampler_out = self.sampler(logits, sampling_metadata)
|
||||
return sampler_out
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -826,7 +870,7 @@ class ModelWrapperV1(nn.Module):
|
||||
kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Executes the forward pass of the model and samples the next token.
|
||||
"""Executes the forward pass of the model.
|
||||
|
||||
Args:
|
||||
input_ids: The input token IDs of shape [num_tokens].
|
||||
@ -837,7 +881,6 @@ class ModelWrapperV1(nn.Module):
|
||||
hidden_size]. It is used for multimodal models.
|
||||
"""
|
||||
|
||||
assert self.model is not None
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
@ -846,17 +889,33 @@ class ModelWrapperV1(nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
|
||||
def compute_logits(
|
||||
def sample_from_hidden(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
logits_indices: torch.Tensor,
|
||||
sampling_metadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(hidden_states, sampling_metadata)
|
||||
selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
|
||||
return selected_token_ids
|
||||
sampling_metadata: TPUSupportedSamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Sample with xla-friendly function. This function is to be traced
|
||||
separately from `forward` for lighter compilation overhead.
|
||||
"""
|
||||
# Tensor `sample_hidden_states` is of fixed pre-compiled size.
|
||||
sample_hidden_states = \
|
||||
hidden_states[sampling_metadata.indices_do_sample]
|
||||
logits = self.compute_logits(sample_hidden_states)
|
||||
# Greedy sampling can't be run without branching the graph on Sampler.
|
||||
# Therefore do_argmax/all_greedy is checked here in a xla-friendly way.
|
||||
# NOTE do_argmax is a scalar, this is just an optimized if/else.
|
||||
out_tokens = torch.where(sampling_metadata.do_argmax,
|
||||
torch.argmax(logits, dim=-1, keepdim=True),
|
||||
self.sample(logits, sampling_metadata)\
|
||||
.sampled_token_ids)
|
||||
return out_tokens
|
||||
|
||||
def compute_logits(self,
|
||||
hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
# SamplingMetadata here for pruning output in LogitsProcessor, disabled
|
||||
logits = self.model.compute_logits(hidden_states, None)
|
||||
return logits
|
||||
|
||||
def get_multimodal_embeddings(self, *args, **kwargs):
|
||||
return self.model.get_multimodal_embeddings(*args, **kwargs)
|
||||
@ -876,5 +935,5 @@ def _get_padded_token_len(x: int) -> int:
|
||||
|
||||
|
||||
def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
|
||||
res = 64 if x <= 64 else 1 << (x - 1).bit_length()
|
||||
res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
|
||||
return min(res, upper_limit)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user