mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-12 08:17:03 +08:00
fix merge
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
parent
cad6447664
commit
03b41b6cad
@ -63,7 +63,6 @@ requires_pplx = pytest.mark.skipif(
|
||||
reason="Requires PPLX kernels",
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ProcessGroupInfo:
|
||||
world_size: int
|
||||
@ -74,6 +73,11 @@ class ProcessGroupInfo:
|
||||
device: torch.device
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def use_pplx_backend(monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "pplx")
|
||||
|
||||
|
||||
def _worker_parallel_launch(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
|
||||
@ -429,8 +429,6 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
"apply_router_weight_on_input is only implemented for topk=1"
|
||||
a1.mul_(topk_weights.to(a1.dtype))
|
||||
|
||||
_, block_k = self.block_shape
|
||||
|
||||
num_tokens, hidden_dim = a1.size()
|
||||
topk = topk_ids.size(1)
|
||||
|
||||
@ -453,6 +451,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
device=a1.device)
|
||||
|
||||
if self.qtype is not None:
|
||||
_, block_k = self.block_shape
|
||||
k_tiles = (hidden_dim + block_k - 1) // block_k
|
||||
b_a1_scale = torch.zeros(
|
||||
(num_local_experts, self.max_num_tokens, k_tiles),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user