mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-03 09:17:03 +08:00
clean up code
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
536668602c
commit
1742f0cdfb
@ -1351,8 +1351,6 @@ class MoRIIOConnectorWorker:
|
||||
|
||||
logger.info("Initializing MoRIIO worker %s", engine_id)
|
||||
|
||||
logging.getLogger("aiter").disabled = True
|
||||
|
||||
# Config.
|
||||
self.vllm_config = vllm_config
|
||||
assert vllm_config.kv_transfer_config is not None, (
|
||||
@ -1507,12 +1505,9 @@ class MoRIIOConnectorWorker:
|
||||
self.block_size,
|
||||
use_mla=self.use_mla,
|
||||
)
|
||||
|
||||
#TODO: consider the integration of flashinfer or other backends.
|
||||
self.backend_name = backend.get_name()
|
||||
attn_backend = AttentionBackendEnum[self.backend_name]
|
||||
self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER
|
||||
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
|
||||
# attn_backend = backend_name_to_enum(self.backend_name)
|
||||
# self._use_flashinfer = attn_backend == _Backend.FLASHINFER
|
||||
logger.debug("Detected attention backend %s", self.backend_name)
|
||||
|
||||
def schedule_write_blocks(
|
||||
@ -1854,13 +1849,8 @@ class MoRIIOConnectorWorker:
|
||||
self.slot_size_bytes = kv_elem_size * kv_latent_dim
|
||||
else:
|
||||
# [2 (k and v), num_blocks, ...]
|
||||
if self._use_flashinfer:
|
||||
# FlashInfer swaps 2<->num_blocks dimensions.
|
||||
self.num_blocks = first_kv_cache.shape[0]
|
||||
block_rank = 4 # [2, block_size, kv_heads, head_dim]
|
||||
else:
|
||||
self.num_blocks = first_kv_cache.shape[1]
|
||||
block_rank = 3 # [block_size, kv_heads, head_dim]
|
||||
self.num_blocks = first_kv_cache.shape[1]
|
||||
block_rank = 3 # [block_size, kv_heads, head_dim]
|
||||
block_shape = first_kv_cache.shape[-block_rank:]
|
||||
block_size, n_kv_heads, head_dim = block_shape[-3:]
|
||||
# head size in bytes.
|
||||
@ -1884,7 +1874,7 @@ class MoRIIOConnectorWorker:
|
||||
for cache_or_caches in kv_caches.values():
|
||||
cache_list = (
|
||||
[cache_or_caches]
|
||||
if use_mla or self._use_flashinfer
|
||||
if use_mla
|
||||
else cache_or_caches
|
||||
)
|
||||
for cache in cache_list:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user