[DCP] Support Decode Context Parallel (DCP) for GQA with FlashAttention (#24864)
Signed-off-by: yuanyongjie.yyj <yuanyongjie.yyj@antgroup.com>
Signed-off-by: FENP <32334296+FENP@users.noreply.github.com>
Signed-off-by: Jaya Yuan <yuanyongjie.yyj@antgroup.com>
parent fdd32750f0
commit ea97940d6c
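Note: a hedged usage sketch of the feature this commit enables, for illustration only. It assumes the existing ParallelConfig.decode_context_parallel_size setting can be passed as a keyword through LLM(); gpt_bigcode-santacoder uses multi-query attention (a single KV head), so with tensor_parallel_size=2 the KV cache is replicated across TP ranks and DCP up to 2 is permitted by the new ModelConfig check below.

    from vllm import LLM, SamplingParams

    # Illustrative configuration only; forwarding of the kwarg to
    # ParallelConfig.decode_context_parallel_size is an assumption.
    llm = LLM(
        model="bigcode/gpt_bigcode-santacoder",
        tensor_parallel_size=2,
        decode_context_parallel_size=2,
    )
    out = llm.generate(["def hello():"], SamplingParams(max_tokens=32))
    print(out[0].outputs[0].text)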
@@ -204,17 +204,21 @@ def _compare_cp_with_tp(


 CP_TEXT_GENERATION_MODELS = {
     # [MLA attention only]
     "deepseek-ai/DeepSeek-V2-Lite-Chat": [
         CPTestSettings.detailed(),
         CPTestSettings.detailed(tp_base=2),
     ],
+    "bigcode/gpt_bigcode-santacoder": [
+        CPTestSettings.detailed(),
+        CPTestSettings.detailed(tp_base=2),
+    ],
 }

 CP_TEST_MODELS = [
     # TODO support other models
     # [LANGUAGE GENERATION]
     "deepseek-ai/DeepSeek-V2-Lite-Chat",
+    "bigcode/gpt_bigcode-santacoder",
 ]

@@ -262,7 +262,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "GPT2LMHeadModel": _HfExamplesInfo("openai-community/gpt2", {"alias": "gpt2"}),
     "GPTBigCodeForCausalLM": _HfExamplesInfo(
         "bigcode/starcoder",
-        extras={"tiny": "bigcode/tiny_starcoder_py"},
+        extras={
+            "tiny": "bigcode/tiny_starcoder_py",
+            "santacoder": "bigcode/gpt_bigcode-santacoder",
+        },
         min_transformers_version="4.55.1",
         transformers_version_reason="HF model broken in 4.55.0",
     ),
@@ -173,6 +173,7 @@ def cp_lse_ag_out_rs(
     cp_attn_lse: torch.Tensor,
     cp_group: GroupCoordinator,
     ctx: CPTritonContext = None,
+    return_lse=False,
 ):
     """
     cp_attn_out: [ B, H, D ]
@@ -192,8 +193,15 @@ def cp_lse_ag_out_rs(

     cp_attn_lse = cp_attn_lse.contiguous()
     lses = cp_group.all_gather(cp_attn_lse, dim=0).view_as(lses)
-    out, _ = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
+    out, lse = correct_attn_out(cp_attn_out, lses, cp_group.rank_in_group, ctx)
+    assert out.is_contiguous()
     out = cp_group.reduce_scatter(out, dim=1)
+
+    if return_lse:
+        cp_num_heads = lse.shape[1] // cp_group.world_size
+        cp_rank = cp_group.rank_in_group
+        lse = lse[:, cp_num_heads * cp_rank : cp_num_heads * (cp_rank + 1)]
+        return out, lse
     return out

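Note: with return_lse=True the helper now also hands back this rank's slice of the corrected log-sum-exp, which the FlashAttention DCP path below feeds into merge_attn_states. A hedged usage sketch, shapes per the docstring above and get_dcp_group() assumed to be initialized:

    # cp_attn_out: [B, H, D] partial attention over this rank's KV shard
    # cp_attn_lse: [B, H] matching log-sum-exp values
    out, lse = cp_lse_ag_out_rs(
        cp_attn_out,
        cp_attn_lse,
        get_dcp_group(),
        return_lse=True,
    )
    # out: [B, H // world_size, D] after the reduce-scatter over heads
    # lse: [B, H // world_size], this rank's slice of the corrected LSE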
@@ -1202,6 +1202,23 @@ class ModelConfig:
                 "Supported models implement the `SupportsPP` interface."
             )

+        decode_context_parallel_size = parallel_config.decode_context_parallel_size
+        if decode_context_parallel_size > 1 and not self.use_mla:
+            total_num_kv_heads = self.get_total_num_kv_heads()
+            assert tensor_parallel_size > total_num_kv_heads, (
+                f"tensor parallel size {tensor_parallel_size} must be greater "
+                f"than total num kv heads {total_num_kv_heads} when enable "
+                f"decode context parallel for GQA/MQA"
+            )
+
+            max_dcp_size = tensor_parallel_size // total_num_kv_heads
+            assert decode_context_parallel_size <= max_dcp_size, (
+                f"decode context parallel size must less than or equal to "
+                f"(tensor parallel size {tensor_parallel_size} // total "
+                f"num kv heads {total_num_kv_heads}) = {max_dcp_size}, "
+                f"but got {decode_context_parallel_size}"
+            )
+
     def get_sliding_window(self) -> int | None:
         """Get the sliding window size from the HF text config if present."""
         return getattr(self.hf_text_config, "sliding_window", None)
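Note: a worked example of the new constraint. For a GQA model with 8 total KV heads served at tensor_parallel_size=16, each KV head is replicated across 16 // 8 = 2 TP ranks, so decode_context_parallel_size may be at most 2; for an MQA model such as gpt_bigcode-santacoder (1 KV head) at tensor_parallel_size=2, DCP up to 2 is allowed. A standalone restatement, with check_dcp_for_gqa as a hypothetical helper for illustration only:

    def check_dcp_for_gqa(tensor_parallel_size: int,
                          total_num_kv_heads: int,
                          decode_context_parallel_size: int) -> None:
        # Mirrors the ModelConfig check above, outside vLLM.
        if decode_context_parallel_size <= 1:
            return  # DCP disabled, nothing to validate
        # DCP for GQA/MQA only applies when KV heads are replicated across TP ranks.
        assert tensor_parallel_size > total_num_kv_heads
        max_dcp_size = tensor_parallel_size // total_num_kv_heads
        assert decode_context_parallel_size <= max_dcp_size

    check_dcp_for_gqa(16, 8, 2)  # ok: 16 // 8 = 2
    check_dcp_for_gqa(2, 1, 2)   # ok: MQA, 2 // 1 = 2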
@@ -17,6 +17,7 @@ from vllm.attention.backends.abstract import (
     is_quantized_kv_cache,
 )
 from vllm.attention.layer import Attention
+from vllm.attention.ops.common import cp_lse_ag_out_rs
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import (
     flash_attn_supports_fp8,
@@ -32,6 +33,7 @@ if is_flash_attn_varlen_func_available():
     )

 from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (
@@ -147,6 +149,10 @@ class FlashAttentionMetadata:
     prefix_kv_lens: torch.Tensor | None
     suffix_kv_lens: torch.Tensor | None

+    # For GQA DCP
+    max_dcp_context_kv_len: int | None = None
+    dcp_context_kv_lens: torch.Tensor | None = None
+
     # Optional aot scheduling
     scheduler_metadata: torch.Tensor | None = None
     prefix_scheduler_metadata: torch.Tensor | None = None
@@ -216,6 +222,16 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
         self.max_num_splits = 0  # No upper bound on the number of splits.
         self.aot_schedule = get_flash_attn_version() == 3

+        try:
+            from vllm.distributed.parallel_state import get_dcp_group
+
+            self.dcp_world_size = get_dcp_group().world_size
+            self.dcp_rank = get_dcp_group().rank_in_group
+        except AssertionError:
+            # DCP might not be initialized in testing
+            self.dcp_world_size = 1
+            self.dcp_rank = 0
+
         self.use_full_cuda_graph = (
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         )
@@ -306,7 +322,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
                 batch_size=batch_size,
                 max_seqlen_q=max_query_len,
                 max_seqlen_k=max_seq_len,
-                num_heads_q=self.num_heads_q,
+                num_heads_q=self.num_heads_q * self.dcp_world_size,
                 num_heads_kv=self.num_heads_kv,
                 headdim=self.headdim,
                 cache_seqlens=seqlens,
@@ -320,8 +336,35 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
             return None

         use_cascade = common_prefix_len > 0
+        max_dcp_context_kv_len = 0
+        dcp_context_kv_lens = None

-        if use_cascade:
+        cu_prefix_query_lens = None
+        prefix_kv_lens = None
+        suffix_kv_lens = None
+        prefix_scheduler_metadata = None
+
+        if self.dcp_world_size > 1:
+            query_kv_lens_cpu = (
+                common_attn_metadata.query_start_loc_cpu[1:]
+                - common_attn_metadata.query_start_loc_cpu[:-1]
+            )
+            dcp_context_kv_lens_cpu = seq_lens_cpu - query_kv_lens_cpu
+            dcp_context_kv_lens_cpu = dcp_context_kv_lens_cpu // self.dcp_world_size + (
+                self.dcp_rank <= (dcp_context_kv_lens_cpu - 1) % self.dcp_world_size
+            )
+            dcp_context_kv_lens = dcp_context_kv_lens_cpu.to(self.device)
+            max_dcp_context_kv_len = dcp_context_kv_lens.max().item()
+
+            scheduler_metadata = schedule(
+                batch_size=num_reqs,
+                cu_query_lens=query_start_loc,
+                max_query_len=max_query_len,
+                seqlens=dcp_context_kv_lens,
+                max_seq_len=max_dcp_context_kv_len,
+                causal=False,
+            )
+        elif use_cascade:
             cu_prefix_query_lens = torch.tensor(
                 [0, num_actual_tokens], dtype=torch.int32, device=self.device
             )
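Note: the per-rank formula above gives every DCP rank either floor(context_len / dcp_world_size) or that value plus one context KV tokens per request; only the context portion (seq_len minus the current query tokens) is counted here, while the query tokens' own keys/values go through a separate local attention pass in _forward_with_dcp below. A small standalone restatement, with per_rank_context_kv_lens as a hypothetical helper for illustration:

    import torch

    def per_rank_context_kv_lens(context_lens: torch.Tensor,
                                 dcp_world_size: int) -> torch.Tensor:
        # Same arithmetic as dcp_context_kv_lens_cpu above, evaluated for
        # every rank at once; returns a [dcp_world_size, num_reqs] tensor.
        ranks = torch.arange(dcp_world_size).unsqueeze(1)
        base = context_lens // dcp_world_size
        extra = ranks <= (context_lens - 1) % dcp_world_size
        return base + extra.int()

    print(per_rank_context_kv_lens(torch.tensor([10, 7]), dcp_world_size=4))
    # tensor([[3, 2],
    #         [3, 2],
    #         [2, 2],
    #         [2, 1]])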
@@ -348,10 +391,6 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
                 causal=True,
             )
         else:
-            cu_prefix_query_lens = None
-            prefix_kv_lens = None
-            suffix_kv_lens = None
-            prefix_scheduler_metadata = None
             scheduler_metadata = schedule(
                 batch_size=num_reqs,
                 cu_query_lens=query_start_loc,
@@ -379,6 +418,8 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
             seq_lens=seq_lens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
+            max_dcp_context_kv_len=max_dcp_context_kv_len,
+            dcp_context_kv_lens=dcp_context_kv_lens,
             use_cascade=use_cascade,
             common_prefix_len=common_prefix_len,
             scheduler_metadata=scheduler_metadata,
@@ -396,6 +437,8 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad


 class FlashAttentionImpl(AttentionImpl):
+    can_return_lse_for_decode: bool = True
+
     def __init__(
         self,
         num_heads: int,
@@ -562,30 +605,45 @@ class FlashAttentionImpl(AttentionImpl):

            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)

-            flash_attn_varlen_func(
-                q=query[:num_actual_tokens],
-                k=key_cache,
-                v=value_cache,
-                out=output[:num_actual_tokens],
-                cu_seqlens_q=cu_seqlens_q,
-                max_seqlen_q=max_seqlen_q,
-                seqused_k=seqused_k,
-                max_seqlen_k=max_seqlen_k,
-                softmax_scale=self.scale,
-                causal=attn_metadata.causal,
-                alibi_slopes=self.alibi_slopes,
-                window_size=self.sliding_window,
-                block_table=block_table,
-                softcap=self.logits_soft_cap,
-                scheduler_metadata=scheduler_metadata,
-                fa_version=self.vllm_flash_attn_version,
-                q_descale=layer._q_scale.expand(descale_shape),
-                k_descale=layer._k_scale.expand(descale_shape),
-                v_descale=layer._v_scale.expand(descale_shape),
-                num_splits=attn_metadata.max_num_splits,
-                s_aux=self.sinks,
-            )
-            return output
+            if self.dcp_world_size > 1:
+                self._forward_with_dcp(
+                    query[:num_actual_tokens],
+                    key[:num_actual_tokens],
+                    value[:num_actual_tokens],
+                    key_cache,
+                    value_cache,
+                    output[:num_actual_tokens],
+                    attn_metadata,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
+                )
+                return output
+            else:
+                flash_attn_varlen_func(
+                    q=query[:num_actual_tokens],
+                    k=key_cache,
+                    v=value_cache,
+                    out=output[:num_actual_tokens],
+                    cu_seqlens_q=cu_seqlens_q,
+                    max_seqlen_q=max_seqlen_q,
+                    seqused_k=seqused_k,
+                    max_seqlen_k=max_seqlen_k,
+                    softmax_scale=self.scale,
+                    causal=attn_metadata.causal,
+                    alibi_slopes=self.alibi_slopes,
+                    window_size=self.sliding_window,
+                    block_table=block_table,
+                    softcap=self.logits_soft_cap,
+                    scheduler_metadata=scheduler_metadata,
+                    fa_version=self.vllm_flash_attn_version,
+                    q_descale=layer._q_scale.expand(descale_shape),
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
+                    num_splits=attn_metadata.max_num_splits,
+                    s_aux=self.sinks,
+                )
+                return output

        # Cascade attention (rare case).
        cascade_attention(
@@ -615,6 +673,86 @@ class FlashAttentionImpl(AttentionImpl):
         )
         return output

+    def _forward_with_dcp(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        output: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        q_descale: torch.Tensor | None = None,
+        k_descale: torch.Tensor | None = None,
+        v_descale: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        cu_seqlens_q = attn_metadata.query_start_loc
+        max_seqlen_q = attn_metadata.max_query_len
+        block_table = attn_metadata.block_table
+
+        query = query.contiguous()
+        query_across_dcp = get_dcp_group().all_gather(query, dim=1)
+        context_attn_out, context_lse = flash_attn_varlen_func(
+            q=query_across_dcp,
+            k=key_cache,
+            v=value_cache,
+            out=None,
+            cu_seqlens_q=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            seqused_k=attn_metadata.dcp_context_kv_lens,
+            max_seqlen_k=attn_metadata.max_dcp_context_kv_len,
+            softmax_scale=self.scale,
+            causal=False,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            block_table=block_table,
+            softcap=self.logits_soft_cap,
+            return_softmax_lse=True,
+            scheduler_metadata=attn_metadata.scheduler_metadata,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=q_descale,
+            k_descale=k_descale,
+            v_descale=v_descale,
+        )
+        # FA returns LSE in shape [ H, B ] but cp_lse_ag_out_rs wants [ B, H ]
+        context_attn_out_cor, context_lse_cor = cp_lse_ag_out_rs(
+            context_attn_out,
+            context_lse.transpose(0, 1),
+            get_dcp_group(),
+            return_lse=True,
+        )
+        context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()
+
+        query_attn_out, query_lse = flash_attn_varlen_func(
+            q=query,
+            k=key,
+            v=value,
+            out=None,
+            cu_seqlens_q=cu_seqlens_q,
+            max_seqlen_q=max_seqlen_q,
+            cu_seqlens_k=cu_seqlens_q,
+            max_seqlen_k=max_seqlen_q,
+            softmax_scale=self.scale,
+            causal=attn_metadata.causal,
+            alibi_slopes=self.alibi_slopes,
+            window_size=self.sliding_window,
+            softcap=self.logits_soft_cap,
+            return_softmax_lse=True,
+            fa_version=self.vllm_flash_attn_version,
+            q_descale=q_descale,
+            k_descale=k_descale,
+            v_descale=v_descale,
+        )
+        assert context_attn_out_cor.shape == query_attn_out.shape
+        assert context_lse_cor.shape == query_lse.shape
+        merge_attn_states(
+            output,
+            context_attn_out_cor,
+            context_lse_cor,
+            query_attn_out,
+            query_lse,
+        )
+
     def _forward_encoder_attention(
         self,
         query: torch.Tensor,
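Note: _forward_with_dcp produces two partial results over disjoint key sets (the DCP-sharded context KV attended by the all-gathered queries, and the current query tokens' own keys/values) and combines them using their log-sum-exp values. A naive single-device reference for what merge_attn_states computes with two partial results, written for clarity only (the real kernel writes into a preallocated output and uses its own LSE layout):

    import torch

    def merge_two_attn_states(out_a: torch.Tensor, lse_a: torch.Tensor,
                              out_b: torch.Tensor, lse_b: torch.Tensor) -> torch.Tensor:
        # out_*: [T, H, D] partial attention outputs over disjoint key sets
        # lse_*: [T, H] matching log-sum-exp values
        lse = torch.logsumexp(torch.stack([lse_a, lse_b]), dim=0)
        w_a = torch.exp(lse_a - lse).unsqueeze(-1)
        w_b = torch.exp(lse_b - lse).unsqueeze(-1)
        return w_a * out_a + w_b * out_b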
@@ -684,6 +822,7 @@ def use_cascade_attention(
     use_sliding_window: bool,
     use_local_attention: bool,
     num_sms: int,
+    dcp_world_size: int,
 ) -> bool:
     """Decide whether to use cascade attention.

@@ -705,6 +844,9 @@ def use_cascade_attention(
     num_reqs = len(query_lens)
     if num_reqs < 8:
         return False
+    # disable cascade attention for DCP
+    if dcp_world_size > 1:
+        return False

     # Heuristics to decide whether using cascade attention is beneficial.
     # 1. When FlashDecoding is not used for normal attention, cascade attention
@@ -345,6 +345,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
         use_sliding_window: bool,
         use_local_attention: bool,
         num_sms: int,
+        dcp_world_size: int,
     ) -> bool:
         return False

@@ -1523,6 +1523,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             use_sliding_window=use_sliding_window,
             use_local_attention=use_local_attention,
             num_sms=self.num_sms,
+            dcp_world_size=self.dcp_world_size,
         )
         return common_prefix_len if use_cascade else 0
