mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-03 17:24:31 +08:00
[TPU][V1] Fix Sampler recompilation (#15309)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
e977c11111
commit
a0dd7dcd49
@ -5,7 +5,18 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch_xla.core.xla_model as xm
|
import torch_xla.core.xla_model as xm
|
||||||
|
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||||
|
|
||||||
|
DEFAULT_SAMPLING_PARAMS = dict(
|
||||||
|
temperature=-1.0,
|
||||||
|
min_p=0.0,
|
||||||
|
# strictly disabled for now
|
||||||
|
# top_k=-1,
|
||||||
|
# top_p=0.0,
|
||||||
|
# frequency_penalties=0.0,
|
||||||
|
# presence_penalties=0.0,
|
||||||
|
# repetition_penalties=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -20,14 +31,8 @@ class TPUSupportedSamplingMetadata:
|
|||||||
top_k: torch.Tensor = None
|
top_k: torch.Tensor = None
|
||||||
top_p: torch.Tensor = None
|
top_p: torch.Tensor = None
|
||||||
|
|
||||||
# XLA-unfriendly control flow in Sampler
|
|
||||||
all_greedy: bool = False
|
|
||||||
all_random: bool = False
|
|
||||||
# Greedy sampling flag for compiling single xla graph.
|
# Greedy sampling flag for compiling single xla graph.
|
||||||
do_argmax: torch.Tensor = None
|
all_greedy: torch.Tensor = None
|
||||||
|
|
||||||
# speculation not supported
|
|
||||||
spec_token_ids = None
|
|
||||||
|
|
||||||
# Generator not supported by xla
|
# Generator not supported by xla
|
||||||
generators: dict[int,
|
generators: dict[int,
|
||||||
@ -54,106 +59,68 @@ class TPUSupportedSamplingMetadata:
|
|||||||
bad_words_token_ids = None
|
bad_words_token_ids = None
|
||||||
indices_do_sample: torch.Tensor = None
|
indices_do_sample: torch.Tensor = None
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
temp = self.temperature
|
|
||||||
if self.indices_do_sample is None:
|
|
||||||
self.indices_do_sample = torch.zeros(temp.shape[0],
|
|
||||||
device=temp.device,
|
|
||||||
dtype=torch.int32)
|
|
||||||
if self.do_argmax is None:
|
|
||||||
self.do_argmax = torch.tensor(0,
|
|
||||||
dtype=torch.bool,
|
|
||||||
device=temp.device)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_sampling_metadata(
|
def from_input_batch(
|
||||||
cls, metadata: SamplingMetadata,
|
cls, input_batch: InputBatch,
|
||||||
padded_do_sample_indices: torch.Tensor, num_do_sample: int,
|
indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
|
||||||
device: torch.device) -> "TPUSupportedSamplingMetadata":
|
|
||||||
"""
|
"""
|
||||||
Create an XLA-friendly SamplingMetadata structure. Do so by first
|
Copy sampling tensors slices from `input_batch` to on device tensors.
|
||||||
instantiating an object with fixed-sized tensors and then writing the
|
|
||||||
values in input `metadata`. Do that only for non-None values so that
|
|
||||||
recompilation is not triggered for optional values (None/torch.Tensor).
|
|
||||||
|
|
||||||
In order to handle different sizes for the params that range from 1 up
|
|
||||||
to `max_num_seqs`, pad tensors to the closest pre-compiled shape.
|
|
||||||
Same thing for `padded_do_sample_indices`, which contains the indices
|
|
||||||
to be fed to the Sampler, padded to the closest pre-compiled shape.
|
|
||||||
|
|
||||||
Eg. pad to 4 temperature: [0.7, 0.2]=>[0.7, 0.2, 0.0, 0.0]
|
`InputBatch._make_sampling_metadata` causes recompilation on XLA as it
|
||||||
do_sample_indices: [4, 10]=>padded_do_sample_indices: [4, 10, 0, 0]
|
slices dynamic shapes on device tensors. This impl moves the dynamic
|
||||||
|
ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
|
||||||
|
also reuses the on-device persistent tensors managed in `input_batch`
|
||||||
|
to reduce waste.
|
||||||
|
|
||||||
|
`indices_do_sample` contains the indices to be fed to the Sampler,
|
||||||
|
normally one per request, here padded to the closest pre-compiled shape.
|
||||||
|
We expect sampling params tensors to be padded to the same fixed shape.
|
||||||
|
|
||||||
|
Eg. 3 requests, tensors padded to 4
|
||||||
|
temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
|
||||||
|
sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
|
||||||
"""
|
"""
|
||||||
metadata = cls._validate_sampling_metadata(metadata)
|
num_reqs = input_batch.num_reqs
|
||||||
# NOTE we have to initialize default tensor-based params first and
|
padded_num_reqs = len(indices_do_sample)
|
||||||
# skip None values altogether to produce the same xla graph.
|
|
||||||
num_samples = len(padded_do_sample_indices)
|
def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor,
|
||||||
do_argmax = torch.tensor(metadata.all_greedy,
|
fill_val) -> torch.Tensor:
|
||||||
dtype=torch.bool,
|
# Copy slice from CPU to corresponding TPU pre-allocated tensor.
|
||||||
device=device)
|
# Pad value is the default one.
|
||||||
new_metadata = cls.get_default_sampling_params(num_samples, device,
|
cpu_tensor[num_reqs:padded_num_reqs] = fill_val
|
||||||
indices_do_sample=\
|
tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs]
|
||||||
padded_do_sample_indices,
|
|
||||||
do_argmax=do_argmax
|
# NOTE NickLucche The sync CPU-TPU graph we produce here must be
|
||||||
)
|
# consistent. We can't have flags to skip copies or we'll end up
|
||||||
supported_params = \
|
# recompiling.
|
||||||
TPUSupportedSamplingMetadata._get_default_params_values()
|
copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature,
|
||||||
# Copy input non-None values into `new_metadata` fixed-sized tensors.
|
DEFAULT_SAMPLING_PARAMS["temperature"])
|
||||||
for p_name in supported_params:
|
# TODO Temporarily disabled until sampling options are enabled
|
||||||
old_val = getattr(metadata, p_name)
|
# copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p)
|
||||||
new_val = getattr(new_metadata, p_name)
|
# copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k)
|
||||||
if isinstance(old_val, torch.Tensor):
|
copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p,
|
||||||
new_val[:num_do_sample] = old_val
|
DEFAULT_SAMPLING_PARAMS["min_p"])
|
||||||
setattr(new_metadata, p_name, new_val)
|
|
||||||
|
# copy_slice(input_batch.frequency_penalties_cpu_tensor,
|
||||||
|
# input_batch.frequency_penalties)
|
||||||
|
# copy_slice(input_batch.presence_penalties_cpu_tensor,
|
||||||
|
# input_batch.presence_penalties)
|
||||||
|
# copy_slice(input_batch.repetition_penalties_cpu_tensor,
|
||||||
|
# input_batch.repetition_penalties)
|
||||||
|
|
||||||
xm.mark_step()
|
xm.mark_step()
|
||||||
xm.wait_device_ops()
|
xm.wait_device_ops()
|
||||||
return new_metadata
|
|
||||||
|
|
||||||
@classmethod
|
# Slice persistent device tensors to a fixed pre-compiled padded shape.
|
||||||
def get_default_sampling_params(
|
return cls(
|
||||||
cls,
|
temperature=input_batch.temperature[:padded_num_reqs],
|
||||||
num_samples: int,
|
# Scalar tensor for xla-friendly tracing.
|
||||||
device: torch.device,
|
all_greedy=torch.tensor(input_batch.all_greedy,
|
||||||
indices_do_sample=None,
|
dtype=torch.bool,
|
||||||
do_argmax=None) -> "TPUSupportedSamplingMetadata":
|
device=input_batch.device),
|
||||||
# As sampling happens on a single traced graph, options
|
# TODO enable more and avoid returning None values
|
||||||
# are "disabled" by having them evaluate to an Identity op.
|
top_p=None, # input_batch.top_p[:padded_num_reqs],
|
||||||
# Note that initialization is dependent on num_samples.
|
top_k=None, # input_batch.top_k[:padded_num_reqs],
|
||||||
sampling_metadata_disable_value = \
|
min_p=input_batch.min_p[:padded_num_reqs],
|
||||||
TPUSupportedSamplingMetadata._get_default_params_values()
|
generators=input_batch.generators,
|
||||||
init_kwargs = dict()
|
indices_do_sample=indices_do_sample)
|
||||||
for p_name, (default_val,
|
|
||||||
dtype) in sampling_metadata_disable_value.items():
|
|
||||||
default_tensor = torch.full((num_samples, ),
|
|
||||||
default_val,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device)
|
|
||||||
init_kwargs[p_name] = default_tensor
|
|
||||||
|
|
||||||
return cls(**init_kwargs,
|
|
||||||
indices_do_sample=indices_do_sample,
|
|
||||||
do_argmax=do_argmax)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _validate_sampling_metadata(
|
|
||||||
sampling_metadata: SamplingMetadata) -> SamplingMetadata:
|
|
||||||
if sampling_metadata.all_greedy:
|
|
||||||
# Set to None since #13587. Make sure default isn't overruled.
|
|
||||||
assert sampling_metadata.temperature is None
|
|
||||||
return sampling_metadata
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _get_default_params_values():
|
|
||||||
return dict(
|
|
||||||
# Since #13587 greedy sampling requires branching off which leads
|
|
||||||
# to separate graphs. We set temp to noop and handle argmax here.
|
|
||||||
temperature=(1.0, torch.float32),
|
|
||||||
min_p=(0.0, torch.float32),
|
|
||||||
# strictly disabled for now
|
|
||||||
# top_k=(-1, torch.int32),
|
|
||||||
# top_p=(0.0, torch.float32),
|
|
||||||
# frequency_penalties=(0.0, torch.float32),
|
|
||||||
# presence_penalties=(0.0, torch.float32),
|
|
||||||
# repetition_penalties=(0.0, torch.float32),
|
|
||||||
)
|
|
||||||
|
|||||||
@ -279,9 +279,6 @@ class TPUModelRunner:
|
|||||||
req_data.num_computed_tokens)
|
req_data.num_computed_tokens)
|
||||||
self.input_batch.block_table.append_row(req_data.new_block_ids,
|
self.input_batch.block_table.append_row(req_data.new_block_ids,
|
||||||
req_index)
|
req_index)
|
||||||
# Check if the batch has changed. If not, we can skip copying the
|
|
||||||
# sampling metadata from CPU to GPU.
|
|
||||||
batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
|
|
||||||
|
|
||||||
# Add the new or resumed requests to the persistent batch.
|
# Add the new or resumed requests to the persistent batch.
|
||||||
# The smaller empty indices are filled first.
|
# The smaller empty indices are filled first.
|
||||||
@ -300,9 +297,6 @@ class TPUModelRunner:
|
|||||||
if removed_req_indices:
|
if removed_req_indices:
|
||||||
self.input_batch.condense(removed_req_indices)
|
self.input_batch.condense(removed_req_indices)
|
||||||
|
|
||||||
# TODO This slices tensors to copy to device, triggering recompilation.
|
|
||||||
if batch_changed:
|
|
||||||
self.input_batch.refresh_sampling_metadata()
|
|
||||||
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
|
return len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
|
||||||
|
|
||||||
def get_model(self) -> nn.Module:
|
def get_model(self) -> nn.Module:
|
||||||
@ -597,14 +591,12 @@ class TPUModelRunner:
|
|||||||
# then the embedding layer is not included in the CUDA graph.
|
# then the embedding layer is not included in the CUDA graph.
|
||||||
input_ids = self.input_ids
|
input_ids = self.input_ids
|
||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
sampling_metadata = self.input_batch.sampling_metadata
|
|
||||||
num_reqs = self.input_batch.num_reqs
|
num_reqs = self.input_batch.num_reqs
|
||||||
# NOTE (NickLucche) here we sync with TPU: if there's any shape
|
# NOTE (NickLucche) here we sync with TPU: sampling params tensors
|
||||||
# mismatch in pre-processing, it will trigger a small recompilation
|
# are copied to device in chunks of pre-compiled padded shape to
|
||||||
# of the code thus far. Forward graph remains untouched.
|
# avoid recompilations.
|
||||||
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
|
tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
|
||||||
from_sampling_metadata(sampling_metadata, logits_indices,
|
from_input_batch(self.input_batch, logits_indices)
|
||||||
num_reqs, self.device)
|
|
||||||
# Run the decoder
|
# Run the decoder
|
||||||
with set_forward_context(attn_metadata, self.vllm_config):
|
with set_forward_context(attn_metadata, self.vllm_config):
|
||||||
hidden_states = self.model(
|
hidden_states = self.model(
|
||||||
@ -797,21 +789,19 @@ class TPUModelRunner:
|
|||||||
device=device,
|
device=device,
|
||||||
dtype=torch.bfloat16)
|
dtype=torch.bfloat16)
|
||||||
while True:
|
while True:
|
||||||
# Default metadata is an all_greedy setup. But since the
|
|
||||||
# `do_argmax` flag is a tensor, we still compile the full graph
|
|
||||||
meta = self.input_batch.sampling_metadata
|
|
||||||
indices = torch.zeros(
|
indices = torch.zeros(
|
||||||
num_reqs_to_sample,
|
num_reqs_to_sample,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
xm.mark_step()
|
||||||
sampling_meta = TPUSupportedSamplingMetadata.\
|
sampling_meta = TPUSupportedSamplingMetadata.\
|
||||||
from_sampling_metadata(meta, indices,
|
from_input_batch(self.input_batch, indices)
|
||||||
num_reqs_to_sample, device)
|
|
||||||
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
|
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
|
||||||
num_reqs_to_sample)
|
num_reqs_to_sample)
|
||||||
self.model.sample_from_hidden(dummy_hidden, sampling_meta)
|
out = self.model.sample_from_hidden(dummy_hidden,
|
||||||
xm.mark_step()
|
sampling_meta)
|
||||||
|
out = out.cpu()
|
||||||
if num_reqs_to_sample >= self.max_num_reqs:
|
if num_reqs_to_sample >= self.max_num_reqs:
|
||||||
break
|
break
|
||||||
num_reqs_to_sample *= 2
|
num_reqs_to_sample *= 2
|
||||||
@ -910,6 +900,7 @@ class ModelWrapperV1(nn.Module):
|
|||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
# @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
|
||||||
def sample_from_hidden(
|
def sample_from_hidden(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@ -923,10 +914,9 @@ class ModelWrapperV1(nn.Module):
|
|||||||
sample_hidden_states = \
|
sample_hidden_states = \
|
||||||
hidden_states[sampling_metadata.indices_do_sample]
|
hidden_states[sampling_metadata.indices_do_sample]
|
||||||
logits = self.compute_logits(sample_hidden_states)
|
logits = self.compute_logits(sample_hidden_states)
|
||||||
# Greedy sampling can't be run without branching the graph on Sampler.
|
# Optimized greedy sampling branch, tracing both paths in a single pass
|
||||||
# Therefore do_argmax/all_greedy is checked here in a xla-friendly way.
|
# NOTE all_greedy is a scalar, this is just an optimized if/else.
|
||||||
# NOTE do_argmax is a scalar, this is just an optimized if/else.
|
out_tokens = torch.where(sampling_metadata.all_greedy,
|
||||||
out_tokens = torch.where(sampling_metadata.do_argmax,
|
|
||||||
torch.argmax(logits, dim=-1, keepdim=True),
|
torch.argmax(logits, dim=-1, keepdim=True),
|
||||||
self.sample(logits, sampling_metadata)\
|
self.sample(logits, sampling_metadata)\
|
||||||
.sampled_token_ids)
|
.sampled_token_ids)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user