[Model] Fix minimax model cache & lm_head precision (#19592)

Signed-off-by: qingjun <qingjun@minimaxi.com>
This commit is contained in:
qscqesze 2025-06-13 20:08:20 +08:00 committed by GitHub
parent 7e8d97dd3f
commit a24cb91600
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -856,7 +856,7 @@ class MiniMaxText01Model(nn.Module):
self._dtype = _dummy.dtype
del _dummy
self.minimax_cache = MinimaxCacheManager(dtype=self._dtype,
self.minimax_cache = MinimaxCacheManager(dtype=torch.float32,
cache_shape=self.cache_shape)
rope_theta = getattr(config, "rope_theta", 10000)
@ -1021,7 +1021,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
else:
self.lm_head = PPMissingLayer()
self.lm_head.float()
flash_layer_count = sum(1 for attn_type in self.config.attn_type_list
if attn_type == 1)
self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
@ -1054,7 +1054,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states.float(),
sampling_metadata)
return logits