Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-18 01:25:01 +08:00
[FlashInfer] Cache hyper params in metadata builder (#23732)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

Parent: 04ff1e43fb
Commit: 11eddf02f0
@@ -214,6 +214,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         # TODO: discard this for trtllm-gen backend
         self.global_hyperparameters = infer_global_hyperparameters(
             get_per_layer_parameters(vllm_config, layer_names, FlashInferImpl))
+        self.sm_scale = self.global_hyperparameters.sm_scale
+        self.window_left = self.global_hyperparameters.window_left
+        self.logits_soft_cap = self.global_hyperparameters.logits_soft_cap
+        self.has_sinks = self.global_hyperparameters.has_sinks
 
         # Preparing persistent buffers (device-side)
         self.paged_kv_indptr = torch.zeros(max_num_reqs + 1,
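The hunk above is the heart of the change: sm_scale, window_left, logits_soft_cap, and has_sinks are read from global_hyperparameters once, in the builder's __init__, and cached as flat instance attributes that the per-step build path can read directly. A minimal sketch of the pattern, with a hypothetical GlobalHyperparameters dataclass standing in for the object infer_global_hyperparameters() actually returns:

from dataclasses import dataclass
from typing import Optional


@dataclass
class GlobalHyperparameters:
    # Hypothetical stand-in; field names mirror those used in the diff.
    sm_scale: float
    window_left: int
    logits_soft_cap: Optional[float]
    has_sinks: bool


class MetadataBuilderSketch:
    def __init__(self, params: GlobalHyperparameters) -> None:
        self.global_hyperparameters = params
        # Cache each hyperparameter once so later code performs a single
        # attribute lookup instead of chasing two objects per access.
        self.sm_scale = params.sm_scale
        self.window_left = params.window_left
        self.logits_soft_cap = params.logits_soft_cap
        self.has_sinks = params.has_sinks


builder = MetadataBuilderSketch(GlobalHyperparameters(
    sm_scale=0.125, window_left=-1, logits_soft_cap=None, has_sinks=False))
assert builder.sm_scale == 0.125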
@@ -381,8 +385,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         )
 
         # Check if any layer uses sinks (requires TRTLLM attention)
-        has_sinks = self.global_hyperparameters.has_sinks
-
         prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                   self.num_kv_heads,
                                                   num_prefill_tokens,
@@ -390,7 +392,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                   self.cache_dtype,
                                                   self.q_data_type,
                                                   is_prefill=True,
-                                                  has_sinks=has_sinks)
+                                                  has_sinks=self.has_sinks)
         decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
                                                  self.num_kv_heads,
                                                  num_decode_tokens,
@@ -398,7 +400,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                                  self.cache_dtype,
                                                  self.q_data_type,
                                                  is_prefill=False,
-                                                 has_sinks=has_sinks)
+                                                 has_sinks=self.has_sinks)
 
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
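With the cached field in place, the two use_trtllm_attention() call sites above switch from a per-build local (has_sinks = self.global_hyperparameters.has_sinks) to self.has_sinks, so the value is resolved once at construction rather than on every build. The saving per access is tiny but the pattern is easy to see; an illustrative micro-benchmark (not a measurement of vLLM itself):

import timeit


class Params:
    has_sinks = False


class Builder:
    def __init__(self) -> None:
        self.global_hyperparameters = Params()
        # Cached copy, as in the diff.
        self.has_sinks = self.global_hyperparameters.has_sinks


b = Builder()
# Two attribute loads per access (builder -> params -> field) ...
chained = timeit.timeit(lambda: b.global_hyperparameters.has_sinks,
                        number=1_000_000)
# ... versus one load after caching in __init__.
cached = timeit.timeit(lambda: b.has_sinks, number=1_000_000)
print(f"chained: {chained:.3f}s  cached: {cached:.3f}s")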
@@ -433,9 +435,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 self.head_dim,
                 self.page_size,
                 causal=True,
-                sm_scale=self.global_hyperparameters.sm_scale,
-                window_left=self.global_hyperparameters.window_left,
-                logits_soft_cap=self.global_hyperparameters.logits_soft_cap,
+                sm_scale=self.sm_scale,
+                window_left=self.window_left,
+                logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
             )
@@ -472,10 +474,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 self.head_dim,
                 self.page_size,
                 causal=True,
-                sm_scale=self.global_hyperparameters.sm_scale,
-                window_left=self.global_hyperparameters.window_left,
-                logits_soft_cap=self.global_hyperparameters.
-                logits_soft_cap,
+                sm_scale=self.sm_scale,
+                window_left=self.window_left,
+                logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
             )
@@ -525,10 +526,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                 self.page_size,
                 # Disable flashinfer's pos encoding and use vllm's rope.
                 pos_encoding_mode="NONE",
-                sm_scale=self.global_hyperparameters.sm_scale,
-                window_left=self.global_hyperparameters.window_left,
-                logits_soft_cap=self.global_hyperparameters.
-                logits_soft_cap,
+                sm_scale=self.sm_scale,
+                window_left=self.window_left,
+                logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
             )
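The last three hunks apply the same substitution to the plan() call sites, which also removes the line wrap that the long global_hyperparameters spelling forced onto logits_soft_cap. Since every call site now passes the same cached values, a natural follow-up (hypothetical, not part of this commit) would be to assemble the shared keyword arguments once:

class PlanKwargsSketch:
    # Hypothetical refactor sketch; attribute names follow the diff.
    def __init__(self, sm_scale: float, window_left: int,
                 logits_soft_cap, q_data_type, kv_cache_dtype) -> None:
        self.sm_scale = sm_scale
        self.window_left = window_left
        self.logits_soft_cap = logits_soft_cap
        self.q_data_type = q_data_type
        self.kv_cache_dtype = kv_cache_dtype

    def common_plan_kwargs(self) -> dict:
        # The keyword arguments shared by all three plan() calls in the diff.
        return dict(
            sm_scale=self.sm_scale,
            window_left=self.window_left,
            logits_soft_cap=self.logits_soft_cap,
            q_data_type=self.q_data_type,
            kv_data_type=self.kv_cache_dtype,
        )


kw = PlanKwargsSketch(0.125, -1, None, "bfloat16", "bfloat16")
print(kw.common_plan_kwargs())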