diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index 76e17f3797a1a..37ec0fb97e06b 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -84,7 +84,7 @@ class BlockTable: self.pcp_world_size = get_pcp_group().world_size self.pcp_rank = get_pcp_group().rank_in_group except AssertionError: - # DCP might not be initialized in testing + # PCP might not be initialized in testing self.pcp_world_size = 1 self.pcp_rank = 0 try: @@ -268,6 +268,11 @@ class MultiGroupBlockTable: # (max_model_len//dcp_world_size) tokens in kvcache, # so the block_size which used for calc max_num_blocks_per_req # must be multiplied by dcp_world_size. + try: + pcp_world_size = get_pcp_group().world_size + except AssertionError: + # PCP might not be initialized in testing + pcp_world_size = 1 try: dcp_world_size = get_dcp_group().world_size except AssertionError: @@ -280,12 +285,14 @@ class MultiGroupBlockTable: f"must match block_sizes length ({len(block_sizes)})" ) + total_cp_world_size = dcp_world_size * pcp_world_size + self.block_tables = [ BlockTable( block_size, max_num_reqs, max( - cdiv(max_model_len, block_size * dcp_world_size), + cdiv(max_model_len, block_size * total_cp_world_size), 1 + num_speculative_tokens, ), max_num_batched_tokens,