mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 19:57:08 +08:00
cleanup flashmla.py
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
1d75a029a9
commit
7e2ff2620e
@ -78,11 +78,11 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
if self.runner.full_cuda_graph:
|
||||
n = num_splits.size(0)
|
||||
# First time around (CUDAGraph capture), allocate the static buffer
|
||||
if self.cg_buf_num_splits is None:
|
||||
self.cg_buf_num_splits = num_splits
|
||||
if self.cg_buf_tile_scheduler_metadata is None:
|
||||
self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
|
||||
elif n <= self.cg_buf_num_splits.size(0):
|
||||
assert self.cg_buf_tile_scheduler_metadata is not None
|
||||
self.cg_buf_num_splits = num_splits
|
||||
else:
|
||||
assert self.cg_buf_num_splits is not None
|
||||
|
||||
# Metadata per-SM, fixed size (#SMs, TileMetadataSize)
|
||||
assert (self.cg_buf_tile_scheduler_metadata.size() ==
|
||||
@ -93,7 +93,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
|
||||
# Num splits is per-batch, varying size (batch_size,)
|
||||
n = num_splits.size(0)
|
||||
# logger.info(f"N: {n} num splits {self.cg_buf_num_splits.size(0)}")
|
||||
# make sure static buffer is large enough
|
||||
assert n <= self.cg_buf_num_splits.size(0)
|
||||
num_splits_view = self.cg_buf_num_splits[:n]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user