[Bugfix] Fix TP > 1 for new granite (#8544)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
Joe Runde 2024-09-17 17:17:08 -06:00 committed by GitHub
parent 56c3de018c
commit 98f9713399
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -428,7 +428,8 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
-logits /= self.config.logits_scaling
+if logits is not None:
+    logits /= self.config.logits_scaling
return logits
def sample(