mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:25:00 +08:00
[Temp fix] Disable torch.compile for Qwen2.5 VL's VisionBlock temporarily. (#27760)
Signed-off-by: Chenheli Hua <huachenheli@outlook.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
parent
b5d90f7400
commit
48eb8eba58
@ -460,15 +460,17 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@support_torch_compile(
|
# (FIXME): Enable this after dynamic slicing is fixed
|
||||||
dynamic_arg_dims={
|
# See https://github.com/vllm-project/vllm/pull/27760
|
||||||
"x": 0,
|
# @support_torch_compile(
|
||||||
"cu_seqlens": 0,
|
# dynamic_arg_dims={
|
||||||
"rotary_pos_emb": 0,
|
# "x": 0,
|
||||||
"seqlens": 0,
|
# "cu_seqlens": 0,
|
||||||
},
|
# "rotary_pos_emb": 0,
|
||||||
mark_unbacked_dims={"seqlens": 0},
|
# "seqlens": 0,
|
||||||
)
|
# },
|
||||||
|
# mark_unbacked_dims={"seqlens": 0},
|
||||||
|
# )
|
||||||
class Qwen2_5_VisionBlock(nn.Module):
|
class Qwen2_5_VisionBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user