[XPU] Add gpt-oss model support for Intel GPU (#27786)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Kunshang Ji 2025-11-05 10:17:23 +08:00 committed by GitHub
parent 4ea62b77f5
commit 18b39828d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 101 additions and 6 deletions

View File

@@ -80,6 +80,13 @@ def flash_attn_supports_fp8() -> bool:
)
def flash_attn_supports_sinks() -> bool:
    """Return whether the flash-attention backend supports attention sinks.

    XPU supports sinks unconditionally; every other platform requires
    FlashAttention version 3.
    """
    return current_platform.is_xpu() or get_flash_attn_version() == 3
def flash_attn_supports_mla():
from vllm.platforms import current_platform

View File

@@ -142,6 +142,9 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
else:
logger.info_once("Using Triton backend")
return Mxfp4Backend.TRITON
elif current_platform.is_xpu():
logger.info_once("Using ipex marlin backend on XPU")
return Mxfp4Backend.MARLIN
elif current_platform.is_rocm() and has_triton_kernels():
logger.info_once("Using Triton backend")
return Mxfp4Backend.TRITON
@@ -188,7 +191,10 @@ class Mxfp4Config(QuantizationConfig):
return UnquantizedLinearMethod()
raise NotImplementedError("Mxfp4 linear layer is not implemented")
elif isinstance(layer, FusedMoE):
return Mxfp4MoEMethod(layer.moe_config)
if current_platform.is_xpu():
return IpexMxfp4MoEMethod(layer.moe_config)
else:
return Mxfp4MoEMethod(layer.moe_config)
elif isinstance(layer, Attention):
raise NotImplementedError("Mxfp4 attention layer is not implemented")
return None
@@ -245,7 +251,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
intermediate_size_per_partition_after_pad = round_up(
intermediate_size_per_partition, 128
)
hidden_size = round_up(hidden_size, 256)
if current_platform.is_xpu():
hidden_size = round_up(hidden_size, 128)
else:
hidden_size = round_up(hidden_size, 256)
layer.params_dtype = params_dtype
layer.num_experts = num_experts
@@ -1071,3 +1080,84 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
)
else:
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
    """MXFP4 fused-MoE method for Intel XPU.

    Weight creation/padding is inherited from Mxfp4MoEMethod; after loading,
    the packed weights are handed to an IPEX ``GatedMLPMOE`` module
    (``intel_extension_for_pytorch``) that performs routing and the gated
    MLP in a single fused call.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        super().__init__(moe_config)
        self.moe_config = moe_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create MoE weights via the parent class, remembering the
        caller-visible (unpadded) hidden size for use in ``apply()``.
        """
        super().create_weights(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )
        # The parent rounds hidden_size up for its weight buffers; keep the
        # original value so apply() can pad inputs and slice outputs back.
        self.original_hidden_size = hidden_size

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Reinterpret the packed weights and build the IPEX fused MoE op."""
        # Imported lazily: IPEX is only available on XPU installs.
        import intel_extension_for_pytorch as ipex

        # NOTE(review): the int32 view reinterprets the MXFP4-packed weight
        # bytes into the layout GatedMLPMOE expects -- confirm against the
        # IPEX GatedMLPMOE documentation.
        layer.w13_weight.data = layer.w13_weight.data.view(torch.int32)
        layer.w2_weight.data = layer.w2_weight.data.view(torch.int32)
        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
            layer.w13_weight,
            layer.w2_weight,
            w1_scale_inv=layer.w13_weight_scale,
            w2_scale_inv=layer.w2_weight_scale,
            w13_bias=layer.w13_bias,
            w2_bias=layer.w2_bias,
            is_mxfp4=True,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: torch.Tensor | None = None,
        logical_to_physical_map: torch.Tensor | None = None,
        logical_replica_count: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the fused MoE forward pass on XPU.

        Only ``activation == "swigluoai"`` is supported. The input's last
        dimension is zero-padded up to a multiple of 128 (matching the
        padded weight layout from create_weights), and the fused kernel's
        output is sliced back to the original hidden size.

        Parameters not forwarded to ``layer.ipex_fusion`` (``expert_map``,
        ``custom_routing_function``, the EPLB arguments, ...) are accepted
        only for interface compatibility with FusedMoEMethodBase and are
        ignored here.
        """
        assert activation == "swigluoai", (
            "Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
        )  # noqa:
        # Pad activations to the 128-aligned hidden size the weights use.
        hidden_size_pad = round_up(self.original_hidden_size, 128)
        x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
        hidden_states = layer.ipex_fusion(
            x_pad,
            use_grouped_topk,
            top_k,
            router_logits,
            renormalize,
            topk_group,
            num_expert_group,
            activation="swiglu_oai",
        )
        # Drop the padding columns before returning to the caller.
        hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
        return hidden_states

View File

@@ -337,9 +337,6 @@ class GptOssModel(nn.Module):
if is_pp_missing_parameter(name, self):
continue
# FIXME(woosuk): Remove this after testing.
weight = weight.cuda()
if ".w13_weight_scale" in name:
# Handle MLP gate and up projection weights scale
if use_ep:

View File

@@ -27,6 +27,7 @@ from vllm.attention.utils.fa_utils import (
if is_flash_attn_varlen_func_available():
from vllm.attention.utils.fa_utils import (
flash_attn_supports_sinks,
flash_attn_varlen_func,
get_scheduler_metadata,
reshape_and_cache_flash,
@@ -497,7 +498,7 @@ class FlashAttentionImpl(AttentionImpl):
self.sinks = sinks
if self.sinks is not None:
assert self.vllm_flash_attn_version == 3, (
assert flash_attn_supports_sinks(), (
"Sinks are only supported in FlashAttention 3"
)
assert self.sinks.shape[0] == num_heads, (