mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 23:07:27 +08:00
Minor
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
1260e43230
commit
ba64a0249f
@@ -38,9 +38,10 @@ class BlockTable:
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.block_table_np = self.block_table_cpu.numpy()
|
||||
+ self.num_blocks_per_row = np.zeros((max_num_reqs,), dtype=np.int32)
|
||||
|
||||
# UVA requires pinned memory.
|
||||
- self.use_uva = is_uva_available() and pin_memory
|
||||
+ self.use_uva = is_uva_available() and pin_memory and False
|
||||
if self.use_uva:
|
||||
logger.info("Using Unified Virtual Addressing (UVA) for block "
|
||||
"table transfer.")
|
||||
@@ -62,6 +63,7 @@ class BlockTable:
|
||||
def add_row(self, row_idx: int, block_ids: List[int]) -> None:
|
||||
num_blocks = len(block_ids)
|
||||
self.block_table_np[row_idx, :num_blocks] = block_ids
|
||||
+ self.num_blocks_per_row[row_idx] = num_blocks
|
||||
if self.use_uva:
|
||||
self.block_table_diff_np[row_idx, 0] = 0
|
||||
self.block_table_diff_np[row_idx, 1] = num_blocks
|
||||
@@ -74,6 +76,7 @@ class BlockTable:
|
||||
) -> None:
|
||||
num_blocks = len(block_ids)
|
||||
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
|
||||
+ self.num_blocks_per_row[row_idx] = start + num_blocks
|
||||
if self.use_uva:
|
||||
self.block_table_diff_np[row_idx, 0] = start
|
||||
# Move-and-append is not allowed.
|
||||
@@ -81,7 +84,10 @@ class BlockTable:
|
||||
self.block_table_diff_np[row_idx, 1] = num_blocks
|
||||
|
||||
def move_row(self, src: int, tgt: int) -> None:
|
||||
- self.block_table_np[tgt] = self.block_table_np[src]
|
||||
+ num_blocks = self.num_blocks_per_row[src]
|
||||
+ self.block_table_np[tgt, :num_blocks] = \
|
||||
+ self.block_table_np[src, :num_blocks]
|
||||
+ self.num_blocks_per_row[tgt] = num_blocks
|
||||
if self.use_uva:
|
||||
# Append-and-move is allowed.
|
||||
self.block_table_diff_np[tgt] = self.block_table_diff_np[src]
|
||||
@@ -108,6 +114,7 @@ class BlockTable:
|
||||
def clear(self) -> None:
|
||||
self.block_table.fill_(0)
|
||||
self.block_table_cpu.fill_(0)
|
||||
+ self.num_blocks_per_row.fill(0)
|
||||
if self.use_uva:
|
||||
self.block_table_diff.fill_(0)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user