From a24cb91600bdfcafd4c18b6647e9184a2b47fcb4 Mon Sep 17 00:00:00 2001 From: qscqesze Date: Fri, 13 Jun 2025 20:08:20 +0800 Subject: [PATCH] [Model] Fix minimax model cache & lm_head precision (#19592) Signed-off-by: qingjun --- vllm/model_executor/models/minimax_text_01.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/minimax_text_01.py b/vllm/model_executor/models/minimax_text_01.py index 02800449bda3c..87480796ae98f 100644 --- a/vllm/model_executor/models/minimax_text_01.py +++ b/vllm/model_executor/models/minimax_text_01.py @@ -856,7 +856,7 @@ class MiniMaxText01Model(nn.Module): self._dtype = _dummy.dtype del _dummy - self.minimax_cache = MinimaxCacheManager(dtype=self._dtype, + self.minimax_cache = MinimaxCacheManager(dtype=torch.float32, cache_shape=self.cache_shape) rope_theta = getattr(config, "rope_theta", 10000) @@ -1021,7 +1021,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, else: self.lm_head = PPMissingLayer() - + self.lm_head.float() flash_layer_count = sum(1 for attn_type in self.config.attn_type_list if attn_type == 1) self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)] @@ -1054,7 +1054,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid, def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: - logits = self.logits_processor(self.lm_head, hidden_states, + logits = self.logits_processor(self.lm_head, hidden_states.float(), sampling_metadata) return logits