[TPU][V1] Fix Sampler recompilation (#15309)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi 2025-03-25 21:43:54 +01:00 committed by GitHub
parent e977c11111
commit a0dd7dcd49
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 83 additions and 126 deletions

View File

@ -5,7 +5,18 @@ from typing import Optional
import torch import torch
import torch_xla.core.xla_model as xm import torch_xla.core.xla_model as xm
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu_input_batch import InputBatch
DEFAULT_SAMPLING_PARAMS = dict(
temperature=-1.0,
min_p=0.0,
# strictly disabled for now
# top_k=-1,
# top_p=0.0,
# frequency_penalties=0.0,
# presence_penalties=0.0,
# repetition_penalties=0.0,
)
@dataclass @dataclass
@ -20,14 +31,8 @@ class TPUSupportedSamplingMetadata:
top_k: torch.Tensor = None top_k: torch.Tensor = None
top_p: torch.Tensor = None top_p: torch.Tensor = None
# XLA-unfriendly control flow in Sampler
all_greedy: bool = False
all_random: bool = False
# Greedy sampling flag for compiling single xla graph. # Greedy sampling flag for compiling single xla graph.
do_argmax: torch.Tensor = None all_greedy: torch.Tensor = None
# speculation not supported
spec_token_ids = None
# Generator not supported by xla # Generator not supported by xla
generators: dict[int, generators: dict[int,
@ -54,106 +59,68 @@ class TPUSupportedSamplingMetadata:
bad_words_token_ids = None bad_words_token_ids = None
indices_do_sample: torch.Tensor = None indices_do_sample: torch.Tensor = None
def __post_init__(self):
temp = self.temperature
if self.indices_do_sample is None:
self.indices_do_sample = torch.zeros(temp.shape[0],
device=temp.device,
dtype=torch.int32)
if self.do_argmax is None:
self.do_argmax = torch.tensor(0,
dtype=torch.bool,
device=temp.device)
@classmethod @classmethod
def from_sampling_metadata( def from_input_batch(
cls, metadata: SamplingMetadata, cls, input_batch: InputBatch,
padded_do_sample_indices: torch.Tensor, num_do_sample: int, indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
device: torch.device) -> "TPUSupportedSamplingMetadata":
""" """
Create an XLA-frienly SamplingMetadata structure. Do so by first Copy sampling tensors slices from `input_batch` to on device tensors.
instantiating an object with fixed-sized tensors and then writing the
values in input `metadata`. Do that only for non-None values so that
recompilation is not triggered for optional values (None/torch.Tensor).
In order to handle different sizes for the params that range from 1 up
to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
Same thing for `padded_do_sample_indices`, which contains the indices
to be fed to the Sampler, padded to the closest pre-compiled shape.
Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0] `InputBatch._make_sampling_metadata` causes recompilation on XLA as it
do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0] slices dynamic shapes on device tensors. This impl moves the dynamic
ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
also reuses the on-device persistent tensors managed in `input_batch`
to reduce waste.
`indices_do_sample` contains the indices to be fed to the Sampler,
normally one per request, here padded to the closest pre-compiled shape
We expect sampling params tensors to be padded to the same fixed shape.
Eg. 3 requests, tensors padded to 4
temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
""" """
metadata = cls._validate_sampling_metadata(metadata) num_reqs = input_batch.num_reqs
# NOTE we have to initialize default tensor-based params first and padded_num_reqs = len(indices_do_sample)
# skip None values altogether to produce the same xla graph.
num_samples = len(padded_do_sample_indices) def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
do_argmax = torch.tensor(metadata.all_greedy, fill_val) -> torch.Tensor:
dtype=torch.bool, # Copy slice from CPU to corresponding TPU pre-allocated tensor.
device=device) # Pad value is the default one.
new_metadata = cls.get_default_sampling_params(num_samples, device, cpu_tensor[num_reqs:padded_num_reqs] = fill_val
indices_do_sample=\ tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
padded_do_sample_indices,
do_argmax=do_argmax # NOTE NickLucche The sync CPU-TPU graph we produce here must be
) # consistent. We can't have flags to skip copies or we'll end up
supported_params = \ # recompiling.
TPUSupportedSamplingMetadata._get_default_params_values() copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature,
# Copy input non-None values into `new_metadata` fixed-sized tensors. DEFAULT_SAMPLING_PARAMS["temperature"])
for p_name in supported_params: # TODO Temporarily disabled until sampling options are enabled
old_val = getattr(metadata, p_name) # copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
new_val = getattr(new_metadata, p_name) # copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
if isinstance(old_val, torch.Tensor): copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
new_val[:num_do_sample] = old_val DEFAULT_SAMPLING_PARAMS["min_p"])
setattr(new_metadata, p_name, new_val)
# copy_slice(input_batch.frequency_penalties_cpu_tensor,
# input_batch.frequency_penalties)
# copy_slice(input_batch.presence_penalties_cpu_tensor,
# input_batch.presence_penalties)
# copy_slice(input_batch.repetition_penalties_cpu_tensor,
# input_batch.repetition_penalties)
xm.mark_step() xm.mark_step()
xm.wait_device_ops() xm.wait_device_ops()
return new_metadata
@classmethod # Slice persistent device tensors to a fixed pre-compiled padded shape.
def get_default_sampling_params( return cls(
cls, temperature=input_batch.temperature[:padded_num_reqs],
num_samples: int, # Scalar tensor for xla-friendly tracing.
device: torch.device, all_greedy=torch.tensor(input_batch.all_greedy,
indices_do_sample=None, dtype=torch.bool,
do_argmax=None) -> "TPUSupportedSamplingMetadata": device=input_batch.device),
# As sampling happens on a single traced graph, options # TODO enable more and avoid returning None values
# are "disabled" by having them evaluate to an Identity op. top_p=None, # input_batch.top_p[:padded_num_reqs],
# Note that initialization is dependent on num_samples. top_k=None, # input_batch.top_k[:padded_num_reqs],
sampling_metadata_disable_value = \ min_p=input_batch.min_p[:padded_num_reqs],
TPUSupportedSamplingMetadata._get_default_params_values() generators=input_batch.generators,
init_kwargs = dict() indices_do_sample=indices_do_sample)
for p_name, (default_val,
dtype) in sampling_metadata_disable_value.items():
default_tensor = torch.full((num_samples, ),
default_val,
dtype=dtype,
device=device)
init_kwargs[p_name] = default_tensor
return cls(**init_kwargs,
indices_do_sample=indices_do_sample,
do_argmax=do_argmax)
@staticmethod
def _validate_sampling_metadata(
sampling_metadata: SamplingMetadata) -> SamplingMetadata:
if sampling_metadata.all_greedy:
# Set to None since #13587. Make sure default isn't overruled.
assert sampling_metadata.temperature is None
return sampling_metadata
@staticmethod
def _get_default_params_values():
return dict(
# Since #13587 greedy sampling requires branching off which leads
# to separate graphs. We set temp to noop and handle argmax here.
temperature=(1.0, torch.float32),
min_p=(0.0, torch.float32),
# strictly disabled for now
# top_k=(-1, torch.int32),
# top_p=(0.0, torch.float32),
# frequency_penalties=(0.0, torch.float32),
# presence_penalties=(0.0, torch.float32),
# repetition_penalties=(0.0, torch.float32),
)

View File

@ -279,9 +279,6 @@ class TPUModelRunner:
req_data.num_computed_tokens) req_data.num_computed_tokens)
self.input_batch.block_table.append_row(req_data.new_block_ids, self.input_batch.block_table.append_row(req_data.new_block_ids,
req_index) req_index)
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
# Add the new or resumed requests to the persistent batch. # Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first. # The smaller empty indices are filled first.
@ -300,9 +297,6 @@ class TPUModelRunner:
if removed_req_indices: if removed_req_indices:
self.input_batch.condense(removed_req_indices) self.input_batch.condense(removed_req_indices)
# TODO This slices tensors to copy to device, triggering recompilation.
if batch_changed:
self.input_batch.refresh_sampling_metadata()
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0 return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
def get_model(self) -> nn.Module: def get_model(self) -> nn.Module:
@ -597,14 +591,12 @@ class TPUModelRunner:
# then the embedding layer is not included in the CUDA graph. # then the embedding layer is not included in the CUDA graph.
input_ids = self.input_ids input_ids = self.input_ids
inputs_embeds = None inputs_embeds = None
sampling_metadata = self.input_batch.sampling_metadata
num_reqs = self.input_batch.num_reqs num_reqs = self.input_batch.num_reqs
# NOTE (NickLucche) here we sync with TPU: if there's any shape # NOTE (NickLucche) here we sync with TPU: sampling params tensors
# mismatch in pre-processing, it will trigger a small recompilation # are copied to device in chunks of pre-compiled padded shape to
# of the code thus far. Forward graph remains untouched. # avoid recompilations.
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\ tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
from_sampling_metadata(sampling_metadata, logits_indices, from_input_batch(self.input_batch, logits_indices)
num_reqs, self.device)
# Run the decoder # Run the decoder
with set_forward_context(attn_metadata, self.vllm_config): with set_forward_context(attn_metadata, self.vllm_config):
hidden_states = self.model( hidden_states = self.model(
@ -797,21 +789,19 @@ class TPUModelRunner:
device=device, device=device,
dtype=torch.bfloat16) dtype=torch.bfloat16)
while True: while True:
# Default metadata is an all_greedy setup. But since the
# `do_argmax` flag is a tensor, we still compile the full graph
meta = self.input_batch.sampling_metadata
indices = torch.zeros( indices = torch.zeros(
num_reqs_to_sample, num_reqs_to_sample,
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
xm.mark_step()
sampling_meta = TPUSupportedSamplingMetadata.\ sampling_meta = TPUSupportedSamplingMetadata.\
from_sampling_metadata(meta, indices, from_input_batch(self.input_batch, indices)
num_reqs_to_sample, device)
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens, logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
num_reqs_to_sample) num_reqs_to_sample)
self.model.sample_from_hidden(dummy_hidden, sampling_meta) out = self.model.sample_from_hidden(dummy_hidden,
xm.mark_step() sampling_meta)
out = out.cpu()
if num_reqs_to_sample >= self.max_num_reqs: if num_reqs_to_sample >= self.max_num_reqs:
break break
num_reqs_to_sample *= 2 num_reqs_to_sample *= 2
@ -910,6 +900,7 @@ class ModelWrapperV1(nn.Module):
return hidden_states return hidden_states
# @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
def sample_from_hidden( def sample_from_hidden(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@ -923,10 +914,9 @@ class ModelWrapperV1(nn.Module):
sample_hidden_states = \ sample_hidden_states = \
hidden_states[sampling_metadata.indices_do_sample] hidden_states[sampling_metadata.indices_do_sample]
logits = self.compute_logits(sample_hidden_states) logits = self.compute_logits(sample_hidden_states)
# Greedy sampling can't be run without branching the graph on Sampler. # Optimized greedy sampling branch, tracing both paths in a single pass
# Therefore do_argmax/all_greedy is checked here in a xla-friendly way. # NOTE all_greedy is a scalar, this is just an optimized if/else.
# NOTE do_argmax is a scalar, this is just an optimized if/else. out_tokens = torch.where(sampling_metadata.all_greedy,
out_tokens = torch.where(sampling_metadata.do_argmax,
torch.argmax(logits, dim=-1, keepdim=True), torch.argmax(logits, dim=-1, keepdim=True),
self.sample(logits, sampling_metadata)\ self.sample(logits, sampling_metadata)\
.sampled_token_ids) .sampled_token_ids)