mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-17 06:17:06 +08:00
[XPU] Add gpt-oss model support for Intel GPU (#27786)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
parent
4ea62b77f5
commit
18b39828d9
@ -80,6 +80,13 @@ def flash_attn_supports_fp8() -> bool:
|
||||
)
|
||||
|
||||
|
||||
def flash_attn_supports_sinks() -> bool:
|
||||
if current_platform.is_xpu():
|
||||
return True
|
||||
else:
|
||||
return get_flash_attn_version() == 3
|
||||
|
||||
|
||||
def flash_attn_supports_mla():
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@ -142,6 +142,9 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
|
||||
else:
|
||||
logger.info_once("Using Triton backend")
|
||||
return Mxfp4Backend.TRITON
|
||||
elif current_platform.is_xpu():
|
||||
logger.info_once("Using ipex marlin backend on XPU")
|
||||
return Mxfp4Backend.MARLIN
|
||||
elif current_platform.is_rocm() and has_triton_kernels():
|
||||
logger.info_once("Using Triton backend")
|
||||
return Mxfp4Backend.TRITON
|
||||
@ -188,7 +191,10 @@ class Mxfp4Config(QuantizationConfig):
|
||||
return UnquantizedLinearMethod()
|
||||
raise NotImplementedError("Mxfp4 linear layer is not implemented")
|
||||
elif isinstance(layer, FusedMoE):
|
||||
return Mxfp4MoEMethod(layer.moe_config)
|
||||
if current_platform.is_xpu():
|
||||
return IpexMxfp4MoEMethod(layer.moe_config)
|
||||
else:
|
||||
return Mxfp4MoEMethod(layer.moe_config)
|
||||
elif isinstance(layer, Attention):
|
||||
raise NotImplementedError("Mxfp4 attention layer is not implemented")
|
||||
return None
|
||||
@ -245,7 +251,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
intermediate_size_per_partition_after_pad = round_up(
|
||||
intermediate_size_per_partition, 128
|
||||
)
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
if current_platform.is_xpu():
|
||||
hidden_size = round_up(hidden_size, 128)
|
||||
else:
|
||||
hidden_size = round_up(hidden_size, 256)
|
||||
|
||||
layer.params_dtype = params_dtype
|
||||
layer.num_experts = num_experts
|
||||
@ -1071,3 +1080,84 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
|
||||
|
||||
|
||||
class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
super().__init__(moe_config)
|
||||
self.moe_config = moe_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
super().create_weights(
|
||||
layer,
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
params_dtype,
|
||||
**extra_weight_attrs,
|
||||
)
|
||||
self.original_hidden_size = hidden_size
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
layer.w13_weight.data = layer.w13_weight.data.view(torch.int32)
|
||||
layer.w2_weight.data = layer.w2_weight.data.view(torch.int32)
|
||||
layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
w1_scale_inv=layer.w13_weight_scale,
|
||||
w2_scale_inv=layer.w2_weight_scale,
|
||||
w13_bias=layer.w13_bias,
|
||||
w2_bias=layer.w2_bias,
|
||||
is_mxfp4=True,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert activation == "swigluoai", (
|
||||
"Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
|
||||
) # noqa:
|
||||
hidden_size_pad = round_up(self.original_hidden_size, 128)
|
||||
x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
|
||||
hidden_states = layer.ipex_fusion(
|
||||
x_pad,
|
||||
use_grouped_topk,
|
||||
top_k,
|
||||
router_logits,
|
||||
renormalize,
|
||||
topk_group,
|
||||
num_expert_group,
|
||||
activation="swiglu_oai",
|
||||
)
|
||||
hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
|
||||
return hidden_states
|
||||
|
||||
@ -337,9 +337,6 @@ class GptOssModel(nn.Module):
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
# FIXME(woosuk): Remove this after testing.
|
||||
weight = weight.cuda()
|
||||
|
||||
if ".w13_weight_scale" in name:
|
||||
# Handle MLP gate and up projection weights scale
|
||||
if use_ep:
|
||||
|
||||
@ -27,6 +27,7 @@ from vllm.attention.utils.fa_utils import (
|
||||
|
||||
if is_flash_attn_varlen_func_available():
|
||||
from vllm.attention.utils.fa_utils import (
|
||||
flash_attn_supports_sinks,
|
||||
flash_attn_varlen_func,
|
||||
get_scheduler_metadata,
|
||||
reshape_and_cache_flash,
|
||||
@ -497,7 +498,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
|
||||
self.sinks = sinks
|
||||
if self.sinks is not None:
|
||||
assert self.vllm_flash_attn_version == 3, (
|
||||
assert flash_attn_supports_sinks(), (
|
||||
"Sinks are only supported in FlashAttention 3"
|
||||
)
|
||||
assert self.sinks.shape[0] == num_heads, (
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user