clean up code

Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-12-01 07:35:10 +00:00
parent 536668602c
commit 1742f0cdfb

View File

@ -1351,8 +1351,6 @@ class MoRIIOConnectorWorker:
logger.info("Initializing MoRIIO worker %s", engine_id)
logging.getLogger("aiter").disabled = True
# Config.
self.vllm_config = vllm_config
assert vllm_config.kv_transfer_config is not None, (
@ -1507,12 +1505,9 @@ class MoRIIOConnectorWorker:
self.block_size,
use_mla=self.use_mla,
)
# TODO: consider integrating FlashInfer or other attention backends.
self.backend_name = backend.get_name()
attn_backend = AttentionBackendEnum[self.backend_name]
self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
# attn_backend = backend_name_to_enum(self.backend_name)
# self._use_flashinfer = attn_backend == _Backend.FLASHINFER
logger.debug("Detected attention backend %s", self.backend_name)
def schedule_write_blocks(
@ -1854,13 +1849,8 @@ class MoRIIOConnectorWorker:
self.slot_size_bytes = kv_elem_size * kv_latent_dim
else:
# [2 (k and v), num_blocks, ...]
if self._use_flashinfer:
# FlashInfer swaps 2<->num_blocks dimensions.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 4 # [2, block_size, kv_heads, head_dim]
else:
self.num_blocks = first_kv_cache.shape[1]
block_rank = 3 # [block_size, kv_heads, head_dim]
self.num_blocks = first_kv_cache.shape[1]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
block_size, n_kv_heads, head_dim = block_shape[-3:]
# head size in bytes.
@ -1884,7 +1874,7 @@ class MoRIIOConnectorWorker:
for cache_or_caches in kv_caches.values():
cache_list = (
[cache_or_caches]
if use_mla or self._use_flashinfer
if use_mla
else cache_or_caches
)
for cache in cache_list: