mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-06 11:11:19 +08:00
[TPU] fix kv cache dtype in model runner (#19244)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
parent
90b78ec5f9
commit
0d49483ea9
@ -29,7 +29,8 @@ from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargs,
|
|||||||
PlaceholderRange)
|
PlaceholderRange)
|
||||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
|
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, cdiv,
|
||||||
|
is_pin_memory_available)
|
||||||
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
|
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
|
||||||
PallasMetadata)
|
PallasMetadata)
|
||||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||||
@ -138,6 +139,11 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
self.pin_memory = is_pin_memory_available()
|
self.pin_memory = is_pin_memory_available()
|
||||||
self.dtype = self.model_config.dtype
|
self.dtype = self.model_config.dtype
|
||||||
|
if cache_config.cache_dtype == "auto":
|
||||||
|
self.kv_cache_dtype = self.dtype
|
||||||
|
else:
|
||||||
|
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||||
|
cache_config.cache_dtype]
|
||||||
self._hidden_states_dtype = self.dtype
|
self._hidden_states_dtype = self.dtype
|
||||||
|
|
||||||
self.is_multimodal_model = model_config.is_multimodal_model
|
self.is_multimodal_model = model_config.is_multimodal_model
|
||||||
@ -480,7 +486,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
num_kv_heads=attn_module.num_kv_heads,
|
num_kv_heads=attn_module.num_kv_heads,
|
||||||
head_size=attn_module.head_size,
|
head_size=attn_module.head_size,
|
||||||
dtype=attn_module.dtype,
|
dtype=self.kv_cache_dtype,
|
||||||
sliding_window=attn_module.sliding_window,
|
sliding_window=attn_module.sliding_window,
|
||||||
use_mla=False,
|
use_mla=False,
|
||||||
)
|
)
|
||||||
@ -489,7 +495,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
block_size=block_size,
|
block_size=block_size,
|
||||||
num_kv_heads=attn_module.num_kv_heads,
|
num_kv_heads=attn_module.num_kv_heads,
|
||||||
head_size=attn_module.head_size,
|
head_size=attn_module.head_size,
|
||||||
dtype=attn_module.dtype,
|
dtype=self.kv_cache_dtype,
|
||||||
use_mla=False,
|
use_mla=False,
|
||||||
)
|
)
|
||||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user