[V1][TPU] Support V1 Sampler for ragged attention (#14227)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi 2025-03-20 05:00:39 +01:00 committed by GitHub
parent 40828ce5fe
commit d8c6d7d6b5
6 changed files with 535 additions and 55 deletions

View File

@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
import tempfile
from time import time
import pytest
from vllm import LLM, envs
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
if not envs.VLLM_USE_V1:
pytest.skip(
"Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
allow_module_level=True,
)
@pytest.mark.parametrize("model_name", ["D4nt3/Qwen2.5-two-layers"])
@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This test needs a TPU")
def test_sampler_compilation(model_name: str, monkeypatch):
"""
Check that no recompilation happens despite changing sampling parameters.
We can't read XLA metrics from the engine process, hence we measure time.
"""
with tempfile.TemporaryDirectory() as temp_dir:
monkeypatch.setenv("VLLM_XLA_CACHE_PATH", temp_dir)
# Compiling model init may still take some time; use enforce_eager to skip it.
llm = LLM(model_name,
enforce_eager=True,
max_num_seqs=16,
max_model_len=1024,
gpu_memory_utilization=0.5)
prompts = [
"A robot may not injure a human being",
"It is only with the heart that one can see rightly;",
]
# First inference should be slow
sampling_params = SamplingParams(
temperature=0.7,
# top_p=0.6, # TODO too slow!
# top_k=10,
min_p=0.2,
max_tokens=16)
s = time()
_ = llm.generate(prompts, sampling_params)
run1 = time() - s
# Second request with different params, for which we already
# compiled in the previous eager iteration.
sampling_params = SamplingParams(temperature=0.1,
min_p=0.8,
max_tokens=24)
s = time()
_ = llm.generate(prompts, sampling_params)
run2 = time() - s
# Much faster after compiling
assert run1 * 0.1 > run2
print("TIMES", run1, run2)
# Third request with min_p set to None. It will not trigger
# recompilation, as the default value of 0 will be used.
sampling_params = SamplingParams(max_tokens=24, temperature=0.0)
s = time()
_ = llm.generate(prompts, sampling_params)
run3 = time() - s
assert run1 * 0.1 > run3
print("TIMES", run1, run3)
@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
@pytest.mark.skipif(not current_platform.is_tpu(),
reason="This test needs a TPU")
def test_sampler_different(model_name: str):
"""
Test significantly different sampling params to assert the model produces
different results.
"""
llm = LLM(
model_name,
enforce_eager=True,
max_num_seqs=1,
max_model_len=64,
# TODO: keep at 0.5, otherwise it goes OOM
gpu_memory_utilization=0.5)
prompts = [
"Write a short story about a robot that dreams for the first time."
]
sampling_params = SamplingParams(temperature=0.9, min_p=0.2, max_tokens=64)
output = llm.generate(prompts, sampling_params)
sampling_params = SamplingParams(temperature=0.1, min_p=0.8, max_tokens=64)
output2 = llm.generate(prompts, sampling_params)
assert output[0].outputs[0].text != output2[0].outputs[0].text

View File

@ -65,6 +65,8 @@ class TopKTopPSampler(nn.Module):
"native implementation of top-p & top-k sampling. For the "
"best performance, please install FlashInfer.")
self.forward = self.forward_native
elif current_platform.is_tpu():
self.forward = self.forward_tpu
else:
self.forward = self.forward_native
@ -96,6 +98,18 @@ class TopKTopPSampler(nn.Module):
return random_sample(probs, generators)
return flashinfer_sample(probs, k, p, generators)
def forward_tpu(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
) -> torch.Tensor:
# TODO Placeholder for a TPU-optimized top-k/top-p kernel
# logits = apply_top_k_top_p(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)
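The commented-out call above hints at the missing kernel. One XLA-friendly way to approximate top-k with static shapes and no data-dependent branching could look like the sketch below; this is an illustration under stated assumptions (k as an int32 tensor of shape [B]), not the planned TPU kernel:
import torch

def _static_top_k_mask(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Sort has a static output shape, so it traces into a single XLA graph.
    vocab = logits.shape[-1]
    sorted_logits, _ = torch.sort(logits, dim=-1, descending=True)
    # The k-th largest value per row is the keep-threshold; indices stay
    # traced tensors, avoiding recompilation when k changes.
    kth_idx = (k.to(torch.long) - 1).clamp(min=0, max=vocab - 1)
    threshold = sorted_logits.gather(-1, kth_idx.unsqueeze(-1))
    return logits.masked_fill(logits < threshold, float("-inf"))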
def apply_top_k_top_p(
logits: torch.Tensor,
@ -112,7 +126,7 @@ def apply_top_k_top_p(
if k is not None:
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long)
top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B
# Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
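A worked trace of the gather trick above, with illustrative numbers (logits_sort is sorted ascending in this function):
import torch

logits_sort = torch.tensor([[0.1, 0.5, 1.2, 3.0]])  # sorted ascending
k = torch.tensor([2])                                # keep top-2 per row
idx = (logits_sort.size(1) - k.to(torch.long)).unsqueeze(1)  # [[2]]
threshold = logits_sort.gather(1, idx)                       # [[1.2]]
mask = logits_sort < threshold               # [[True, True, False, False]]
masked = logits_sort.masked_fill(mask, -float("inf"))
# Only the top-2 values (1.2 and 3.0) survive.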

View File

View File

@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field
from typing import Optional
import torch
import torch_xla.core.xla_model as xm
from vllm.v1.sample.metadata import SamplingMetadata
@dataclass
class TPUSupportedSamplingMetadata:
# This class exposes a more XLA-friendly interface than SamplingMetadata
# on TPU. In particular, all arguments must be traceable and no Optionals
# are allowed, to avoid graph recompilation on None values.
temperature: torch.Tensor
min_p: torch.Tensor
# Still too slow on forward_native!
top_k: torch.Tensor = None
top_p: torch.Tensor = None
# XLA-unfriendly control flow in Sampler
all_greedy: bool = False
all_random: bool = False
# Greedy sampling flag for compiling single xla graph.
do_argmax: torch.Tensor = None
# speculation not supported
spec_token_ids = None
# Generator not supported by xla
generators: dict[int,
torch.Generator] = field(default_factory=lambda: dict())
# Unsupported: it would require returning an extra tensor of static size B x V
max_num_logprobs = None
# TODO No penalties for now
no_penalties: bool = True
prompt_token_ids = None
frequency_penalties = None
presence_penalties = None
repetition_penalties = None
# Should use a tensor rather than a Python list
output_token_ids: list[list[int]] = field(default_factory=lambda: list())
min_tokens = None  # implementation is not vectorized
logit_bias: list[Optional[dict[int, float]]] = field(
default_factory=lambda: list())
allowed_token_ids_mask = None
bad_words_token_ids = None
indices_do_sample: torch.Tensor = None
def __post_init__(self):
temp = self.temperature
if self.indices_do_sample is None:
self.indices_do_sample = torch.zeros(temp.shape[0],
device=temp.device,
dtype=torch.int32)
if self.do_argmax is None:
self.do_argmax = torch.tensor(0,
dtype=torch.bool,
device=temp.device)
@classmethod
def from_sampling_metadata(
cls, metadata: SamplingMetadata,
padded_do_sample_indices: torch.Tensor, num_do_sample: int,
device: torch.device) -> "TPUSupportedSamplingMetadata":
"""
Create an XLA-friendly SamplingMetadata structure. Do so by first
instantiating an object with fixed-sized tensors, and then writing in
the values from the input `metadata`. Do that only for non-None values,
so that recompilation is not triggered for optional values
(None/torch.Tensor).
To handle params whose size ranges from 1 up to `max_num_seqs`, pad
tensors to the closest pre-compiled shape. The same goes for
`padded_do_sample_indices`, which contains the indices to be fed to
the Sampler, padded to the closest pre-compiled shape.
E.g. pad to 4: temperature [0.7, 0.2] => [0.7, 0.2, 1.0, 1.0] (1.0 is
the default "disable" value);
do_sample_indices [4, 10] => padded_do_sample_indices [4, 10, 0, 0]
"""
metadata = cls._validate_sampling_metadata(metadata)
# NOTE we have to initialize default tensor-based params first and
# skip None values altogether to produce the same xla graph.
num_samples = len(padded_do_sample_indices)
do_argmax = torch.tensor(metadata.all_greedy,
dtype=torch.bool,
device=device)
new_metadata = cls.get_default_sampling_params(num_samples, device,
indices_do_sample=\
padded_do_sample_indices,
do_argmax=do_argmax
)
supported_params = \
TPUSupportedSamplingMetadata._get_default_params_values()
# Copy input non-None values into `new_metadata` fixed-sized tensors.
for p_name in supported_params:
old_val = getattr(metadata, p_name)
new_val = getattr(new_metadata, p_name)
if isinstance(old_val, torch.Tensor):
new_val[:num_do_sample] = old_val
setattr(new_metadata, p_name, new_val)
xm.mark_step()
xm.wait_device_ops()
return new_metadata
@classmethod
def get_default_sampling_params(
cls,
num_samples: int,
device: torch.device,
indices_do_sample=None,
do_argmax=None) -> "TPUSupportedSamplingMetadata":
# As sampling happens on a single traced graph, options
# are "disabled" by having them evaluate to an Identity op.
# Note that initialization is dependent on num_samples.
sampling_metadata_disable_value = \
TPUSupportedSamplingMetadata._get_default_params_values()
init_kwargs = dict()
for p_name, (default_val,
dtype) in sampling_metadata_disable_value.items():
default_tensor = torch.full((num_samples, ),
default_val,
dtype=dtype,
device=device)
init_kwargs[p_name] = default_tensor
return cls(**init_kwargs,
indices_do_sample=indices_do_sample,
do_argmax=do_argmax)
@staticmethod
def _validate_sampling_metadata(
sampling_metadata: SamplingMetadata) -> SamplingMetadata:
if sampling_metadata.all_greedy:
# Set to None since #13587. Make sure default isn't overruled.
assert sampling_metadata.temperature is None
return sampling_metadata
@staticmethod
def _get_default_params_values():
return dict(
# Since #13587, greedy sampling requires branching off, which leads
# to separate graphs. We set temp to a no-op and handle argmax here.
temperature=(1.0, torch.float32),
min_p=(0.0, torch.float32),
# strictly disabled for now
# top_k=(-1, torch.int32),
# top_p=(0.0, torch.float32),
# frequency_penalties=(0.0, torch.float32),
# presence_penalties=(0.0, torch.float32),
# repetition_penalties=(0.0, torch.float32),
)
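To make the padding contract concrete, here is a small standalone sketch with made-up values (illustrative only): two live requests padded to a pre-compiled batch of 4, where the tail rows take the "disable" defaults above and therefore trace as identity ops.
import torch

device = torch.device("cpu")  # the XLA device in the real runner
temperature = torch.full((4, ), 1.0, dtype=torch.float32, device=device)
min_p = torch.full((4, ), 0.0, dtype=torch.float32, device=device)
temperature[:2] = torch.tensor([0.7, 0.2])  # -> [0.7, 0.2, 1.0, 1.0]
min_p[:2] = torch.tensor([0.2, 0.8])        # -> [0.2, 0.8, 0.0, 0.0]
# The padded shape (4) never changes across steps, so the sampling graph
# compiled for batch=4 is reused regardless of how many requests are live.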

View File

@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
"""Sampler layer implementing TPU supported operations."""
import torch
import torch.nn as nn
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
_SAMPLING_EPS = 1e-5
class Sampler(nn.Module):
def __init__(self):
super().__init__()
self.topk_topp_sampler = TopKTopPSampler()
def forward(
self,
logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata,
) -> SamplerOutput:
# NOTE(woosuk): Use the original logits (before any penalties or
# temperature scaling) for the top-k logprobs.
# This is different from the V0 sampler, which uses the logits that
# are used for sampling (after penalties and temperature scaling).
# Use float32 for the logits.
logits = logits.to(torch.float32)
# Sample the next token.
sampled = self.sample(logits, sampling_metadata)
# Use int32 to reduce the tensor size.
sampled = sampled.to(torch.int32)
# These are device tensors.
sampler_output = SamplerOutput(
# The sampled tokens are expanded to 2D tensor with shape
# [num_requests, 1], where each row represents one generated
# token per request.
sampled_token_ids=sampled.unsqueeze(-1),
logprobs_tensors=None,
)
return sampler_output
def apply_temperature(
self,
logits: torch.Tensor,
temp: torch.Tensor,
) -> torch.Tensor:
# Use in-place division to avoid creating a new tensor.
return logits.div_(temp.unsqueeze(dim=1))
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
return logits.argmax(dim=-1).view(-1)
def sample(
self,
logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata,
) -> torch.Tensor:
greedy_sampled = self.greedy_sample(logits)
assert sampling_metadata.temperature is not None
# Apply temperature.
logits = self.apply_temperature(logits, sampling_metadata.temperature)
# Apply min_p.
if sampling_metadata.min_p is not None:
logits = self.apply_min_p(logits, sampling_metadata.min_p)
# Apply top_k and/or top_p.
random_sampled = self.topk_topp_sampler(
logits,
sampling_metadata.generators,
sampling_metadata.top_k,
sampling_metadata.top_p,
)
sampled = torch.where(sampling_metadata.temperature < _SAMPLING_EPS,
greedy_sampled, random_sampled)
return sampled
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
return logits.log_softmax(dim=-1, dtype=torch.float32)
def gather_logprobs(
self,
logprobs: torch.Tensor,
num_logprobs: int,
token_ids: torch.Tensor,
) -> LogprobsTensors:
"""
Gather logprobs for topk and sampled/prompt token.
Args:
logprobs: (num tokens) x (vocab) tensor
num_logprobs: minimum number of logprobs to
retain per token
token_ids: prompt tokens (if prompt logprobs)
or sampled tokens (if sampled
logprobs); 1D token ID tensor
with (num tokens) elements
Returns:
Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
Sampled token rank tensor, (num tokens)
"""
# Find the topK values.
topk_logprobs, topk_indices = torch.topk(logprobs,
num_logprobs,
dim=-1)
# Get the logprob of the prompt or sampled token.
token_ids = token_ids.unsqueeze(-1)
token_logprobs = logprobs.gather(-1, token_ids)
# Compute the ranks of the actual token.
token_ranks = (logprobs >= token_logprobs).sum(-1)
# Concatenate together with the topk.
indices = torch.cat((token_ids, topk_indices), dim=1)
logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)
# Use int32 to reduce the tensor size.
indices = indices.to(torch.int32)
return LogprobsTensors(indices, logprobs, token_ranks)
def apply_min_p(
self,
logits: torch.Tensor,
min_p: torch.Tensor,
) -> torch.Tensor:
"""
Filters logits using adaptive probability thresholding.
"""
# Convert logits to probability distribution
probability_values = torch.nn.functional.softmax(logits, dim=-1)
# Calculate maximum probabilities per sequence
max_probabilities = torch.amax(probability_values,
dim=-1,
keepdim=True)
# Reshape min_p for broadcasting
adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
# Identify valid tokens using threshold comparison
valid_token_mask = probability_values >= adjusted_min_p
# Apply the mask with masked_fill_ (XLA-friendly)
logits.masked_fill_(~valid_token_mask, -float("inf"))
return logits
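A worked example of the min_p threshold (illustrative numbers, not part of the commit):
import torch

# probs = [0.5, 0.3, 0.15, 0.05] with min_p = 0.2:
# threshold = 0.2 * max(probs) = 0.1, so only 0.05 is masked to -inf.
logits = torch.log(torch.tensor([[0.5, 0.3, 0.15, 0.05]]))
min_p = torch.tensor([0.2])
probs = torch.nn.functional.softmax(logits, dim=-1)
threshold = min_p.unsqueeze(1) * torch.amax(probs, dim=-1, keepdim=True)
masked = logits.masked_fill(probs < threshold, -float("inf"))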

View File

@ -23,13 +23,16 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
PallasAttentionBackend,
PallasMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
ModelRunnerOutput, SamplerOutput)
from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
@ -42,6 +45,8 @@ logger = init_logger(__name__)
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
INVALID_TOKEN_ID = -1
# Smallest padded number of sequences (output batch size) to compile for
MIN_NUM_SEQS = 8
class TPUModelRunner:
@ -138,8 +143,10 @@ class TPUModelRunner:
device="cpu")
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
padded_max_num_blocks_per_req = _get_padded_number(
self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
self.block_table_cpu = torch.zeros(
(self.max_num_tokens, self.max_num_blocks_per_req),
(self.max_num_tokens, padded_max_num_blocks_per_req),
dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
device="cpu")
@ -267,6 +274,9 @@ class TPUModelRunner:
req_data.num_computed_tokens)
self.input_batch.block_table.append_row(req_data.new_block_ids,
req_index)
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to the TPU.
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
@ -284,6 +294,10 @@ class TPUModelRunner:
# Condense the batched states if there are empty indices.
if removed_req_indices:
self.input_batch.condense(removed_req_indices)
# TODO This slices tensors to copy to device, triggering recompilation.
if batch_changed:
self.input_batch.refresh_sampling_metadata()
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
def get_model(self) -> nn.Module:
@ -447,6 +461,8 @@ class TPUModelRunner:
# TODO: Support prompt logprobs.
padded_num_reqs = _get_padded_num_reqs_with_upper_limit(
num_reqs, self.max_num_reqs)
# Indices at which we sample (positions of last token in the sequence).
# Padded to avoid recompiling when `num_reqs` varies.
logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
logits_indices = logits_indices.to(self.device)
return attn_metadata, logits_indices
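For illustration, the last-token indices fall out of the cumulative query lengths like this (assumed values; a padded batch of 8 with 3 live requests):
import torch

# Query lengths [4, 2, 3] give cumulative starts [0, 4, 6, 9]; padding
# repeats the last entry so the slice below keeps a static shape.
query_start_loc = torch.tensor([0, 4, 6, 9, 9, 9, 9, 9, 9])
logits_indices = query_start_loc[1:8 + 1] - 1
# -> [3, 5, 8, 8, 8, 8, 8, 8]: last-token positions; the padded rows all
# point at the final token and are discarded after sampling.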
@ -576,7 +592,14 @@ class TPUModelRunner:
# then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids
inputs_embeds = None
sampling_metadata = self.input_batch.sampling_metadata
num_reqs = self.input_batch.num_reqs
# NOTE (NickLucche) here we sync with TPU: if there's any shape
# mismatch in pre-processing, it will trigger a small recompilation
# of the code thus far. Forward graph remains untouched.
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
from_sampling_metadata(sampling_metadata, logits_indices,
num_reqs, self.device)
# Run the decoder
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model(
@ -585,12 +608,13 @@ class TPUModelRunner:
kv_caches=self.kv_caches,
inputs_embeds=inputs_embeds,
)
num_reqs = self.input_batch.num_reqs
selected_token_ids = self.model.compute_logits(hidden_states,
logits_indices, None)
selected_token_ids = self.model.sample_from_hidden(
hidden_states, tpu_sampling_metadata)
# Remove padding on CPU and keep the dynamic op outside of the XLA graph.
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
# Then, let's update the cache state.
# Update the cache state concurrently. The code above will not block
# until we use `selected_token_ids`. Add a mark_step if post-processing
# changes.
request_seq_lens: list[tuple[int, CachedRequestState, int]] = []
for i, req_id in zip(range(num_reqs), self.input_batch.req_ids):
assert req_id is not None
@ -607,7 +631,6 @@ class TPUModelRunner:
# This relies on cuda-specific torch-internal impl details
generator.set_offset(generator.get_offset() - 4)
# num_reqs entries should be non-None
assert all(
req_id is not None for req_id in
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
@ -620,6 +643,7 @@ class TPUModelRunner:
max_gen_len = selected_token_ids.shape[-1]
if max_gen_len == 1:
valid_sampled_token_ids = selected_token_ids.tolist()
for i, req_state, seq_len in request_seq_lens:
token_id = valid_sampled_token_ids[i][0]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
@ -676,11 +700,8 @@ class TPUModelRunner:
fullgraph=True,
dynamic=False)
def _dummy_run(
self,
kv_caches,
num_tokens: int,
) -> None:
@torch.no_grad()
def _dummy_run(self, kv_caches, num_tokens: int) -> None:
if self.is_multimodal_model:
input_ids = None
inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
@ -729,32 +750,10 @@ class TPUModelRunner:
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
with set_forward_context(attn_metadata, self.vllm_config, 0):
assert self.model is not None
hidden_states = self.model(
input_ids=input_ids,
positions=position_ids,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds,
)
num_reqs = _get_padded_num_reqs_with_upper_limit(
64, self.max_num_reqs)
# NOTE(chengjiyao): In total, the compute_logits function utilizes a
# compilation cache size of token_bucket_num multiplied by
# req_bucket_num. This is acceptable, given the graph's relatively
# small size.
while True:
logits_indices = torch.zeros(
num_reqs,
dtype=torch.int32,
device=self.device,
)
torch._dynamo.mark_dynamic(hidden_states, 0)
torch._dynamo.mark_dynamic(logits_indices, 0)
self.model.compute_logits(hidden_states, logits_indices, None)
if num_reqs >= self.max_num_reqs:
break
num_reqs = _get_padded_num_reqs_with_upper_limit(
num_reqs + 1, self.max_num_reqs)
self.model(input_ids=input_ids,
positions=position_ids,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds)
def capture_model(self) -> None:
"""Compile the model."""
@ -764,13 +763,51 @@ class TPUModelRunner:
start = time.perf_counter()
num_tokens = 16
while True:
self._dummy_run(self.kv_caches, num_tokens)
logger.info(" -- num_tokens: %d", num_tokens)
self._dummy_run(self.kv_caches, num_tokens)
xm.mark_step()
xm.wait_device_ops()
if num_tokens >= self.max_num_tokens:
break
num_tokens *= 2
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
logger.info("Compiling sampling with different input shapes.")
start = time.perf_counter()
num_tokens = 16
hsize = self.model_config.get_hidden_size()
device = self.device
# Compile the sampling step for different model+sampler output shapes,
# bucketed as n_tokens x max_num_reqs. The graph is really small, so
# this is fine.
while True:
num_reqs_to_sample = MIN_NUM_SEQS
dummy_hidden = torch.randn((num_tokens, hsize),
device=device,
dtype=torch.bfloat16)
while True:
# Default metadata is an all_greedy setup. But since the
# `do_argmax` flag is a tensor, we still compile the full graph.
meta = self.input_batch.sampling_metadata
indices = torch.zeros(
num_reqs_to_sample,
dtype=torch.int32,
device=device,
)
sampling_meta = TPUSupportedSamplingMetadata.\
from_sampling_metadata(meta, indices,
num_reqs_to_sample, device)
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
num_reqs_to_sample)
self.model.sample_from_hidden(dummy_hidden, sampling_meta)
xm.mark_step()
if num_reqs_to_sample >= self.max_num_reqs:
break
num_reqs_to_sample *= 2
if num_tokens >= self.max_num_tokens:
break
num_tokens *= 2
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
@ -818,6 +855,13 @@ class ModelWrapperV1(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
self.sampler = TPUSampler()
def sample(
self, logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
sampler_out = self.sampler(logits, sampling_metadata)
return sampler_out
def forward(
self,
@ -826,7 +870,7 @@ class ModelWrapperV1(nn.Module):
kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Executes the forward pass of the model and samples the next token.
"""Executes the forward pass of the model.
Args:
input_ids: The input token IDs of shape [num_tokens].
@ -837,7 +881,6 @@ class ModelWrapperV1(nn.Module):
hidden_size]. It is used for multimodal models.
"""
assert self.model is not None
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
@ -846,17 +889,33 @@ class ModelWrapperV1(nn.Module):
return hidden_states
@torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def compute_logits(
def sample_from_hidden(
self,
hidden_states: torch.Tensor,
logits_indices: torch.Tensor,
sampling_metadata,
) -> Optional[torch.Tensor]:
hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(hidden_states, sampling_metadata)
selected_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
return selected_token_ids
sampling_metadata: TPUSupportedSamplingMetadata,
) -> torch.Tensor:
"""
Sample with an XLA-friendly function. This function is traced
separately from `forward` for lighter compilation overhead.
"""
# Tensor `sample_hidden_states` is of fixed pre-compiled size.
sample_hidden_states = \
hidden_states[sampling_metadata.indices_do_sample]
logits = self.compute_logits(sample_hidden_states)
# Greedy sampling can't run without branching the graph in the Sampler.
# Therefore do_argmax/all_greedy is checked here in an XLA-friendly way.
# NOTE do_argmax is a scalar tensor; this is just an optimized if/else.
out_tokens = torch.where(sampling_metadata.do_argmax,
torch.argmax(logits, dim=-1, keepdim=True),
self.sample(logits, sampling_metadata)\
.sampled_token_ids)
return out_tokens
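A minimal illustration of the scalar-predicate trick (assumed values): both branches are computed, but the graph stays single-path, so greedy and random requests share one compiled graph.
import torch

do_argmax = torch.tensor(True)  # 0-dim bool tensor, not a Python bool
logits = torch.tensor([[0.1, 2.0, 0.3]])
greedy = torch.argmax(logits, dim=-1, keepdim=True)        # [[1]]
random_ids = torch.multinomial(logits.softmax(dim=-1), 1)  # e.g. [[2]]
out = torch.where(do_argmax, greedy, random_ids)           # -> [[1]]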
def compute_logits(self,
hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
# SamplingMetadata is only used to prune output in LogitsProcessor;
# pass None to disable.
logits = self.model.compute_logits(hidden_states, None)
return logits
def get_multimodal_embeddings(self, *args, **kwargs):
return self.model.get_multimodal_embeddings(*args, **kwargs)
@ -876,5 +935,5 @@ def _get_padded_token_len(x: int) -> int:
def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
res = 64 if x <= 64 else 1 << (x - 1).bit_length()
res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
return min(res, upper_limit)
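Worked examples of the padding helper above (illustrative; upper_limit=32, MIN_NUM_SEQS=8):
assert _get_padded_num_reqs_with_upper_limit(3, 32) == 8    # floored to MIN_NUM_SEQS
assert _get_padded_num_reqs_with_upper_limit(9, 32) == 16   # next power of two
assert _get_padded_num_reqs_with_upper_limit(40, 32) == 32  # capped at upper_limit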