mirror of https://git.datalinker.icu/vllm-project/vllm.git
commit 5bbfd95bdb (parent 6b0c303ab4)

add support for multiple builders in the model runner

Signed-off-by: Sage Moore <sage@neuralmagic.com>
@@ -53,6 +53,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         super().__init__()

         self.buffers = buffers
+        for buffer in self.buffers:
+            buffer.set_num_sms(4)
         self.max_tokens_per_rank = max_tokens_per_rank
         self.use_fp8_dispatch = use_fp8_dispatch
         # The dispatch function returns a handle that the combine function
@@ -444,6 +444,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                 #   2*(192*128)*(64*1024) = 3gb
                 # (assuming 192 QK head dim, 128 heads, and fp16)
                 128 * 1024)
+            self.chunked_prefill_workspace_size = scheduler_config.max_num_seqs * cache_config.block_size
             assert self.chunked_prefill_workspace_size >= \
                 scheduler_config.max_num_seqs * cache_config.block_size
             self.chunked_prefill_workspace = torch.empty(
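The 3 GB figure in the workspace comment above checks out; a quick sketch of the arithmetic (editor's example, using only the numbers quoted in the comment):

    # fp16 bytes * (QK head dim * num heads) per token * 64K workspace tokens
    bytes_per_elem = 2                        # fp16
    per_token_elems = 192 * 128               # QK head dim * num heads
    workspace_tokens = 64 * 1024
    workspace_bytes = bytes_per_elem * per_token_elems * workspace_tokens
    assert workspace_bytes == 3 * 1024**3     # exactly 3 GiB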
@@ -566,8 +567,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             decode_threshold=1)

     def _build_decode(self, block_table_tensor: torch.Tensor,
-                      seq_lens: torch.Tensor,
-                      ubatch_id: Optional[int] = None):
+                      seq_lens: torch.Tensor):
         return MLACommonDecodeMetadata(
             block_table=block_table_tensor,
             seq_lens=seq_lens,
@@ -600,8 +600,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     def build(self,
               common_prefix_len: int,
               common_attn_metadata: CommonAttentionMetadata,
-              fast_build: bool = False,
-              ubatch_id: Optional[int] = None) -> M:
+              fast_build: bool = False) -> M:
         num_reqs = common_attn_metadata.num_reqs
         num_tokens = common_attn_metadata.num_actual_tokens
         max_query_len = common_attn_metadata.max_query_len
@@ -724,7 +723,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
         decode_metadata = self._build_decode(
             block_table_tensor=block_table_tensor[:num_decodes, ...],
             seq_lens=seq_lens[:num_decodes],
-            ubatch_id=ubatch_id
         )

         attn_metadata = self.metadata_cls(
@@ -65,15 +65,12 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
             vllm_config.parallel_config)

-        self.cg_buf_tile_scheduler_metadata = [None, None]
-        self.cg_buf_num_splits = [None, None]
+        self.cg_buf_tile_scheduler_metadata = None
+        self.cg_buf_num_splits = None

     def _build_decode(self, block_table_tensor: torch.Tensor,
-                      seq_lens: torch.Tensor,
-                      ubatch_id: Optional[int] = None) -> FlashMLADecodeMetadata:
-        # print(f"UBATCH ID: {ubatch_id}")
-        ubatch_id = 0 if ubatch_id is None else ubatch_id
-        assert ubatch_id < 2
+                      seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
         tile_scheduler_metadata, num_splits = \
             get_mla_metadata(
                 seq_lens,
@@ -85,28 +82,28 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
         # if False:
             n = num_splits.size(0)
             # First time around (CUDAGraph capture), allocate the static buffer
-            if self.cg_buf_num_splits[ubatch_id] is None:
-                # logger.info(f"ALLOCATING FLASH MLA DATA FOR SIZE {n}")
-                self.cg_buf_num_splits[ubatch_id] = num_splits
-                self.cg_buf_tile_scheduler_metadata[ubatch_id] = tile_scheduler_metadata
-            elif n <= self.cg_buf_num_splits[ubatch_id].size(0):
-                assert self.cg_buf_tile_scheduler_metadata[ubatch_id] is not None
+            if self.cg_buf_num_splits is None:
+                self.cg_buf_num_splits = num_splits
+                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
+            elif n <= self.cg_buf_num_splits.size(0):
+                assert self.cg_buf_tile_scheduler_metadata is not None

                 # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
-                assert (self.cg_buf_tile_scheduler_metadata[ubatch_id].size() ==
+                assert (self.cg_buf_tile_scheduler_metadata.size() ==
                         tile_scheduler_metadata.size())
-                self.cg_buf_tile_scheduler_metadata[ubatch_id].\
+                self.cg_buf_tile_scheduler_metadata.\
                     copy_(tile_scheduler_metadata)
-                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata[ubatch_id]
+                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata

                 # Num splits is per-batch, varying size (batch_size,)
                 n = num_splits.size(0)
-                # logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
                 # make sure static buffer is large enough
-                assert n <= self.cg_buf_num_splits[ubatch_id].size(0)
-                num_splits_view = self.cg_buf_num_splits[ubatch_id][:n]
+                assert n <= self.cg_buf_num_splits.size(0)
+                num_splits_view = self.cg_buf_num_splits[:n]
                 num_splits_view.copy_(num_splits)
-                self.cg_buf_num_splits[ubatch_id][n:].fill_(0)  # fill the rest with 0s
+                self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
                 num_splits = num_splits_view

         return FlashMLADecodeMetadata(
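The cg_buf_* logic above is the usual static-buffer pattern for CUDA graphs: the tensor adopted on the first (capture) call must keep a fixed address, so later calls copy into a prefix view of it instead of allocating anew. A minimal standalone sketch of that pattern (hypothetical helper, not code from this diff):

    from typing import Optional

    import torch

    class StaticSplitsBuffer:
        """Keep one tensor alive across CUDA graph capture/replay and copy
        per-step values into a prefix view of it (illustrative sketch)."""

        def __init__(self) -> None:
            self.buf: Optional[torch.Tensor] = None

        def update(self, num_splits: torch.Tensor) -> torch.Tensor:
            n = num_splits.size(0)
            if self.buf is None:
                # First call (capture): adopt this tensor as the static buffer.
                self.buf = num_splits
                return self.buf
            # Later calls: reuse the captured storage.
            assert n <= self.buf.size(0), "static buffer too small"
            view = self.buf[:n]
            view.copy_(num_splits)   # in-place copy keeps the captured address stable
            self.buf[n:].fill_(0)    # zero the unused tail
            return view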
@@ -197,7 +197,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # self.model: nn.Module  # Set after load_model
         # Initialize in initialize_kv_cache
         self.kv_caches: list[torch.Tensor] = []
-        self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
+        # Outer index is kv cache index, inner index is ubatch id
+        self.attn_metadata_builders: list[list[AttentionMetadataBuilder]] = []
         self.attn_backends: list[type[AttentionBackend]] = []
         # self.kv_cache_config: KVCacheConfig
@@ -361,6 +362,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
             self.max_num_tokens, dtype=torch.int32, device=self.device)

+    def get_builder(self, index: int, ubatch_id: Optional[int] = None) -> AttentionMetadataBuilder:
+        if ubatch_id is None:
+            return self.attn_metadata_builders[index][0]
+        else:
+            return self.attn_metadata_builders[index][ubatch_id]
+
     def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
         """
         Update the order of requests in the batch based on the attention
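With the builders stored as a list of lists, the indexing convention is attn_metadata_builders[kv_cache_group_index][ubatch_id], and get_builder falls back to ubatch 0 when no id is given. A short usage sketch (runner and group_idx are placeholder names, not from this diff):

    # One builder per (KV cache group, ubatch); without microbatching the inner list has length 1.
    builder = runner.get_builder(group_idx)                        # defaults to ubatch 0
    ubatch1_builder = runner.get_builder(group_idx, ubatch_id=1)   # second microbatch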
@@ -379,7 +386,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         if len(self.kv_cache_config.kv_cache_groups) == 0:
             return

-        self.attn_metadata_builders[0].reorder_batch(self.input_batch,
+        self.get_builder(0).reorder_batch(self.input_batch,
                                                      scheduler_output)

         # For models with multiple KV cache groups, the groups should agree on
@@ -390,7 +397,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # re-order the batch (not only the first).
         # TODO(tdoublep): verify this during engine init instead of at runtime
         for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
-            batch_reordered = self.attn_metadata_builders[i].reorder_batch(
+            batch_reordered = self.get_builder(i).reorder_batch(
                 self.input_batch, scheduler_output)
             assert not batch_reordered
@@ -943,7 +950,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

         # Prepare for cascade attention if enabled & beneficial.
         common_prefix_len = 0
-        builder = self.attn_metadata_builders[kv_cache_group_id]
+        builder = self.get_builder(kv_cache_group_id)
         if self.cascade_attn_enabled:
             common_prefix_len = self._compute_cascade_attn_prefix_len(
                 num_scheduled_tokens,
@@ -960,10 +967,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 common_attn_metadata_list):
             assert common_attn_metadata.max_query_len == 1
             attn_metadata_i = (
-                builder.build(
+                self.get_builder(kv_cache_group_id, ubatch_id=ubid).build(
                     common_prefix_len=common_prefix_len,
-                    common_attn_metadata=common_attn_metadata,
-                    ubatch_id=ubid))
+                    common_attn_metadata=common_attn_metadata))
             for layer_name in kv_cache_group_spec.layer_names:
                 assert type(attn_metadata) is list
                 attn_metadata[ubid][layer_name] = attn_metadata_i
@@ -1023,8 +1029,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             attn_metadata[layer_name] = local_attn_metadata_i

         attention_cuda_graphs = all(
-            b.can_run_in_cudagraph(common_attn_metadata)
-            for b in self.attn_metadata_builders)
+            builder_list[0].can_run_in_cudagraph(common_attn_metadata)
+            for builder_list in self.attn_metadata_builders)

         # Hot-Swap lora model
         if self.lora_config:
@@ -1897,27 +1903,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 is_dummy_run=is_dummy_run)
             if num_scheduled_tokens not in self.cudagraphs \
                     and not skip_cuda_graphs and build_cuda_graph:
-                # if is_global_first_rank():
-                # logger.info(f"CAPTURING {num_scheduled_tokens}")
+                if is_global_first_rank():
+                    logger.info(f"CAPTURING {num_scheduled_tokens}")
                 return self._capture_ubatches(ubatch_metadata, self.model)
             elif num_scheduled_tokens in self.cudagraphs and not skip_cuda_graphs:
                 # assert False
                 cudagraph_metadata = self.cudagraphs[num_scheduled_tokens]
-                # if is_global_first_rank():
-                # logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
+                if is_global_first_rank():
+                    logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
                 cudagraph_metadata.cudagraph.replay()
                 return cudagraph_metadata.outputs
             else:
-                # if is_global_first_rank():
-                # logger.info(f"RUNNING NORMALLY {num_scheduled_tokens}")
+                if is_global_first_rank():
+                    logger.info(f"RUNNING NORMALLY {num_scheduled_tokens}")
                 return self._run_ubatches(ubatch_metadata, self.model)
         # run normal batch
         else:
             input_ids, positions, inputs_embeds, intermediate_tensors = \
                 self.model_inputs(slice(0, num_scheduled_tokens),
                                   scheduler_output, is_dummy_run)
-            # if is_global_first_rank():
-            # logger.info(f"RUNNING FULL BATCH {num_scheduled_tokens}")
+            if is_global_first_rank():
+                logger.info(f"RUNNING FULL BATCH {num_scheduled_tokens}")
             skip_cuda_graphs = self.parallel_config.enable_microbatching
             with set_forward_context(attn_metadata,
                                      vllm_config=self.vllm_config,
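The branch above is a capture-or-replay dispatch keyed on num_scheduled_tokens: capture a new ubatch CUDA graph, replay a cached one, or fall back to eager execution. A generic sketch of that idiom with torch.cuda.CUDAGraph (illustrative only; the real runner does this through _capture_ubatches / _run_ubatches and per-graph metadata, and capture normally needs warmup iterations that are omitted here for brevity):

    import torch

    graphs: dict[int, tuple[torch.cuda.CUDAGraph, torch.Tensor, torch.Tensor]] = {}

    def run(num_tokens: int, x: torch.Tensor, fn) -> torch.Tensor:
        if num_tokens not in graphs:
            static_in = x.clone()
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):            # capture records the kernels once
                static_out = fn(static_in)
            graphs[num_tokens] = (g, static_in, static_out)
            return static_out
        g, static_in, static_out = graphs[num_tokens]
        static_in.copy_(x)                       # inputs must land in the captured buffers
        g.replay()                               # rerun the captured kernels
        return static_out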
@@ -2725,18 +2731,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 )
                 for ubid, common_attn_metadata in enumerate(common_attn_metadata_list):
                     attn_metadata_i = (
-                        self.attn_metadata_builders[kv_cache_group_id].
+                        self.get_builder(kv_cache_group_id, ubatch_id=ubid).
                         build(
                             common_prefix_len=0,
-                            common_attn_metadata=common_attn_metadata,
-                            ubatch_id=ubid
+                            common_attn_metadata=common_attn_metadata
                         ))
                     for layer_name in kv_cache_group_spec.layer_names:
                         assert type(attn_metadata) is list
                         attn_metadata[ubid][layer_name] = attn_metadata_i
             else:
-                attn_metadata_i = self.attn_metadata_builders[
-                    kv_cache_group_id].build_for_cudagraph_capture(
+                attn_metadata_i = self.get_builder(
+                    kv_cache_group_id).build_for_cudagraph_capture(
                         common_attn_metadata)
                 for layer_name in kv_cache_group_spec.layer_names:
                     attn_metadata[layer_name] = attn_metadata_i
@@ -3071,7 +3076,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

     def _initialize_single_attn_backend(
         self, kv_cache_spec: KVCacheSpec, layer_names: list[str]
-    ) -> tuple[AttentionBackend, AttentionMetadataBuilder]:
+    ) -> tuple[type[AttentionBackend], list[AttentionMetadataBuilder]]:
         if isinstance(kv_cache_spec, AttentionSpec):
             attn_backend_i = get_attn_backend(
                 kv_cache_spec.head_size,
@@ -3098,12 +3103,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             raise ValueError(
                 f"Unknown KV cache spec type: {type(kv_cache_spec)}")

+        builders = []
         attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
             kv_cache_spec,
             layer_names,
             self.vllm_config,
             self.device,
         )
+        builders.append(attn_metadata_builder_i)
+        if self.parallel_config.enable_microbatching:
+            attn_metadata_builder_2 = attn_backend_i.get_builder_cls()(
+                kv_cache_spec,
+                layer_names,
+                self.vllm_config,
+                self.device,
+            )
+            builders.append(attn_metadata_builder_2)
+

         if (self.full_cuda_graph
                 and not attn_metadata_builder_i.full_cudagraph_supported):
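Since the second builder is constructed with exactly the same arguments as the first, the same logic can also be written as a loop sized by the number of microbatches; an equivalent sketch (editor's restructuring, not the code in this commit):

    num_builders = 2 if self.parallel_config.enable_microbatching else 1
    builders = [
        attn_backend_i.get_builder_cls()(
            kv_cache_spec,
            layer_names,
            self.vllm_config,
            self.device,
        )
        for _ in range(num_builders)
    ]

Giving each microbatch its own builder presumably keeps per-ubatch persistent state (for example the FlashMLA CUDA-graph buffers above) isolated without threading ubatch_id through every build() call.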
@@ -3111,7 +3127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 f"Full CUDAGraph not supported for "
                 f"{attn_backend_i.__name__}. Turn off CompilationConfig."
                 f"full_cuda_graph or use a different attention backend.")
-        return attn_backend_i, attn_metadata_builder_i
+        return attn_backend_i, builders

     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
         """
@@ -3124,11 +3140,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 kv_cache_config.kv_cache_groups):
             kv_cache_spec = kv_cache_group_spec.kv_cache_spec

-            attn_backend_i, attn_metadata_builder_i = (
+            attn_backend_i, attn_metadata_builders = (
                 self._initialize_single_attn_backend(
                     kv_cache_spec, kv_cache_group_spec.layer_names))
             self.attn_backends.append(attn_backend_i)
-            self.attn_metadata_builders.append(attn_metadata_builder_i)
+            self.attn_metadata_builders.append(attn_metadata_builders)

         if len(self.attn_backends) > 0:
             return
@@ -3157,11 +3173,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         assert len(attn_specs) == len(attn_layers), \
             "All or none of the layers are expected to be encoder-only"

-        attn_backend, attn_metadata_builder = (
+        attn_backend, attn_metadata_builders = (
             self._initialize_single_attn_backend(attn_specs[0],
                                                  attn_layers.keys()))
         self.attn_backends.append(attn_backend)
-        self.attn_metadata_builders.append(attn_metadata_builder)
+        self.attn_metadata_builders.append(attn_metadata_builders)
         self.is_encoder_only_model = True

     def may_reinitialize_input_batch(self,