mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 20:04:58 +08:00
[XPU] Support Triton path for LoRA operations on XPU (#28511)
Signed-off-by: Fanli Lin <fanli.lin@intel.com>
This commit is contained in:
parent
7dca0c90cb
commit
dbbe0c756a
@@ -48,6 +48,7 @@ def _lora_expand_kernel(
|
|||||||
SLICE_NUM: tl.constexpr,
|
SLICE_NUM: tl.constexpr,
|
||||||
SAME_STRIDE: tl.constexpr,
|
SAME_STRIDE: tl.constexpr,
|
||||||
USE_GDC: tl.constexpr,
|
USE_GDC: tl.constexpr,
|
||||||
|
launch_pdl: tl.constexpr,
|
||||||
):
|
):
|
||||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||||
cta_m_num = tl.cdiv(M, BLOCK_M)
|
cta_m_num = tl.cdiv(M, BLOCK_M)
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ def _lora_shrink_kernel(
|
|||||||
GROUP_SIZE_M: tl.constexpr,
|
GROUP_SIZE_M: tl.constexpr,
|
||||||
SLICE_NUM: tl.constexpr,
|
SLICE_NUM: tl.constexpr,
|
||||||
USE_GDC: tl.constexpr,
|
USE_GDC: tl.constexpr,
|
||||||
|
launch_pdl: tl.constexpr,
|
||||||
):
|
):
|
||||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||||
cta_m_num = tl.cdiv(M, BLOCK_M)
|
cta_m_num = tl.cdiv(M, BLOCK_M)
|
||||||
|
|||||||
@@ -101,7 +101,11 @@ class XPUPlatform(Platform):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_punica_wrapper(cls) -> str:
|
def get_punica_wrapper(cls) -> str:
|
||||||
return "vllm.lora.punica_wrapper.punica_xpu.PunicaWrapperXPU"
|
xpu_use_triton_kernel = os.getenv("XPU_USE_TRITON_KERNEL", "0") == "1"
|
||||||
|
if not xpu_use_triton_kernel:
|
||||||
|
return "vllm.lora.punica_wrapper.punica_xpu.PunicaWrapperXPU"
|
||||||
|
else:
|
||||||
|
return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user