fix merge

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
Bill Nell 2025-05-28 23:29:30 +00:00
parent cad6447664
commit 03b41b6cad
2 changed files with 6 additions and 3 deletions

View File

@@ -63,7 +63,6 @@ requires_pplx = pytest.mark.skipif(
reason="Requires PPLX kernels",
)
@dataclasses.dataclass
class ProcessGroupInfo:
world_size: int
@@ -74,6 +73,11 @@ class ProcessGroupInfo:
device: torch.device
@pytest.fixture(scope="function", autouse=True)
def use_pplx_backend(monkeypatch):
monkeypatch.setenv("VLLM_ALL2ALL_BACKEND", "pplx")
def _worker_parallel_launch(
local_rank: int,
world_size: int,

View File

@@ -429,8 +429,6 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"apply_router_weight_on_input is only implemented for topk=1"
a1.mul_(topk_weights.to(a1.dtype))
_, block_k = self.block_shape
num_tokens, hidden_dim = a1.size()
topk = topk_ids.size(1)
@@ -453,6 +451,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
device=a1.device)
if self.qtype is not None:
_, block_k = self.block_shape
k_tiles = (hidden_dim + block_k - 1) // block_k
b_a1_scale = torch.zeros(
(num_local_experts, self.max_num_tokens, k_tiles),