[DCP] Support Decode Context Parallel (DCP) for GQA with Flashinfer (#25438)
Signed-off-by: gaojc <1055866782@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: Jingchun Gao <63247409+gjc0824@users.noreply.github.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Co-authored-by: gaojingchun (A) <g00955623@china.huawei.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: QiuChunshuo <qiuchunshuo@huawei.com>
parent 41b92f7d38
commit 4516d44b7f
@@ -39,6 +39,7 @@ class ParallelSetup(NamedTuple):
 class CPTestOptions(NamedTuple):
     multi_node_only: bool
     load_format: str | None = None
+    attn_backend: str | None = None


 @dataclass
@@ -58,6 +59,7 @@ class CPTestSettings:
         multi_node_only: bool = False,
         runner: RunnerOption = "auto",
         load_format: str | None = None,
+        attn_backend: str | None = None,
     ):
         parallel_setups = []
         for eager_mode_val in [False]:
@@ -79,7 +81,9 @@ class CPTestSettings:
             distributed_backends=["mp"],
             runner=runner,
             test_options=CPTestOptions(
-                multi_node_only=multi_node_only, load_format=load_format
+                multi_node_only=multi_node_only,
+                load_format=load_format,
+                attn_backend=attn_backend,
             ),
         )
@@ -117,7 +121,7 @@ def _compare_cp_with_tp(
         chunked_prefill,
     ) = parallel_setup

-    multi_node_only, load_format = test_options
+    multi_node_only, load_format, attn_backend = test_options

     model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
     model_info.check_transformers_version(on_fail="skip")
@@ -177,6 +181,13 @@ def _compare_cp_with_tp(
     if hf_overrides:
         common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

+    if not attn_backend:
+        cp_env = tp_env = {}
+    else:
+        cp_env = tp_env = {
+            "VLLM_ATTENTION_BACKEND": attn_backend,
+        }
+
     cp_args = [
         *common_args,
         "--tensor-parallel-size",
@@ -205,6 +216,8 @@ def _compare_cp_with_tp(
         model_id,
         cp_args,
         tp_args,
+        cp_env,
+        tp_env,
         method=method,
         max_wait_seconds=720,
     )
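The new attn_backend option lets these CP tests pin VLLM_ATTENTION_BACKEND, which is also how one would pair DCP with FlashInfer outside the test harness. A hedged sketch of such a launch (model choice and the TP=2/DCP=2 sizes are illustrative assumptions; the decode_context_parallel_size argument comes from the existing DCP engine args, not from this diff):

    # Illustrative only: model and parallel sizes are assumptions, not from this commit.
    import os

    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # what the test sets via cp_env/tp_env

    from vllm import LLM

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # a GQA model (32 query heads, 8 KV heads)
        tensor_parallel_size=2,
        decode_context_parallel_size=2,  # query-per-KV ratio is 4, divisible by DCP=2
    )
    print(llm.generate("Hello")[0].outputs[0].text)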
@@ -1183,6 +1183,14 @@ class ModelConfig:
                 f"but got {decode_context_parallel_size}"
             )

+        num_q_per_kv = total_num_attention_heads // total_num_kv_heads
+        assert num_q_per_kv % decode_context_parallel_size == 0, (
+            f"Total number of q per kv attn heads ({num_q_per_kv})"
+            " must be divisible by dcp world size when enable "
+            "decode context parallel for GQA "
+            f"({parallel_config.decode_context_parallel_size})."
+        )
+
     def get_sliding_window(self) -> int | None:
         """Get the sliding window size from the HF text config if present."""
         return getattr(self.hf_text_config, "sliding_window", None)
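The assertion added to ModelConfig requires the GQA group size (query heads per KV head) to be a multiple of the DCP world size. A worked example of the check, with illustrative head counts rather than values from this diff:

    # Illustrative head counts, not from this commit.
    total_num_attention_heads = 32    # query heads
    total_num_kv_heads = 8            # KV heads (GQA)
    decode_context_parallel_size = 2  # DCP world size

    num_q_per_kv = total_num_attention_heads // total_num_kv_heads  # 4
    assert num_q_per_kv % decode_context_parallel_size == 0  # 4 % 2 == 0, so DCP=2 is allowed
    # DCP=3 would fail here: 4 % 3 != 0, and the assertion above fires at engine init.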
@@ -259,6 +259,7 @@ def use_trtllm_attention(
     num_kv_heads: int,
     num_tokens: int,
     max_seq_len: int,
+    dcp_world_size: int,
     kv_cache_dtype: str,
     q_dtype: torch.dtype,
     is_prefill: bool,
@@ -272,6 +273,14 @@ def use_trtllm_attention(
     if force_use_trtllm is not None and not force_use_trtllm:
         return False

+    # Decode context parallel is not supported
+    if dcp_world_size > 1:
+        logger.warning_once(
+            "Trtllm does not support returning LSE and as a result "
+            "does not support DCP, reverting to FlashInfer"
+        )
+        return False
+
     # The platform is not supported
     if not supports_trtllm_attention():
         if force_use_trtllm:
@@ -10,6 +10,7 @@ import torch
 from flashinfer import (
     BatchDecodeWithPagedKVCacheWrapper,
     BatchPrefillWithPagedKVCacheWrapper,
+    BatchPrefillWithRaggedKVCacheWrapper,
     MultiLevelCascadeAttentionWrapper,
 )
 from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
@@ -24,8 +25,11 @@ from vllm.attention.backends.abstract import (
     AttentionType,
     MultipleOf,
 )
+from vllm.attention.ops.common import cp_lse_ag_out_rs
+from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.config.cache import CacheDType
+from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
     vllm_is_batch_invariant,
@@ -50,6 +54,7 @@ from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
     KVCacheLayoutType,
+    get_dcp_local_seq_lens,
     get_kv_cache_layout,
     get_per_layer_parameters,
     infer_global_hyperparameters,
@@ -160,6 +165,113 @@ def trtllm_prefill_attn_kvfp8_dequant(
     return mock_kv_cache, mock_block_table


+class BatchDCPPrefillWrapper:
+    def __init__(
+        self,
+        workspace_buffer: torch.Tensor | None = None,
+    ):
+        self._context = BatchPrefillWithPagedKVCacheWrapper(
+            workspace_buffer, get_kv_cache_layout()
+        )
+        self._new_tokens = BatchPrefillWithRaggedKVCacheWrapper(
+            workspace_buffer, get_kv_cache_layout()
+        )
+
+    def plan(
+        self,
+        qo_indptr_cpu: torch.Tensor,
+        paged_kv_indptr_cpu: torch.Tensor,
+        paged_kv_indices: torch.Tensor,
+        paged_kv_last_page_len_cpu: torch.Tensor,
+        prefill_start: int,
+        page_size: int,
+        num_qo_heads: int,
+        dcp_world_size: int,
+        num_kv_heads: int,
+        head_dim: int,
+        sm_scale: float,
+        window_left: int,
+        logits_soft_cap: float | None,
+        q_data_type: torch.dtype,
+        kv_cache_dtype: torch.dtype,
+        prefill_fixed_split_size: int,
+        disable_split_kv: bool,
+    ):
+        """Plan the prefill operation with given parameters."""
+        self._context.plan(
+            qo_indptr_cpu,
+            paged_kv_indptr_cpu,
+            paged_kv_indices,
+            paged_kv_last_page_len_cpu[prefill_start:],
+            num_qo_heads * dcp_world_size,
+            num_kv_heads,
+            head_dim,
+            page_size,
+            causal=False,  # This is context run
+            sm_scale=sm_scale,
+            window_left=window_left,
+            logits_soft_cap=logits_soft_cap,
+            q_data_type=q_data_type,
+            kv_data_type=kv_cache_dtype,
+            fixed_split_size=prefill_fixed_split_size,
+            disable_split_kv=disable_split_kv,
+        )
+        self._new_tokens.plan(
+            qo_indptr=qo_indptr_cpu,
+            kv_indptr=qo_indptr_cpu,
+            num_qo_heads=num_qo_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim_qk=head_dim,
+            head_dim_vo=head_dim,
+            causal=True,  # This is newtokens run
+            sm_scale=sm_scale,
+            window_left=window_left,
+            logits_soft_cap=logits_soft_cap,
+            q_data_type=q_data_type,
+        )
+
+    def run(
+        self,
+        layer: torch.nn.Module,
+        prefill_query: torch.Tensor,
+        kv_cache_permute: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        out: torch.Tensor,
+    ):
+        prefill_query_across_dcp = get_dcp_group().all_gather(
+            prefill_query.contiguous(), dim=1
+        )
+        output_context_tmp, lse_context_tmp = self._context.run(
+            prefill_query_across_dcp,
+            kv_cache_permute,
+            k_scale=layer._k_scale_float,
+            v_scale=layer._v_scale_float,
+            return_lse=True,
+        )
+        output_context, lse_context = cp_lse_ag_out_rs(
+            output_context_tmp, lse_context_tmp, get_dcp_group(), return_lse=True
+        )
+        lse_context = lse_context.transpose(0, 1).contiguous()
+
+        output_query, lse_query = self._new_tokens.run(
+            prefill_query,
+            key,
+            value,
+            return_lse=True,
+        )
+        lse_query = lse_query.transpose(0, 1).contiguous()
+
+        merge_attn_states(
+            out,
+            output_context,
+            lse_context,
+            output_query,
+            lse_query,
+        )
+        return out
+
+
 class FlashInferBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
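BatchDCPPrefillWrapper computes prefill attention in two pieces: a non-causal pass over the already-cached context held by this DCP rank (with queries gathered across the group) and a causal pass over the new tokens, then merges the two partial results with merge_attn_states using their log-sum-exp values. A minimal PyTorch sketch of that LSE merge (a reference formula, not vLLM's fused kernel; shapes are assumed for illustration):

    import torch

    def merge_two_attn_states(out_a, lse_a, out_b, lse_b):
        # out_*: [num_tokens, num_heads, head_dim] partial attention outputs
        # lse_*: [num_tokens, num_heads] log-sum-exp of the logits behind each partial
        lse_max = torch.maximum(lse_a, lse_b)
        w_a = torch.exp(lse_a - lse_max)  # relative weight of the context partial
        w_b = torch.exp(lse_b - lse_max)  # relative weight of the new-token partial
        merged = (out_a * w_a.unsqueeze(-1) + out_b * w_b.unsqueeze(-1)) / (
            w_a + w_b
        ).unsqueeze(-1)
        merged_lse = lse_max + torch.log(w_a + w_b)
        return merged, merged_lse

Because the context keys and the new-token keys are disjoint, this merge reproduces the result a single attention pass over the full key set would give.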
@@ -281,7 +393,9 @@ class FlashInferMetadata:
     # For cascade attention (CPU for planning).
     use_cascade: bool

-    prefill_wrapper: BatchPrefillWithPagedKVCacheWrapper | None = None
+    prefill_wrapper: (
+        BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
+    ) = None
     decode_wrapper: BatchDecodeWithPagedKVCacheWrapper | None = None
     cascade_wrapper: MultiLevelCascadeAttentionWrapper | None = None
@@ -303,7 +417,9 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self.cache_config = vllm_config.cache_config
         self.model_config = vllm_config.model_config
         self._workspace_buffer = None
-        self._prefill_wrapper = None  # Wrapper for prefill/append
+        self._prefill_wrapper: (
+            BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper | None
+        ) = None  # Wrapper for prefill/append
         self._decode_wrapper = None  # Wrapper for decode (general shape)

         if vllm_is_batch_invariant():
@@ -341,9 +457,23 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             self.compilation_config.max_cudagraph_capture_size,
         )

-        self.num_qo_heads = self.model_config.get_num_attention_heads(
-            self.vllm_config.parallel_config
+        try:
+            self.dcp_world_size = get_dcp_group().world_size
+            self.dcp_rank = get_dcp_group().rank_in_group
+            self.dcp_kv_cache_interleave_size = (
+                vllm_config.parallel_config.dcp_kv_cache_interleave_size
+            )
+        except AssertionError:
+            # DCP might not be initialized in testing
+            self.dcp_world_size = 1
+            self.dcp_rank = 0
+            self.dcp_kv_cache_interleave_size = 1
+
+        self.num_qo_heads = (
+            self.model_config.get_num_attention_heads(self.vllm_config.parallel_config)
+            * self.dcp_world_size
         )

         self.num_kv_heads = self.kv_cache_spec.num_kv_heads
         self.head_dim = self.kv_cache_spec.head_size
         self.page_size = self.kv_cache_spec.block_size
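Planning uses the per-TP-rank query head count scaled by the DCP world size because the prefill and decode paths below all-gather queries across the DCP group along the head dimension (dim=1 / dim=-2) before calling FlashInfer. A quick arithmetic sketch with illustrative sizes:

    # Illustrative sizes, not from this commit.
    total_q_heads = 32   # model query heads
    tp_size = 2          # tensor parallel size
    dcp_world_size = 2   # decode context parallel size

    q_heads_per_tp_rank = total_q_heads // tp_size            # 16 heads owned locally
    planned_qo_heads = q_heads_per_tp_rank * dcp_world_size   # 32 heads after the DCP all-gather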
@@ -455,11 +585,19 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         )
         return self._workspace_buffer

-    def _get_prefill_wrapper(self):
+    def _get_prefill_wrapper(
+        self,
+    ) -> BatchPrefillWithPagedKVCacheWrapper | BatchDCPPrefillWrapper:
         if self._prefill_wrapper is None:
+            if self.dcp_world_size > 1:
+                self._prefill_wrapper = BatchDCPPrefillWrapper(
+                    workspace_buffer=self._get_workspace_buffer(),
+                )
+            else:
+                self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                    self._get_workspace_buffer(), get_kv_cache_layout()
+                )
+        assert self._prefill_wrapper is not None
         return self._prefill_wrapper

     def _get_decode_wrapper(self, batch_size: int, use_cudagraph: bool = False):
@@ -526,9 +664,29 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         max_seq_len = common_attn_metadata.max_seq_len
         seq_lens = common_attn_metadata.seq_lens
         seq_lens_cpu = common_attn_metadata.seq_lens_cpu
-        seq_lens_np = seq_lens_cpu.numpy()
         block_table_tensor = common_attn_metadata.block_table_tensor
+        qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
+
+        if self.dcp_world_size > 1:
+            if num_prefills > 0:
+                qo_indptr_prefill_cpu = (
+                    qo_indptr_cpu[num_decodes:] - qo_indptr_cpu[num_decodes]
+                )
+                query_lens_prefill_cpu = (
+                    qo_indptr_prefill_cpu[1:] - qo_indptr_prefill_cpu[:-1]
+                )
+                seq_lens_cpu[num_decodes:] = (
+                    seq_lens_cpu[num_decodes:] - query_lens_prefill_cpu
+                )
+
+            seq_lens_cpu = get_dcp_local_seq_lens(
+                seq_lens_cpu,
+                self.dcp_world_size,
+                self.dcp_rank,
+                self.dcp_kv_cache_interleave_size,
+            )
+
+        seq_lens_np = seq_lens_cpu.numpy()
         num_blocks_np = (seq_lens_np + (page_size - 1)) // page_size

         use_cascade = common_prefix_len > 0
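get_dcp_local_seq_lens reduces each request's context length to the number of KV tokens this DCP rank actually stores; with dcp_kv_cache_interleave_size, tokens are dealt to ranks in fixed-size chunks. A simplified sketch of that layout (a hypothetical helper written for illustration, not the vLLM utility itself):

    def local_kv_len_sketch(
        seq_len: int, dcp_world_size: int, dcp_rank: int, interleave: int
    ) -> int:
        # Assumed layout: token t lives on rank (t // interleave) % dcp_world_size.
        return sum(
            1
            for t in range(seq_len)
            if (t // interleave) % dcp_world_size == dcp_rank
        )

    # e.g. seq_len=10, world=2, interleave=2 -> rank 0 stores 6 tokens, rank 1 stores 4.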
@@ -589,7 +747,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         # write self.paged_kv_last_page_len_cpu inplace
         paged_kv_last_page_len_np = seq_lens_np % page_size
         self.paged_kv_last_page_len_np[:num_reqs] = np.where(
-            paged_kv_last_page_len_np == 0,
+            (paged_kv_last_page_len_np == 0) & (seq_lens_np != 0),
             page_size,
             paged_kv_last_page_len_np,
         )
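The added (seq_lens_np != 0) guard matters once sequence lengths have been reduced to per-DCP-rank values: a rank can hold zero KV tokens for a short request, and the old expression would have mapped that 0 to a full page_size last-page length. A small NumPy illustration (values are made up):

    import numpy as np

    page_size = 16
    seq_lens_np = np.array([0, 16, 20])  # local per-rank lengths; 0 = this rank holds nothing yet

    last = seq_lens_np % page_size                                      # [0, 0, 4]
    old = np.where(last == 0, page_size, last)                          # [16, 16, 4] - wrong for the empty request
    new = np.where((last == 0) & (seq_lens_np != 0), page_size, last)   # [ 0, 16, 4]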
@@ -600,13 +758,16 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             self.num_kv_heads,
             num_prefill_tokens,
             max_seq_len,
+            self.dcp_world_size,
             self.cache_dtype,
             self.q_data_type,
             is_prefill=True,
             has_sinks=self.has_sinks,
             has_spec=uses_spec_reorder,
         )
-        decode_use_trtllm = self.use_trtllm_decode_attention
+        decode_use_trtllm = (
+            self.use_trtllm_decode_attention and self.dcp_world_size <= 1
+        )

         if not (prefill_use_trtllm and decode_use_trtllm):
             if self.has_sinks:
@@ -651,7 +812,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             use_cascade=use_cascade,
         )

-        qo_indptr_cpu = common_attn_metadata.query_start_loc_cpu
         paged_kv_indptr_cpu = self.paged_kv_indptr_cpu[: 1 + num_reqs]
         paged_kv_last_page_len_cpu = self.paged_kv_last_page_len_cpu[:num_reqs]
@@ -703,6 +863,34 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             attn_metadata.max_q_len_prefill = int(query_lens_prefill.max().item())

             if not attn_metadata.prefill_use_trtllm:
+                if self.dcp_world_size > 1:
+                    assert isinstance(
+                        attn_metadata.prefill_wrapper, BatchDCPPrefillWrapper
+                    )
+                    attn_metadata.prefill_wrapper.plan(
+                        qo_indptr_cpu=qo_indptr_cpu,
+                        paged_kv_indptr_cpu=paged_kv_indptr_cpu,
+                        paged_kv_indices=paged_kv_indices,
+                        paged_kv_last_page_len_cpu=paged_kv_last_page_len_cpu,
+                        prefill_start=prefill_start,
+                        page_size=self.page_size,
+                        num_qo_heads=self.num_qo_heads,
+                        dcp_world_size=self.dcp_world_size,
+                        num_kv_heads=self.num_kv_heads,
+                        head_dim=self.head_dim,
+                        sm_scale=self.sm_scale,
+                        window_left=self.window_left,
+                        logits_soft_cap=self.logits_soft_cap,
+                        q_data_type=self.q_data_type,
+                        kv_cache_dtype=self.kv_cache_dtype,
+                        prefill_fixed_split_size=self.prefill_fixed_split_size,
+                        disable_split_kv=self.disable_split_kv,
+                    )
+                else:
+                    assert isinstance(
+                        attn_metadata.prefill_wrapper,
+                        BatchPrefillWithPagedKVCacheWrapper,
+                    )
                     attn_metadata.prefill_wrapper.plan(
                         qo_indptr_cpu,
                         paged_kv_indptr_cpu,
@@ -770,7 +958,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             paged_kv_indices,
             self.paged_kv_last_page_len_cpu[:num_input_tokens],
             seq_lens_cpu[:num_input_tokens],
-            self.num_qo_heads,
+            self.num_qo_heads * self.dcp_world_size,
             self.num_kv_heads,
             self.head_dim,
             self.page_size,
@@ -797,6 +985,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):


 class FlashInferImpl(AttentionImpl):
+    can_return_lse_for_decode: bool = True
+
     def __init__(
         self,
         num_heads: int,
@@ -989,6 +1179,8 @@ class FlashInferImpl(AttentionImpl):

         # Inputs and outputs may be padded for CUDA graphs
         query = query[:num_actual_tokens]
+        key = key[:num_actual_tokens]
+        value = value[:num_actual_tokens]
         output_padded = output
         output = output[:num_actual_tokens]
@@ -1015,10 +1207,39 @@ class FlashInferImpl(AttentionImpl):
             assert prefill_wrapper is not None

             if not attn_metadata.prefill_use_trtllm:
-                assert prefill_wrapper._causal
+                if self.dcp_world_size > 1:
+                    assert isinstance(prefill_wrapper, BatchDCPPrefillWrapper)
+                    assert prefill_wrapper._context._window_left == self.window_left
+                    assert prefill_wrapper._context._logits_soft_cap == (
+                        self.logits_soft_cap or 0.0
+                    )
+                    assert prefill_wrapper._context._sm_scale == self.scale
+                    assert not prefill_wrapper._context._causal
+                    assert prefill_wrapper._new_tokens._window_left == self.window_left
+                    assert prefill_wrapper._new_tokens._logits_soft_cap == (
+                        self.logits_soft_cap or 0.0
+                    )
+                    assert prefill_wrapper._new_tokens._sm_scale == self.scale
+                    assert prefill_wrapper._new_tokens._causal
+
+                    prefill_wrapper.run(
+                        layer,
+                        prefill_query,
+                        kv_cache_permute,
+                        key[num_decode_tokens:],
+                        value[num_decode_tokens:],
+                        out=output[num_decode_tokens:],
+                    )
+                else:
+                    assert isinstance(
+                        prefill_wrapper, BatchPrefillWithPagedKVCacheWrapper
+                    )
                     assert prefill_wrapper._window_left == self.window_left
-                assert prefill_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
+                    assert prefill_wrapper._logits_soft_cap == (
+                        self.logits_soft_cap or 0.0
+                    )
                     assert prefill_wrapper._sm_scale == self.scale
-                assert prefill_wrapper._causal
+                    assert prefill_wrapper._causal
                     prefill_wrapper.run(
                         prefill_query,
                         kv_cache_permute,
@@ -1101,6 +1322,30 @@ class FlashInferImpl(AttentionImpl):
             assert decode_wrapper._window_left == self.window_left
             assert decode_wrapper._logits_soft_cap == (self.logits_soft_cap or 0.0)
             assert decode_wrapper._sm_scale == self.scale

+            if self.dcp_world_size > 1:
+                decode_query = get_dcp_group().all_gather(
+                    decode_query.contiguous(), dim=-2
+                )
+                output_tmp = torch.empty_like(decode_query)
+                lse = torch.empty(
+                    (decode_query.size(0), decode_query.size(1)),
+                    dtype=torch.float32,
+                    device=decode_query.device,
+                )
+                decode_wrapper.run(
+                    decode_query,
+                    kv_cache_permute,
+                    k_scale=layer._k_scale_float,
+                    v_scale=layer._v_scale_float,
+                    out=output_tmp,
+                    lse=lse,
+                    return_lse=True,
+                )
+                output[:num_decode_tokens] = cp_lse_ag_out_rs(
+                    output_tmp, lse, get_dcp_group()
+                )
+            else:
                 decode_wrapper.run(
                     decode_query,
                     kv_cache_permute,
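For decode, each DCP rank attends over only its local KV slice, returns a partial output plus its LSE, and cp_lse_ag_out_rs combines the partials across the group. A self-contained, single-process sketch of why that combination is exact (no distributed code; the sizes and the one-token interleave are illustrative):

    import torch

    torch.manual_seed(0)
    seq_len, head_dim, world = 10, 16, 2
    q = torch.randn(head_dim)
    k = torch.randn(seq_len, head_dim)
    v = torch.randn(seq_len, head_dim)
    scale = head_dim ** -0.5

    def partial_attn(q, k, v):
        s = (k @ q) * scale                # [local_seq] attention logits
        lse = torch.logsumexp(s, dim=0)
        return torch.softmax(s, dim=0) @ v, lse

    # Each "rank" sees an interleaved shard of the KV sequence.
    outs, lses = zip(*(partial_attn(q, k[r::world], v[r::world]) for r in range(world)))
    w = torch.exp(torch.stack(lses) - torch.stack(lses).max())
    merged = (torch.stack(outs) * w[:, None]).sum(0) / w.sum()

    full, _ = partial_attn(q, k, v)        # reference: attention over the whole sequence
    assert torch.allclose(merged, full, atol=1e-5)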
@@ -31,6 +31,7 @@ from vllm.distributed import destroy_distributed_environment, destroy_model_para
 from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
 from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.distributed.parallel_state import (
+    get_dcp_group,
     get_dp_group,
     get_ep_group,
     get_pp_group,
@@ -726,6 +727,8 @@ class WorkerProc:
         pp_rank = get_pp_group().rank_in_group
         tp_size = get_tp_group().world_size
         tp_rank = get_tp_group().rank_in_group
+        dcp_size = get_dcp_group().world_size
+        dcp_rank = get_dcp_group().rank_in_group
         process_name = "Worker"
         if dp_size > 1:
             process_name += f"_DP{dp_rank}"
@@ -733,6 +736,8 @@ class WorkerProc:
             process_name += f"_PP{pp_rank}"
         if tp_size > 1:
             process_name += f"_TP{tp_rank}"
+        if dcp_size > 1:
+            process_name += f"_DCP{dcp_rank}"
         if enable_ep:
             ep_rank = get_ep_group().rank_in_group
             process_name += f"_EP{ep_rank}"