mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-17 17:50:15 +08:00
add support for multiple builders in the model runner
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
6b0c303ab4
commit
5bbfd95bdb
@ -53,6 +53,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.buffers = buffers
|
self.buffers = buffers
|
||||||
|
for buffer in self.buffers:
|
||||||
|
buffer.set_num_sms(4)
|
||||||
self.max_tokens_per_rank = max_tokens_per_rank
|
self.max_tokens_per_rank = max_tokens_per_rank
|
||||||
self.use_fp8_dispatch = use_fp8_dispatch
|
self.use_fp8_dispatch = use_fp8_dispatch
|
||||||
# The dispatch function returns a handle that the combine function
|
# The dispatch function returns a handle that the combine function
|
||||||
|
|||||||
@ -444,6 +444,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
# 2*(192*128)*(64*1024) = 3gb
|
# 2*(192*128)*(64*1024) = 3gb
|
||||||
# (assuming 192 QK head dim, 128 heads, and fp16)
|
# (assuming 192 QK head dim, 128 heads, and fp16)
|
||||||
128 * 1024)
|
128 * 1024)
|
||||||
|
self.chunked_prefill_workspace_size = scheduler_config.max_num_seqs * cache_config.block_size
|
||||||
assert self.chunked_prefill_workspace_size >= \
|
assert self.chunked_prefill_workspace_size >= \
|
||||||
scheduler_config.max_num_seqs * cache_config.block_size
|
scheduler_config.max_num_seqs * cache_config.block_size
|
||||||
self.chunked_prefill_workspace = torch.empty(
|
self.chunked_prefill_workspace = torch.empty(
|
||||||
@ -566,8 +567,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
decode_threshold=1)
|
decode_threshold=1)
|
||||||
|
|
||||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor):
|
||||||
ubatch_id: Optional[int] = None):
|
|
||||||
return MLACommonDecodeMetadata(
|
return MLACommonDecodeMetadata(
|
||||||
block_table=block_table_tensor,
|
block_table=block_table_tensor,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
@ -600,8 +600,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
def build(self,
|
def build(self,
|
||||||
common_prefix_len: int,
|
common_prefix_len: int,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
fast_build: bool = False,
|
fast_build: bool = False) -> M:
|
||||||
ubatch_id: Optional[int] = None) -> M:
|
|
||||||
num_reqs = common_attn_metadata.num_reqs
|
num_reqs = common_attn_metadata.num_reqs
|
||||||
num_tokens = common_attn_metadata.num_actual_tokens
|
num_tokens = common_attn_metadata.num_actual_tokens
|
||||||
max_query_len = common_attn_metadata.max_query_len
|
max_query_len = common_attn_metadata.max_query_len
|
||||||
@ -724,7 +723,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
|
|||||||
decode_metadata = self._build_decode(
|
decode_metadata = self._build_decode(
|
||||||
block_table_tensor=block_table_tensor[:num_decodes, ...],
|
block_table_tensor=block_table_tensor[:num_decodes, ...],
|
||||||
seq_lens=seq_lens[:num_decodes],
|
seq_lens=seq_lens[:num_decodes],
|
||||||
ubatch_id=ubatch_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_metadata = self.metadata_cls(
|
attn_metadata = self.metadata_cls(
|
||||||
|
|||||||
@ -65,15 +65,12 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
|||||||
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
|
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
|
||||||
vllm_config.parallel_config)
|
vllm_config.parallel_config)
|
||||||
|
|
||||||
self.cg_buf_tile_scheduler_metadata = [None, None]
|
self.cg_buf_tile_scheduler_metadata = None
|
||||||
self.cg_buf_num_splits = [None, None]
|
self.cg_buf_num_splits = None
|
||||||
|
|
||||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
|
||||||
ubatch_id: Optional[int] = None) -> FlashMLADecodeMetadata:
|
|
||||||
# print(f"UBATCH ID: {ubatch_id}")
|
# print(f"UBATCH ID: {ubatch_id}")
|
||||||
ubatch_id = 0 if ubatch_id is None else ubatch_id
|
|
||||||
assert ubatch_id < 2
|
|
||||||
tile_scheduler_metadata, num_splits = \
|
tile_scheduler_metadata, num_splits = \
|
||||||
get_mla_metadata(
|
get_mla_metadata(
|
||||||
seq_lens,
|
seq_lens,
|
||||||
@ -85,28 +82,28 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
|||||||
# if False:
|
# if False:
|
||||||
n = num_splits.size(0)
|
n = num_splits.size(0)
|
||||||
# First time around (CUDAGraph capture), allocate the static buffer
|
# First time around (CUDAGraph capture), allocate the static buffer
|
||||||
if self.cg_buf_num_splits[ubatch_id] is None:
|
if self.cg_buf_num_splits is None:
|
||||||
# logger.info(f"ALLOCATING FLASH MLA DATA FOR SIZE {n}")
|
# logger.info(f"ALLOCATING FLASH MLA DATA FOR SIZE {n}")
|
||||||
self.cg_buf_num_splits[ubatch_id] = num_splits
|
self.cg_buf_num_splits = num_splits
|
||||||
self.cg_buf_tile_scheduler_metadata[ubatch_id] = tile_scheduler_metadata
|
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
|
||||||
elif n <= self.cg_buf_num_splits[ubatch_id].size(0):
|
elif n <= self.cg_buf_num_splits.size(0):
|
||||||
assert self.cg_buf_tile_scheduler_metadata[ubatch_id] is not None
|
assert self.cg_buf_tile_scheduler_metadata is not None
|
||||||
|
|
||||||
# Metadata per-SM, fixed size (#SMs, TileMetadataSize)
|
# Metadata per-SM, fixed size (#SMs, TileMetadataSize)
|
||||||
assert (self.cg_buf_tile_scheduler_metadata[ubatch_id].size() ==
|
assert (self.cg_buf_tile_scheduler_metadata.size() ==
|
||||||
tile_scheduler_metadata.size())
|
tile_scheduler_metadata.size())
|
||||||
self.cg_buf_tile_scheduler_metadata[ubatch_id].\
|
self.cg_buf_tile_scheduler_metadata.\
|
||||||
copy_(tile_scheduler_metadata)
|
copy_(tile_scheduler_metadata)
|
||||||
tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata[ubatch_id]
|
tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
|
||||||
|
|
||||||
# Num splits is per-batch, varying size (batch_size,)
|
# Num splits is per-batch, varying size (batch_size,)
|
||||||
n = num_splits.size(0)
|
n = num_splits.size(0)
|
||||||
# logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
|
# logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
|
||||||
# make sure static buffer is large enough
|
# make sure static buffer is large enough
|
||||||
assert n <= self.cg_buf_num_splits[ubatch_id].size(0)
|
assert n <= self.cg_buf_num_splits.size(0)
|
||||||
num_splits_view = self.cg_buf_num_splits[ubatch_id][:n]
|
num_splits_view = self.cg_buf_num_splits[:n]
|
||||||
num_splits_view.copy_(num_splits)
|
num_splits_view.copy_(num_splits)
|
||||||
self.cg_buf_num_splits[ubatch_id][n:].fill_(0) # fill the rest with 0s
|
self.cg_buf_num_splits[n:].fill_(0) # fill the rest with 0s
|
||||||
num_splits = num_splits_view
|
num_splits = num_splits_view
|
||||||
|
|
||||||
return FlashMLADecodeMetadata(
|
return FlashMLADecodeMetadata(
|
||||||
|
|||||||
@ -197,7 +197,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# self.model: nn.Module # Set after load_model
|
# self.model: nn.Module # Set after load_model
|
||||||
# Initialize in initialize_kv_cache
|
# Initialize in initialize_kv_cache
|
||||||
self.kv_caches: list[torch.Tensor] = []
|
self.kv_caches: list[torch.Tensor] = []
|
||||||
self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
|
# Outer index is kv cache index, inner index is ubatch id
|
||||||
|
self.attn_metadata_builders: list[list[AttentionMetadataBuilder]] = []
|
||||||
self.attn_backends: list[type[AttentionBackend]] = []
|
self.attn_backends: list[type[AttentionBackend]] = []
|
||||||
# self.kv_cache_config: KVCacheConfig
|
# self.kv_cache_config: KVCacheConfig
|
||||||
|
|
||||||
@ -361,6 +362,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
|
self.kv_sharing_fast_prefill_logits_indices = torch.zeros(
|
||||||
self.max_num_tokens, dtype=torch.int32, device=self.device)
|
self.max_num_tokens, dtype=torch.int32, device=self.device)
|
||||||
|
|
||||||
|
def get_builder(self,
                index: int,
                ubatch_id: Optional[int] = None
                ) -> "AttentionMetadataBuilder":
    """Return the attention metadata builder for one KV cache group.

    ``self.attn_metadata_builders`` is a list of lists: the outer index is
    the KV cache group, the inner index is the microbatch (ubatch) id.

    Args:
        index: Index of the KV cache group.
        ubatch_id: Microbatch whose builder is wanted. ``None`` selects the
            first builder of the group.

    Returns:
        The builder stored for the requested group and microbatch.

    Raises:
        IndexError: If ``index`` is out of range, or if ``ubatch_id`` is
            negative or has no builder in that group.
    """
    group_builders = self.attn_metadata_builders[index]
    slot = 0 if ubatch_id is None else ubatch_id
    # Reject negative ids explicitly: plain list indexing would wrap around
    # and silently hand back a different microbatch's builder.
    if not 0 <= slot < len(group_builders):
        raise IndexError(
            f"No attention metadata builder for ubatch {slot} in KV cache "
            f"group {index} ({len(group_builders)} builder(s) available)")
    return group_builders[slot]
|
||||||
|
|
||||||
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
|
def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None:
|
||||||
"""
|
"""
|
||||||
Update the order of requests in the batch based on the attention
|
Update the order of requests in the batch based on the attention
|
||||||
@ -379,7 +386,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
if len(self.kv_cache_config.kv_cache_groups) == 0:
|
if len(self.kv_cache_config.kv_cache_groups) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.attn_metadata_builders[0].reorder_batch(self.input_batch,
|
self.get_builder(0).reorder_batch(self.input_batch,
|
||||||
scheduler_output)
|
scheduler_output)
|
||||||
|
|
||||||
# For models with multiple KV cache groups, the groups should agree on
|
# For models with multiple KV cache groups, the groups should agree on
|
||||||
@ -390,7 +397,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
# re-order the batch (not only the first).
|
# re-order the batch (not only the first).
|
||||||
# TODO(tdoublep): verify this during engine init instead of at runtime
|
# TODO(tdoublep): verify this during engine init instead of at runtime
|
||||||
for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
|
for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
|
||||||
batch_reordered = self.attn_metadata_builders[i].reorder_batch(
|
batch_reordered = self.get_builder(i).reorder_batch(
|
||||||
self.input_batch, scheduler_output)
|
self.input_batch, scheduler_output)
|
||||||
assert not batch_reordered
|
assert not batch_reordered
|
||||||
|
|
||||||
@ -943,7 +950,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
# Prepare for cascade attention if enabled & beneficial.
|
# Prepare for cascade attention if enabled & beneficial.
|
||||||
common_prefix_len = 0
|
common_prefix_len = 0
|
||||||
builder = self.attn_metadata_builders[kv_cache_group_id]
|
builder = self.get_builder(kv_cache_group_id)
|
||||||
if self.cascade_attn_enabled:
|
if self.cascade_attn_enabled:
|
||||||
common_prefix_len = self._compute_cascade_attn_prefix_len(
|
common_prefix_len = self._compute_cascade_attn_prefix_len(
|
||||||
num_scheduled_tokens,
|
num_scheduled_tokens,
|
||||||
@ -960,10 +967,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
common_attn_metadata_list):
|
common_attn_metadata_list):
|
||||||
assert common_attn_metadata.max_query_len == 1
|
assert common_attn_metadata.max_query_len == 1
|
||||||
attn_metadata_i = (
|
attn_metadata_i = (
|
||||||
builder.build(
|
self.get_builder(kv_cache_group_id, ubatch_id=ubid).build(
|
||||||
common_prefix_len=common_prefix_len,
|
common_prefix_len=common_prefix_len,
|
||||||
common_attn_metadata=common_attn_metadata,
|
common_attn_metadata=common_attn_metadata))
|
||||||
ubatch_id=ubid))
|
|
||||||
for layer_name in kv_cache_group_spec.layer_names:
|
for layer_name in kv_cache_group_spec.layer_names:
|
||||||
assert type(attn_metadata) is list
|
assert type(attn_metadata) is list
|
||||||
attn_metadata[ubid][layer_name] = attn_metadata_i
|
attn_metadata[ubid][layer_name] = attn_metadata_i
|
||||||
@ -1023,8 +1029,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
attn_metadata[layer_name] = local_attn_metadata_i
|
attn_metadata[layer_name] = local_attn_metadata_i
|
||||||
|
|
||||||
attention_cuda_graphs = all(
|
attention_cuda_graphs = all(
|
||||||
b.can_run_in_cudagraph(common_attn_metadata)
|
builder_list[0].can_run_in_cudagraph(common_attn_metadata)
|
||||||
for b in self.attn_metadata_builders)
|
for builder_list in self.attn_metadata_builders)
|
||||||
|
|
||||||
# Hot-Swap lora model
|
# Hot-Swap lora model
|
||||||
if self.lora_config:
|
if self.lora_config:
|
||||||
@ -1897,27 +1903,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
is_dummy_run=is_dummy_run)
|
is_dummy_run=is_dummy_run)
|
||||||
if num_scheduled_tokens not in self.cudagraphs \
|
if num_scheduled_tokens not in self.cudagraphs \
|
||||||
and not skip_cuda_graphs and build_cuda_graph:
|
and not skip_cuda_graphs and build_cuda_graph:
|
||||||
# if is_global_first_rank():
|
if is_global_first_rank():
|
||||||
# logger.info(f"CAPTURING {num_scheduled_tokens}")
|
logger.info(f"CAPTURING {num_scheduled_tokens}")
|
||||||
return self._capture_ubatches(ubatch_metadata, self.model)
|
return self._capture_ubatches(ubatch_metadata, self.model)
|
||||||
elif num_scheduled_tokens in self.cudagraphs and not skip_cuda_graphs:
|
elif num_scheduled_tokens in self.cudagraphs and not skip_cuda_graphs:
|
||||||
# assert False
|
# assert False
|
||||||
cudagraph_metadata = self.cudagraphs[num_scheduled_tokens]
|
cudagraph_metadata = self.cudagraphs[num_scheduled_tokens]
|
||||||
# if is_global_first_rank():
|
if is_global_first_rank():
|
||||||
# logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
|
logger.info(f"UBATCH REPLAY {num_scheduled_tokens}")
|
||||||
cudagraph_metadata.cudagraph.replay()
|
cudagraph_metadata.cudagraph.replay()
|
||||||
return cudagraph_metadata.outputs
|
return cudagraph_metadata.outputs
|
||||||
else:
|
else:
|
||||||
# if is_global_first_rank():
|
if is_global_first_rank():
|
||||||
# logger.info(f"RUNNING NORMALLY {num_scheduled_tokens}")
|
logger.info(f"RUNNING NORMALLY {num_scheduled_tokens}")
|
||||||
return self._run_ubatches(ubatch_metadata, self.model)
|
return self._run_ubatches(ubatch_metadata, self.model)
|
||||||
# run normal batch
|
# run normal batch
|
||||||
else:
|
else:
|
||||||
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
input_ids, positions, inputs_embeds, intermediate_tensors = \
|
||||||
self.model_inputs(slice(0, num_scheduled_tokens),
|
self.model_inputs(slice(0, num_scheduled_tokens),
|
||||||
scheduler_output, is_dummy_run)
|
scheduler_output, is_dummy_run)
|
||||||
# if is_global_first_rank():
|
if is_global_first_rank():
|
||||||
# logger.info(f"RUNNING FULL BATCH {num_scheduled_tokens}")
|
logger.info(f"RUNNING FULL BATCH {num_scheduled_tokens}")
|
||||||
skip_cuda_graphs = self.parallel_config.enable_microbatching
|
skip_cuda_graphs = self.parallel_config.enable_microbatching
|
||||||
with set_forward_context(attn_metadata,
|
with set_forward_context(attn_metadata,
|
||||||
vllm_config=self.vllm_config,
|
vllm_config=self.vllm_config,
|
||||||
@ -2725,18 +2731,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
)
|
)
|
||||||
for ubid, common_attn_metadata in enumerate(common_attn_metadata_list):
|
for ubid, common_attn_metadata in enumerate(common_attn_metadata_list):
|
||||||
attn_metadata_i = (
|
attn_metadata_i = (
|
||||||
self.attn_metadata_builders[kv_cache_group_id].
|
self.get_builder(kv_cache_group_id, ubatch_id=ubid).
|
||||||
build(
|
build(
|
||||||
common_prefix_len=0,
|
common_prefix_len=0,
|
||||||
common_attn_metadata=common_attn_metadata,
|
common_attn_metadata=common_attn_metadata
|
||||||
ubatch_id=ubid
|
|
||||||
))
|
))
|
||||||
for layer_name in kv_cache_group_spec.layer_names:
|
for layer_name in kv_cache_group_spec.layer_names:
|
||||||
assert type(attn_metadata) is list
|
assert type(attn_metadata) is list
|
||||||
attn_metadata[ubid][layer_name] = attn_metadata_i
|
attn_metadata[ubid][layer_name] = attn_metadata_i
|
||||||
else:
|
else:
|
||||||
attn_metadata_i = self.attn_metadata_builders[
|
attn_metadata_i = self.get_builder(
|
||||||
kv_cache_group_id].build_for_cudagraph_capture(
|
kv_cache_group_id).build_for_cudagraph_capture(
|
||||||
common_attn_metadata)
|
common_attn_metadata)
|
||||||
for layer_name in kv_cache_group_spec.layer_names:
|
for layer_name in kv_cache_group_spec.layer_names:
|
||||||
attn_metadata[layer_name] = attn_metadata_i
|
attn_metadata[layer_name] = attn_metadata_i
|
||||||
@ -3071,7 +3076,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
def _initialize_single_attn_backend(
|
def _initialize_single_attn_backend(
|
||||||
self, kv_cache_spec: KVCacheSpec, layer_names: list[str]
|
self, kv_cache_spec: KVCacheSpec, layer_names: list[str]
|
||||||
) -> tuple[AttentionBackend, AttentionMetadataBuilder]:
|
) -> tuple[type[AttentionBackend], list[AttentionMetadataBuilder]]:
|
||||||
if isinstance(kv_cache_spec, AttentionSpec):
|
if isinstance(kv_cache_spec, AttentionSpec):
|
||||||
attn_backend_i = get_attn_backend(
|
attn_backend_i = get_attn_backend(
|
||||||
kv_cache_spec.head_size,
|
kv_cache_spec.head_size,
|
||||||
@ -3098,12 +3103,23 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown KV cache spec type: {type(kv_cache_spec)}")
|
f"Unknown KV cache spec type: {type(kv_cache_spec)}")
|
||||||
|
|
||||||
|
builders = []
|
||||||
attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
|
attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
|
||||||
kv_cache_spec,
|
kv_cache_spec,
|
||||||
layer_names,
|
layer_names,
|
||||||
self.vllm_config,
|
self.vllm_config,
|
||||||
self.device,
|
self.device,
|
||||||
)
|
)
|
||||||
|
builders.append(attn_metadata_builder_i)
|
||||||
|
if self.parallel_config.enable_microbatching:
|
||||||
|
attn_metadata_builder_2 = attn_backend_i.get_builder_cls()(
|
||||||
|
kv_cache_spec,
|
||||||
|
layer_names,
|
||||||
|
self.vllm_config,
|
||||||
|
self.device,
|
||||||
|
)
|
||||||
|
builders.append(attn_metadata_builder_2)
|
||||||
|
|
||||||
|
|
||||||
if (self.full_cuda_graph
|
if (self.full_cuda_graph
|
||||||
and not attn_metadata_builder_i.full_cudagraph_supported):
|
and not attn_metadata_builder_i.full_cudagraph_supported):
|
||||||
@ -3111,7 +3127,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
f"Full CUDAGraph not supported for "
|
f"Full CUDAGraph not supported for "
|
||||||
f"{attn_backend_i.__name__}. Turn off CompilationConfig."
|
f"{attn_backend_i.__name__}. Turn off CompilationConfig."
|
||||||
f"full_cuda_graph or use a different attention backend.")
|
f"full_cuda_graph or use a different attention backend.")
|
||||||
return attn_backend_i, attn_metadata_builder_i
|
return attn_backend_i, builders
|
||||||
|
|
||||||
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
|
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
|
||||||
"""
|
"""
|
||||||
@ -3124,11 +3140,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
kv_cache_config.kv_cache_groups):
|
kv_cache_config.kv_cache_groups):
|
||||||
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
|
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
|
||||||
|
|
||||||
attn_backend_i, attn_metadata_builder_i = (
|
attn_backend_i, attn_metadata_builders = (
|
||||||
self._initialize_single_attn_backend(
|
self._initialize_single_attn_backend(
|
||||||
kv_cache_spec, kv_cache_group_spec.layer_names))
|
kv_cache_spec, kv_cache_group_spec.layer_names))
|
||||||
self.attn_backends.append(attn_backend_i)
|
self.attn_backends.append(attn_backend_i)
|
||||||
self.attn_metadata_builders.append(attn_metadata_builder_i)
|
self.attn_metadata_builders.append(attn_metadata_builders)
|
||||||
|
|
||||||
if len(self.attn_backends) > 0:
|
if len(self.attn_backends) > 0:
|
||||||
return
|
return
|
||||||
@ -3157,11 +3173,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
assert len(attn_specs) == len(attn_layers), \
|
assert len(attn_specs) == len(attn_layers), \
|
||||||
"All or none of the layers are expected to be encoder-only"
|
"All or none of the layers are expected to be encoder-only"
|
||||||
|
|
||||||
attn_backend, attn_metadata_builder = (
|
attn_backend, attn_metadata_builders = (
|
||||||
self._initialize_single_attn_backend(attn_specs[0],
|
self._initialize_single_attn_backend(attn_specs[0],
|
||||||
attn_layers.keys()))
|
attn_layers.keys()))
|
||||||
self.attn_backends.append(attn_backend)
|
self.attn_backends.append(attn_backend)
|
||||||
self.attn_metadata_builders.append(attn_metadata_builder)
|
self.attn_metadata_builders.append(attn_metadata_builders)
|
||||||
self.is_encoder_only_model = True
|
self.is_encoder_only_model = True
|
||||||
|
|
||||||
def may_reinitialize_input_batch(self,
|
def may_reinitialize_input_batch(self,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user