mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-01 06:17:03 +08:00
minimize changes to gpu_model_runner.py
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
f33ec6d2d2
commit
34fb0cbbd0
@ -179,7 +179,7 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
assert output is not None, "Output tensor must be provided."
|
assert output is not None, "Output tensor must be provided."
|
||||||
|
|
||||||
if attn_metadata is None:
|
if attn_metadata is None:
|
||||||
# Dynamic shape profiling run.
|
# Profiling run.
|
||||||
return output
|
return output
|
||||||
|
|
||||||
# IMPORTANT!
|
# IMPORTANT!
|
||||||
@ -193,13 +193,17 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
|
|
||||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||||
# Reshape the input keys and values and store them in the cache.
|
# Reshape the input keys and values and store them in the cache.
|
||||||
|
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
|
||||||
|
# not padded. However, we don't need to do key[:num_actual_tokens] and
|
||||||
|
# value[:num_actual_tokens] because the reshape_and_cache_flash op uses
|
||||||
|
# the slot_mapping's shape to determine the number of actual tokens.
|
||||||
key_cache, value_cache = kv_cache.unbind(0)
|
key_cache, value_cache = kv_cache.unbind(0)
|
||||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||||
key,
|
key,
|
||||||
value,
|
value,
|
||||||
key_cache,
|
key_cache,
|
||||||
value_cache,
|
value_cache,
|
||||||
attn_metadata.slot_mapping[:num_actual_tokens],
|
attn_metadata.slot_mapping,
|
||||||
self.kv_cache_dtype,
|
self.kv_cache_dtype,
|
||||||
layer._k_scale,
|
layer._k_scale,
|
||||||
layer._v_scale,
|
layer._v_scale,
|
||||||
@ -208,17 +212,14 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
# Compute attention and update output up to `num_actual_tokens`.
|
# Compute attention and update output up to `num_actual_tokens`.
|
||||||
if not attn_metadata.use_cascade:
|
if not attn_metadata.use_cascade:
|
||||||
# Regular attention (common case).
|
# Regular attention (common case).
|
||||||
batch_size = attn_metadata.block_table.shape[0]
|
|
||||||
|
|
||||||
#TODO: Do we need to slice by [:batch_size+1]?
|
|
||||||
flash_attn_varlen_func(
|
flash_attn_varlen_func(
|
||||||
q=query[:num_actual_tokens],
|
q=query[:num_actual_tokens],
|
||||||
k=key_cache,
|
k=key_cache,
|
||||||
v=value_cache,
|
v=value_cache,
|
||||||
out=output[:num_actual_tokens],
|
out=output[:num_actual_tokens],
|
||||||
cu_seqlens_q=attn_metadata.query_start_loc[:batch_size+1],
|
cu_seqlens_q=attn_metadata.query_start_loc,
|
||||||
max_seqlen_q=attn_metadata.max_query_len,
|
max_seqlen_q=attn_metadata.max_query_len,
|
||||||
seqused_k=attn_metadata.seq_lens[:batch_size],
|
seqused_k=attn_metadata.seq_lens,
|
||||||
max_seqlen_k=attn_metadata.max_seq_len,
|
max_seqlen_k=attn_metadata.max_seq_len,
|
||||||
softmax_scale=self.scale,
|
softmax_scale=self.scale,
|
||||||
causal=True,
|
causal=True,
|
||||||
|
|||||||
@ -11,7 +11,7 @@ import torch.nn as nn
|
|||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionType
|
from vllm.attention.backends.abstract import AttentionType
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import CompilationLevel, VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.parallel_state import graph_capture
|
from vllm.distributed.parallel_state import graph_capture
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.inputs import INPUT_REGISTRY
|
from vllm.inputs import INPUT_REGISTRY
|
||||||
@ -122,9 +122,6 @@ class GPUModelRunner:
|
|||||||
vocab_size=model_config.get_vocab_size(),
|
vocab_size=model_config.get_vocab_size(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# self.use_cuda_graph = (self.vllm_config.compilation_config.level
|
|
||||||
# == CompilationLevel.PIECEWISE
|
|
||||||
# and not self.model_config.enforce_eager)
|
|
||||||
self.use_cuda_graph = not self.model_config.enforce_eager
|
self.use_cuda_graph = not self.model_config.enforce_eager
|
||||||
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
|
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
|
||||||
# The convention is different.
|
# The convention is different.
|
||||||
@ -467,7 +464,8 @@ class GPUModelRunner:
|
|||||||
self.input_batch.block_table.get_device_tensor()[num_reqs:].fill_(-1)
|
self.input_batch.block_table.get_device_tensor()[num_reqs:].fill_(-1)
|
||||||
|
|
||||||
# Fill with -1s -- needed for reshape_and_cache
|
# Fill with -1s -- needed for reshape_and_cache
|
||||||
self.slot_mapping[total_num_scheduled_tokens:].fill_(-1) # Definitely needed
|
self.slot_mapping[total_num_scheduled_tokens:].fill_(
|
||||||
|
-1) # Definitely needed
|
||||||
|
|
||||||
# Prepare for cascade attention if needed.
|
# Prepare for cascade attention if needed.
|
||||||
common_prefix_len = (scheduler_output.num_common_prefix_blocks *
|
common_prefix_len = (scheduler_output.num_common_prefix_blocks *
|
||||||
@ -550,12 +548,12 @@ class GPUModelRunner:
|
|||||||
attn_metadata = FlashAttentionMetadata(
|
attn_metadata = FlashAttentionMetadata(
|
||||||
num_actual_tokens=total_num_scheduled_tokens,
|
num_actual_tokens=total_num_scheduled_tokens,
|
||||||
max_query_len=max_num_scheduled_tokens,
|
max_query_len=max_num_scheduled_tokens,
|
||||||
query_start_loc=self.query_start_loc,
|
query_start_loc=self.query_start_loc[:num_reqs + 1],
|
||||||
max_seq_len=max_seq_len,
|
max_seq_len=max_seq_len,
|
||||||
seq_lens=self.seq_lens,
|
seq_lens=self.seq_lens[:num_reqs],
|
||||||
block_table=(
|
block_table=(
|
||||||
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
|
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
|
||||||
slot_mapping=self.slot_mapping,
|
slot_mapping=self.slot_mapping[:total_num_scheduled_tokens],
|
||||||
# Cascade stuff
|
# Cascade stuff
|
||||||
use_cascade=use_cascade,
|
use_cascade=use_cascade,
|
||||||
common_prefix_len=common_prefix_len,
|
common_prefix_len=common_prefix_len,
|
||||||
@ -914,12 +912,12 @@ class GPUModelRunner:
|
|||||||
return FlashAttentionMetadata(
|
return FlashAttentionMetadata(
|
||||||
num_actual_tokens=num_tokens,
|
num_actual_tokens=num_tokens,
|
||||||
max_query_len=max_query_len,
|
max_query_len=max_query_len,
|
||||||
query_start_loc=self.query_start_loc,
|
query_start_loc=self.query_start_loc[:num_reqs + 1],
|
||||||
max_seq_len=max_seq_len,
|
max_seq_len=max_seq_len,
|
||||||
seq_lens=self.seq_lens,
|
seq_lens=self.seq_lens[:num_reqs],
|
||||||
block_table=(
|
block_table=(
|
||||||
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
|
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
|
||||||
slot_mapping=self.slot_mapping,
|
slot_mapping=self.slot_mapping[:max_seq_len],
|
||||||
# Cascade stuff. Non-piecewise CUDA graphs NYI
|
# Cascade stuff. Non-piecewise CUDA graphs NYI
|
||||||
use_cascade=False,
|
use_cascade=False,
|
||||||
common_prefix_len=0,
|
common_prefix_len=0,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user