mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-27 04:27:53 +08:00
[FalconH1] Fix output dtype in RMSNorm fallback path for Falcon-H1 (e.g. 0.5B) (#18500)
Signed-off-by: dhia.rhaiem <dhia.rhaiem@tii.ae>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Ilyas Chahed <ilyas.chahed@tii.ae>
Co-authored-by: Jingwei Zuo <jingwei.zuo@tii.ae>
This commit is contained in:
parent
1f079540db
commit
20bd6f4d2e
@@ -77,7 +77,7 @@ class Mixer2RMSNormGated(CustomOp):
|
|||||||
input_dtype = x.dtype
|
input_dtype = x.dtype
|
||||||
x = x * nn.functional.silu(gate.to(torch.float32))
|
x = x * nn.functional.silu(gate.to(torch.float32))
|
||||||
if not self.use_rms_norm:
|
if not self.use_rms_norm:
|
||||||
return x
|
return x.to(input_dtype)
|
||||||
|
|
||||||
if self.n_groups == 1:
|
if self.n_groups == 1:
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
@@ -117,9 +117,11 @@ class Mixer2RMSNormGated(CustomOp):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
gate: torch.Tensor,
|
gate: torch.Tensor,
|
||||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
input_dtype = x.dtype
|
||||||
if not self.use_rms_norm:
|
if not self.use_rms_norm:
|
||||||
return x * nn.functional.silu(gate.to(torch.float32))
|
# Keep gate in float32 for numerical stability during silu
|
||||||
|
return x * nn.functional.silu(gate.to(
|
||||||
|
torch.float32)).to(input_dtype)
|
||||||
|
|
||||||
if self.tp_size > 1 or self.n_groups != 1:
|
if self.tp_size > 1 or self.n_groups != 1:
|
||||||
return self.forward_native(x, gate)
|
return self.forward_native(x, gate)
|
||||||
|
|||||||
@@ -453,7 +453,6 @@ class FalconH1Model(nn.Module):
|
|||||||
attn_metadata = get_forward_context().attn_metadata
|
attn_metadata = get_forward_context().attn_metadata
|
||||||
mamba2_metadata = prepare_mamba2_metadata(
|
mamba2_metadata = prepare_mamba2_metadata(
|
||||||
chunk_size=self.config.mamba_chunk_size,
|
chunk_size=self.config.mamba_chunk_size,
|
||||||
input_ids=input_ids,
|
|
||||||
attn_metadata=attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
)
|
)
|
||||||
if get_pp_group().is_first_rank:
|
if get_pp_group().is_first_rank:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user