Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 17:55:53 +08:00)
[Core/Bugfix] Add query dtype as per FlashInfer API requirements. (#8173)
commit e39ebf5cf5
parent ba262c4e5a
@@ -445,7 +445,8 @@ def test_flashinfer_decode_with_paged_fp8_kv(
                           head_size,
                           block_size,
                           "NONE",
-                          data_type=dtype)
+                          data_type=dtype,
+                          q_data_type=dtype)
     output = wrapper.forward(query,
                              kv_cache_fp8,
                              logits_soft_cap=soft_cap,
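For context: with an fp8 KV cache the query tensor keeps the test's fp16/bf16 dtype, so FlashInfer's planner has to be told the query dtype separately from the cache dtype. Below is a minimal sketch of the decode-wrapper call pattern this hunk exercises, assuming a FlashInfer build contemporary with this commit; everything outside the begin_forward arguments shown in the hunk (workspace size, page-table contents, head counts) is illustrative.

import torch
import flashinfer

num_query_heads, num_kv_heads = 8, 8
head_size, block_size = 128, 16
dtype = torch.float16  # query dtype (and cache dtype before fp8 quantization)

# Workspace buffer and wrapper construction, per FlashInfer's documented usage.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

# A toy paged layout: 2 sequences, 4 pages each, last page fully used.
kv_indptr = torch.tensor([0, 4, 8], dtype=torch.int32, device="cuda")
kv_indices = torch.arange(8, dtype=torch.int32, device="cuda")
kv_last_page_lens = torch.tensor([16, 16], dtype=torch.int32, device="cuda")

wrapper.begin_forward(kv_indptr,
                      kv_indices,
                      kv_last_page_lens,
                      num_query_heads,
                      num_kv_heads,
                      head_size,
                      block_size,
                      "NONE",
                      data_type=dtype,    # planning dtype for the KV cache
                      q_data_type=dtype)  # query dtype, now passed explicitly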
@@ -224,6 +224,7 @@ class FlashInferState(AttentionState):
             query_start_loc=query_start_loc_host,
             device=self.runner.device,
             data_type=kv_cache_dtype,
+            q_data_type=self.runner.model_config.dtype,
             use_cuda_graph=True,
             decode_wrapper=self._graph_decode_wrapper,
             prefill_wrapper=None)
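This is the CUDA-graph capture path, so the query dtype is taken from the model config rather than from a live tensor. A one-line illustration of why it cannot be recovered from data_type alone (the concrete dtypes below are assumptions, not values from the diff):

import torch

model_dtype = torch.float16           # stands in for self.runner.model_config.dtype
kv_cache_dtype = torch.float8_e4m3fn  # stands in for an fp8 paged-KV-cache dtype
assert model_dtype != kv_cache_dtype  # hence the separate q_data_type argument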
@@ -292,6 +293,8 @@ class FlashInferMetadata(AttentionMetadata):
     page_size: Optional[int] = None
     # The data type of the paged kv cache
     data_type: torch.dtype = None
+    # The data type of the query
+    q_data_type: torch.dtype = None
     device: torch.device = torch.device("cuda")
     is_profile_run: bool = False

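The new field mirrors the existing data_type declaration (strictly, Optional[torch.dtype] would match the annotation style of page_size, but the hunk follows the established convention). A self-contained sketch of the relevant fields after this change:

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class MetadataSketch:  # illustrative stand-in for FlashInferMetadata
    page_size: Optional[int] = None
    data_type: torch.dtype = None    # dtype of the paged KV cache
    q_data_type: torch.dtype = None  # dtype of the query (added by this commit)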
@@ -353,7 +356,10 @@ class FlashInferMetadata(AttentionMetadata):
                 self.page_size,
                 # Disable flashinfer's pos encoding and use vllm's rope.
                 pos_encoding_mode="NONE",
-                data_type=self.data_type)
+                # kv-cache data type.
+                data_type=self.data_type,
+                # query data type.
+                q_data_type=self.q_data_type)

     def asdict_zerocopy(self,
                         skip_fields: Optional[Set[str]] = None
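The added comments pin down the split: data_type describes the paged KV cache, q_data_type the query. At the tensor level that split looks like the sketch below (shapes are arbitrary; only the dtype relationship matters):

import torch

query = torch.randn(2, 8, 128, dtype=torch.float16, device="cuda")
kv_cache = torch.randn(8, 2, 16, 8, 128, dtype=torch.float16, device="cuda")
kv_cache_fp8 = kv_cache.to(torch.float8_e4m3fn)  # what data_type describes
# q_data_type describes query.dtype; the two no longer have to agree.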
@@ -617,6 +623,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             query_start_loc=query_start_loc,
             device=device,
             data_type=kv_cache_dtype,
+            q_data_type=self.runner.model_config.dtype,
             use_cuda_graph=use_captured_graph,
             is_profile_run=self.is_profile_run)

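The eager builder path now passes the same query dtype as the graph-capture path above. A hypothetical sanity check (not part of this commit) capturing the invariant both paths establish:

import torch

def check_planned_q_dtype(attn_metadata, query: torch.Tensor) -> None:
    # The dtype planned in begin_forward must match the query handed to forward().
    assert attn_metadata.q_data_type == query.dtype, (
        f"planned for {attn_metadata.q_data_type}, got {query.dtype}")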