add support for multiple builders in the model runner

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-08-08 19:01:20 +00:00
parent 6b0c303ab4
commit 5bbfd95bdb
4 changed files with 63 additions and 50 deletions

View File

@ -53,6 +53,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
super().__init__()
self.buffers = buffers
for buffer in self.buffers:
buffer.set_num_sms(4)
self.max_tokens_per_rank = max_tokens_per_rank
self.use_fp8_dispatch = use_fp8_dispatch
# The dispatch function returns a handle that the combine function

View File

@ -444,6 +444,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# 2*(192*128)*(64*1024) = 3gb
# (assuming 192 QK head dim, 128 heads, and fp16)
128 * 1024)
self.chunked_prefill_workspace_size = scheduler_config.max_num_seqs * cache_config.block_size
assert self.chunked_prefill_workspace_size >= \
scheduler_config.max_num_seqs * cache_config.block_size
self.chunked_prefill_workspace = torch.empty(
@ -566,8 +567,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
decode_threshold=1)
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor,
ubatch_id: Optional[int] = None):
seq_lens: torch.Tensor):
return MLACommonDecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens,
@ -600,8 +600,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
def build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
ubatch_id: Optional[int] = None) -> M:
fast_build: bool = False) -> M:
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
max_query_len = common_attn_metadata.max_query_len
@ -724,7 +723,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens=seq_lens[:num_decodes],
ubatch_id=ubatch_id
)
attn_metadata = self.metadata_cls(

View File

@ -65,15 +65,12 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config)
self.cg_buf_tile_scheduler_metadata = [None, None]
self.cg_buf_num_splits = [None, None]
self.cg_buf_tile_scheduler_metadata = None
self.cg_buf_num_splits = None
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens: torch.Tensor,
ubatch_id: Optional[int] = None) -> FlashMLADecodeMetadata:
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
# print(f"UBATCH ID: {ubatch_id}")
ubatch_id = 0 if ubatch_id is None else ubatch_id
assert ubatch_id < 2
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
seq_lens,
@ -85,28 +82,28 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
# if False:
n = num_splits.size(0)
# First time around (CUDAGraph capture), allocate the static buffer
if self.cg_buf_num_splits[ubatch_id] is None:
if self.cg_buf_num_splits is None:
# logger.info(f"ALLOCATING FLASH MLA DATA FOR SIZE {n}")
self.cg_buf_num_splits[ubatch_id] = num_splits
self.cg_buf_tile_scheduler_metadata[ubatch_id] = tile_scheduler_metadata
elif n <= self.cg_buf_num_splits[ubatch_id].size(0):
assert self.cg_buf_tile_scheduler_metadata[ubatch_id] is not None
self.cg_buf_num_splits = num_splits
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
elif n <= self.cg_buf_num_splits.size(0):
assert self.cg_buf_tile_scheduler_metadata is not None
# Metadata per-SM, fixed size (#SMs, TileMetadataSize)
assert (self.cg_buf_tile_scheduler_metadata[ubatch_id].size() ==
assert (self.cg_buf_tile_scheduler_metadata.size() ==
tile_scheduler_metadata.size())
self.cg_buf_tile_scheduler_metadata[ubatch_id].\
self.cg_buf_tile_scheduler_metadata.\
copy_(tile_scheduler_metadata)
tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata[ubatch_id]
tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
# Num splits is per-batch, varying size (batch_size,)
n = num_splits.size(0)
# logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
# make sure static buffer is large enough
assert n <= self.cg_buf_num_splits[ubatch_id].size(0)
num_splits_view = self.cg_buf_num_splits[ubatch_id][:n]
assert n <= self.cg_buf_num_splits.size(0)
num_splits_view = self.cg_buf_num_splits[:n]
num_splits_view.copy_(num_splits)
self.cg_buf_num_splits[ubatch_id][n:].fill_(0) # fill the rest with 0s
self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s
num_splits = num_splits_view
return FlashMLADecodeMetadata(

View File

@ -197,7 +197,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# self.model: nn.Module # Set after load_model
# Initialize in initialize_kv_cache
self.kv_caches: list[torch.Tensor] = []
self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
# Outer index is kv cache index, inner index is ubatch id
self.attn_metadata_builders: list[list[AttentionMetadataBuilder]] = []
self.attn_backends: list[type[AttentionBackend]] = []
# self.kv_cache_config: KVCacheConfig
@ -361,6 +362,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
self.max_num_tokens, dtype=torch.int32, device=self.device)
def get_builder(self, index: int, ubatch_id: Optional[int] = None) -> AttentionMetadataBuilder:
if ubatch_id is None:
return self.attn_metadata_builders[index][0]
else:
return self.attn_metadata_builders[index][ubatch_id]
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
"""
Update the order of requests in the batch based on the attention
@ -379,7 +386,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if len(self.kv_cache_config.kv_cache_groups) == 0:
return
self.attn_metadata_builders[0].reorder_batch(self.input_batch,
self.get_builder(0).reorder_batch(self.input_batch,
scheduler_output)
# For models with multiple KV cache groups, the groups should agree on
@ -390,7 +397,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# re-order the batch (not only the first).
# TODO(tdoublep): verify this during engine init instead of at runtime
for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
batch_reordered = self.attn_metadata_builders[i].reorder_batch(
batch_reordered = self.get_builder(i).reorder_batch(
self.input_batch, scheduler_output)
assert not batch_reordered
@ -943,7 +950,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len = 0
builder = self.attn_metadata_builders[kv_cache_group_id]
builder = self.get_builder(kv_cache_group_id)
if self.cascade_attn_enabled:
common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens,
@ -960,10 +967,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
common_attn_metadata_list):
assert common_attn_metadata.max_query_len == 1
attn_metadata_i = (
builder.build(
self.get_builder(kv_cache_group_id, ubatch_id=ubid).build(
common_prefix_len=common_prefix_len,
common_attn_metadata=common_attn_metadata,
ubatch_id=ubid))
common_attn_metadata=common_attn_metadata))
for layer_name in kv_cache_group_spec.layer_names:
assert type(attn_metadata) is list
attn_metadata[ubid][layer_name] = attn_metadata_i
@ -1023,8 +1029,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
attn_metadata[layer_name] = local_attn_metadata_i
attention_cuda_graphs = all(
b.can_run_in_cudagraph(common_attn_metadata)
for b in self.attn_metadata_builders)
builder_list[0].can_run_in_cudagraph(common_attn_metadata)
for builder_list in self.attn_metadata_builders)
# Hot-Swap lora model
if self.lora_config:
@ -1897,27 +1903,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
is_dummy_run=is_dummy_run)
if num_scheduled_tokens not in self.cudagraphs \
and not skip_cuda_graphs and build_cuda_graph:
# if is_global_first_rank():
# logger.info(f"CAPTURING {num_scheduled_tokens}")
if is_global_first_rank():
logger.info(f"CAPTURING {num_scheduled_tokens}")
return self._capture_ubatches(ubatch_metadata, self.model)
elif num_scheduled_tokens in self.cudagraphs and not skip_cuda_graphs:
# assert False
cudagraph_metadata = self.cudagraphs[num_scheduled_tokens]
# if is_global_first_rank():
# logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
if is_global_first_rank():
logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
cudagraph_metadata.cudagraph.replay()
return cudagraph_metadata.outputs
else:
# if is_global_first_rank():
# logger.info(f"RUNNING NORMALLY {num_scheduled_tokens}")
if is_global_first_rank():
logger.info(f"RUNNING NORMALLY {num_scheduled_tokens}")
return self._run_ubatches(ubatch_metadata, self.model)
# run normal batch
else:
input_ids, positions, inputs_embeds, intermediate_tensors = \
self.model_inputs(slice(0, num_scheduled_tokens),
scheduler_output, is_dummy_run)
# if is_global_first_rank():
# logger.info(f"RUNNING FULL BATCH {num_scheduled_tokens}")
if is_global_first_rank():
logger.info(f"RUNNING FULL BATCH {num_scheduled_tokens}")
skip_cuda_graphs = self.parallel_config.enable_microbatching
with set_forward_context(attn_metadata,
vllm_config=self.vllm_config,
@ -2725,18 +2731,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
for ubid, common_attn_metadata in enumerate(common_attn_metadata_list):
attn_metadata_i = (
self.attn_metadata_builders[kv_cache_group_id].
self.get_builder(kv_cache_group_id, ubatch_id=ubid).
build(
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
ubatch_id=ubid
common_attn_metadata=common_attn_metadata
))
for layer_name in kv_cache_group_spec.layer_names:
assert type(attn_metadata) is list
attn_metadata[ubid][layer_name] = attn_metadata_i
else:
attn_metadata_i = self.attn_metadata_builders[
kv_cache_group_id].build_for_cudagraph_capture(
attn_metadata_i = self.get_builder(
kv_cache_group_id).build_for_cudagraph_capture(
common_attn_metadata)
for layer_name in kv_cache_group_spec.layer_names:
attn_metadata[layer_name] = attn_metadata_i
@ -3071,7 +3076,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def _initialize_single_attn_backend(
self, kv_cache_spec: KVCacheSpec, layer_names: list[str]
) -> tuple[AttentionBackend, AttentionMetadataBuilder]:
) -> tuple[type[AttentionBackend], list[AttentionMetadataBuilder]]:
if isinstance(kv_cache_spec, AttentionSpec):
attn_backend_i = get_attn_backend(
kv_cache_spec.head_size,
@ -3098,12 +3103,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
raise ValueError(
f"Unknown KV cache spec type: {type(kv_cache_spec)}")
builders = []
attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
kv_cache_spec,
layer_names,
self.vllm_config,
self.device,
)
builders.append(attn_metadata_builder_i)
if self.parallel_config.enable_microbatching:
attn_metadata_builder_2 = attn_backend_i.get_builder_cls()(
kv_cache_spec,
layer_names,
self.vllm_config,
self.device,
)
builders.append(attn_metadata_builder_2)
if (self.full_cuda_graph
and not attn_metadata_builder_i.full_cudagraph_supported):
@ -3111,7 +3127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
f"Full CUDAGraph not supported for "
f"{attn_backend_i.__name__}. Turn off CompilationConfig."
f"full_cuda_graph or use a different attention backend.")
return attn_backend_i, attn_metadata_builder_i
return attn_backend_i, builders
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
"""
@ -3124,11 +3140,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
attn_backend_i, attn_metadata_builder_i = (
attn_backend_i, attn_metadata_builders = (
self._initialize_single_attn_backend(
kv_cache_spec, kv_cache_group_spec.layer_names))
self.attn_backends.append(attn_backend_i)
self.attn_metadata_builders.append(attn_metadata_builder_i)
self.attn_metadata_builders.append(attn_metadata_builders)
if len(self.attn_backends) > 0:
return
@ -3157,11 +3173,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert len(attn_specs) == len(attn_layers), \
"All or none of the layers are expected to be encoder-only"
attn_backend, attn_metadata_builder = (
attn_backend, attn_metadata_builders = (
self._initialize_single_attn_backend(attn_specs[0],
attn_layers.keys()))
self.attn_backends.append(attn_backend)
self.attn_metadata_builders.append(attn_metadata_builder)
self.attn_metadata_builders.append(attn_metadata_builders)
self.is_encoder_only_model = True
def may_reinitialize_input_batch(self,