[FalconH1] Fix output dtype in RMSNorm fallback path for Falcon-H1 (e.g. 0.5B) (#18500)

Signed-off-by: dhia.rhaiem <dhia.rhaiem@tii.ae>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Ilyas Chahed <ilyas.chahed@tii.ae>
Co-authored-by: Jingwei Zuo <jingwei.zuo@tii.ae>
This commit is contained in:
Dhia Eddine Rhaiem 2025-05-22 06:23:59 +04:00 committed by GitHub
parent 1f079540db
commit 20bd6f4d2e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 4 deletions

View File

@ -77,7 +77,7 @@ class Mixer2RMSNormGated(CustomOp):
input_dtype = x.dtype
x = x * nn.functional.silu(gate.to(torch.float32))
if not self.use_rms_norm:
return x
return x.to(input_dtype)
if self.n_groups == 1:
if self.tp_size > 1:
@ -117,9 +117,11 @@ class Mixer2RMSNormGated(CustomOp):
x: torch.Tensor,
gate: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
input_dtype = x.dtype
if not self.use_rms_norm:
return x * nn.functional.silu(gate.to(torch.float32))
# Keep gate in float32 for numerical stability during silu
return x * nn.functional.silu(gate.to(
torch.float32)).to(input_dtype)
if self.tp_size > 1 or self.n_groups != 1:
return self.forward_native(x, gate)

View File

@ -453,7 +453,6 @@ class FalconH1Model(nn.Module):
attn_metadata = get_forward_context().attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
input_ids=input_ids,
attn_metadata=attn_metadata,
)
if get_pp_group().is_first_rank: