[Model] Fix Skywork R1V mlp (#26673)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-10-13 13:42:17 +08:00 committed by GitHub
parent 3cd36660f7
commit 98f30b8cba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -691,7 +691,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
prefix=maybe_prefix(prefix, "language_model"),
)
self.mlp1 = self._init_mlp1(config)
self.mlp1 = self._init_mlp1(
config, quant_config, prefix=maybe_prefix(prefix, "mlp1")
)
self.img_context_token_id = None
self.visual_token_mask = None
@@ -738,7 +740,12 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else:
return InternVisionPatchModel(config.vision_config)
def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
def _init_mlp1(
self,
config: PretrainedConfig,
quant_config: QuantizationConfig,
prefix: str = "",
) -> nn.Module:
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
@@ -748,9 +755,17 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
llm_hidden_size,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.1",
),
nn.GELU(),
ReplicatedLinear(llm_hidden_size, llm_hidden_size, return_bias=False),
ReplicatedLinear(
llm_hidden_size,
llm_hidden_size,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.3",
),
)
def pixel_shuffle(self, x, scale_factor=0.5):