[XPU] Add gpt-oss model support for Intel GPU (#27786)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
commit 18b39828d9 (parent 4ea62b77f5)
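For context on what this change enables, here is a minimal offline-inference sketch on an Intel GPU host; the model id, prompt, and sampling values are illustrative assumptions and are not part of this commit:

# Minimal sketch (assumed usage, not from this commit): run gpt-oss on an XPU
# device once this support is in place. Requires a vLLM build with XPU enabled
# and intel_extension_for_pytorch installed.
from vllm import LLM, SamplingParams

llm = LLM(model="openai/gpt-oss-20b")  # model id assumed for illustration
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["What is MXFP4 quantization?"], params)
print(outputs[0].outputs[0].text)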
@@ -80,6 +80,13 @@ def flash_attn_supports_fp8() -> bool:
     )


+def flash_attn_supports_sinks() -> bool:
+    if current_platform.is_xpu():
+        return True
+    else:
+        return get_flash_attn_version() == 3
+
+
 def flash_attn_supports_mla():
     from vllm.platforms import current_platform
@@ -142,6 +142,9 @@ def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
         else:
             logger.info_once("Using Triton backend")
             return Mxfp4Backend.TRITON
+    elif current_platform.is_xpu():
+        logger.info_once("Using ipex marlin backend on XPU")
+        return Mxfp4Backend.MARLIN
     elif current_platform.is_rocm() and has_triton_kernels():
         logger.info_once("Using Triton backend")
         return Mxfp4Backend.TRITON
@@ -188,7 +191,10 @@ class Mxfp4Config(QuantizationConfig):
                 return UnquantizedLinearMethod()
             raise NotImplementedError("Mxfp4 linear layer is not implemented")
         elif isinstance(layer, FusedMoE):
-            return Mxfp4MoEMethod(layer.moe_config)
+            if current_platform.is_xpu():
+                return IpexMxfp4MoEMethod(layer.moe_config)
+            else:
+                return Mxfp4MoEMethod(layer.moe_config)
         elif isinstance(layer, Attention):
             raise NotImplementedError("Mxfp4 attention layer is not implemented")
         return None
@@ -245,7 +251,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         intermediate_size_per_partition_after_pad = round_up(
             intermediate_size_per_partition, 128
         )
-        hidden_size = round_up(hidden_size, 256)
+        if current_platform.is_xpu():
+            hidden_size = round_up(hidden_size, 128)
+        else:
+            hidden_size = round_up(hidden_size, 256)

         layer.params_dtype = params_dtype
         layer.num_experts = num_experts
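The hunk above relaxes the hidden-size padding from a multiple of 256 to a multiple of 128 on XPU. A small self-contained sketch of that arithmetic, using gpt-oss's hidden size of 2880 as an assumed example value:

def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x (mirrors the round_up
    # helper used in the diff).
    return ((x + multiple - 1) // multiple) * multiple

hidden_size = 2880  # assumed gpt-oss hidden size, for illustration only
assert round_up(hidden_size, 256) == 3072  # default padding target
assert round_up(hidden_size, 128) == 2944  # XPU padding target in this change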
@@ -1071,3 +1080,84 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             )
         else:
             raise ValueError(f"Unsupported backend: {self.mxfp4_backend}")
+
+
+class IpexMxfp4MoEMethod(Mxfp4MoEMethod):
+    def __init__(self, moe_config: FusedMoEConfig):
+        super().__init__(moe_config)
+        self.moe_config = moe_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        super().create_weights(
+            layer,
+            num_experts,
+            hidden_size,
+            intermediate_size_per_partition,
+            params_dtype,
+            **extra_weight_attrs,
+        )
+        self.original_hidden_size = hidden_size
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        import intel_extension_for_pytorch as ipex
+
+        layer.w13_weight.data = layer.w13_weight.data.view(torch.int32)
+        layer.w2_weight.data = layer.w2_weight.data.view(torch.int32)
+        layer.ipex_fusion = ipex.llm.modules.GatedMLPMOE(
+            layer.w13_weight,
+            layer.w2_weight,
+            w1_scale_inv=layer.w13_weight_scale,
+            w2_scale_inv=layer.w2_weight_scale,
+            w13_bias=layer.w13_bias,
+            w2_bias=layer.w2_bias,
+            is_mxfp4=True,
+        )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: int | None = None,
+        num_expert_group: int | None = None,
+        global_num_experts: int = -1,
+        expert_map: torch.Tensor | None = None,
+        custom_routing_function: Callable | None = None,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: torch.Tensor | None = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: torch.Tensor | None = None,
+        logical_to_physical_map: torch.Tensor | None = None,
+        logical_replica_count: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        assert activation == "swigluoai", (
+            "Only swiglu_oai activation is supported for IPEX MXFP4 MoE"
+        )  # noqa:
+        hidden_size_pad = round_up(self.original_hidden_size, 128)
+        x_pad = torch.nn.functional.pad(x, (0, hidden_size_pad - x.size(-1)))
+        hidden_states = layer.ipex_fusion(
+            x_pad,
+            use_grouped_topk,
+            top_k,
+            router_logits,
+            renormalize,
+            topk_group,
+            num_expert_group,
+            activation="swiglu_oai",
+        )
+        hidden_states = hidden_states[..., : self.original_hidden_size].contiguous()
+        return hidden_states
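In process_weights_after_loading above, the packed MXFP4 weight buffers are reinterpreted as int32 before being handed to IPEX. A hedged sketch of why the shapes work out (tensor sizes are illustrative, not from the checkpoint):

import torch

# MXFP4 stores two 4-bit values per byte, so the loader keeps weights as uint8.
# Viewing the same buffer as int32 groups eight FP4 values per element without
# copying; only the last dimension shrinks by a factor of 4.
packed_u8 = torch.randint(0, 256, (32, 64), dtype=torch.uint8)  # illustrative shape
packed_i32 = packed_u8.view(torch.int32)

assert packed_i32.shape == (32, 16)
assert packed_i32.data_ptr() == packed_u8.data_ptr()  # no copy, same storage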
@@ -337,9 +337,6 @@ class GptOssModel(nn.Module):
                 if is_pp_missing_parameter(name, self):
                     continue

-                # FIXME(woosuk): Remove this after testing.
-                weight = weight.cuda()
-
                 if ".w13_weight_scale" in name:
                     # Handle MLP gate and up projection weights scale
                     if use_ep:
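The deleted `weight = weight.cuda()` line (marked FIXME in the original) pinned every loaded weight to a CUDA device, which cannot work on XPU. If an explicit move were still needed, a device-agnostic sketch might look like the following; the helper name and fallback logic are assumptions, not part of this commit:

import torch

def move_to_accelerator(weight: torch.Tensor) -> torch.Tensor:
    # Hypothetical replacement for a hard-coded .cuda() call: pick whichever
    # accelerator is actually available, and otherwise leave the tensor where
    # the loader placed it.
    if torch.cuda.is_available():
        return weight.cuda()
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return weight.to("xpu")
    return weight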
@@ -27,6 +27,7 @@ from vllm.attention.utils.fa_utils import (

 if is_flash_attn_varlen_func_available():
     from vllm.attention.utils.fa_utils import (
+        flash_attn_supports_sinks,
         flash_attn_varlen_func,
         get_scheduler_metadata,
         reshape_and_cache_flash,
@@ -497,7 +498,7 @@ class FlashAttentionImpl(AttentionImpl):

         self.sinks = sinks
         if self.sinks is not None:
-            assert self.vllm_flash_attn_version == 3, (
+            assert flash_attn_supports_sinks(), (
                 "Sinks are only supported in FlashAttention 3"
             )
             assert self.sinks.shape[0] == num_heads, (