minimize changes to gpu_model_runner.py

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Tyler Michael Smith 2025-02-04 21:32:41 +00:00
parent f33ec6d2d2
commit 34fb0cbbd0
2 changed files with 17 additions and 18 deletions

View File

@@ -179,7 +179,7 @@ class FlashAttentionImpl(AttentionImpl):
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if attn_metadata is None: if attn_metadata is None:
# Dynamic shape profiling run. # Profiling run.
return output return output
# IMPORTANT! # IMPORTANT!
@@ -193,13 +193,17 @@ class FlashAttentionImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens] and
# value[:num_actual_tokens] because the reshape_and_cache_flash op uses
# the slot_mapping's shape to determine the number of actual tokens.
key_cache, value_cache = kv_cache.unbind(0) key_cache, value_cache = kv_cache.unbind(0)
torch.ops._C_cache_ops.reshape_and_cache_flash( torch.ops._C_cache_ops.reshape_and_cache_flash(
key, key,
value, value,
key_cache, key_cache,
value_cache, value_cache,
attn_metadata.slot_mapping[:num_actual_tokens], attn_metadata.slot_mapping,
self.kv_cache_dtype, self.kv_cache_dtype,
layer._k_scale, layer._k_scale,
layer._v_scale, layer._v_scale,
@@ -208,17 +212,14 @@ class FlashAttentionImpl(AttentionImpl):
# Compute attention and update output up to `num_actual_tokens`. # Compute attention and update output up to `num_actual_tokens`.
if not attn_metadata.use_cascade: if not attn_metadata.use_cascade:
# Regular attention (common case). # Regular attention (common case).
batch_size = attn_metadata.block_table.shape[0]
#TODO: Do we need to slice by [:batch_size+1]?
flash_attn_varlen_func( flash_attn_varlen_func(
q=query[:num_actual_tokens], q=query[:num_actual_tokens],
k=key_cache, k=key_cache,
v=value_cache, v=value_cache,
out=output[:num_actual_tokens], out=output[:num_actual_tokens],
cu_seqlens_q=attn_metadata.query_start_loc[:batch_size+1], cu_seqlens_q=attn_metadata.query_start_loc,
max_seqlen_q=attn_metadata.max_query_len, max_seqlen_q=attn_metadata.max_query_len,
seqused_k=attn_metadata.seq_lens[:batch_size], seqused_k=attn_metadata.seq_lens,
max_seqlen_k=attn_metadata.max_seq_len, max_seqlen_k=attn_metadata.max_seq_len,
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,

View File

@@ -11,7 +11,7 @@ import torch.nn as nn
from vllm.attention.backends.abstract import AttentionType from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig from vllm.config import VllmConfig
from vllm.distributed.parallel_state import graph_capture from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY from vllm.inputs import INPUT_REGISTRY
@@ -122,9 +122,6 @@ class GPUModelRunner:
vocab_size=model_config.get_vocab_size(), vocab_size=model_config.get_vocab_size(),
) )
# self.use_cuda_graph = (self.vllm_config.compilation_config.level
# == CompilationLevel.PIECEWISE
# and not self.model_config.enforce_eager)
self.use_cuda_graph = not self.model_config.enforce_eager self.use_cuda_graph = not self.model_config.enforce_eager
# TODO(woosuk): Provide an option to tune the max cudagraph batch size. # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
# The convention is different. # The convention is different.
@@ -467,7 +464,8 @@ class GPUModelRunner:
self.input_batch.block_table.get_device_tensor()[num_reqs:].fill_(-1) self.input_batch.block_table.get_device_tensor()[num_reqs:].fill_(-1)
# Fill with -1s -- needed for reshape_and_cache # Fill with -1s -- needed for reshape_and_cache
self.slot_mapping[total_num_scheduled_tokens:].fill_(-1) # Definitely needed self.slot_mapping[total_num_scheduled_tokens:].fill_(
-1) # Definitely needed
# Prepare for cascade attention if needed. # Prepare for cascade attention if needed.
common_prefix_len = (scheduler_output.num_common_prefix_blocks * common_prefix_len = (scheduler_output.num_common_prefix_blocks *
@@ -550,12 +548,12 @@ class GPUModelRunner:
attn_metadata = FlashAttentionMetadata( attn_metadata = FlashAttentionMetadata(
num_actual_tokens=total_num_scheduled_tokens, num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens,
query_start_loc=self.query_start_loc, query_start_loc=self.query_start_loc[:num_reqs + 1],
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
seq_lens=self.seq_lens, seq_lens=self.seq_lens[:num_reqs],
block_table=( block_table=(
self.input_batch.block_table.get_device_tensor()[:num_reqs]), self.input_batch.block_table.get_device_tensor()[:num_reqs]),
slot_mapping=self.slot_mapping, slot_mapping=self.slot_mapping[:total_num_scheduled_tokens],
# Cascade stuff # Cascade stuff
use_cascade=use_cascade, use_cascade=use_cascade,
common_prefix_len=common_prefix_len, common_prefix_len=common_prefix_len,
@@ -914,12 +912,12 @@ class GPUModelRunner:
return FlashAttentionMetadata( return FlashAttentionMetadata(
num_actual_tokens=num_tokens, num_actual_tokens=num_tokens,
max_query_len=max_query_len, max_query_len=max_query_len,
query_start_loc=self.query_start_loc, query_start_loc=self.query_start_loc[:num_reqs + 1],
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
seq_lens=self.seq_lens, seq_lens=self.seq_lens[:num_reqs],
block_table=( block_table=(
self.input_batch.block_table.get_device_tensor()[:num_reqs]), self.input_batch.block_table.get_device_tensor()[:num_reqs]),
slot_mapping=self.slot_mapping, slot_mapping=self.slot_mapping[:max_seq_len],
# Cascade stuff. Non-piecewise CUDA graphs NYI # Cascade stuff. Non-piecewise CUDA graphs NYI
use_cascade=False, use_cascade=False,
common_prefix_len=0, common_prefix_len=0,