mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-20 17:47:06 +08:00
Fixes IMA for TP w/ flex-attention (#19712)
Signed-off-by: drisspg <drisspguessous@gmail.com>
This commit is contained in:
parent
5b3ad5ecf2
commit
ddfed314f9
@ -51,7 +51,6 @@ def test_flex_attention_vs_default_backend(monkeypatch):
|
|||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
m.setenv("VLLM_USE_V1", "1")
|
m.setenv("VLLM_USE_V1", "1")
|
||||||
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
|
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
|
||||||
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
|
||||||
|
|
||||||
set_seed(seed)
|
set_seed(seed)
|
||||||
|
|
||||||
@ -66,7 +65,6 @@ def test_flex_attention_vs_default_backend(monkeypatch):
|
|||||||
# Run with default backend
|
# Run with default backend
|
||||||
with monkeypatch.context() as m:
|
with monkeypatch.context() as m:
|
||||||
m.setenv("VLLM_USE_V1", "1")
|
m.setenv("VLLM_USE_V1", "1")
|
||||||
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
|
||||||
set_seed(seed)
|
set_seed(seed)
|
||||||
llm_default = LLM(
|
llm_default = LLM(
|
||||||
model_name,
|
model_name,
|
||||||
|
|||||||
@ -13,7 +13,6 @@ from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature,
|
|||||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
AttentionMetadata, AttentionType,
|
AttentionMetadata, AttentionType,
|
||||||
is_quantized_kv_cache)
|
is_quantized_kv_cache)
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||||
@ -237,17 +236,13 @@ class FlexAttentionMetadata:
|
|||||||
|
|
||||||
def build_block_mask(self) -> BlockMask:
|
def build_block_mask(self) -> BlockMask:
|
||||||
assert self.mask_mod is not None
|
assert self.mask_mod is not None
|
||||||
# FIXME: With TP>1, create_block_mask_compiled will raise
|
return create_block_mask_compiled(
|
||||||
# CUDA error: an illegal memory access was encountered
|
|
||||||
create_block_mask_fn = (create_block_mask_compiled
|
|
||||||
if get_tensor_model_parallel_world_size() == 1
|
|
||||||
else create_block_mask)
|
|
||||||
return create_block_mask_fn(
|
|
||||||
self.mask_mod,
|
self.mask_mod,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
self.num_actual_tokens,
|
self.num_actual_tokens,
|
||||||
self.total_cache_tokens,
|
self.total_cache_tokens,
|
||||||
|
device=self.block_table.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@ -429,7 +424,6 @@ class FlexAttentionImpl(AttentionImpl):
|
|||||||
shape = [num_tokens, num_heads * head_size]
|
shape = [num_tokens, num_heads * head_size]
|
||||||
"""
|
"""
|
||||||
assert output is not None, "Output tensor must be provided."
|
assert output is not None, "Output tensor must be provided."
|
||||||
|
|
||||||
if output_scale is not None:
|
if output_scale is not None:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"fused output quantization is not yet supported"
|
"fused output quantization is not yet supported"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user