mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-30 10:17:07 +08:00
Fix compatibility with the new main branch
Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
parent
b3e31b42d8
commit
a7ea23d16d
@ -20,7 +20,9 @@ import torch
|
|||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
|
from vllm.attention.selector import get_attn_backend
|
||||||
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||||
@ -29,8 +31,7 @@ from vllm.distributed.parallel_state import (
|
|||||||
get_tp_group, get_world_group)
|
get_tp_group, get_world_group)
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import _Backend
|
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
|
||||||
from vllm.utils import get_ip, make_zmq_path, make_zmq_socket
|
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.request import RequestStatus
|
from vllm.v1.request import RequestStatus
|
||||||
from weakref import ref as weakref_ref
|
from weakref import ref as weakref_ref
|
||||||
@ -38,6 +39,7 @@ from weakref import ref as weakref_ref
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
@ -835,7 +837,7 @@ class MoRIIOConnectorMetadata(KVConnectorMetadata):
|
|||||||
|
|
||||||
class MoRIIOConnector(KVConnectorBase_V1):
|
class MoRIIOConnector(KVConnectorBase_V1):
|
||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
|
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole, kv_cache_config: Optional["KVCacheConfig"] = None,):
|
||||||
assert vllm_config.kv_transfer_config is not None
|
assert vllm_config.kv_transfer_config is not None
|
||||||
# assert vllm_config.kv_transfer_config.engine_id is not None
|
# assert vllm_config.kv_transfer_config.engine_id is not None
|
||||||
self.engine_id = str(
|
self.engine_id = str(
|
||||||
@ -927,6 +929,16 @@ class MoRIIOConnector(KVConnectorBase_V1):
|
|||||||
|
|
||||||
def wait_for_save(self):
|
def wait_for_save(self):
|
||||||
pass
|
pass
|
||||||
|
def has_connector_metadata(self) -> bool:
    """Check whether the connector metadata is currently set.

    A defaulted attribute lookup is used so that an instance on which
    ``_connector_metadata`` has never been assigned reports ``False``
    instead of raising ``AttributeError``.

    Returns:
        bool: True if connector metadata exists, False otherwise.
    """
    # A missing attribute and an attribute explicitly set to None are
    # both treated as "no metadata".
    return getattr(self, "_connector_metadata", None) is not None
|
||||||
|
|
||||||
|
|
||||||
class MoRIIOConnectorScheduler:
|
class MoRIIOConnectorScheduler:
|
||||||
@ -1402,8 +1414,11 @@ class MoRIIOConnectorWorker:
|
|||||||
self.block_size,
|
self.block_size,
|
||||||
use_mla=self.use_mla)
|
use_mla=self.use_mla)
|
||||||
self.backend_name = backend.get_name()
|
self.backend_name = backend.get_name()
|
||||||
attn_backend = backend_name_to_enum(self.backend_name)
|
attn_backend = AttentionBackendEnum[self.backend_name]
|
||||||
self._use_flashinfer = attn_backend == _Backend.FLASHINFER
|
self._use_flashinfer = attn_backend == AttentionBackendEnum.FLASHINFER
|
||||||
|
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
|
||||||
|
# attn_backend = backend_name_to_enum(self.backend_name)
|
||||||
|
# self._use_flashinfer = attn_backend == _Backend.FLASHINFER
|
||||||
logger.debug("Detected attention backend %s", self.backend_name)
|
logger.debug("Detected attention backend %s", self.backend_name)
|
||||||
|
|
||||||
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
|
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user