mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-05 05:17:03 +08:00
[V1] Integrate Piecewise CUDA graphs (#10058)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
9d59b75593
commit
4089985552
@ -496,8 +496,11 @@ class PiecewiseBackend:
|
||||
return entry.runnable(*args)
|
||||
|
||||
if self.is_first_graph:
|
||||
logger.info("Capturing a cudagraph for shape %s",
|
||||
runtime_shape)
|
||||
# Since we capture cudagraph for many different shapes and
|
||||
# capturing is fast, we don't need to log it for every shape.
|
||||
# We only log it in the debug mode.
|
||||
logger.debug("Capturing a cudagraph for shape %s",
|
||||
runtime_shape)
|
||||
|
||||
input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
|
||||
@ -51,6 +51,7 @@ class FlashAttentionMetadata:
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
@ -134,7 +135,9 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
assert k_scale == 1.0 and v_scale == 1.0, (
|
||||
"key/v_scale is not supported in FlashAttention.")
|
||||
|
||||
output = torch.ops.vllm.unified_flash_attention(
|
||||
output = torch.empty_like(query)
|
||||
torch.ops.vllm.unified_flash_attention(
|
||||
output,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
@ -154,6 +157,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
|
||||
|
||||
def unified_flash_attention(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
@ -168,17 +172,17 @@ def unified_flash_attention(
|
||||
window_size: Optional[List[int]] = None,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> None:
|
||||
current_metadata = get_forward_context()
|
||||
if current_metadata is None:
|
||||
# Profiling run.
|
||||
return torch.empty_like(query)
|
||||
return
|
||||
|
||||
assert current_metadata is not None
|
||||
assert isinstance(current_metadata, FlashAttentionMetadata)
|
||||
attn_metadata: FlashAttentionMetadata = current_metadata
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
num_tokens, hidden_size = query.shape
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, num_heads, head_size)
|
||||
key = key.view(-1, num_kv_heads, head_size)
|
||||
@ -188,18 +192,18 @@ def unified_flash_attention(
|
||||
key_cache = kv_cache[0]
|
||||
value_cache = kv_cache[1]
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
kv_cache[0],
|
||||
kv_cache[1],
|
||||
key[:num_actual_tokens],
|
||||
value[:num_actual_tokens],
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
output = flash_attn_varlen_func(
|
||||
q=query,
|
||||
attn_output = flash_attn_varlen_func(
|
||||
q=query[:num_actual_tokens],
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
cu_seqlens_q=attn_metadata.query_start_loc,
|
||||
@ -213,10 +217,13 @@ def unified_flash_attention(
|
||||
block_table=attn_metadata.block_table,
|
||||
softcap=logits_soft_cap,
|
||||
)
|
||||
return output.view(num_tokens, hidden_size)
|
||||
attn_output = attn_output.view(num_actual_tokens, -1)
|
||||
# TODO(woosuk): Optimize this.
|
||||
output[:num_actual_tokens].copy_(attn_output)
|
||||
|
||||
|
||||
def unified_flash_attention_fake(
|
||||
output: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
@ -231,13 +238,13 @@ def unified_flash_attention_fake(
|
||||
window_size: Optional[List[int]] = None,
|
||||
alibi_slopes: Optional[torch.Tensor] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(query)
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="unified_flash_attention",
|
||||
op_func=unified_flash_attention,
|
||||
mutates_args=["kv_cache"],
|
||||
mutates_args=["kv_cache", "output"],
|
||||
fake_impl=unified_flash_attention_fake,
|
||||
)
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set
|
||||
from unittest.mock import patch
|
||||
@ -7,11 +9,16 @@ import torch
|
||||
import torch.distributed
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm import envs
|
||||
from vllm.compilation.compile_context import set_compile_context
|
||||
from vllm.compilation.config import CompilationConfig
|
||||
from vllm.compilation.levels import CompilationLevel
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.multimodal import MultiModalDataDict
|
||||
from vllm.plugins import set_compilation_config
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
|
||||
is_pin_memory_available)
|
||||
@ -86,6 +93,18 @@ class GPUModelRunner:
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
|
||||
self.use_cuda_graph = (envs.VLLM_TORCH_COMPILE_LEVEL
|
||||
== CompilationLevel.PIECEWISE
|
||||
and not self.model_config.enforce_eager)
|
||||
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
|
||||
self.cudagraph_batch_sizes = [1, 2, 4] + [i for i in range(8, 513, 8)]
|
||||
self.input_ids = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
self.positions = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
|
||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
# Remove stopped requests from the cached states.
|
||||
# Keep the states of the pre-empted requests.
|
||||
@ -268,12 +287,16 @@ class GPUModelRunner:
|
||||
seq_start_loc_np[0] = 0
|
||||
np.cumsum(seq_lens, out=seq_start_loc_np[1:])
|
||||
|
||||
input_ids = input_ids.to(self.device, non_blocking=True)
|
||||
positions = positions.to(self.device, non_blocking=True).long()
|
||||
self.input_ids[:total_num_scheduled_tokens].copy_(input_ids,
|
||||
non_blocking=True)
|
||||
self.positions[:total_num_scheduled_tokens].copy_(positions,
|
||||
non_blocking=True)
|
||||
|
||||
query_start_loc = query_start_loc.to(self.device, non_blocking=True)
|
||||
seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
|
||||
slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
|
||||
attn_metadata = FlashAttentionMetadata(
|
||||
num_actual_tokens=total_num_scheduled_tokens,
|
||||
max_query_len=max_num_scheduled_tokens,
|
||||
query_start_loc=query_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
@ -287,7 +310,7 @@ class GPUModelRunner:
|
||||
# token from the partial request.
|
||||
# TODO: Support prompt logprobs.
|
||||
logits_indices = query_start_loc[1:] - 1
|
||||
return input_ids, positions, attn_metadata, logits_indices
|
||||
return attn_metadata, logits_indices
|
||||
|
||||
def _prepare_sampling(
|
||||
self,
|
||||
@ -310,16 +333,26 @@ class GPUModelRunner:
|
||||
scheduler_output: "SchedulerOutput",
|
||||
) -> ModelRunnerOutput:
|
||||
self._update_states(scheduler_output)
|
||||
inputs = self._prepare_inputs(scheduler_output)
|
||||
input_ids, positions, attn_metadata, logits_indices = inputs
|
||||
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
if (self.use_cuda_graph
|
||||
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
|
||||
# Use piecewise CUDA graphs.
|
||||
# Add padding to the batch size.
|
||||
num_input_tokens = self._get_padded_batch_size(
|
||||
num_scheduled_tokens)
|
||||
else:
|
||||
# Eager mode.
|
||||
num_input_tokens = num_scheduled_tokens
|
||||
|
||||
with set_forward_context(attn_metadata):
|
||||
hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
input_ids=self.input_ids[:num_input_tokens],
|
||||
positions=self.positions[:num_input_tokens],
|
||||
kv_caches=self.kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
attn_metadata=None,
|
||||
)
|
||||
hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(hidden_states, None)
|
||||
|
||||
@ -371,6 +404,18 @@ class GPUModelRunner:
|
||||
return model_runner_output
|
||||
|
||||
def load_model(self) -> None:
|
||||
if self.use_cuda_graph:
|
||||
# FIXME(woosuk): Currently, the custom ops are not supported
|
||||
# in the piecewise compilation mode. We rely on TorchInductor
|
||||
# to optimize the model.
|
||||
os.environ["VLLM_CUSTOM_OPS"] = "none"
|
||||
set_compilation_config(
|
||||
CompilationConfig(
|
||||
use_cudagraph=True,
|
||||
non_cudagraph_ops=["vllm.unified_flash_attention"],
|
||||
use_inductor=True,
|
||||
))
|
||||
|
||||
logger.info("Starting to load model %s...", self.model_config.model)
|
||||
with DeviceMemoryProfiler() as m: # noqa: SIM117
|
||||
with patch("vllm.model_executor.layers.sampler.Sampler", Sampler):
|
||||
@ -381,26 +426,61 @@ class GPUModelRunner:
|
||||
self.model_memory_usage / float(2**30))
|
||||
|
||||
def _dummy_run(self, model: nn.Module, num_tokens: int) -> None:
|
||||
input_ids = torch.zeros(num_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
positions = torch.zeros(num_tokens,
|
||||
dtype=torch.long,
|
||||
device=self.device)
|
||||
kv_caches = [None for _ in range(self.num_attn_layers)]
|
||||
model(input_ids, positions, kv_caches, attn_metadata=None)
|
||||
return
|
||||
# use an empty tensor instead of `None`` to force Dynamo to pass
|
||||
# it by reference, rather by specializing on the value `None`.
|
||||
# the `dtype` argument does not matter, and we use `float32` as
|
||||
# a placeholder (it has wide hardware support).
|
||||
# it is important to create tensors inside the loop, rather than
|
||||
# multiplying the list, to avoid Dynamo from treating them as
|
||||
# tensor aliasing.
|
||||
dummy_kv_caches = [
|
||||
torch.tensor([], dtype=torch.float32, device=self.device)
|
||||
for _ in range(self.num_attn_layers)
|
||||
]
|
||||
with set_forward_context(None): # noqa: SIM117
|
||||
with set_compile_context(self.cudagraph_batch_sizes):
|
||||
# Trigger compilation for general shape.
|
||||
model(self.input_ids,
|
||||
self.positions,
|
||||
dummy_kv_caches,
|
||||
attn_metadata=None)
|
||||
|
||||
@torch.inference_mode()
|
||||
def profile_run(self) -> None:
|
||||
self._dummy_run(self.model, self.max_num_tokens)
|
||||
torch.cuda.synchronize()
|
||||
return
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_model(self) -> None:
|
||||
# TODO: Implement CUDA graph support.
|
||||
return
|
||||
if not self.use_cuda_graph:
|
||||
logger.warning(
|
||||
"Skipping CUDA graph capture. Please set "
|
||||
"VLLM_TORCH_COMPILE_LEVEL=%d to use CUDA graphs.",
|
||||
CompilationLevel.PIECEWISE)
|
||||
return
|
||||
|
||||
start_time = time.perf_counter()
|
||||
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
|
||||
with set_forward_context(None):
|
||||
# Trigger CUDA graph capture for specific shapes.
|
||||
# Capture the large shapes first so that the smaller shapes
|
||||
# can reuse the memory pool allocated for the large shapes.
|
||||
for num_tokens in reversed(self.cudagraph_batch_sizes):
|
||||
self.model(
|
||||
self.input_ids[:num_tokens],
|
||||
self.positions[:num_tokens],
|
||||
kv_caches=self.kv_caches,
|
||||
attn_metadata=None,
|
||||
)
|
||||
|
||||
end_time = time.perf_counter()
|
||||
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
|
||||
elapsed_time = end_time - start_time
|
||||
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
|
||||
# This usually takes 5~20 seconds.
|
||||
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
|
||||
elapsed_time, cuda_graph_size / (1 << 30))
|
||||
|
||||
def initialize_kv_cache(self, num_blocks: int) -> None:
|
||||
assert len(self.kv_caches) == 0
|
||||
@ -412,6 +492,13 @@ class GPUModelRunner:
|
||||
dtype=self.kv_cache_dtype,
|
||||
device=self.device))
|
||||
|
||||
def _get_padded_batch_size(self, batch_size: int) -> Optional[int]:
|
||||
# TODO: Optimize this?
|
||||
for size in self.cudagraph_batch_sizes:
|
||||
if batch_size <= size:
|
||||
return size
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedRequestState:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user