mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-24 01:32:28 +08:00
[Model] Fix minimax model cache & lm_head precision (#19592)
Signed-off-by: qingjun <qingjun@minimaxi.com>
This commit is contained in:
parent
7e8d97dd3f
commit
a24cb91600
@ -856,7 +856,7 @@ class MiniMaxText01Model(nn.Module):
|
||||
self._dtype = _dummy.dtype
|
||||
del _dummy
|
||||
|
||||
self.minimax_cache = MinimaxCacheManager(dtype=self._dtype,
|
||||
self.minimax_cache = MinimaxCacheManager(dtype=torch.float32,
|
||||
cache_shape=self.cache_shape)
|
||||
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
@ -1021,7 +1021,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
|
||||
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.lm_head.float()
|
||||
flash_layer_count = sum(1 for attn_type in self.config.attn_type_list
|
||||
if attn_type == 1)
|
||||
self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
|
||||
@ -1054,7 +1054,7 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
logits = self.logits_processor(self.lm_head, hidden_states.float(),
|
||||
sampling_metadata)
|
||||
|
||||
return logits
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user