Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Kernels] Clean up FusedMoeMethodBase and modular kernel setup. Remove extra arguments from modular kernel methods. (#22035)
Signed-off-by: Bill Nell <bnell@redhat.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
Parent: 48b01fd4d4
Commit: 8ad7285ea2
@@ -399,6 +399,7 @@ steps:
 - label: Kernels MoE Test %N
   mirror_hardwares: [amdexperimental]
   source_file_dependencies:
+  - csrc/quantization/cutlass_w8a8/moe/
   - csrc/moe/
   - tests/kernels/moe
   - vllm/model_executor/layers/fused_moe/
@@ -175,11 +175,19 @@ implementations that input `FusedMoEActivationFormat.Standard` support chunking

 ### FusedMoEModularKernel Initialization

-`FusedMoEMethodBase` class has 2 methods that are collectively responsible in creating the `FusedMoEModularKernel` object. They are,
+The `FusedMoEMethodBase` class has 3 methods that are collectively responsible for creating the `FusedMoEModularKernel` object. They are:

+* maybe_make_prepare_finalize,
 * select_gemm_impl, and
 * init_prepare_finalize

+#### maybe_make_prepare_finalize
+
+The `maybe_make_prepare_finalize` method is responsible for constructing an instance of `FusedMoEPrepareAndFinalize` when appropriate, based on the current all2all backend, e.g. when EP + DP is enabled. The base class method currently constructs all the `FusedMoEPrepareAndFinalize` objects for the EP+DP case. Derived classes can override this method to construct prepare/finalize objects for different scenarios, e.g. `ModelOptNvFp4FusedMoE` can construct a `FlashInferCutlassMoEPrepareAndFinalize` for the EP+TP case.
+
+Please refer to the implementations in:
+
+* `ModelOptNvFp4FusedMoE`
+
 #### select_gemm_impl

 The `select_gemm_impl` method is undefined in the base class. It is the responsibility of the derived class to implement a method that constructs a valid/appropriate `FusedMoEPermuteExpertsUnpermute` object.
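For orientation, here is a minimal sketch of what the assembled object looks like once these steps have run. It mirrors the single-GPU / no-EP path exercised by the test helpers in this change (the constructor keyword arguments are copied from those tests); in real model code the kernel is built through `init_prepare_finalize` rather than by hand.

```python
# Minimal sketch (single GPU, no EP) of composing a FusedMoEModularKernel,
# mirroring the test helpers in this change.
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)

# Without expert parallelism there is nothing to dispatch across ranks.
prepare_finalize = MoEPrepareAndFinalizeNoEP()

# Unquantized Triton experts; the kwargs are the ones the tests pass.
fused_experts = TritonExperts(
    use_fp8_w8a8=False,
    use_int8_w8a8=False,
    use_int8_w8a16=False,
    use_int4_w4a16=False,
    block_shape=None,
    per_act_token_quant=False,
)

# The modular kernel chains prepare -> fused experts -> finalize.
fused_moe = mk.FusedMoEModularKernel(
    prepare_finalize=prepare_finalize,
    fused_experts=fused_experts,
)
```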
@@ -70,12 +70,27 @@ def parse_args():
         default=64,
         help=("Maximum number of sequences to be processed in a single iteration."),
     )
+    parser.add_argument(
+        "--max-model-len",
+        type=int,
+        help=("Maximum number of tokens to be processed in a single iteration."),
+    )
+    parser.add_argument(
+        "--timeout",
+        type=int,
+        default=300,
+        help=("Number of seconds before unresponsive process is killed."),
+    )
     parser.add_argument(
         "--gpu-memory-utilization",
         type=float,
         default=0.8,
         help=("Fraction of GPU memory vLLM is allowed to allocate (0.0, 1.0]."),
     )
+    parser.add_argument(
+        "--quantization",
+        type=str,
+    )
     return parser.parse_args()


@@ -90,7 +105,9 @@ def main(
     enforce_eager,
     trust_remote_code,
     max_num_seqs,
+    max_model_len,
     gpu_memory_utilization,
+    quantization,
 ):
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
     os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
@@ -142,7 +159,9 @@ def main(
         enable_expert_parallel=True,
         trust_remote_code=trust_remote_code,
         max_num_seqs=max_num_seqs,
+        max_model_len=max_model_len,
         gpu_memory_utilization=gpu_memory_utilization,
+        quantization=quantization,
     )
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
@@ -198,14 +217,16 @@ if __name__ == "__main__":
                 args.enforce_eager,
                 args.trust_remote_code,
                 args.max_num_seqs,
+                args.max_model_len,
                 args.gpu_memory_utilization,
+                args.quantization,
             ),
         )
         proc.start()
         procs.append(proc)
     exit_code = 0
     for proc in procs:
-        proc.join(timeout=300)
+        proc.join(timeout=args.timeout)
         if proc.exitcode is None:
             print(f"Killing process {proc.pid} that didn't stop within 5 minutes.")
             proc.kill()
|
|||||||
@ -7,41 +7,22 @@ import torch
|
|||||||
|
|
||||||
import vllm._custom_ops as ops
|
import vllm._custom_ops as ops
|
||||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
|
from tests.kernels.moe.utils import make_test_weights, per_token_cast_to_fp8
|
||||||
|
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||||
|
FLOAT8_E4M3_MAX,
|
||||||
|
dequantize_nvfp4_to_dtype)
|
||||||
from tests.kernels.utils import torch_experts
|
from tests.kernels.utils import torch_experts
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
|
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
|
||||||
# Fused experts and PrepareFinalize imports
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
|
||||||
BatchedDeepGemmExperts)
|
|
||||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
|
||||||
BatchedTritonOrDeepGemmExperts)
|
|
||||||
from vllm.model_executor.layers.fused_moe.config import (
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig)
|
FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig)
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
|
|
||||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
|
||||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
|
||||||
BatchedTritonExperts, NaiveBatchedExperts)
|
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
|
|
||||||
TritonExperts)
|
|
||||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
|
||||||
MoEPrepareAndFinalizeNoEP)
|
|
||||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
|
||||||
TritonOrDeepGemmExperts)
|
|
||||||
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||||
|
|
||||||
|
from .mk_objects import (expert_info, make_fused_experts,
|
||||||
|
make_prepare_finalize, prepare_finalize_info)
|
||||||
from .parallel_utils import ProcessGroupInfo
|
from .parallel_utils import ProcessGroupInfo
|
||||||
from .utils import (make_block_quant_fp8_weights, make_non_quant_weights,
|
|
||||||
make_quant_fp8_weights, per_token_cast_to_fp8)
|
|
||||||
|
|
||||||
if has_pplx():
|
|
||||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
|
||||||
PplxPrepareAndFinalize)
|
|
||||||
if has_deep_ep():
|
|
||||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
|
||||||
DeepEPHTPrepareAndFinalize)
|
|
||||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
|
||||||
DeepEPLLPrepareAndFinalize)
|
|
||||||
|
|
||||||
|
|
||||||
def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:
|
def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:
|
||||||
@ -69,24 +50,31 @@ class Config:
|
|||||||
|
|
||||||
torch_trace_dir_path: Optional[str] = None
|
torch_trace_dir_path: Optional[str] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.quant_config is None:
|
||||||
|
self.quant_config = FusedMoEQuantConfig()
|
||||||
|
|
||||||
def describe(self) -> str:
|
def describe(self) -> str:
|
||||||
s = ""
|
s = ""
|
||||||
s += "== Config: \n"
|
s += "== Config:\n"
|
||||||
s += f" world_size={self.world_size} \n"
|
s += f" world_size={self.world_size}\n"
|
||||||
s += f" PF={self.prepare_finalize_type.__name__} \n"
|
s += f" PF={self.prepare_finalize_type.__name__}\n"
|
||||||
s += f" FE={self.fused_experts_type.__name__} \n"
|
s += f" FE={self.fused_experts_type.__name__}\n"
|
||||||
s += f" topk={self.topks} \n"
|
s += f" E={self.E}\n"
|
||||||
s += f" dtype={self.dtype} \n"
|
s += f" Ms={self.Ms}\n"
|
||||||
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n"
|
s += f" N={self.N}\n"
|
||||||
s += " Quant: \n"
|
s += f" K={self.K}\n"
|
||||||
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n "
|
s += f" topk={self.topks}\n"
|
||||||
|
s += f" dtype={self.dtype}\n"
|
||||||
|
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size}\n"
|
||||||
|
s += " Quant:\n"
|
||||||
if self.quant_config is not None:
|
if self.quant_config is not None:
|
||||||
s += f" q_dtype={self.quant_dtype} \n"
|
s += f" q_dtype={self.quant_dtype}\n"
|
||||||
s += f" q_block_shape={self.quant_block_shape} \n"
|
s += f" q_block_shape={self.quant_block_shape}\n"
|
||||||
s += f" q_per_out_ch_quant={self.is_per_out_ch_quant} \n"
|
s += f" q_per_out_ch_quant={self.is_per_out_ch_quant}\n"
|
||||||
s += f" q_per_act_token={self.is_per_act_token_quant} \n"
|
s += f" q_per_act_token={self.is_per_act_token_quant}\n"
|
||||||
else:
|
else:
|
||||||
s += " quant=None \n"
|
s += " quant=None\n"
|
||||||
return s
|
return s
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -95,34 +83,28 @@ class Config:
|
|||||||
return self.Ms
|
return self.Ms
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def quant_dtype(self) -> Optional[torch.dtype]:
|
def quant_dtype(self) -> Union[torch.dtype, str, None]:
|
||||||
if self.quant_config is None:
|
assert self.quant_config is not None
|
||||||
return None
|
|
||||||
return self.quant_config.quant_dtype
|
return self.quant_config.quant_dtype
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_per_act_token_quant(self) -> bool:
|
def is_per_act_token_quant(self) -> bool:
|
||||||
if self.quant_config is None:
|
assert self.quant_config is not None
|
||||||
return False
|
|
||||||
return self.quant_config.per_act_token_quant
|
return self.quant_config.per_act_token_quant
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_per_tensor_act_quant(self) -> bool:
|
def is_per_tensor_act_quant(self) -> bool:
|
||||||
if self.quant_config is None:
|
|
||||||
return False
|
|
||||||
return (not self.is_per_act_token_quant
|
return (not self.is_per_act_token_quant
|
||||||
and self.quant_block_shape is None)
|
and self.quant_block_shape is None)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_per_out_ch_quant(self) -> bool:
|
def is_per_out_ch_quant(self) -> bool:
|
||||||
if self.quant_config is None:
|
assert self.quant_config is not None
|
||||||
return False
|
|
||||||
return self.quant_config.per_out_ch_quant
|
return self.quant_config.per_out_ch_quant
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def quant_block_shape(self) -> Optional[list[int]]:
|
def quant_block_shape(self) -> Optional[list[int]]:
|
||||||
if self.quant_config is None:
|
assert self.quant_config is not None
|
||||||
return None
|
|
||||||
return self.quant_config.block_shape
|
return self.quant_config.block_shape
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -130,36 +112,30 @@ class Config:
|
|||||||
assert isinstance(self.topks, int)
|
assert isinstance(self.topks, int)
|
||||||
return self.topks
|
return self.topks
|
||||||
|
|
||||||
@property
|
|
||||||
def topk_ids_dtype(self) -> Optional[torch.dtype]:
|
|
||||||
topk_ids_dtype = None
|
|
||||||
if self.prepare_finalize_type == PplxPrepareAndFinalize:
|
|
||||||
topk_ids_dtype = torch.uint32
|
|
||||||
elif self.prepare_finalize_type in [
|
|
||||||
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
|
|
||||||
]:
|
|
||||||
topk_ids_dtype = torch.int64
|
|
||||||
return topk_ids_dtype
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_local_experts(self) -> int:
|
def num_local_experts(self) -> int:
|
||||||
return self.E // self.world_size
|
return self.E // self.world_size
|
||||||
|
|
||||||
def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]:
|
def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]:
|
||||||
"""
|
"""
|
||||||
make env data for vllm launch.
|
make env data for vllm launch.
|
||||||
"""
|
"""
|
||||||
vllm_config = VllmConfig()
|
vllm_config = VllmConfig()
|
||||||
vllm_config.parallel_config.data_parallel_size = self.world_size
|
vllm_config.parallel_config.data_parallel_size = self.world_size
|
||||||
vllm_config.parallel_config.enable_expert_parallel = True
|
vllm_config.parallel_config.enable_expert_parallel = True
|
||||||
|
|
||||||
env_dict = {
|
env_dict = {
|
||||||
"VLLM_ALL2ALL_BACKEND": self.all2all_backend(),
|
|
||||||
"VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())),
|
"VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
backend = self.all2all_backend()
|
||||||
|
if backend is not None:
|
||||||
|
env_dict.update({"VLLM_ALL2ALL_BACKEND": backend})
|
||||||
|
|
||||||
if self.fused_moe_chunk_size is not None:
|
if self.fused_moe_chunk_size is not None:
|
||||||
env_dict.update(
|
env_dict.update(
|
||||||
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)})
|
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)})
|
||||||
|
|
||||||
return vllm_config, env_dict
|
return vllm_config, env_dict
|
||||||
|
|
||||||
def is_fp8_block_quantized(self):
|
def is_fp8_block_quantized(self):
|
||||||
@ -167,85 +143,59 @@ class Config:
|
|||||||
and self.quant_block_shape is not None)
|
and self.quant_block_shape is not None)
|
||||||
|
|
||||||
def is_batched_prepare_finalize(self):
|
def is_batched_prepare_finalize(self):
|
||||||
return self.prepare_finalize_type in [
|
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||||
PplxPrepareAndFinalize, DeepEPLLPrepareAndFinalize
|
return (mk.FusedMoEActivationFormat.BatchedExperts ==
|
||||||
]
|
info.activation_format)
|
||||||
|
|
||||||
def is_batched_fused_experts(self):
|
def is_batched_fused_experts(self):
|
||||||
return self.fused_experts_type in [
|
info = expert_info(self.fused_experts_type)
|
||||||
CutlassExpertsFp8, BatchedDeepGemmExperts, BatchedTritonExperts,
|
return (mk.FusedMoEActivationFormat.BatchedExperts ==
|
||||||
NaiveBatchedExperts, BatchedTritonOrDeepGemmExperts
|
info.activation_format)
|
||||||
]
|
|
||||||
|
|
||||||
def is_standard_fused_experts(self):
|
def is_standard_fused_experts(self):
|
||||||
return self.fused_experts_type in [
|
info = expert_info(self.fused_experts_type)
|
||||||
CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts,
|
return mk.FusedMoEActivationFormat.Standard == info.activation_format
|
||||||
TritonExperts
|
|
||||||
]
|
|
||||||
|
|
||||||
def is_fe_16bit_supported(self):
|
def fe_supported_types(self):
|
||||||
return self.fused_experts_type in [
|
info = expert_info(self.fused_experts_type)
|
||||||
BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
|
return info.supported_dtypes
|
||||||
NaiveBatchedExperts, TritonExperts
|
|
||||||
]
|
|
||||||
|
|
||||||
def is_fe_fp8_supported(self):
|
def pf_supported_types(self):
|
||||||
return self.fused_experts_type in [
|
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||||
BatchedDeepGemmExperts,
|
return info.supported_dtypes
|
||||||
BatchedTritonExperts,
|
|
||||||
BatchedTritonOrDeepGemmExperts,
|
|
||||||
CutlassExpertsFp8,
|
|
||||||
DeepGemmExperts,
|
|
||||||
TritonExperts,
|
|
||||||
TritonOrDeepGemmExperts,
|
|
||||||
NaiveBatchedExperts,
|
|
||||||
]
|
|
||||||
|
|
||||||
def is_fe_block_fp8_supported(self):
|
def is_block_quant_supported(self):
|
||||||
return self.fused_experts_type in [
|
info = expert_info(self.fused_experts_type)
|
||||||
BatchedDeepGemmExperts,
|
return info.blocked_quantization_support
|
||||||
BatchedTritonOrDeepGemmExperts,
|
|
||||||
DeepGemmExperts,
|
|
||||||
TritonExperts,
|
|
||||||
TritonOrDeepGemmExperts,
|
|
||||||
BatchedTritonExperts,
|
|
||||||
NaiveBatchedExperts,
|
|
||||||
]
|
|
||||||
|
|
||||||
def is_fe_supports_chunking(self):
|
def is_fe_supports_chunking(self):
|
||||||
return self.fused_experts_type in [
|
info = expert_info(self.fused_experts_type)
|
||||||
CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts,
|
return info.supports_chunking
|
||||||
TritonExperts
|
|
||||||
]
|
def supports_expert_map(self):
|
||||||
|
info = expert_info(self.fused_experts_type)
|
||||||
|
return info.supports_expert_map
|
||||||
|
|
||||||
|
def supports_apply_weight_on_input(self):
|
||||||
|
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||||
|
return info.supports_apply_weight_on_input
|
||||||
|
|
||||||
def needs_deep_gemm(self):
|
def needs_deep_gemm(self):
|
||||||
return self.fused_experts_type in [
|
info = expert_info(self.fused_experts_type)
|
||||||
BatchedDeepGemmExperts,
|
return info.needs_deep_gemm
|
||||||
DeepGemmExperts,
|
|
||||||
]
|
|
||||||
|
|
||||||
def needs_pplx(self):
|
def needs_pplx(self):
|
||||||
return self.prepare_finalize_type in [PplxPrepareAndFinalize]
|
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||||
|
return info.backend == "pplx"
|
||||||
|
|
||||||
def needs_deep_ep(self):
|
def needs_deep_ep(self):
|
||||||
return self.prepare_finalize_type in [
|
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||||
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
|
return (info.backend == "deepep_high_throughput"
|
||||||
]
|
or info.backend == "deepep_low_latency")
|
||||||
|
|
||||||
def all2all_backend(self):
|
def all2all_backend(self):
|
||||||
if self.needs_pplx():
|
info = prepare_finalize_info(self.prepare_finalize_type)
|
||||||
return "pplx"
|
return info.backend
|
||||||
if self.prepare_finalize_type == DeepEPHTPrepareAndFinalize:
|
|
||||||
return "deepep_high_throughput"
|
|
||||||
if self.prepare_finalize_type == DeepEPLLPrepareAndFinalize:
|
|
||||||
return "deepep_low_latency"
|
|
||||||
return "naive"
|
|
||||||
|
|
||||||
def needs_all2all(self):
|
|
||||||
return self.prepare_finalize_type in [
|
|
||||||
PplxPrepareAndFinalize, DeepEPHTPrepareAndFinalize,
|
|
||||||
DeepEPLLPrepareAndFinalize
|
|
||||||
]
|
|
||||||
|
|
||||||
def is_valid(self):
|
def is_valid(self):
|
||||||
# Check prepare-finalize and fused-experts compatibility
|
# Check prepare-finalize and fused-experts compatibility
|
||||||
@ -267,28 +217,28 @@ class Config:
|
|||||||
# invalid quant config
|
# invalid quant config
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# check bf16 / fp16 support
|
# check type support
|
||||||
is_16bit = (self.dtype.itemsize == 2 and self.quant_dtype is None)
|
if self.quant_dtype is None:
|
||||||
if is_16bit and not self.is_fe_16bit_supported():
|
if (self.dtype not in self.pf_supported_types()
|
||||||
return False
|
or self.dtype not in self.fe_supported_types()):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
if (self.quant_dtype not in self.pf_supported_types()
|
||||||
|
or self.quant_dtype not in self.fe_supported_types()):
|
||||||
|
return False
|
||||||
|
|
||||||
# Check fp8 support
|
# Check block quantization support
|
||||||
is_fp8 = self.quant_dtype == torch.float8_e4m3fn
|
|
||||||
if is_fp8 and not self.is_fe_fp8_supported():
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check fp8 block quanization support
|
|
||||||
is_block_quatized = self.quant_block_shape is not None
|
is_block_quatized = self.quant_block_shape is not None
|
||||||
if is_block_quatized and not is_fp8:
|
if is_block_quatized and self.quant_dtype is None:
|
||||||
return False
|
return False
|
||||||
if is_block_quatized and not self.is_fe_block_fp8_supported():
|
if is_block_quatized and not self.is_block_quant_supported():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# deep_gemm only works with block-quantized
|
# deep_gemm only works with block-quantized
|
||||||
if self.needs_deep_gemm() and not is_block_quatized:
|
if self.needs_deep_gemm() and not is_block_quatized:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check dependencies
|
# Check dependencies (turn into asserts?)
|
||||||
if self.needs_deep_ep() and not has_deep_ep():
|
if self.needs_deep_ep() and not has_deep_ep():
|
||||||
return False
|
return False
|
||||||
if self.needs_deep_gemm() and not has_deep_gemm():
|
if self.needs_deep_gemm() and not has_deep_gemm():
|
||||||
@ -305,6 +255,8 @@ class WeightTensors:
|
|||||||
w2: torch.Tensor
|
w2: torch.Tensor
|
||||||
w1_scale: Optional[torch.Tensor]
|
w1_scale: Optional[torch.Tensor]
|
||||||
w2_scale: Optional[torch.Tensor]
|
w2_scale: Optional[torch.Tensor]
|
||||||
|
w1_gs: Optional[torch.Tensor] = None
|
||||||
|
w2_gs: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
def describe(self):
|
def describe(self):
|
||||||
s = ""
|
s = ""
|
||||||
@ -313,13 +265,20 @@ class WeightTensors:
|
|||||||
s += f' - {_describe_tensor(self.w2, "w2")} \n'
|
s += f' - {_describe_tensor(self.w2, "w2")} \n'
|
||||||
s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n'
|
s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n'
|
||||||
s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n'
|
s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n'
|
||||||
|
s += f' - {_describe_tensor(self.w1_gs, "w1_gs")} \n'
|
||||||
|
s += f' - {_describe_tensor(self.w2_gs, "w2_gs")} \n'
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
def is_quantized(self) -> bool:
|
||||||
|
# or w1_scale is not None?
|
||||||
|
return (self.w1.dtype == torch.float8_e4m3fn
|
||||||
|
or self.w1.dtype == torch.uint8 or self.w1.dtype == torch.int8)
|
||||||
|
|
||||||
def to_current_device(self):
|
def to_current_device(self):
|
||||||
self.w1 = self.w1.to(device=torch.cuda.current_device())
|
self.w1 = self.w1.to(device=torch.cuda.current_device())
|
||||||
self.w2 = self.w2.to(device=torch.cuda.current_device())
|
self.w2 = self.w2.to(device=torch.cuda.current_device())
|
||||||
is_quantized = self.w1.dtype == torch.float8_e4m3fn
|
|
||||||
if is_quantized:
|
if self.is_quantized():
|
||||||
assert self.w1_scale is not None
|
assert self.w1_scale is not None
|
||||||
assert self.w2_scale is not None
|
assert self.w2_scale is not None
|
||||||
self.w1_scale = self.w1_scale.to(
|
self.w1_scale = self.w1_scale.to(
|
||||||
@ -327,56 +286,51 @@ class WeightTensors:
|
|||||||
self.w2_scale = self.w2_scale.to(
|
self.w2_scale = self.w2_scale.to(
|
||||||
device=torch.cuda.current_device())
|
device=torch.cuda.current_device())
|
||||||
|
|
||||||
|
if self.w1_gs is not None:
|
||||||
|
assert self.w2_gs is not None
|
||||||
|
self.w1_gs = self.w1_gs.to(device=torch.cuda.current_device())
|
||||||
|
self.w2_gs = self.w2_gs.to(device=torch.cuda.current_device())
|
||||||
|
|
||||||
def slice_weights(self, rank: int,
|
def slice_weights(self, rank: int,
|
||||||
num_local_experts: int) -> "WeightTensors":
|
num_local_experts: int) -> "WeightTensors":
|
||||||
s = rank * num_local_experts
|
s = rank * num_local_experts
|
||||||
e = s + num_local_experts
|
e = s + num_local_experts
|
||||||
w1 = self.w1[s:e, :, :]
|
w1 = self.w1[s:e, :, :]
|
||||||
w2 = self.w2[s:e, :, :]
|
w2 = self.w2[s:e, :, :]
|
||||||
is_quantized = self.w1.dtype == torch.float8_e4m3fn
|
|
||||||
w1_scale, w2_scale = (None, None)
|
w1_scale, w2_scale = (None, None)
|
||||||
if is_quantized:
|
if self.is_quantized():
|
||||||
assert self.w1_scale is not None
|
assert self.w1_scale is not None
|
||||||
assert self.w2_scale is not None
|
assert self.w2_scale is not None
|
||||||
w1_scale = self.w1_scale[s:e, :, :]
|
w1_scale = self.w1_scale[s:e, :, :]
|
||||||
w2_scale = self.w2_scale[s:e, :, :]
|
w2_scale = self.w2_scale[s:e, :, :]
|
||||||
return WeightTensors(w1, w2, w1_scale, w2_scale)
|
|
||||||
|
w1_gs = self.w1_gs
|
||||||
|
w2_gs = self.w2_gs
|
||||||
|
if w1_gs is not None:
|
||||||
|
assert w2_gs is not None
|
||||||
|
w1_gs = w1_gs[s:e]
|
||||||
|
w2_gs = w2_gs[s:e]
|
||||||
|
|
||||||
|
return WeightTensors(w1, w2, w1_scale, w2_scale, w1_gs, w2_gs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def make(config: Config) -> "WeightTensors":
|
def make(config: Config) -> "WeightTensors":
|
||||||
|
(_, w1, w1_scale, w1_gs), (_, w2, w2_scale, w2_gs) = make_test_weights(
|
||||||
if config.quant_dtype is None:
|
|
||||||
# just make normal dtype weights
|
|
||||||
w1, w2 = make_non_quant_weights(e=config.E,
|
|
||||||
n=config.N,
|
|
||||||
k=config.K,
|
|
||||||
dtype=config.dtype)
|
|
||||||
return WeightTensors(w1=w1, w2=w2, w1_scale=None, w2_scale=None)
|
|
||||||
|
|
||||||
assert config.quant_dtype == torch.float8_e4m3fn
|
|
||||||
if not config.is_fp8_block_quantized():
|
|
||||||
w1, w2, w1_scale, w2_scale = make_quant_fp8_weights(
|
|
||||||
e=config.E,
|
|
||||||
n=config.N,
|
|
||||||
k=config.K,
|
|
||||||
per_out_channel_quant=config.is_per_out_ch_quant,
|
|
||||||
)
|
|
||||||
return WeightTensors(w1=w1,
|
|
||||||
w2=w2,
|
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale)
|
|
||||||
|
|
||||||
assert config.quant_block_shape is not None
|
|
||||||
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
|
|
||||||
e=config.E,
|
e=config.E,
|
||||||
n=config.N,
|
n=config.N,
|
||||||
k=config.K,
|
k=config.K,
|
||||||
block_size=config.quant_block_shape,
|
in_dtype=config.dtype,
|
||||||
|
quant_dtype=config.quant_dtype,
|
||||||
|
block_shape=config.quant_block_shape,
|
||||||
|
per_act_token_quant=config.is_per_out_ch_quant,
|
||||||
)
|
)
|
||||||
return WeightTensors(w1=w1,
|
return WeightTensors(w1=w1,
|
||||||
w2=w2,
|
w2=w2,
|
||||||
w1_scale=w1_scale,
|
w1_scale=w1_scale,
|
||||||
w2_scale=w2_scale)
|
w2_scale=w2_scale,
|
||||||
|
w1_gs=w1_gs,
|
||||||
|
w2_gs=w2_gs)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -449,7 +403,6 @@ class RankTensors:
|
|||||||
dtype=dtype)
|
dtype=dtype)
|
||||||
topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk,
|
topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk,
|
||||||
False)
|
False)
|
||||||
topk_ids = topk_ids.to(config.topk_ids_dtype)
|
|
||||||
|
|
||||||
# distribute topk_ids evenly
|
# distribute topk_ids evenly
|
||||||
for mi in range(m):
|
for mi in range(m):
|
||||||
@ -457,7 +410,7 @@ class RankTensors:
|
|||||||
topk_ids = topk_ids.to(device=torch.cuda.current_device())
|
topk_ids = topk_ids.to(device=torch.cuda.current_device())
|
||||||
|
|
||||||
expert_map = None
|
expert_map = None
|
||||||
if config.world_size > 1:
|
if config.world_size > 1 and config.supports_expert_map():
|
||||||
expert_map = torch.full((global_num_experts, ),
|
expert_map = torch.full((global_num_experts, ),
|
||||||
fill_value=-1,
|
fill_value=-1,
|
||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
@ -480,92 +433,100 @@ class RankTensors:
|
|||||||
def reference_moe_impl(config: Config, weights: WeightTensors,
|
def reference_moe_impl(config: Config, weights: WeightTensors,
|
||||||
rank_tensors: RankTensors) -> torch.Tensor:
|
rank_tensors: RankTensors) -> torch.Tensor:
|
||||||
|
|
||||||
return torch_experts(a=rank_tensors.hidden_states,
|
if config.quant_dtype == "nvfp4":
|
||||||
w1=weights.w1,
|
quant_blocksize = 16
|
||||||
w2=weights.w2,
|
dtype = config.dtype
|
||||||
|
|
||||||
|
w1_q = weights.w1
|
||||||
|
w1_blockscale = weights.w1_scale
|
||||||
|
w1_gs = weights.w1_gs
|
||||||
|
|
||||||
|
w2_q = weights.w2
|
||||||
|
w2_blockscale = weights.w2_scale
|
||||||
|
w2_gs = weights.w2_gs
|
||||||
|
|
||||||
|
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(
|
||||||
|
rank_tensors.hidden_states.flatten(), dim=-1)).to(torch.float32)
|
||||||
|
|
||||||
|
assert w1_gs is not None
|
||||||
|
assert w2_gs is not None
|
||||||
|
assert w1_blockscale is not None
|
||||||
|
assert w2_blockscale is not None
|
||||||
|
|
||||||
|
assert w1_blockscale.shape[1] % 128 == 0
|
||||||
|
assert w1_blockscale.shape[2] % 4 == 0
|
||||||
|
assert w2_blockscale.shape[1] % 128 == 0
|
||||||
|
assert w2_blockscale.shape[2] % 4 == 0
|
||||||
|
|
||||||
|
a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(
|
||||||
|
rank_tensors.hidden_states, a_global_scale)
|
||||||
|
|
||||||
|
a = dequantize_nvfp4_to_dtype(a_fp4,
|
||||||
|
a_scale_interleaved,
|
||||||
|
a_global_scale,
|
||||||
|
dtype=dtype,
|
||||||
|
device=a_fp4.device,
|
||||||
|
block_size=quant_blocksize)
|
||||||
|
|
||||||
|
e = w1_q.shape[0]
|
||||||
|
n = w1_q.shape[1] // 2
|
||||||
|
k = w2_q.shape[1]
|
||||||
|
|
||||||
|
w1 = torch.zeros((e, 2 * n, k), device="cuda", dtype=dtype)
|
||||||
|
w2 = torch.zeros((e, k, n), device="cuda", dtype=dtype)
|
||||||
|
|
||||||
|
for idx in range(0, e):
|
||||||
|
w1[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
|
||||||
|
w1_blockscale[idx],
|
||||||
|
w1_gs[idx],
|
||||||
|
dtype=dtype,
|
||||||
|
device=w1_q.device,
|
||||||
|
block_size=quant_blocksize)
|
||||||
|
w2[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
|
||||||
|
w2_blockscale[idx],
|
||||||
|
w2_gs[idx],
|
||||||
|
dtype=dtype,
|
||||||
|
device=w2_q.device,
|
||||||
|
block_size=quant_blocksize)
|
||||||
|
a_scale = None
|
||||||
|
w1_scale = None
|
||||||
|
w2_scale = None
|
||||||
|
quant_dtype = None
|
||||||
|
per_act_token_quant = False
|
||||||
|
block_shape = None
|
||||||
|
else:
|
||||||
|
a = rank_tensors.hidden_states
|
||||||
|
a_scale = rank_tensors.hidden_states_scale
|
||||||
|
w1 = weights.w1
|
||||||
|
w1_scale = weights.w1_scale
|
||||||
|
w2 = weights.w2
|
||||||
|
w2_scale = weights.w2_scale
|
||||||
|
quant_dtype = config.quant_dtype
|
||||||
|
per_act_token_quant = config.is_per_act_token_quant
|
||||||
|
block_shape = config.quant_block_shape
|
||||||
|
|
||||||
|
return torch_experts(a=a,
|
||||||
|
w1=w1,
|
||||||
|
w2=w2,
|
||||||
topk_weight=rank_tensors.topk_weights,
|
topk_weight=rank_tensors.topk_weights,
|
||||||
topk_ids=rank_tensors.topk_ids,
|
topk_ids=rank_tensors.topk_ids,
|
||||||
global_num_experts=config.E,
|
global_num_experts=config.E,
|
||||||
expert_map=None,
|
expert_map=None,
|
||||||
w1_scale=weights.w1_scale,
|
w1_scale=w1_scale,
|
||||||
w2_scale=weights.w2_scale,
|
w2_scale=w2_scale,
|
||||||
a1_scale=rank_tensors.hidden_states_scale,
|
a1_scale=a_scale,
|
||||||
quant_dtype=config.quant_dtype,
|
quant_dtype=quant_dtype,
|
||||||
per_act_token_quant=config.is_per_act_token_quant,
|
per_act_token_quant=per_act_token_quant,
|
||||||
block_shape=config.quant_block_shape,
|
block_shape=block_shape,
|
||||||
apply_router_weights_on_input=config.topk == 1)
|
apply_router_weights_on_input=config.topk == 1
|
||||||
|
and config.supports_apply_weight_on_input())
|
||||||
|
|
||||||
|
|
||||||
def make_fused_experts(
|
def make_modular_kernel(
|
||||||
config: Config, moe: FusedMoEConfig,
|
config: Config,
|
||||||
num_dispatchers: int) -> mk.FusedMoEPermuteExpertsUnpermute:
|
vllm_config: VllmConfig,
|
||||||
|
weights: WeightTensors,
|
||||||
use_fp8 = config.quant_dtype == torch.float8_e4m3fn
|
) -> mk.FusedMoEModularKernel:
|
||||||
batch_kwargs = {
|
|
||||||
"max_num_tokens": moe.max_num_tokens,
|
|
||||||
"num_dispatchers": num_dispatchers,
|
|
||||||
}
|
|
||||||
quant_kwargs = {
|
|
||||||
"use_fp8_w8a8": use_fp8,
|
|
||||||
"use_int8_w8a8": False,
|
|
||||||
"use_int8_w8a16": False,
|
|
||||||
"use_int4_w4a16": False,
|
|
||||||
"block_shape": config.quant_block_shape,
|
|
||||||
"per_act_token_quant": config.is_per_act_token_quant,
|
|
||||||
}
|
|
||||||
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
|
|
||||||
|
|
||||||
if config.fused_experts_type == BatchedDeepGemmExperts:
|
|
||||||
kwargs = batch_kwargs | {
|
|
||||||
"block_shape": config.quant_block_shape,
|
|
||||||
"per_act_token_quant": config.is_per_act_token_quant,
|
|
||||||
}
|
|
||||||
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
|
|
||||||
experts = BatchedDeepGemmExperts(**kwargs)
|
|
||||||
elif config.fused_experts_type == BatchedTritonExperts:
|
|
||||||
kwargs = batch_kwargs | quant_kwargs
|
|
||||||
print(f"Making BatchedTritonExperts {kwargs} ...")
|
|
||||||
experts = BatchedTritonExperts(**kwargs)
|
|
||||||
elif config.fused_experts_type == BatchedTritonOrDeepGemmExperts:
|
|
||||||
kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
|
|
||||||
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
|
|
||||||
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
|
|
||||||
elif config.fused_experts_type == DeepGemmExperts:
|
|
||||||
print("Making DeepGemmExperts () ...")
|
|
||||||
experts = DeepGemmExperts()
|
|
||||||
elif config.fused_experts_type == TritonExperts:
|
|
||||||
kwargs = quant_kwargs
|
|
||||||
print(f"Making TritonExperts {kwargs} ...")
|
|
||||||
experts = TritonExperts(**kwargs)
|
|
||||||
elif config.fused_experts_type == TritonOrDeepGemmExperts:
|
|
||||||
kwargs = quant_kwargs | deepgemm_kwargs
|
|
||||||
print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
|
|
||||||
experts = TritonOrDeepGemmExperts(**kwargs)
|
|
||||||
elif config.fused_experts_type == NaiveBatchedExperts:
|
|
||||||
kwargs = batch_kwargs | quant_kwargs
|
|
||||||
print(f"Making NaiveBatchedExperts {kwargs} ...")
|
|
||||||
experts = NaiveBatchedExperts(**kwargs)
|
|
||||||
elif config.fused_experts_type == CutlassExpertsFp8:
|
|
||||||
use_batched_format = config.is_batched_prepare_finalize()
|
|
||||||
num_experts = (moe.num_local_experts
|
|
||||||
if use_batched_format else moe.num_experts)
|
|
||||||
kwargs = {
|
|
||||||
"max_experts_per_worker": num_experts,
|
|
||||||
"out_dtype": moe.in_dtype,
|
|
||||||
"per_act_token_quant": config.is_per_act_token_quant,
|
|
||||||
"per_out_ch_quant": config.is_per_out_ch_quant,
|
|
||||||
"block_shape": config.quant_block_shape,
|
|
||||||
"num_dispatchers": num_dispatchers,
|
|
||||||
"use_batched_format": use_batched_format
|
|
||||||
}
|
|
||||||
print(f"Making CutlassExpertsFp8 {kwargs} ...")
|
|
||||||
experts = CutlassExpertsFp8(**kwargs)
|
|
||||||
|
|
||||||
return experts
|
|
||||||
|
|
||||||
|
|
||||||
def make_modular_kernel(config: Config,
|
|
||||||
vllm_config: VllmConfig) -> mk.FusedMoEModularKernel:
|
|
||||||
|
|
||||||
def next_power_of_2(x):
|
def next_power_of_2(x):
|
||||||
import math
|
import math
|
||||||
@ -579,6 +540,7 @@ def make_modular_kernel(config: Config,
|
|||||||
dp_size_=get_dp_group().world_size,
|
dp_size_=get_dp_group().world_size,
|
||||||
vllm_parallel_config=vllm_config.parallel_config,
|
vllm_parallel_config=vllm_config.parallel_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
moe = FusedMoEConfig(
|
moe = FusedMoEConfig(
|
||||||
num_experts=config.E,
|
num_experts=config.E,
|
||||||
experts_per_token=config.topk,
|
experts_per_token=config.topk,
|
||||||
@ -591,15 +553,16 @@ def make_modular_kernel(config: Config,
|
|||||||
)
|
)
|
||||||
|
|
||||||
# make modular kernel
|
# make modular kernel
|
||||||
prepare_finalize = None
|
prepare_finalize = make_prepare_finalize(config.prepare_finalize_type,
|
||||||
if config.needs_all2all():
|
config.all2all_backend(), moe)
|
||||||
prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(moe)
|
|
||||||
assert prepare_finalize is not None
|
|
||||||
else:
|
|
||||||
prepare_finalize = MoEPrepareAndFinalizeNoEP()
|
|
||||||
|
|
||||||
fused_experts = make_fused_experts(config, moe,
|
fused_experts = make_fused_experts(
|
||||||
prepare_finalize.num_dispatchers())
|
config.fused_experts_type,
|
||||||
|
moe,
|
||||||
|
prepare_finalize.num_dispatchers(),
|
||||||
|
weights.w1_gs,
|
||||||
|
weights.w2_gs,
|
||||||
|
)
|
||||||
|
|
||||||
modular_kernel = mk.FusedMoEModularKernel(
|
modular_kernel = mk.FusedMoEModularKernel(
|
||||||
prepare_finalize=prepare_finalize, fused_experts=fused_experts)
|
prepare_finalize=prepare_finalize, fused_experts=fused_experts)
|
||||||
@ -620,22 +583,45 @@ def run_modular_kernel(
|
|||||||
# weights for rank
|
# weights for rank
|
||||||
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
|
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
|
||||||
|
|
||||||
mk = make_modular_kernel(config, vllm_config)
|
mk = make_modular_kernel(config, vllm_config, weights)
|
||||||
|
|
||||||
mk_kwargs = {
|
mk_kwargs = {
|
||||||
"hidden_states": rank_tensors.hidden_states.clone(
|
"hidden_states":
|
||||||
|
rank_tensors.hidden_states.clone(
|
||||||
), # impls might update the tensor in place
|
), # impls might update the tensor in place
|
||||||
"w1": rank_weights.w1,
|
"w1":
|
||||||
"w2": rank_weights.w2,
|
rank_weights.w1,
|
||||||
"topk_weights": rank_tensors.topk_weights,
|
"w2":
|
||||||
"topk_ids": rank_tensors.topk_ids,
|
rank_weights.w2,
|
||||||
"expert_map": rank_tensors.expert_map,
|
"topk_weights":
|
||||||
"w1_scale": rank_weights.w1_scale,
|
rank_tensors.topk_weights,
|
||||||
"w2_scale": rank_weights.w2_scale,
|
"topk_ids":
|
||||||
"a1_scale": rank_tensors.hidden_states_scale,
|
rank_tensors.topk_ids.to(mk.prepare_finalize.topk_indices_dtype()),
|
||||||
"global_num_experts": config.E,
|
"expert_map":
|
||||||
"apply_router_weight_on_input": config.topk == 1,
|
rank_tensors.expert_map,
|
||||||
|
"w1_scale":
|
||||||
|
rank_weights.w1_scale,
|
||||||
|
"w2_scale":
|
||||||
|
rank_weights.w2_scale,
|
||||||
|
"a1_scale":
|
||||||
|
rank_tensors.hidden_states_scale,
|
||||||
|
"global_num_experts":
|
||||||
|
config.E,
|
||||||
|
"apply_router_weight_on_input":
|
||||||
|
config.topk == 1 and config.supports_apply_weight_on_input(),
|
||||||
}
|
}
|
||||||
out = mk.forward(**mk_kwargs)
|
|
||||||
|
num_tokens = rank_tensors.hidden_states.shape[0]
|
||||||
|
num_tokens_across_dp = torch.tensor([num_tokens] * config.world_size,
|
||||||
|
device="cuda",
|
||||||
|
dtype=torch.int)
|
||||||
|
|
||||||
|
with set_forward_context(
|
||||||
|
None,
|
||||||
|
vllm_config,
|
||||||
|
num_tokens=num_tokens,
|
||||||
|
num_tokens_across_dp=num_tokens_across_dp,
|
||||||
|
):
|
||||||
|
out = mk.forward(**mk_kwargs)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
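The forward call in the new version runs inside vLLM's forward context, which is where per-rank token counts for data parallelism are made available to the kernels. Below is a standalone sketch of that wrapper using the same call shape as the change above; the token counts are made-up example values and the tensor is left on the CPU to keep the sketch device-agnostic.

```python
# Sketch of wrapping a forward pass in set_forward_context, mirroring the
# test change above. The numbers are illustrative only.
import torch

from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context

vllm_config = VllmConfig()
num_tokens = 16
world_size = 2
num_tokens_across_dp = torch.tensor([num_tokens] * world_size, dtype=torch.int)

with set_forward_context(
        None,  # no attention metadata is needed for this standalone sketch
        vllm_config,
        num_tokens=num_tokens,
        num_tokens_across_dp=num_tokens_across_dp,
):
    # Run the kernel under the context, e.g. out = mk.forward(**mk_kwargs).
    pass
```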
|
|||||||
@ -1,58 +1,316 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Fused experts and PrepareFinalize imports
|
# Fused experts and PrepareFinalize imports
|
||||||
|
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||||
BatchedDeepGemmExperts)
|
BatchedDeepGemmExperts)
|
||||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||||
BatchedTritonOrDeepGemmExperts)
|
BatchedTritonOrDeepGemmExperts)
|
||||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
|
||||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
|
FusedMoEQuantConfig)
|
||||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
||||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||||
BatchedTritonExperts, NaiveBatchedExperts)
|
BatchedTritonExperts, NaiveBatchedExperts)
|
||||||
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
|
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
|
||||||
|
TritonExperts)
|
||||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||||
MoEPrepareAndFinalizeNoEP)
|
MoEPrepareAndFinalizeNoEP)
|
||||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||||
TritonOrDeepGemmExperts)
|
TritonOrDeepGemmExperts)
|
||||||
from vllm.utils import has_deep_ep, has_pplx
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
cutlass_fp4_supported)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
|
cutlass_fp8_supported)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||||
|
from vllm.utils.deep_gemm import is_deep_gemm_supported
|
||||||
|
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||||
|
|
||||||
if has_deep_ep():
|
|
||||||
|
@dataclass
|
||||||
|
class PrepareFinalizeInfo:
|
||||||
|
activation_format: mk.FusedMoEActivationFormat
|
||||||
|
supported_dtypes: list[Union[torch.dtype, str]]
|
||||||
|
blocked_quantization_support: bool
|
||||||
|
backend: Optional[str]
|
||||||
|
supports_apply_weight_on_input: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExpertInfo:
|
||||||
|
activation_format: mk.FusedMoEActivationFormat
|
||||||
|
supported_dtypes: list[Union[torch.dtype, str]]
|
||||||
|
blocked_quantization_support: bool
|
||||||
|
supports_chunking: bool
|
||||||
|
supports_expert_map: bool
|
||||||
|
needs_matching_quant: bool = False
|
||||||
|
needs_deep_gemm: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
PREPARE_FINALIZE_INFO: dict[mk.FusedMoEPrepareAndFinalize,
|
||||||
|
PrepareFinalizeInfo] = {}
|
||||||
|
EXPERT_INFO: dict[mk.FusedMoEPermuteExpertsUnpermute, ExpertInfo] = {}
|
||||||
|
MK_ALL_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
|
||||||
|
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
|
||||||
|
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES: list[mk.FusedMoEPrepareAndFinalize] = []
|
||||||
|
MK_FUSED_EXPERT_TYPES: list[mk.FusedMoEPermuteExpertsUnpermute] = []
|
||||||
|
|
||||||
|
standard_format = mk.FusedMoEActivationFormat.Standard
|
||||||
|
batched_format = mk.FusedMoEActivationFormat.BatchedExperts
|
||||||
|
common_float_types: list[Union[torch.dtype, str]] = [
|
||||||
|
torch.float8_e4m3fn, torch.bfloat16, torch.float16, torch.float32
|
||||||
|
]
|
||||||
|
common_float_and_int_types = common_float_types + [torch.int8]
|
||||||
|
nv_fp4_types = ["nvfp4"]
|
||||||
|
fp8_types = [torch.float8_e4m3fn]
|
||||||
|
|
||||||
|
|
||||||
|
def register_prepare_and_finalize(
|
||||||
|
kind,
|
||||||
|
activation_format: mk.FusedMoEActivationFormat,
|
||||||
|
supported_dtypes: list[Union[torch.dtype, str]],
|
||||||
|
blocked_quantization_support: bool,
|
||||||
|
backend: Optional[str],
|
||||||
|
force_multigpu: bool = False,
|
||||||
|
supports_apply_weight_on_input: bool = True,
|
||||||
|
):
|
||||||
|
global PREPARE_FINALIZE_INFO
|
||||||
|
global MK_ALL_PREPARE_FINALIZE_TYPES
|
||||||
|
global MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
|
||||||
|
global MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES
|
||||||
|
assert kind not in PREPARE_FINALIZE_INFO
|
||||||
|
|
||||||
|
PREPARE_FINALIZE_INFO[kind] = PrepareFinalizeInfo(
|
||||||
|
activation_format,
|
||||||
|
supported_dtypes,
|
||||||
|
blocked_quantization_support,
|
||||||
|
backend,
|
||||||
|
supports_apply_weight_on_input,
|
||||||
|
)
|
||||||
|
MK_ALL_PREPARE_FINALIZE_TYPES.append(kind)
|
||||||
|
if backend is not None or force_multigpu:
|
||||||
|
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES.append(kind)
|
||||||
|
else:
|
||||||
|
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES.append(kind)
|
||||||
|
|
||||||
|
|
||||||
|
def register_experts(
|
||||||
|
kind,
|
||||||
|
activation_format: mk.FusedMoEActivationFormat,
|
||||||
|
supported_dtypes: list[Union[torch.dtype, str]],
|
||||||
|
blocked_quantization_support: bool,
|
||||||
|
supports_chunking: bool,
|
||||||
|
supports_expert_map: bool,
|
||||||
|
needs_matching_quant: bool = False,
|
||||||
|
needs_deep_gemm: bool = False,
|
||||||
|
):
|
||||||
|
global EXPERT_INFO
|
||||||
|
global MK_FUSED_EXPERT_TYPES
|
||||||
|
assert kind not in EXPERT_INFO
|
||||||
|
|
||||||
|
EXPERT_INFO[kind] = ExpertInfo(
|
||||||
|
activation_format,
|
||||||
|
supported_dtypes,
|
||||||
|
blocked_quantization_support,
|
||||||
|
supports_chunking,
|
||||||
|
supports_expert_map,
|
||||||
|
needs_matching_quant,
|
||||||
|
needs_deep_gemm,
|
||||||
|
)
|
||||||
|
|
||||||
|
MK_FUSED_EXPERT_TYPES.append(kind)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_finalize_info(kind) -> PrepareFinalizeInfo:
|
||||||
|
info = PREPARE_FINALIZE_INFO.get(kind)
|
||||||
|
assert info is not None
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
def expert_info(kind) -> ExpertInfo:
|
||||||
|
info = EXPERT_INFO.get(kind)
|
||||||
|
assert info is not None
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
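To make the registry pattern concrete: registering a kernel type once here is what lets the test `Config` helpers elsewhere in this diff answer capability questions without hard-coded type lists. The snippet below is a hypothetical illustration that assumes it runs inside this module (so `register_experts`, `expert_info`, `standard_format` and `common_float_types` are in scope); `MyExperts` is a made-up placeholder, not a real vLLM class.

```python
# Hypothetical illustration of the registry above; not meant for the real
# module, since it would also append MyExperts to MK_FUSED_EXPERT_TYPES.
class MyExperts:  # placeholder for a FusedMoEPermuteExpertsUnpermute subclass
    pass


register_experts(
    MyExperts,
    standard_format,
    common_float_types,
    blocked_quantization_support=False,
    supports_chunking=True,
    supports_expert_map=True,
)

info = expert_info(MyExperts)
assert info.supports_chunking and info.supports_expert_map
assert info.activation_format == mk.FusedMoEActivationFormat.Standard
```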
|
register_prepare_and_finalize(
|
||||||
|
MoEPrepareAndFinalizeNoEP,
|
||||||
|
standard_format,
|
||||||
|
common_float_types,
|
||||||
|
blocked_quantization_support=True,
|
||||||
|
backend=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
register_experts(
|
||||||
|
BatchedTritonExperts,
|
||||||
|
batched_format,
|
||||||
|
common_float_types,
|
||||||
|
blocked_quantization_support=True,
|
||||||
|
supports_chunking=False,
|
||||||
|
supports_expert_map=False,
|
||||||
|
needs_matching_quant=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
register_experts(
|
||||||
|
TritonExperts,
|
||||||
|
standard_format,
|
||||||
|
common_float_and_int_types,
|
||||||
|
blocked_quantization_support=True,
|
||||||
|
supports_chunking=True,
|
||||||
|
supports_expert_map=True,
|
||||||
|
needs_matching_quant=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
register_experts(
|
||||||
|
NaiveBatchedExperts,
|
||||||
|
batched_format,
|
||||||
|
common_float_and_int_types,
|
||||||
|
blocked_quantization_support=True,
|
||||||
|
supports_chunking=False,
|
||||||
|
supports_expert_map=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Disable on blackwell for now
|
||||||
|
if has_deep_ep() and not current_platform.has_device_capability(100):
|
||||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
||||||
DeepEPHTPrepareAndFinalize)
|
DeepEPHTPrepareAndFinalize)
|
||||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
||||||
DeepEPLLPrepareAndFinalize)
|
DeepEPLLPrepareAndFinalize)
|
||||||
|
|
||||||
|
register_prepare_and_finalize(
|
||||||
|
DeepEPHTPrepareAndFinalize,
|
||||||
|
standard_format,
|
||||||
|
common_float_types,
|
||||||
|
blocked_quantization_support=True,
|
||||||
|
backend="deepep_high_throughput",
|
||||||
|
)
|
||||||
|
|
||||||
|
register_prepare_and_finalize(
|
||||||
|
DeepEPLLPrepareAndFinalize,
|
||||||
|
batched_format,
|
||||||
|
common_float_types,
|
||||||
|
blocked_quantization_support=True,
|
||||||
|
backend="deepep_low_latency",
|
||||||
|
)
|
||||||
|
|
||||||
if has_pplx():
|
if has_pplx():
|
||||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||||
PplxPrepareAndFinalize)
|
PplxPrepareAndFinalize)
|
||||||
|
register_prepare_and_finalize(
|
||||||
|
PplxPrepareAndFinalize,
|
||||||
|
batched_format,
|
||||||
|
common_float_and_int_types,
|
||||||
|
blocked_quantization_support=True,
|
||||||
|
backend="pplx",
|
||||||
|
)
|
||||||
|
|
||||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES = []
|
if (has_flashinfer_cutlass_fused_moe()
|
||||||
if has_pplx():
|
and current_platform.has_device_capability(100)):
|
||||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [PplxPrepareAndFinalize]
|
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
|
||||||
if has_deep_ep():
|
FlashInferExperts)
|
||||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [
|
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
|
||||||
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
|
FlashInferCutlassMoEPrepareAndFinalize)
|
||||||
]
|
|
||||||
|
|
||||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES = [MoEPrepareAndFinalizeNoEP]
|
register_prepare_and_finalize(
|
||||||
|
FlashInferCutlassMoEPrepareAndFinalize,
|
||||||
|
standard_format,
|
||||||
|
nv_fp4_types,
|
||||||
|
blocked_quantization_support=True,
|
||||||
|
backend=None,
|
||||||
|
force_multigpu=True,
|
||||||
|
supports_apply_weight_on_input=False,
|
||||||
|
)
|
||||||
|
|
||||||
MK_ALL_PREPARE_FINALIZE_TYPES = (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES +
|
register_experts(
|
||||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
|
FlashInferExperts,
|
||||||
|
standard_format,
|
||||||
|
nv_fp4_types,
|
||||||
|
blocked_quantization_support=True,
|
||||||
|
supports_chunking=True,
|
||||||
|
# Note: this is a hack to get it to run for now
|
||||||
|
supports_expert_map=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
FlashInferCutlassMoEPrepareAndFinalize = None
|
||||||
|
|
||||||
MK_FUSED_EXPERT_TYPES = [
|
if has_deep_gemm() and is_deep_gemm_supported():
|
||||||
BatchedDeepGemmExperts,
|
register_experts(
|
||||||
BatchedTritonExperts,
|
BatchedDeepGemmExperts,
|
||||||
NaiveBatchedExperts,
|
batched_format,
|
||||||
BatchedTritonOrDeepGemmExperts,
|
fp8_types,
|
||||||
CutlassExpertsFp8,
|
blocked_quantization_support=True,
|
||||||
DeepGemmExperts,
|
supports_chunking=False,
|
||||||
TritonOrDeepGemmExperts,
|
supports_expert_map=False,
|
||||||
TritonExperts,
|
needs_matching_quant=False,
|
||||||
]
|
needs_deep_gemm=True,
|
||||||
|
)
|
||||||
|
register_experts(
|
||||||
|
DeepGemmExperts,
|
||||||
|
standard_format,
|
||||||
|
fp8_types,
|
||||||
|
blocked_quantization_support=True,
|
||||||
|
supports_chunking=True,
|
||||||
|
supports_expert_map=True,
|
||||||
|
needs_matching_quant=False,
|
||||||
|
needs_deep_gemm=True,
|
||||||
|
),
|
||||||
|
register_experts(
|
||||||
|
BatchedTritonOrDeepGemmExperts,
|
||||||
|
batched_format,
|
||||||
|
common_float_and_int_types,
|
||||||
|
blocked_quantization_support=True,
|
||||||
|
supports_chunking=False,
|
||||||
|
supports_expert_map=False,
|
||||||
|
needs_matching_quant=True,
|
||||||
|
needs_deep_gemm=True,
|
||||||
|
)
|
||||||
|
register_experts(
|
||||||
|
TritonOrDeepGemmExperts,
|
||||||
|
standard_format,
|
||||||
|
common_float_and_int_types,
|
||||||
|
blocked_quantization_support=True,
|
||||||
|
supports_chunking=True,
|
||||||
|
supports_expert_map=True,
|
||||||
|
needs_matching_quant=True,
|
||||||
|
needs_deep_gemm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cutlass_fp8_supported():
|
||||||
|
from vllm.model_executor.layers.fused_moe import (CutlassBatchedExpertsFp8,
|
||||||
|
CutlassExpertsFp8)
|
||||||
|
register_experts(
|
||||||
|
CutlassExpertsFp8,
|
||||||
|
standard_format,
|
||||||
|
fp8_types,
|
||||||
|
blocked_quantization_support=False,
|
||||||
|
supports_chunking=True,
|
||||||
|
supports_expert_map=False,
|
||||||
|
)
|
||||||
|
register_experts(
|
||||||
|
CutlassBatchedExpertsFp8,
|
||||||
|
batched_format,
|
||||||
|
fp8_types,
|
||||||
|
blocked_quantization_support=False,
|
||||||
|
supports_chunking=False,
|
||||||
|
supports_expert_map=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cutlass_fp4_supported():
|
||||||
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||||
|
CutlassExpertsFp4)
|
||||||
|
register_experts(
|
||||||
|
CutlassExpertsFp4,
|
||||||
|
standard_format,
|
||||||
|
nv_fp4_types,
|
||||||
|
blocked_quantization_support=True,
|
||||||
|
supports_chunking=True,
|
||||||
|
supports_expert_map=False,
|
||||||
|
)
|
||||||
|
|
||||||
MK_QUANT_CONFIGS = [
|
MK_QUANT_CONFIGS = [
|
||||||
None,
|
None,
|
||||||
@@ -85,3 +343,156 @@ MK_QUANT_CONFIGS = [
     # block-quantized weights and per-token activations
     # block-quantized weights and per-tensor activations
 ]
 
+if cutlass_fp4_supported() or has_flashinfer_cutlass_fused_moe():
+    MK_QUANT_CONFIGS += [
+        FusedMoEQuantConfig(quant_dtype="nvfp4",
+                            per_out_ch_quant=False,
+                            per_act_token_quant=False,
+                            block_shape=None),
+    ]
+
+
+def _make_gscale(num_experts: int) -> torch.Tensor:
+    return torch.ones((num_experts, ),
+                      device=torch.cuda.current_device(),
+                      dtype=torch.float32)
+
+
+def make_prepare_finalize(
+    prepare_finalize_type: mk.FusedMoEPrepareAndFinalize,
+    backend: Optional[str],
+    moe: FusedMoEConfig,
+) -> mk.FusedMoEPrepareAndFinalize:
+    if backend != "naive" and backend is not None:
+        prepare_finalize = FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
+        assert prepare_finalize is not None
+        return prepare_finalize
+    elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
+        return FlashInferCutlassMoEPrepareAndFinalize(
+            use_dp=moe.moe_parallel_config.dp_size > 1,
+            a1_gscale=_make_gscale(moe.num_local_experts),
+        )
+    else:
+        return MoEPrepareAndFinalizeNoEP()
+
+
+def _slice(rank: int, num_local_experts: int, t: torch.Tensor) -> torch.Tensor:
+    s = rank * num_local_experts
+    e = s + num_local_experts
+    return t[s:e]
+
+
+def make_fused_experts(
+    fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute,
+    moe: FusedMoEConfig,
+    num_dispatchers: int,
+    w1_gs: Optional[torch.Tensor],
+    w2_gs: Optional[torch.Tensor],
+) -> mk.FusedMoEPermuteExpertsUnpermute:
+
+    use_fp8 = moe.quant_dtype == torch.float8_e4m3fn
+    batch_kwargs = {
+        "max_num_tokens": moe.max_num_tokens,
+        "num_dispatchers": num_dispatchers,
+    }
+    quant_kwargs = {
+        "use_fp8_w8a8": use_fp8,
+        "use_int8_w8a8": False,
+        "use_int8_w8a16": False,
+        "use_int4_w4a16": False,
+        "block_shape": moe.block_shape,
+        "per_act_token_quant": moe.per_act_token_quant,
+    }
+    deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
+
+    if fused_experts_type == BatchedDeepGemmExperts:
+        kwargs = batch_kwargs | {
+            "block_shape": moe.block_shape,
+            "per_act_token_quant": moe.per_act_token_quant,
+        }
+        print(f"Making BatchedDeepGemmExperts {kwargs} ...")
+        experts = BatchedDeepGemmExperts(**kwargs)
+    elif fused_experts_type == BatchedTritonExperts:
+        kwargs = batch_kwargs | quant_kwargs
+        print(f"Making BatchedTritonExperts {kwargs} ...")
+        experts = BatchedTritonExperts(**kwargs)
+    elif fused_experts_type == BatchedTritonOrDeepGemmExperts:
+        kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
+        print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
+        experts = BatchedTritonOrDeepGemmExperts(**kwargs)
+    elif fused_experts_type == DeepGemmExperts:
+        print("Making DeepGemmExperts () ...")
+        experts = DeepGemmExperts()
+    elif fused_experts_type == TritonExperts:
+        kwargs = quant_kwargs
+        print(f"Making TritonExperts {kwargs} ...")
+        experts = TritonExperts(**kwargs)
+    elif fused_experts_type == TritonOrDeepGemmExperts:
+        kwargs = quant_kwargs | deepgemm_kwargs
+        print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
+        experts = TritonOrDeepGemmExperts(**kwargs)
+    elif fused_experts_type == NaiveBatchedExperts:
+        kwargs = batch_kwargs | quant_kwargs
+        print(f"Making NaiveBatchedExperts {kwargs} ...")
+        experts = NaiveBatchedExperts(**kwargs)
+    elif fused_experts_type == CutlassExpertsFp8:
+        kwargs = {
+            "out_dtype": moe.in_dtype,
+            "per_act_token_quant": moe.per_act_token_quant,
+            "per_out_ch_quant": moe.per_out_ch_quant,
+            "block_shape": moe.block_shape,
+        }
+        print(f"Making CutlassExpertsFp8 {kwargs} ...")
+        experts = CutlassExpertsFp8(**kwargs)
+    elif fused_experts_type == CutlassBatchedExpertsFp8:
+        kwargs = {
+            "max_experts_per_worker": moe.num_local_experts,
+            "num_dispatchers": num_dispatchers,
+            "out_dtype": moe.in_dtype,
+            "per_act_token_quant": moe.per_act_token_quant,
+            "per_out_ch_quant": moe.per_out_ch_quant,
+            "block_shape": moe.block_shape,
+        }
+        print(f"Making CutlassBatchedExpertsFp8 {kwargs} ...")
+        experts = CutlassBatchedExpertsFp8(**kwargs)
+    elif fused_experts_type == CutlassExpertsFp4:
+        assert w1_gs is not None and w2_gs is not None
+        num_experts = moe.num_local_experts
+        rank = moe.moe_parallel_config.dp_rank
+        kwargs = {
+            "g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
+            "g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
+            "a1_gscale": _make_gscale(num_experts),
+            "a2_gscale": _make_gscale(num_experts),
+            "max_experts_per_worker": num_experts,
+            "out_dtype": moe.in_dtype,
+            "per_act_token_quant": moe.per_act_token_quant,
+            "per_out_ch_quant": moe.per_out_ch_quant,
+            "block_shape": moe.block_shape,
+            "num_dispatchers": num_dispatchers,
+        }
+        print(f"Making CutlassExpertsFp4 {kwargs} ...")
+        experts = CutlassExpertsFp4(**kwargs)
+    elif fused_experts_type == FlashInferExperts:
+        assert w1_gs is not None and w2_gs is not None
+        num_experts = moe.num_local_experts
+        rank = moe.moe_parallel_config.dp_rank
+        kwargs = {
+            "g1_alphas": _slice(rank, num_experts, (1 / w1_gs)),
+            "g2_alphas": _slice(rank, num_experts, (1 / w2_gs)),
+            "a1_gscale": _make_gscale(num_experts),
+            "a2_gscale": _make_gscale(num_experts),
+            "out_dtype": moe.in_dtype,
+            "quant_dtype": "nvfp4",
+            "ep_rank": moe.ep_rank,
+            "ep_size": moe.ep_size,
+            "tp_rank": moe.tp_rank,
+            "tp_size": moe.tp_size,
+        }
+        print(f"Making FlashInferExperts {kwargs} ...")
+        experts = FlashInferExperts(**kwargs)
+    else:
+        raise RuntimeError(f"Unknown fused experts type: {fused_experts_type}")
+
+    return experts
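For orientation, the helpers added above are composed into a `FusedMoEModularKernel` by the test driver. A minimal sketch of that wiring, assuming `prepare_finalize_type`, `fused_experts_type`, `backend`, `moe`, `num_dispatchers`, `w1_gs`, and `w2_gs` are supplied by the surrounding test `Config` and weight tensors rather than defined in this hunk:

```python
# Rough sketch only; the surrounding test code supplies these inputs.
prepare_finalize = make_prepare_finalize(prepare_finalize_type, backend, moe)
experts = make_fused_experts(fused_experts_type, moe, num_dispatchers,
                             w1_gs, w2_gs)
kernel = mk.FusedMoEModularKernel(prepare_finalize, experts)
```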
@@ -52,7 +52,7 @@ def profile_modular_kernel(
     rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
 
     # make modular kernel
-    mk = make_modular_kernel(config, vllm_config)
+    mk = make_modular_kernel(config, vllm_config, weights)
 
     mk_kwargs = {
         "hidden_states": rank_tensors.hidden_states,
@@ -83,7 +83,7 @@ def rank_worker(
     # sanity check
     from vllm import envs
     if config.fused_moe_chunk_size is not None:
-        assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
+        assert config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE
 
     # get weights to this device
     weights.to_current_device()
@@ -1,117 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import torch
-
-import vllm._custom_ops as ops
-from vllm.utils.deep_gemm import per_block_cast_to_fp8
-
-
-def per_token_cast_to_fp8(
-        x: torch.Tensor, block_size: int) -> tuple[torch.Tensor, torch.Tensor]:
-    assert x.dim() == 2
-    m, n = x.shape
-    pad_size = (block_size - (n % block_size)) % block_size
-    x = torch.nn.functional.pad(x,
-                                (0, pad_size), value=0) if pad_size > 0 else x
-    x_view = x.view(m, -1, block_size)
-    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
-    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
-    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
-
-
-def make_non_quant_weights(
-    e: int,
-    n: int,
-    k: int,
-    dtype: torch.dtype,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Return weights w1, w2
-    """
-    device = torch.cuda.current_device()
-    w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 15
-    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 15
-    return w1, w2
-
-
-def make_block_quant_fp8_weights(
-    e: int,
-    n: int,
-    k: int,
-    block_size: list[int],
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    """
-    Return weights w1, w2, w1_scale, w2_scale
-    """
-    dtype = torch.bfloat16
-    device = torch.cuda.current_device()
-
-    fp8_info = torch.finfo(torch.float8_e4m3fn)
-    fp8_max, fp8_min = fp8_info.max, fp8_info.min
-
-    w1_bf16, w2_bf16 = make_non_quant_weights(e, n, k, dtype)
-    w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
-    w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
-
-    block_n, block_k = block_size[0], block_size[1]
-    n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
-    k_tiles_w1 = (k + block_k - 1) // block_k
-    n_tiles_w2 = (k + block_n - 1) // block_n
-    k_tiles_w2 = (n + block_k - 1) // block_k
-
-    w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn, device=device)
-    w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn, device=device)
-
-    w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
-                       device=device,
-                       dtype=torch.float32)
-    w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
-                       device=device,
-                       dtype=torch.float32)
-
-    assert w1_s.shape == (e, (2 * n + (block_n - 1)) // block_n,
-                          (k + (block_k - 1)) // block_k)
-    assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
-
-    for i in range(e):
-        w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
-                                               block_size=[block_k, block_n])
-        w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
-                                               block_size=[block_k, block_n])
-
-    return w1, w2, w1_s, w2_s
-
-
-def make_quant_fp8_weights(
-    e: int,
-    n: int,
-    k: int,
-    per_out_channel_quant: bool,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    """
-    Return w1, w2, w1_scale, w2_scale
-    """
-    q_dtype = torch.float8_e4m3fn
-
-    w1, w2 = make_non_quant_weights(e, n, k, dtype=torch.bfloat16)
-
-    # w1 -> w1_q, w2 -> w2_q
-    w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
-    w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
-
-    n_b_scales = 2 * n if per_out_channel_quant else 1
-    k_b_scales = k if per_out_channel_quant else 1
-    w1_scale = torch.empty((e, n_b_scales, 1),
-                           device="cuda",
-                           dtype=torch.float32)
-    w2_scale = torch.empty((e, k_b_scales, 1),
-                           device="cuda",
-                           dtype=torch.float32)
-
-    for expert in range(e):
-        w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
-            w1[expert], use_per_token_if_dynamic=per_out_channel_quant)
-        w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
-            w2[expert], use_per_token_if_dynamic=per_out_channel_quant)
-    return w1_q, w2_q, w1_scale, w2_scale
@@ -133,7 +133,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
         per_act_token_quant=per_act_token_quant,
     )
 
-    B, B_q, B_scale, _, _, _ = make_test_weights(
+    (B, B_q, B_scale, _), _ = make_test_weights(
         num_experts,
         N // 2,
         K,
@@ -243,7 +243,7 @@ def test_fused_moe_batched_experts(
         act_dtype = dtype
         quant_dtype = None
 
-    w1_16, w1, w1_s, w2_16, w2, w2_s = make_test_weights(
+    (w1_16, w1, w1_s, _), (w2_16, w2, w2_s, _) = make_test_weights(
         e,
         n,
         k,
@@ -161,18 +161,20 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed,
     a = torch.randn((M, K), dtype=dtype) / 10
     score = torch.randn((M, E), dtype=dtype)
 
-    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
-                                                 N,
-                                                 K,
-                                                 dtype,
-                                                 torch.float8_e4m3fn,
-                                                 per_act_token_quant=False,
-                                                 block_shape=block_size)
+    (_, w1, w1_s, _), (_, w2, w2_s,
+                       _) = make_test_weights(E,
+                                              N,
+                                              K,
+                                              dtype,
+                                              torch.float8_e4m3fn,
+                                              per_act_token_quant=False,
+                                              block_shape=block_size)
 
     m_fused_moe = modular_triton_fused_moe(use_fp8_w8a8=True,
                                            use_int8_w8a8=False,
                                            use_int8_w8a16=False,
                                            use_int4_w4a16=False,
+                                           use_mxfp4_w4a4=False,
                                            per_act_token_quant=False,
                                            block_shape=block_size)
 
@@ -247,13 +249,14 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed,
     a = torch.randn((M, K), dtype=dtype) / 10
     score = torch.randn((M, E), dtype=dtype)
 
-    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
-                                                 N,
-                                                 K,
-                                                 dtype,
-                                                 torch.float8_e4m3fn,
-                                                 per_act_token_quant=False,
-                                                 block_shape=block_size)
+    (_, w1, w1_s, _), (_, w2, w2_s,
+                       _) = make_test_weights(E,
+                                              N,
+                                              K,
+                                              dtype,
+                                              torch.float8_e4m3fn,
+                                              per_act_token_quant=False,
+                                              block_shape=block_size)
 
     # Note: for now use_compile will error out if the problem size is
     # large enough to trigger chunking. I'm leaving the flag and
@@ -118,13 +118,14 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
     a = torch.randn((M, K), dtype=dtype) / 10
     score = torch.randn((M, E), dtype=dtype)
 
-    _, w1, w1_s, _, w2, w2_s = make_test_weights(E,
-                                                 N,
-                                                 K,
-                                                 dtype,
-                                                 torch.int8,
-                                                 per_act_token_quant=False,
-                                                 block_shape=block_size)
+    (_, w1, w1_s, _), (_, w2, w2_s,
+                       _) = make_test_weights(E,
+                                              N,
+                                              K,
+                                              dtype,
+                                              torch.int8,
+                                              per_act_token_quant=False,
+                                              block_shape=block_size)
 
     # Set the context to avoid lots of warning spam.
     with set_current_vllm_config(vllm_config):
@@ -9,6 +9,7 @@ import random
 import pytest
 import torch
 
+from tests.kernels.moe.utils import per_token_cast_to_fp8
 from tests.kernels.utils import baseline_scaled_mm
 from vllm import _custom_ops as ops
 from vllm.platforms import current_platform
@@ -16,20 +17,6 @@ from vllm.utils import cdiv
 from vllm.utils.deep_gemm import per_block_cast_to_fp8
 
 
-def per_token_cast_to_fp8(
-        x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
-    assert x.dim() == 2
-    m, n = x.shape
-    pad_size = (128 - (n % 128)) % 128
-    x = torch.nn.functional.pad(x,
-                                (0, pad_size), value=0) if pad_size > 0 else x
-    x_view = x.view(m, -1, 128)
-    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
-    fp8_data = (x_view *
-                (448.0 / x_amax.unsqueeze(2))).to(dtype=torch.float8_e4m3fn)
-    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
-
-
 @pytest.mark.parametrize("num_groups, expected_m_per_group, k, n", [
     (4, 8192, 7168, 4096),
     (4, 8192, 2048, 7168),
@@ -76,7 +63,7 @@ def test_cutlass_grouped_gemm(
                           device=device,
                           dtype=torch.float))
     for i in range(num_groups):
-        y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
+        y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], [128, 128])
 
     for i in range(num_groups):
         a = x_fp8[0][ep_offset[i]:ep_offset[i + 1]]
@@ -70,8 +70,10 @@ def make_block_quant_fp8_weights(
     """
     Return weights w1q, w2q, w1_scale, w2_scale
     """
-    w1, w1q, w1_scale, w2, w2q, w2_scale = make_test_weights(
-        e, n, k, torch.bfloat16, torch.float8_e4m3fn, block_size)
+    (_, w1q, w1_scale, _), (_, w2q, w2_scale,
+                            _) = make_test_weights(e, n, k, torch.bfloat16,
+                                                   torch.float8_e4m3fn,
+                                                   block_size)
     return w1q, w2q, w1_scale, w2_scale
 
 
@@ -132,9 +132,9 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
 # Note: W1 has shape (E, 2N, K), so N = 512
 # can trigger the deepgemm path.
 MNKs = [
-    (1024, 512, 128),
-    (1024, 512, 512),
-    (2048, 512, 512),
+    (1024, 768, 128),
+    (1024, 768, 512),
+    (2048, 768, 512),
     (512, 1024, 1024),
     (512, 2048, 2048),
     (4096, 4096, 1024),
147  tests/kernels/moe/test_flashinfer_moe.py  (new file)
@@ -0,0 +1,147 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import pytest
+import torch
+
+from tests.kernels.moe.utils import make_test_weights
+from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
+                                                    FLOAT8_E4M3_MAX,
+                                                    dequantize_nvfp4_to_dtype)
+from tests.kernels.utils import torch_moe
+from vllm import _custom_ops as ops
+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
+    FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe)
+from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+from vllm.model_executor.layers.fused_moe.modular_kernel import (
+    FusedMoEModularKernel)
+from vllm.model_executor.layers.fused_moe.prepare_finalize import (
+    MoEPrepareAndFinalizeNoEP)
+from vllm.platforms import current_platform
+from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
+
+if not has_flashinfer_cutlass_fused_moe(
+) or not current_platform.has_device_capability(100):
+    pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
+                allow_module_level=True)
+
+MNK_FACTORS = [
+    (2, 1024, 1024),
+    (2, 1024, 1536),
+    (2, 3072, 1024),
+    (2, 3072, 1536),
+    (64, 1024, 1024),
+    (64, 1024, 1536),
+    (64, 3072, 1024),
+    (64, 2048, 1536),
+    (224, 1024, 1024),
+    (224, 1024, 1536),
+]
+
+
+@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
+@pytest.mark.parametrize("e", [40, 64, 256])
+#@pytest.mark.parametrize("e", [128, 256])
+@pytest.mark.parametrize("topk", [1, 6, 8])
+@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
+@torch.inference_mode()
+def test_flashinfer_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
+                                     dtype: torch.dtype):
+    current_platform.seed_everything(7)
+    with set_current_vllm_config(
+            VllmConfig(parallel_config=ParallelConfig(
+                pipeline_parallel_size=1))):
+
+        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+
+        quant_blocksize = 16
+
+        (_, w1_q, w1_blockscale,
+         w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
+             e,
+             n,
+             k,
+             in_dtype=dtype,
+             quant_dtype="nvfp4",
+             block_shape=None,  # use quant_blocksize?
+             per_act_token_quant=False,
+         )
+
+        score = torch.randn((m, e), device="cuda", dtype=dtype)
+        topk_weights, topk_ids, _ = fused_topk(a,
+                                               score,
+                                               topk,
+                                               renormalize=False)
+
+        a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
+        a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
+
+        assert is_valid_flashinfer_cutlass_fused_moe(a, w1_q, w2_q)
+
+        assert w1_gs is not None
+        assert w2_gs is not None
+        assert w1_blockscale is not None
+        assert w2_blockscale is not None
+
+        flashinfer_experts = FusedMoEModularKernel(
+            MoEPrepareAndFinalizeNoEP(),
+            FlashInferExperts(
+                a1_gscale=a1_gs,
+                g1_alphas=(1 / w1_gs),
+                a2_gscale=a2_gs,
+                g2_alphas=(1 / w2_gs),
+                out_dtype=dtype,
+                quant_dtype="nvfp4",
+            ))
+
+        flashinfer_output = flashinfer_experts(
+            hidden_states=a,
+            w1=w1_q,
+            w1_scale=w1_blockscale,
+            w2=w2_q,
+            w2_scale=w2_blockscale,
+            a1_scale=a1_gs,
+            a2_scale=a2_gs,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+        )
+
+        # Reference check:
+        a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
+                          torch.amax(a.flatten(), dim=-1)).to(torch.float32)
+        a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
+        _, m_k = a_fp4.shape
+        a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
+                                               a_scale_interleaved,
+                                               a_global_scale,
+                                               dtype=a.dtype,
+                                               device=a.device,
+                                               block_size=quant_blocksize)
+
+        w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype)
+        w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype)
+
+        for idx in range(0, e):
+            w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
+                                                  w1_blockscale[idx],
+                                                  w1_gs[idx],
+                                                  dtype=dtype,
+                                                  device=w1_q.device,
+                                                  block_size=quant_blocksize)
+            w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
+                                                  w2_blockscale[idx],
+                                                  w2_gs[idx],
+                                                  dtype=dtype,
+                                                  device=w2_q.device,
+                                                  block_size=quant_blocksize)
+
+        torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
+
+        torch.testing.assert_close(torch_output,
+                                   flashinfer_output,
+                                   atol=1e-1,
+                                   rtol=1e-1)
+
+
+if __name__ == "__main__":
+    test_flashinfer_fp4_moe_no_graph((2, 1024, 1024), 40, 1, torch.half)
@@ -2,6 +2,8 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import copy
+import textwrap
+import traceback
 from itertools import product
 from typing import Optional
 
@@ -10,41 +12,51 @@ import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.config import VllmConfig, current_platform, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
-    BatchedTritonOrDeepGemmExperts)
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
-from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    BatchedTritonExperts)
-from vllm.model_executor.layers.fused_moe.layer import TritonExperts
-from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
-    TritonOrDeepGemmExperts)
 from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
+from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
 
 from .modular_kernel_tools.common import (Config, RankTensors, WeightTensors,
                                           reference_moe_impl,
                                           run_modular_kernel)
 from .modular_kernel_tools.mk_objects import (
     MK_FUSED_EXPERT_TYPES, MK_MULTI_GPU_PREPARE_FINALIZE_TYPES,
-    MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
+    MK_QUANT_CONFIGS, MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, expert_info)
 from .modular_kernel_tools.parallel_utils import (ProcessGroupInfo,
                                                   parallel_launch_with_config)
 
-# TODO (varun): These requirements are very strict and could be relaxed.
-has_all_packages = (has_deep_ep() and has_deep_gemm() and has_pplx())
+has_any_multi_gpu_package = (has_deep_ep() or has_deep_gemm() or has_pplx()
+                             or has_flashinfer_cutlass_fused_moe())
 
-meets_package_requirements = pytest.mark.skipif(
-    not has_all_packages,
-    reason="Requires deep_ep & deep_gemm & pplx packages",
+meets_multi_gpu_requirements = pytest.mark.skipif(
+    not has_any_multi_gpu_package,
    reason="Requires deep_ep or deep_gemm or pplx or flashinfer packages",
 )
 
 
+def format_result(verbose, msg, ex=None):
+    if ex is not None:
+        x = str(ex)
+        newx = x.strip(" \n\t")[:16]
+        if len(newx) < len(x):
+            newx = newx + " ..."
+
+        prefix = "E\t"
+        print(f"{textwrap.indent(traceback.format_exc(), prefix)}")
+        print(f"FAILED {msg} - {newx}\n")
+    elif verbose:
+        print(f"PASSED {msg}")
+    else:
+        print(".", end="")
+
+
 def rank_worker(
     pgi: ProcessGroupInfo,
     vllm_config: VllmConfig,
     cpu_group,
     config: Config,
     weights: WeightTensors,
+    verbose: bool,
 ):
     current_platform.seed_everything(pgi.rank)
 
@@ -61,39 +73,64 @@ def rank_worker(
     TOPKs = config.topks
     assert isinstance(TOPKs, list)
 
+    exceptions = []
+    count = 0
+
     for m, topk in product(Ms, TOPKs):
-        print(f"Running m={m}, topk={topk} ...")
-        # override m and topk
-        cfgx = copy.deepcopy(config)
-        cfgx.Ms = m
-        cfgx.topks = topk
+        try:
+            print(f"Running[{pgi.rank}]: m={m}, topk={topk} ...")
+            count = count + 1
+            # override m and topk
+            cfgx = copy.deepcopy(config)
+            cfgx.Ms = m
+            cfgx.topks = topk
 
             # inputs for rank
            rank_tensors = RankTensors.make(cfgx, pgi)
 
            # modular kernel out
            mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
                                        rank_tensors)
 
            with set_current_vllm_config(vllm_config):
                ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
 
-        torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2)
+            if config.quant_dtype == "nvfp4":
+                atol = 1e-1
+                rtol = 1e-1
+            else:
+                atol = 3e-2
+                rtol = 3e-2
+
+            torch.testing.assert_close(ref_out, mk_out, atol=atol, rtol=rtol)
+            format_result(verbose, config.describe())
+        except Exception as ex:
+            format_result(verbose, config.describe(), ex)
+            exceptions.append(ex)
+
+    if len(exceptions) > 0:
+        raise RuntimeError(
+            f"{len(exceptions)} of {count} tests failed in child process, "
+            f"rank={pgi.rank}.")
+    else:
+        print(f"{count} of {count} tests passed in child process, "
+              f"rank={pgi.rank}.")
 
 
-def run(config: Config):
+def run(config: Config, verbose: bool):
     assert config.is_valid()
-    print(f"Testing config \n{config.describe()} ...")
 
     weights: WeightTensors = WeightTensors.make(config)
 
     vllm_config, env_dict = config.make_env_data()
     parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
-                                env_dict, config, weights)
+                                env_dict, config, weights, verbose)
 
 
 Ms = [32, 64]
-Ks = [7168]  # hidden sizes
+# hidden sizes, making this too large will cause fp4 tests to fail.
+# Also needs to be a multiple of 1024 for deep_gemm.
+Ks = [2048]
 Ns = [2048]
 TOPKs = [4, 1]
 Es = [32]
@@ -103,19 +140,16 @@ FUSED_MOE_CHUNK_SIZEs = [None, 16]
 
 def is_nyi_config(config: Config) -> bool:
     # We know these configs to be legitimate. but still fail.
+    info = expert_info(config.fused_experts_type)
 
-    if (config.fused_experts_type in [
-            BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
-            TritonExperts, TritonOrDeepGemmExperts
-    ]):
+    if info.needs_matching_quant:
         # The triton kernels expect both per-act-token-quant and
         # per-out-ch-quant or neither.
         unsupported_quant_config = ((config.is_per_act_token_quant +
                                      config.is_per_out_ch_quant) == 1)
        return unsupported_quant_config
 
-    # cutlass kernels dont support expert_maps yet.
-    return config.fused_experts_type == CutlassExpertsFp8
+    return not info.supports_expert_map
 
 
 @pytest.mark.parametrize("k", Ks)
@@ -128,13 +162,13 @@ def is_nyi_config(config: Config) -> bool:
     product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
 @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
 @pytest.mark.parametrize("world_size", [2])
-@meets_package_requirements
+@meets_multi_gpu_requirements
 def test_modular_kernel_combinations_multigpu(
         k: int, n: int, e: int, dtype: torch.dtype,
-        quant_config: FusedMoEQuantConfig,
+        quant_config: Optional[FusedMoEQuantConfig],
         combination: tuple[mk.FusedMoEPrepareAndFinalize,
                            mk.FusedMoEPermuteExpertsUnpermute],
-        fused_moe_chunk_size: Optional[int], world_size: int):
+        fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
 
     config = Config(
         Ms=Ms,
@@ -149,14 +183,15 @@ def test_modular_kernel_combinations_multigpu(
         fused_moe_chunk_size=fused_moe_chunk_size,
         world_size=world_size,
     )
 
     if not config.is_valid():
         pytest.skip(f"Tests config {config} is not valid. Skipping ...")
 
     if is_nyi_config(config):
         pytest.skip(f"Tests config {config} is nyi. Skipping ...")
 
-    print(f"{config.describe()}")
-    run(config)
+    verbosity = pytestconfig.getoption('verbose')
+    run(config, verbosity > 0)
 
 
 @pytest.mark.parametrize("k", Ks)
@@ -169,13 +204,12 @@ def test_modular_kernel_combinations_multigpu(
     product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES))
 @pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs)
 @pytest.mark.parametrize("world_size", [1])
-@meets_package_requirements
 def test_modular_kernel_combinations_singlegpu(
         k: int, n: int, e: int, dtype: torch.dtype,
-        quant_config: FusedMoEQuantConfig,
+        quant_config: Optional[FusedMoEQuantConfig],
        combination: tuple[mk.FusedMoEPrepareAndFinalize,
                           mk.FusedMoEPermuteExpertsUnpermute],
-        fused_moe_chunk_size: Optional[int], world_size: int):
+        fused_moe_chunk_size: Optional[int], world_size: int, pytestconfig):
     config = Config(
         Ms=Ms,
         K=k,
@@ -196,7 +230,8 @@ def test_modular_kernel_combinations_singlegpu(
     if is_nyi_config(config):
         pytest.skip(f"Tests config {config} is nyi. Skipping ...")
 
-    run(config)
+    verbosity = pytestconfig.getoption('verbose')
+    run(config, verbosity > 0)
 
 
 if __name__ == '__main__':
@@ -211,4 +246,4 @@ if __name__ == '__main__':
     args = parser.parse_args()
     config = make_config(args)
 
-    run(config)
+    run(config, True)
@@ -3,6 +3,7 @@
 import pytest
 import torch
 
+from tests.kernels.moe.utils import make_test_weights
 from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
                                                     FLOAT8_E4M3_MAX,
                                                     dequantize_nvfp4_to_dtype)
@@ -43,41 +44,20 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
             VllmConfig(parallel_config=ParallelConfig(
                 pipeline_parallel_size=1))):
 
-        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
-        w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
         quant_blocksize = 16
-        round_up = lambda x, y: (x + y - 1) // y * y
-        sf_w1_2n = round_up(2 * n, 128)
-        sf_w1_k = round_up(k // quant_blocksize, 4)
-        w1_blockscale = torch.empty((e, sf_w1_2n, sf_w1_k),
-                                    device="cuda",
-                                    dtype=torch.float8_e4m3fn)
 
-        w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
-        sf_w2_k = round_up(k, 128)
-        sf_w2_n = round_up(n // quant_blocksize, 4)
-        w2_blockscale = torch.empty((e, sf_w2_k, sf_w2_n),
-                                    device="cuda",
-                                    dtype=torch.float8_e4m3fn)
-
-        w1_q = torch.empty((e, 2 * n, k // 2),
-                           device="cuda",
-                           dtype=torch.uint8)
-        w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
-        w1_gs = torch.empty((e, ), device="cuda", dtype=torch.float32)
-        w2_gs = torch.empty((e, ), device="cuda", dtype=torch.float32)
-
-        for expert in range(e):
-            w1_amax = torch.abs(w1).max().to(torch.float32)
-            w2_amax = torch.abs(w2).max().to(torch.float32)
-            w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
-            w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
-
-            w1_q[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
-                w1[expert], w1_gs[expert])
-
-            w2_q[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
-                w2[expert], w2_gs[expert])
+        a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+        (_, w1_q, w1_blockscale,
+         w1_gs), (_, w2_q, w2_blockscale, w2_gs) = make_test_weights(
+             e,
+             n,
+             k,
+             in_dtype=dtype,
+             quant_dtype="nvfp4",
+             block_shape=None,  # use quant_blocksize?
+             per_act_token_quant=False,
+         )
 
         score = torch.randn((m, e), device="cuda", dtype=dtype)
         topk_weights, topk_ids, _ = fused_topk(a,
@@ -88,6 +68,11 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
         a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
         a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32)
 
+        assert w1_gs is not None
+        assert w2_gs is not None
+        assert w1_blockscale is not None
+        assert w2_blockscale is not None
+
         cutlass_output = cutlass_moe_fp4(
             a=a,
             a1_gscale=a1_gs,
@@ -104,14 +89,13 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
             n=n,
             k=k,
             e=e,
-            device=a.device,
         )
 
         # Reference check:
         a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
                           torch.amax(a.flatten(), dim=-1)).to(torch.float32)
         a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale)
-        _, m_k = a_fp4.shape
         a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4,
                                                a_scale_interleaved,
                                                a_global_scale,
@@ -126,14 +110,14 @@ def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int,
            w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx],
                                                  w1_blockscale[idx],
                                                  w1_gs[idx],
-                                                  dtype=w1.dtype,
-                                                  device=w1.device,
+                                                  dtype=dtype,
+                                                  device=w1_q.device,
                                                   block_size=quant_blocksize)
            w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx],
                                                  w2_blockscale[idx],
                                                  w2_gs[idx],
-                                                  dtype=w2.dtype,
-                                                  device=w2.device,
+                                                  dtype=dtype,
+                                                  device=w2_q.device,
                                                   block_size=quant_blocksize)
 
         torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk)
@@ -9,7 +9,8 @@ import torch
 from tests.kernels.utils import torch_experts
 from vllm import _custom_ops as ops
 from vllm.config import VllmConfig, set_current_vllm_config
-from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
+from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+    CutlassBatchedExpertsFp8)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
@@ -123,12 +124,8 @@ def pplx_cutlass_moe(
                                          num_local_experts=num_local_experts,
                                          num_dispatchers=num_dispatchers)
 
-    experts = CutlassExpertsFp8(num_local_experts,
-                                out_dtype,
-                                per_act_token,
-                                per_out_ch,
-                                num_dispatchers=num_dispatchers,
-                                use_batched_format=True)
+    experts = CutlassBatchedExpertsFp8(num_local_experts, num_dispatchers,
+                                       out_dtype, per_act_token, per_out_ch)
 
     fused_cutlass_experts = FusedMoEModularKernel(
         prepare_finalize,
|
|||||||
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
|
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
|
||||||
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
_, w1, w1_s, _, w2, w2_s = make_test_weights(
|
(_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
|
||||||
e,
|
e,
|
||||||
n,
|
n,
|
||||||
k,
|
k,
|
||||||
@ -836,7 +836,7 @@ def _pplx_test_loop(pgi: ProcessGroupInfo, dp_size: int, use_internode: bool,
|
|||||||
|
|
||||||
args = dict()
|
args = dict()
|
||||||
if make_weights:
|
if make_weights:
|
||||||
_, w1, w1_s, _, w2, w2_s = make_test_weights(
|
(_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
|
||||||
e,
|
e,
|
||||||
n,
|
n,
|
||||||
k,
|
k,
|
||||||
|
|||||||
@ -1,11 +1,13 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import vllm._custom_ops as ops
|
import vllm._custom_ops as ops
|
||||||
from tests.kernels.quant_utils import per_block_cast_to_int8
|
from tests.kernels.quant_utils import per_block_cast_to_int8
|
||||||
|
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||||
|
FLOAT8_E4M3_MAX)
|
||||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||||
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
|
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
|
||||||
@ -169,28 +171,41 @@ def make_quantized_test_activations(
|
|||||||
def moe_quantize_weights(
|
def moe_quantize_weights(
|
||||||
w: torch.Tensor,
|
w: torch.Tensor,
|
||||||
w_s: Optional[torch.Tensor],
|
w_s: Optional[torch.Tensor],
|
||||||
quant_dtype: Optional[torch.dtype],
|
quant_dtype: Union[torch.dtype, str, None],
|
||||||
per_token_quant: bool,
|
per_token_quant: bool,
|
||||||
block_shape: Optional[list[int]],
|
block_shape: Optional[list[int]],
|
||||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||||
assert (quant_dtype == torch.float8_e4m3fn
|
assert (quant_dtype == torch.float8_e4m3fn or quant_dtype == torch.int8
|
||||||
or quant_dtype == torch.int8), "only fp8/int8 supported"
|
or quant_dtype == "nvfp4"), "only fp8/int8/nvfp4 supported"
|
||||||
|
|
||||||
|
w_gs = None
|
||||||
|
|
||||||
if block_shape is not None:
|
if block_shape is not None:
|
||||||
assert not per_token_quant
|
assert not per_token_quant
|
||||||
if quant_dtype == torch.int8:
|
if quant_dtype == torch.int8:
|
||||||
w, w_s = per_block_cast_to_int8(w, block_shape)
|
w, w_s = per_block_cast_to_int8(w, block_shape)
|
||||||
else:
|
elif quant_dtype == torch.float8_e4m3fn:
|
||||||
w, w_s = per_block_cast_to_fp8(w, block_shape)
|
w, w_s = per_block_cast_to_fp8(w, block_shape)
|
||||||
|
elif quant_dtype == "nvfp4":
|
||||||
|
raise RuntimeError("blocked quantization not supported for nvfp4")
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unsupported quant type {quant_dtype}")
|
||||||
else:
|
else:
|
||||||
if quant_dtype == torch.int8:
|
if quant_dtype == torch.int8:
|
||||||
w, w_s = ops.scaled_int8_quant(
|
w, w_s = ops.scaled_int8_quant(
|
||||||
w, w_s, use_per_token_if_dynamic=per_token_quant)
|
w, w_s, use_per_token_if_dynamic=per_token_quant)
|
||||||
else:
|
elif quant_dtype == torch.float8_e4m3fn:
|
||||||
w, w_s = ops.scaled_fp8_quant(
|
w, w_s = ops.scaled_fp8_quant(
|
||||||
w, w_s, use_per_token_if_dynamic=per_token_quant)
|
w, w_s, use_per_token_if_dynamic=per_token_quant)
|
||||||
|
elif quant_dtype == "nvfp4":
|
||||||
|
assert not per_token_quant
|
||||||
|
w_amax = torch.abs(w).max().to(torch.float32)
|
||||||
|
w_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w_amax
|
||||||
|
w, w_s = ops.scaled_fp4_quant(w, w_gs)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unsupported quant type {quant_dtype}")
|
||||||
|
|
||||||
return w, w_s
|
return w, w_s, w_gs
|
||||||
|
|
||||||
|
|
||||||
def make_test_weight(
|
def make_test_weight(
|
||||||
@ -198,21 +213,26 @@ def make_test_weight(
|
|||||||
rows: int,
|
rows: int,
|
||||||
cols: int,
|
cols: int,
|
||||||
in_dtype: torch.dtype = torch.bfloat16,
|
in_dtype: torch.dtype = torch.bfloat16,
|
||||||
quant_dtype: Optional[torch.dtype] = None,
|
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
per_act_token_quant: bool = False,
|
per_act_token_quant: bool = False,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||||
|
Optional[torch.Tensor]]:
|
||||||
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
|
w_16 = torch.randn((e, rows, cols), device="cuda", dtype=in_dtype) / 15
|
||||||
|
w_gs = None
|
||||||
|
|
||||||
if quant_dtype is not None:
|
if quant_dtype is not None:
|
||||||
w_l = [None] * e
|
w_l = [None] * e
|
||||||
w_s_l = [None] * e
|
w_s_l = [None] * e
|
||||||
|
w_gs_l = [None] * e
|
||||||
for idx in range(e):
|
for idx in range(e):
|
||||||
w_l[idx], w_s_l[idx] = moe_quantize_weights(
|
w_l[idx], w_s_l[idx], w_gs_l[idx] = moe_quantize_weights(
|
||||||
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
|
w_16[idx], None, quant_dtype, per_act_token_quant, block_shape)
|
||||||
|
|
||||||
w = torch.stack(w_l)
|
w = torch.stack(w_l)
|
||||||
w_s = torch.stack(w_s_l)
|
w_s = torch.stack(w_s_l)
|
||||||
|
if e > 0 and w_gs_l[0] is not None:
|
||||||
|
w_gs = torch.stack(w_gs_l)
|
||||||
if w_s.ndim == 2:
|
if w_s.ndim == 2:
|
||||||
assert w_s.shape[-1] == 1
|
assert w_s.shape[-1] == 1
|
||||||
w_s = w_s.view(-1, 1, 1)
|
w_s = w_s.view(-1, 1, 1)
|
||||||
@ -225,8 +245,9 @@ def make_test_weight(
|
|||||||
else:
|
else:
|
||||||
w = w_16
|
w = w_16
|
||||||
w_s = None
|
w_s = None
|
||||||
|
w_gs = None
|
||||||
|
|
||||||
return w_16, w, w_s
|
return w_16, w, w_s, w_gs
|
||||||
|
|
||||||
|
|
||||||
def make_test_weights(
|
def make_test_weights(
|
||||||
@ -234,14 +255,30 @@ def make_test_weights(
|
|||||||
n: int,
|
n: int,
|
||||||
k: int,
|
k: int,
|
||||||
in_dtype: torch.dtype = torch.bfloat16,
|
in_dtype: torch.dtype = torch.bfloat16,
|
||||||
quant_dtype: Optional[torch.dtype] = None,
|
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||||
block_shape: Optional[list[int]] = None,
|
block_shape: Optional[list[int]] = None,
|
||||||
per_act_token_quant: bool = False,
|
per_act_token_quant: bool = False,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor,
|
) -> tuple[tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||||
torch.Tensor, Optional[torch.Tensor]]:
|
Optional[torch.Tensor]],
|
||||||
|
tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||||
|
Optional[torch.Tensor]]]:
|
||||||
return (
|
return (
|
||||||
*make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
|
make_test_weight(e, 2 * n, k, in_dtype, quant_dtype, block_shape,
|
||||||
per_act_token_quant),
|
per_act_token_quant),
|
||||||
*make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
|
make_test_weight(e, k, n, in_dtype, quant_dtype, block_shape,
|
||||||
per_act_token_quant),
|
per_act_token_quant),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def per_token_cast_to_fp8(
|
||||||
|
x: torch.Tensor,
|
||||||
|
block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
assert x.dim() == 2
|
||||||
|
m, n = x.shape
|
||||||
|
pad_size = (block_size - (n % block_size)) % block_size
|
||||||
|
x = torch.nn.functional.pad(x,
|
||||||
|
(0, pad_size), value=0) if pad_size > 0 else x
|
||||||
|
x_view = x.view(m, -1, block_size)
|
||||||
|
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||||
|
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
|
||||||
|
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
|
||||||
|
|||||||
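The call sites updated throughout this commit unpack the new grouped return value of `make_test_weights`, which now yields one `(w_16, w_q, w_scale, w_global_scale)` tuple per weight matrix. A small illustration, with arbitrary sizes and fp8 per-tensor quantization (so the global-scale entries come back as `None`):

```python
# Illustrative only; sizes are arbitrary.
(w1_16, w1_q, w1_scale, w1_gs), (w2_16, w2_q, w2_scale, w2_gs) = \
    make_test_weights(8, 128, 256,
                      in_dtype=torch.bfloat16,
                      quant_dtype=torch.float8_e4m3fn)
```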
@@ -105,7 +105,8 @@ class DeviceCommunicatorBase:
         # we initialize the all2all manager used in expert parallel.
         use_ep = config.parallel_config.data_parallel_size > 1
 
-        self.use_all2all = "ep" in unique_name and use_ep
+        self.is_ep_communicator = "ep" in unique_name
+        self.use_all2all = self.is_ep_communicator and use_ep
         self.all2all_manager: Optional[All2AllManagerBase] = None
 
     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
@@ -246,7 +247,7 @@ class DeviceCommunicatorBase:
         """
         Prepare the communication buffer for the model.
         """
-        if not self.use_all2all:
+        if not self.is_ep_communicator:
             return
 
         moe_modules = [
@@ -254,7 +255,7 @@ class DeviceCommunicatorBase:
             if module.__class__.__name__ == "FusedMoE"
         ]
         for module in moe_modules:
-            module.quant_method.init_prepare_finalize(module.moe_config)
+            module.quant_method.init_prepare_finalize()
 
     def dispatch(
             self, hidden_states: torch.Tensor,
@@ -49,7 +49,8 @@ if HAS_TRITON:
     from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
         BatchedTritonOrDeepGemmExperts)
     from vllm.model_executor.layers.fused_moe.cutlass_moe import (
-        CutlassExpertsFp8, cutlass_moe_fp4, cutlass_moe_fp8)
+        CutlassBatchedExpertsFp8, CutlassExpertsFp8, cutlass_moe_fp4,
+        cutlass_moe_fp8)
     from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
         DeepGemmExperts)
     from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
@@ -69,6 +70,7 @@ if HAS_TRITON:
         "cutlass_moe_fp8",
         "cutlass_moe_fp4",
         "CutlassExpertsFp8",
+        "CutlassBatchedExpertsFp8",
         "TritonExperts",
         "BatchedTritonExperts",
         "DeepGemmExperts",
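With the export added above, the batched variant can be imported from the package root alongside the existing class (these names sit under the `HAS_TRITON` branch, so Triton is assumed to be available):

```python
# Both expert implementations are exported from the fused_moe package.
from vllm.model_executor.layers.fused_moe import (CutlassBatchedExpertsFp8,
                                                  CutlassExpertsFp8)
```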
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional
+from typing import Optional
 
 import torch
 
@@ -254,18 +254,28 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         output = (num_experts, max_num_tokens * num_dispatchers, K)
         return (workspace13, workspace2, output, a.dtype)
 
-    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
-              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
-              expert_map: Optional[torch.Tensor],
-              w1_scale: Optional[torch.Tensor],
-              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
-              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
-              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
-              workspace2: torch.Tensor,
-              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-              apply_router_weight_on_input: bool,
-              extra_expert_args: Optional[dict[str, Any]]):
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+    ):
         assert expert_tokens_meta is not None
         expert_num_tokens = expert_tokens_meta.expert_num_tokens
 
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional
+from typing import Optional
 
 import torch
 
@@ -132,18 +132,28 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             a, aq, M, N, K, topk, global_num_experts, local_num_experts,
             expert_tokens_metadata)
 
-    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
-              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
-              expert_map: Optional[torch.Tensor],
-              w1_scale: Optional[torch.Tensor],
-              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
-              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
-              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
-              workspace2: torch.Tensor,
-              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-              apply_router_weight_on_input: bool,
-              extra_expert_args: Optional[dict[str, Any]]):
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+    ):
         experts = (self.batched_deep_gemm_experts
                    if self.allow_deep_gemm else self.batched_triton_experts)
         assert experts is not None
@@ -151,4 +161,4 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
                       activation, global_num_experts, expert_map, w1_scale,
                       w2_scale, w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
                       workspace2, expert_tokens_meta,
-                      apply_router_weight_on_input, extra_expert_args)
+                      apply_router_weight_on_input)
|||||||
@@ -45,7 +45,6 @@ def get_quant_config_weight_quant(
     return _get_quant_config_quantization_args(quant_config, "weights")


-# TODO (bnell): use scalar_type instead of bools?
 def get_config_quant_dtype(
     use_fp8_w8a8: bool,
     use_int8_w8a8: bool,
@@ -65,7 +64,8 @@ def get_config_quant_dtype(
 @dataclass
 class FusedMoEQuantConfig:
     # The post quantization activation type.
-    quant_dtype: Optional[torch.dtype] = None
+    # TODO (bnell): use scalar_type instead of Union.
+    quant_dtype: Union[torch.dtype, str, None] = None
     per_act_token_quant: bool = False
     per_out_ch_quant: bool = False
     block_shape: Optional[list[int]] = None
@@ -141,6 +141,7 @@ class FusedMoEQuantConfig:
                 use_int8_w8a8,
                 use_int8_w8a16,
                 use_int4_w4a16,
+                use_mxfp4_w4a4,
             ]
         ]) <= 1, "Quantization flags are mutually exclusive."

@@ -334,7 +335,7 @@ class FusedMoEConfig:
         assert self.max_num_tokens > 0

     @property
-    def quant_dtype(self) -> Optional[torch.dtype]:
+    def quant_dtype(self) -> Union[torch.dtype, str, None]:
         if self.quant_config is not None:
             return self.quant_config.quant_dtype
         else:
@@ -429,7 +430,7 @@ class FusedMoEConfig:
         block_shape = None
         per_act_token_quant = False
         per_out_ch_quant = False
-        quant_dtype: Optional[torch.dtype] = None
+        quant_dtype: Union[torch.dtype, str, None] = None

         input_quant = get_quant_config_input_quant(quant_config)
         weight_quant = get_quant_config_weight_quant(quant_config)
@@ -453,7 +454,7 @@ class FusedMoEConfig:
             ModelOptNvFp4Config)
         if quant_dtype is None and isinstance(quant_config,
                                               ModelOptNvFp4Config):
-            quant_dtype = torch.uint8
+            quant_dtype = "nvfp4"

         if weight_quant is not None:
             per_out_ch_quant = (
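The widened `quant_dtype` field above now carries either a real `torch.dtype`, a string tag such as `"nvfp4"` for packed formats with no native dtype, or `None`. A minimal, hypothetical sketch (not part of the diff) of how downstream code might dispatch on it:

```python
from typing import Union

import torch


# Hypothetical helper for illustration only; not a function in vLLM.
def describe_quant_dtype(quant_dtype: Union[torch.dtype, str, None]) -> str:
    if quant_dtype is None:
        return "unquantized activations"
    if isinstance(quant_dtype, str):
        # String tags cover packed formats such as "nvfp4".
        return f"packed format: {quant_dtype}"
    return f"torch dtype: {quant_dtype}"


assert describe_quant_dtype("nvfp4") == "packed format: nvfp4"
assert describe_quant_dtype(None) == "unquantized activations"
```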
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """ CUTLASS based Fused MoE kernels."""
-from typing import Any, Callable, Optional
+from typing import Callable, Optional

 import torch

@@ -12,11 +12,10 @@ from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.prepare_finalize import (
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceDelegate)
+    TopKWeightAndReduceDelegate, TopKWeightAndReduceNoOP)
 from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
                                                         _fp8_quantize,
-                                                        _resize_cache,
-                                                        extract_required_args)
+                                                        _resize_cache)
 from vllm.scalar_type import scalar_types

 logger = init_logger(__name__)
@@ -213,19 +212,14 @@ def run_cutlass_moe_fp8(
         output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)


-# TODO (bnell): split class batched vs. non-batched?
-# maybe remove need for passing aq to workspace_shapes
-class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
+class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):

     def __init__(
         self,
-        max_experts_per_worker: int,
         out_dtype: Optional[torch.dtype],
         per_act_token_quant: bool,
         per_out_ch_quant: bool,
         block_shape: Optional[list[int]] = None,
-        num_dispatchers: Optional[int] = None,
-        use_batched_format: bool = False,
     ):
         super().__init__(
             FusedMoEQuantConfig(
@@ -234,33 +228,84 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
                 per_out_ch_quant=per_out_ch_quant,
                 block_shape=block_shape,
             ))
-        assert max_experts_per_worker > 0
-        assert not use_batched_format or num_dispatchers is not None
-        self.max_experts_per_worker = max_experts_per_worker
-        self.num_dispatchers = num_dispatchers
         self.out_dtype = out_dtype
-        self.use_batched_format = use_batched_format
+
+    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
+        # Let PrepareAndFinalize::finalize() decide the impl.
+        return TopKWeightAndReduceDelegate()
+
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+    ):
+        assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
+        assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
+
+        expert_num_tokens = None
+        if expert_tokens_meta is not None:
+            expert_num_tokens = expert_tokens_meta.expert_num_tokens
+
+        activation_callable = lambda o, i: self.activation(activation, o, i)
+
+        use_batched_format = self.activation_formats[
+            0] == mk.FusedMoEActivationFormat.BatchedExperts
+
+        in_dtype = hidden_states.dtype
+        run_cutlass_moe_fp8(
+            output, hidden_states, w1, w2, topk_ids, activation_callable,
+            global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
+            a2_scale, workspace13, workspace2, expert_num_tokens,
+            self.out_dtype if self.out_dtype is not None else in_dtype,
+            self.per_act_token_quant, self.per_out_ch_quant,
+            use_batched_format)
+
+
+class CutlassExpertsFp8(CutlassExpertsFp8Base):
+
+    def __init__(
+        self,
+        out_dtype: Optional[torch.dtype],
+        per_act_token_quant: bool,
+        per_out_ch_quant: bool,
+        block_shape: Optional[list[int]] = None,
+    ):
+        super().__init__(
+            out_dtype,
+            per_act_token_quant,
+            per_out_ch_quant,
+            block_shape,
+        )

     @property
     def activation_formats(
         self
     ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
-        if self.use_batched_format:
-            return (mk.FusedMoEActivationFormat.BatchedExperts,
-                    mk.FusedMoEActivationFormat.BatchedExperts)
-        else:
-            return (mk.FusedMoEActivationFormat.Standard,
-                    mk.FusedMoEActivationFormat.Standard)
+        return (mk.FusedMoEActivationFormat.Standard,
+                mk.FusedMoEActivationFormat.Standard)

     def supports_chunking(self) -> bool:
-        return not self.use_batched_format
+        return True

     def supports_expert_map(self) -> bool:
-        return not self.use_batched_format
-
-    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
-        # Let PrepareAndFinalize::finalize() decide the impl.
-        return TopKWeightAndReduceDelegate()
+        return True

     def workspace_shapes(
         self,
@@ -274,54 +319,69 @@ class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute):
         local_num_experts: int,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
-        workspace1: tuple[int, ...] = ()
-        workspace2: tuple[int, ...] = ()
-        output: tuple[int, ...] = ()
-        if self.use_batched_format:
-            padded_M = aq.size(1)
-            num_dp = self.num_dispatchers
-            assert num_dp is not None
-            workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
-                          max(N, K))
-            workspace2 = (self.max_experts_per_worker, padded_M * num_dp,
-                          (N // 2))
-            output = (self.max_experts_per_worker, padded_M, K)
-        else:
-            workspace1 = (M * topk, max(N, K))
-            workspace2 = (M * topk, N // 2)
-            output = (M * topk, K)
+        workspace1 = (M * topk, max(N, K))
+        workspace2 = (M * topk, N // 2)
+        output = (M * topk, K)
         return (workspace1, workspace2, output,
                 self.out_dtype if self.out_dtype is not None else a.dtype)

-    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
-              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
-              expert_map: Optional[torch.Tensor],
-              w1_scale: Optional[torch.Tensor],
-              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
-              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
-              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
-              workspace2: torch.Tensor,
-              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-              apply_router_weight_on_input: bool,
-              extra_expert_args: Optional[dict[str, Any]]):
-        assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
-        assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"
-
-        expert_num_tokens = None
-        if expert_tokens_meta is not None:
-            expert_num_tokens = expert_tokens_meta.expert_num_tokens
-
-        activation_callable = lambda o, i: self.activation(activation, o, i)
-
-        in_dtype = hidden_states.dtype
-        run_cutlass_moe_fp8(
-            output, hidden_states, w1, w2, topk_ids, activation_callable,
-            global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
-            a2_scale, workspace13, workspace2, expert_num_tokens,
-            self.out_dtype if self.out_dtype is not None else in_dtype,
-            self.per_act_token_quant, self.per_out_ch_quant,
-            self.use_batched_format)
+
+class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base):
+
+    def __init__(
+        self,
+        max_experts_per_worker: int,
+        num_dispatchers: int,
+        out_dtype: Optional[torch.dtype],
+        per_act_token_quant: bool,
+        per_out_ch_quant: bool,
+        block_shape: Optional[list[int]] = None,
+    ):
+        super().__init__(
+            out_dtype,
+            per_act_token_quant,
+            per_out_ch_quant,
+            block_shape,
+        )
+        assert max_experts_per_worker > 0
+        self.max_experts_per_worker = max_experts_per_worker
+        self.num_dispatchers = num_dispatchers
+
+    @property
+    def activation_formats(
+        self
+    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
+        return (mk.FusedMoEActivationFormat.BatchedExperts,
+                mk.FusedMoEActivationFormat.BatchedExperts)
+
+    def supports_chunking(self) -> bool:
+        return False
+
+    def supports_expert_map(self) -> bool:
+        return False
+
+    # TODO(bnell): maybe remove need for passing aq to workspace_shapes
+    def workspace_shapes(
+        self,
+        a: torch.Tensor,
+        aq: torch.Tensor,
+        M: int,
+        N: int,
+        K: int,
+        topk: int,
+        global_num_experts: int,
+        local_num_experts: int,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
+        padded_M = aq.size(1)
+        num_dp = self.num_dispatchers
+        assert num_dp is not None
+        workspace1 = (self.max_experts_per_worker, padded_M * num_dp,
+                      max(N, K))
+        workspace2 = (self.max_experts_per_worker, padded_M * num_dp, (N // 2))
+        output = (self.max_experts_per_worker, padded_M, K)
+        return (workspace1, workspace2, output,
+                self.out_dtype if self.out_dtype is not None else a.dtype)


 def cutlass_moe_fp8(
@@ -387,11 +447,9 @@ def cutlass_moe_fp8(
     fn = mk.FusedMoEModularKernel(
         MoEPrepareAndFinalizeNoEP(),
         CutlassExpertsFp8(
-            max_experts_per_worker=num_experts,
             out_dtype=a.dtype,
             per_act_token_quant=per_act_token,
             per_out_ch_quant=per_out_ch,
-            use_batched_format=False,
         ),
     )

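A hedged sketch of what the fp8 expert split above means for callers: the former `use_batched_format` flag is replaced by two classes with the constructor signatures shown in the diff. The surrounding values below (dtypes, counts) are illustrative only; construction just stores configuration and does not touch the GPU, assuming vLLM at this commit is importable.

```python
import torch

from vllm.model_executor.layers.fused_moe.cutlass_moe import (
    CutlassBatchedExpertsFp8, CutlassExpertsFp8)

# Illustrative values; the real ones come from the MoE layer / parallel config.
num_local_experts, num_dispatchers = 8, 2

# Standard activation format: supports chunking and expert maps.
standard_experts = CutlassExpertsFp8(
    out_dtype=torch.half,
    per_act_token_quant=False,
    per_out_ch_quant=False,
    block_shape=None,
)

# Batched activation format: takes the per-worker expert count and the
# number of dispatchers instead of a use_batched_format flag.
batched_experts = CutlassBatchedExpertsFp8(
    num_local_experts,
    num_dispatchers,
    out_dtype=torch.half,
    per_act_token_quant=False,
    per_out_ch_quant=False,
    block_shape=None,
)

assert standard_experts.supports_chunking()
assert not batched_experts.supports_chunking()
```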
@@ -476,8 +534,9 @@ def run_cutlass_moe_fp4(
     e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
     e_w2, k_w2, half_n_w2 = w2_fp4.shape

-    assert (e_w1 == e_w2 and e_w1 == e), ("Number of experts must match",
-                                          " between weights.")
+    assert (e_w1 == e_w2
+            and e_w1 == e), ("Number of experts must match",
+                             f" between weights. {e_w1}, {e_w2}, {e}")
     assert (k_a == half_k_w1 * 2
             and k == k_w2), ("Hidden size mismatch between a, w1 and w2")
     assert (nx2_w1 == n * 2 and half_n_w2 * 2 == n), ("mismatch in "
@@ -554,6 +613,10 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):

     def __init__(
         self,
+        g1_alphas: torch.Tensor,
+        g2_alphas: torch.Tensor,
+        a1_gscale: torch.Tensor,
+        a2_gscale: torch.Tensor,
         max_experts_per_worker: int,
         out_dtype: torch.dtype,
         per_act_token_quant: bool,
@@ -562,8 +625,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
         use_batched_format: bool = False,
     ):
         super().__init__(
+            # NVFP4 requires two levels of quantization, which involves
+            # computing some scaling factors dynamically. This makes it
+            # incompatible with the typical prepare -> MoE -> finalize
+            # pipeline. Move the quantization logic into the MoE body.
             FusedMoEQuantConfig(
-                quant_dtype=torch.uint8,
+                quant_dtype=None,  # skip quantization in prepare/finalize
                 per_act_token_quant=per_act_token_quant,
                 per_out_ch_quant=per_out_ch_quant,
                 block_shape=block_shape,
@@ -572,6 +639,12 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
         self.out_dtype = out_dtype
         self.use_batched_format = use_batched_format

+        # TODO(bnell): put this stuff into quant config?
+        self.g1_alphas = g1_alphas
+        self.g2_alphas = g2_alphas
+        self.a1_gscale = a1_gscale
+        self.a2_gscale = a2_gscale
+
     @property
     def activation_formats(
         self
@@ -590,8 +663,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
         return True

     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
-        # Let PrepareAndFinalize::finalize() decide the impl.
-        return TopKWeightAndReduceDelegate()
+        return TopKWeightAndReduceNoOP()

     def workspace_shapes(
         self,
@@ -620,34 +692,42 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
         return (workspace1, workspace2, output,
                 self.out_dtype if self.out_dtype is not None else a.dtype)

-    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
-              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
-              expert_map: Optional[torch.Tensor], w1_scale: torch.Tensor,
-              w2_scale: torch.Tensor, w1_zp: Optional[torch.Tensor],
-              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
-              a2_scale: torch.Tensor, workspace13: Optional[torch.Tensor],
-              workspace2: Optional[torch.Tensor],
-              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-              apply_router_weight_on_input: bool,
-              extra_expert_args: Optional[dict[str, Any]]):
-        required_keys = [
-            "g1_alphas", "g2_alphas", "a1_gscale", "a2_gscale", "m", "n", "k",
-            "e", "device"
-        ]
-        (g1_alphas, g2_alphas, a1_gscale, a2_gscale, m, n, k, e,
-         device) = extract_required_args(extra_expert_args, required_keys)
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: torch.Tensor,
+        w2_scale: torch.Tensor,
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: torch.Tensor,
+        workspace13: Optional[torch.Tensor],
+        workspace2: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+    ):
+        e, m, n, k, _ = mk._moe_problem_size(hidden_states, w1, w2, topk_ids)
+        n = w2.shape[2] * 2
+
         run_cutlass_moe_fp4(
             output=output,
             a=hidden_states,
-            a1_gscale=a1_gscale,
+            a1_gscale=self.a1_gscale,
             w1_fp4=w1,
             w1_blockscale=w1_scale,
-            w1_alphas=g1_alphas,
-            a2_gscale=a2_gscale,
+            w1_alphas=self.g1_alphas,
+            a2_gscale=self.a2_gscale,
             w2_fp4=w2,
             w2_blockscale=w2_scale,
-            w2_alphas=g2_alphas,
+            w2_alphas=self.g2_alphas,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             workspace13=workspace13,
@@ -656,7 +736,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
             n=n,
             k=k,
             e=e,
-            device=device,
+            device=hidden_states.device,
             apply_router_weight_on_input=apply_router_weight_on_input,
         )

@@ -677,7 +757,6 @@ def cutlass_moe_fp4(
                     n: int,
                     k: int,
                     e: int,
-                    device: torch.device,
                     expert_map: Optional[torch.Tensor] = None,
                     apply_router_weight_on_input: bool = False) -> torch.Tensor:
    assert expert_map is None, ("Expert Parallelism / expert_map "
@@ -686,6 +765,10 @@ def cutlass_moe_fp4(
    fn = mk.FusedMoEModularKernel(
        MoEPrepareAndFinalizeNoEP(),
        CutlassExpertsFp4(
+            g1_alphas,
+            g2_alphas,
+            a1_gscale,
+            a2_gscale,
            max_experts_per_worker=e,
            out_dtype=a.dtype,
            per_act_token_quant=False,
@@ -693,29 +776,7 @@ def cutlass_moe_fp4(
            use_batched_format=False,
        ),
    )
-    extra_expert_args = {
-        'g1_alphas': g1_alphas,
-        'g2_alphas': g2_alphas,
-        'a1_gscale': a1_gscale,
-        'a2_gscale': a2_gscale,
-        'm': m,
-        'n': n,
-        'k': k,
-        'e': e,
-        'device': device,
-    }
-
-    # NVFP4 requires two levels of quantization, which involves computing some
-    # scaling factors dynamically. This makes it incompatible with the typical
-    # prepare -> MoE -> finalize pipeline. Move the quantization logic into the
-    # MoE body.
-    extra_prepare_args = {
-        'skip_quant': True,
-    }
-    # Similar reason as above.
-    extra_finalize_args = {
-        'skip_weight_reduce': True,
-    }
+
    return fn(
        hidden_states=a,
        w1=w1_fp4,
@@ -731,9 +792,6 @@ def cutlass_moe_fp4(
        a1_scale=None,
        a2_scale=None,
        apply_router_weight_on_input=apply_router_weight_on_input,
-        extra_expert_args=extra_expert_args,
-        extra_prepare_args=extra_prepare_args,
-        extra_finalize_args=extra_finalize_args,
    )


@@ -824,16 +882,6 @@ def run_cutlass_block_scaled_fused_experts(
    k = w1_q.size(1)
    n = w2_q.size(1)

-    expert_offsets = torch.empty((num_experts + 1, ),
-                                 dtype=torch.int32,
-                                 device="cuda")
-    problem_sizes1 = torch.empty((num_experts, 3),
-                                 dtype=torch.int32,
-                                 device="cuda")
-    problem_sizes2 = torch.empty((num_experts, 3),
-                                 dtype=torch.int32,
-                                 device="cuda")
-
    topk = topk_ids.size(1)

    a_q, a1_scale = _fp8_quantize(a,
@@ -842,6 +890,16 @@ def run_cutlass_block_scaled_fused_experts(
                                  block_shape=[128, 128])
    device = a_q.device

+    expert_offsets = torch.empty((num_experts + 1, ),
+                                 dtype=torch.int32,
+                                 device=device)
+    problem_sizes1 = torch.empty((num_experts, 3),
+                                 dtype=torch.int32,
+                                 device=device)
+    problem_sizes2 = torch.empty((num_experts, 3),
+                                 dtype=torch.int32,
+                                 device=device)
+
    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)

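The recurring pattern in the NVFP4 changes above is that per-call state (scales, problem sizes, device) moves out of `extra_expert_args` dictionaries and into the expert object's constructor. A tiny, self-contained sketch of that design shift, using made-up class and argument names rather than the real vLLM interfaces:

```python
import torch


# Before (sketch): extras threaded through every apply() call, e.g.
#   experts.apply(x, extra_expert_args={"g1_alphas": g1_alphas, ...})
#
# After (sketch): the expert object owns its scales; apply() takes only
# the tensors that change per call.
class TinyExperts:
    def __init__(self, g1_alphas: torch.Tensor):
        self.g1_alphas = g1_alphas  # bound once, at construction time

    def apply(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Use the constructor-held scale instead of digging into a dict.
        return hidden_states * self.g1_alphas


experts = TinyExperts(g1_alphas=torch.tensor(2.0))
assert torch.equal(experts.apply(torch.ones(2)), torch.full((2,), 2.0))
```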
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import functools
-from typing import Any, Optional
+from typing import Optional

 import torch
 from tqdm import tqdm
@@ -230,7 +230,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
-        extra_expert_args: Optional[dict[str, Any]],
     ):
         assert self.block_shape is not None
         assert a1q_scale is not None
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional
+from typing import Optional

 import deep_ep
 import torch
@@ -127,12 +127,16 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                 expert_topk_weights)

     def prepare(
-        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor, num_experts: int,
-        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
+        self,
+        a1: torch.Tensor,
+        a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-        extra_prepare_args: Optional[dict[str, Any]]
     ) -> tuple[torch.Tensor, Optional[torch.Tensor],
                Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
                Optional[torch.Tensor]]:
@@ -187,11 +191,15 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
                 expert_topk_weights)

-    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
-                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                 apply_router_weight_on_input: bool,
-                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
-                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:

         assert self.handle is not None

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional, Union
+from typing import Optional, Union

 import deep_ep
 import torch
@@ -77,7 +77,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         a1_scale: Optional[torch.Tensor],
         a2_scale: Optional[torch.Tensor],
         a1_dtype: torch.dtype,
-        quant_dtype: Optional[torch.dtype],
+        quant_dtype: Union[torch.dtype, str, None],
         per_act_token_quant: bool,
         block_shape: Optional[list[int]],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -111,12 +111,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         return x, x_scales

     def prepare(
-        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor, num_experts: int,
-        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
+        self,
+        a1: torch.Tensor,
+        a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-        extra_prepare_args: Optional[dict[str, Any]]
     ) -> tuple[torch.Tensor, Optional[torch.Tensor],
                Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
                Optional[torch.Tensor]]:
@@ -162,11 +166,15 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

         return (expert_x, expert_x_scale, expert_tokens_meta, None, None)

-    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
-                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                 apply_router_weight_on_input: bool,
-                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
-                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
         assert isinstance(
             weight_and_reduce_impl, TopKWeightAndReduceDelegate
         ), ("Weight application and reduction happens in the combine kernel.")
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional
+from typing import Optional, Union

 import torch

@@ -8,8 +8,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
-    TopKWeightAndReduceDelegate)
+    TopKWeightAndReduceNoOP)
-from vllm.model_executor.layers.fused_moe.utils import extract_required_args
 from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe,
                                    has_flashinfer_cutlass_fused_moe)

@@ -20,7 +19,7 @@ def is_valid_flashinfer_cutlass_fused_moe(hidden_states: torch.Tensor,
                                           w1: torch.Tensor,
                                           w2: torch.Tensor) -> bool:
     """
     Check if the given problem size is supported by the FlashInfer CUTLASS MoE
     kernel.
     """
     if not has_flashinfer_cutlass_fused_moe():
@@ -43,31 +42,34 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):

     def __init__(
         self,
-        use_nvfp4_w4a4: bool = False,
-        use_fp8_w8a8: bool = False,
-        use_dp: bool = False,
+        g1_alphas: torch.Tensor,
+        g2_alphas: torch.Tensor,
+        a1_gscale: torch.Tensor,
+        a2_gscale: torch.Tensor,
+        out_dtype: torch.dtype,
+        quant_dtype: Union[torch.dtype, str, None],
         ep_rank: int = 0,
         ep_size: int = 1,
         tp_rank: int = 0,
         tp_size: int = 1,
-        num_dispatchers: Optional[int] = None,
-        use_batched_format: bool = False,
     ):
         super().__init__(
             FusedMoEQuantConfig(
-                quant_dtype=torch.uint8,
+                quant_dtype=quant_dtype,
                 per_act_token_quant=False,
                 block_shape=None,
             ))
-        self.use_nvfp4_w4a4 = use_nvfp4_w4a4
-        self.use_fp8_w8a8 = use_fp8_w8a8
+        assert quant_dtype == "nvfp4", ("Only nvfp4 quantization is "
+                                        "currently supported.")
         self.ep_rank = ep_rank
         self.ep_size = ep_size
         self.tp_rank = tp_rank
         self.tp_size = tp_size
-        self.use_dp = use_dp
-        assert not use_batched_format or num_dispatchers is not None
-        self.num_dispatchers = num_dispatchers
+        self.g1_alphas = g1_alphas
+        self.g2_alphas = g2_alphas
+        self.a1_gscale = a1_gscale
+        self.a2_gscale = a2_gscale
+        self.out_dtype = out_dtype

     @property
     def activation_formats(
@@ -84,8 +86,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         return True

     def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
-        # Let PrepareAndFinalize::finalize() decide the impl.
-        return TopKWeightAndReduceDelegate()
+        return TopKWeightAndReduceNoOP()

     def workspace_shapes(
         self,
@@ -117,8 +118,6 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         - Note: in order for activation chunking to work, the first dimension
           of each tuple must be the number of tokens.
         """
-        assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is "
-                                             "currently supported.")
         aq_m, aq_n = aq.shape
         workspace2 = ()
         output_shape = (aq_m, aq_n * 2)
@@ -149,21 +148,9 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         workspace2: Optional[torch.Tensor],
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: Optional[bool],
-        extra_expert_args: Optional[dict[str, Any]],
     ):
-        assert extra_expert_args is not None, \
-            "extra_expert_args must be provided"
-        required_keys = [
-            'g1_alphas', 'g2_alphas', 'a1_gscale', 'a2_gscale', 'out_dtype'
-        ]
-
-        g1_alphas, g2_alphas, a1_gscale, a2_gscale, out_dtype = (
-            extract_required_args(extra_expert_args, required_keys))
-
         # Flashinfer CUTLASS kernel takes scalar global scales,
         # min because inv_scale.
-        assert self.use_nvfp4_w4a4 is True, ("Only nvfp4 quantization is "
-                                             "currently supported.")

         # Ensure w1_scale and w2_scale are not None before calling view
         assert w1_scale is not None and w2_scale is not None, (
@@ -171,12 +158,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
             "be None for FlashInferExperts")

         quant_scales = [
-            a1_gscale,
+            self.a1_gscale,
             w1_scale.view(torch.int32),
-            g1_alphas,
-            a2_gscale,
+            self.g1_alphas,
+            self.a2_gscale,
             w2_scale.view(torch.int32),
-            g2_alphas,
+            self.g2_alphas,
         ]
         _ = flashinfer_cutlass_fused_moe(
             input=hidden_states,
@@ -185,7 +172,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
             # FlashInfer API requires weight to be long for nvfp4
             fc1_expert_weights=w1.view(torch.long),
             fc2_expert_weights=w2.view(torch.long),
-            output_dtype=out_dtype,
+            output_dtype=self.out_dtype,
             quant_scales=quant_scales,
             input_sf=a1q_scale,
             tp_size=self.tp_size,
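For reference, the `quant_scales` ordering that `FlashInferExperts.apply()` now assembles from its constructor-held tensors is shown in the diff above. A small stand-alone sketch of that ordering, using placeholder tensors (the real ones are per-expert scales computed elsewhere):

```python
import torch

# Placeholders standing in for the real global/per-expert scales.
a1_gscale = torch.tensor([1.0])
g1_alphas = torch.tensor([1.0])
a2_gscale = torch.tensor([1.0])
g2_alphas = torch.tensor([1.0])
w1_scale = torch.zeros(4, dtype=torch.uint8)  # block scales, stored as bytes
w2_scale = torch.zeros(4, dtype=torch.uint8)

# Same six-element ordering passed to flashinfer_cutlass_fused_moe.
quant_scales = [
    a1_gscale,
    w1_scale.view(torch.int32),
    g1_alphas,
    a2_gscale,
    w2_scale.view(torch.int32),
    g2_alphas,
]
assert len(quant_scales) == 6
```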
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional
+from typing import Optional

 import torch

@@ -9,7 +9,7 @@ from vllm.distributed import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.utils import (
-    extract_required_args, moe_kernel_quantize_input)
+    moe_kernel_quantize_input)
 from vllm.utils.flashinfer import nvfp4_block_scale_interleave


@@ -21,16 +21,15 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

     def __init__(
         self,
-        quant_dtype: Optional[torch.dtype] = None,
-        per_channel_quant: bool = False,
-        block_shape: Optional[list[int]] = None,
+        use_dp: bool,
+        a1_gscale: Optional[torch.Tensor],
         num_dispatchers: int = 1,
     ):
         super().__init__()
-        self.per_channel_quant = per_channel_quant
-        self.block_shape = block_shape
-        self.quant_dtype = quant_dtype
         self.num_dispatchers_ = num_dispatchers
+        self.use_dp = use_dp
+        self.a1_gscale = a1_gscale
+        self.local_tokens = None

     @property
     def activation_format(self) -> mk.FusedMoEActivationFormat:
@@ -55,10 +54,11 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         num_experts: int,
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
+        # TODO(bnell): use quant_config + scales instead of ctor args
         quant_config: FusedMoEQuantConfig,
-        extra_prepare_args: Optional[dict[str, Any]]
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
-               Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
+               Optional[torch.Tensor]]:

         if apply_router_weight_on_input:
             topk = topk_ids.size(1)
@@ -67,22 +67,22 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                 "apply_router_weight_on_input is only implemented for topk=1"
             a1.mul_(topk_weights.to(a1.dtype))

-        (a1_gscale, use_dp, local_tokens) = extract_required_args(
-            extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens'])
-
         a1q, a1q_scale = moe_kernel_quantize_input(
             a1,
-            a1_gscale,
+            self.a1_gscale,
             quant_config.quant_dtype,
-            self.per_channel_quant,
-            self.block_shape,
-            is_fp4_scale_swizzled=not use_dp,  # Swizzling after communication
+            quant_config.per_act_token_quant,
+            quant_config.block_shape,
+            # Swizzling after communication
+            is_fp4_scale_swizzled=not self.use_dp,
         )
-        if use_dp:
+        if self.use_dp:
             topk_weights, topk_ids, a1q, a1q_scale = \
-                get_dp_group().all_gatherv([topk_weights, topk_ids, a1q, a1q_scale],  # noqa: E501
-                                           dim=0,
-                                           sizes=get_local_sizes())
+                get_dp_group().all_gatherv(
+                    [topk_weights, topk_ids, a1q, a1q_scale],
+                    dim=0,
+                    sizes=get_local_sizes(),
+                )
             a1_m, a1_n = a1q.shape
             a1q_scale = nvfp4_block_scale_interleave(a1q_scale)

@@ -91,13 +91,9 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
     def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                  topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                  apply_router_weight_on_input: bool,
-                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
-                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
+                 weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:

-        (use_dp,
-         local_tokens) = extract_required_args(extra_finalize_args,
-                                               ['use_dp', 'local_tokens'])
-        if use_dp:
+        if self.use_dp:
             fused_expert_output = get_dp_group().reduce_scatterv(
                 fused_expert_output, dim=0, sizes=get_local_sizes())
         output.copy_(fused_expert_output)
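The prepare/finalize class above now decides the data-parallel hand-off from `self.use_dp`, set once at construction, rather than from per-call dictionaries. A plain-Python sketch of that shape (made-up class and callables, purely illustrative):

```python
class DPPrepareFinalizeSketch:
    """Toy stand-in: prepare all-gathers inputs, finalize reduce-scatters."""

    def __init__(self, use_dp: bool, a1_gscale=None):
        self.use_dp = use_dp        # decided once, at construction
        self.a1_gscale = a1_gscale  # global input scale held by the object

    def prepare(self, a1q, all_gather):
        return all_gather(a1q) if self.use_dp else a1q

    def finalize(self, out, reduce_scatter):
        return reduce_scatter(out) if self.use_dp else out


pf = DPPrepareFinalizeSketch(use_dp=True)
assert pf.prepare([1], all_gather=lambda x: x * 2) == [1, 1]
assert pf.finalize([2, 2], reduce_scatter=lambda x: x[:1]) == [2]
```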
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Fused batched MoE kernel."""
-from typing import Any, Optional
+from typing import Optional

 import torch

@@ -496,12 +496,16 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         return self.num_dispatchers_

     def prepare(
-        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor, num_experts: int,
-        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
+        self,
+        a1: torch.Tensor,
+        a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-        extra_prepare_args: Optional[dict[str, Any]]
     ) -> tuple[torch.Tensor, Optional[torch.Tensor],
                Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
                Optional[torch.Tensor]]:
@@ -590,11 +594,15 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

         return b_a1, b_a1_scale, expert_tokens_meta, None, None

-    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
-                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                 apply_router_weight_on_input: bool,
-                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
-                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
         if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
             weight_and_reduce_impl = TopKWeightAndReduceNaiveBatched(self.rank)
         weight_and_reduce_impl.apply(
@@ -688,18 +696,28 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
         else:
             return t.to(f32) * group_broadcast(scale, t.shape)

-    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
-              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
-              expert_map: Optional[torch.Tensor],
-              w1_scale: Optional[torch.Tensor],
-              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
-              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
-              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
-              workspace2: torch.Tensor,
-              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-              apply_router_weight_on_input: bool,
-              extra_expert_args: Optional[dict[str, Any]]):
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+    ):
         assert hidden_states.dim() == 3
         assert expert_tokens_meta is not None
         expert_num_tokens = expert_tokens_meta.expert_num_tokens
@@ -894,18 +912,28 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         output = (num_experts, max_num_tokens * num_dp, K)
         return (workspace13, workspace2, output, a.dtype)

-    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
-              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
-              expert_map: Optional[torch.Tensor],
-              w1_scale: Optional[torch.Tensor],
-              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
-              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
-              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
-              workspace2: torch.Tensor,
-              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-              apply_router_weight_on_input: bool,
-              extra_expert_args: Optional[dict[str, Any]]):
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+    ):
         # Check constraints.
         if self.use_int4_w4a16:
             assert hidden_states.size(-1) // 2 == w1.size(2), (
@@ -1394,9 +1394,9 @@ def fused_experts(hidden_states: torch.Tensor,
     # E8M0 scale, which means we requantize the weight and input to the specific
     # scale. Fallen back to cutlass or triton for some cases would cause
     # accuracy issue.
-    should_use_deep_gemm = is_blackwell_deep_gemm_e8m0_used(
-    ) or _valid_deep_gemm(hidden_states, w1, w2)
-    if (allow_deep_gemm and use_fp8_w8a8 and should_use_deep_gemm):
+    if (allow_deep_gemm and use_fp8_w8a8
+            and (is_blackwell_deep_gemm_e8m0_used()
+                 or _valid_deep_gemm(hidden_states, w1, w2))):
         assert apply_router_weight_on_input is False
         assert is_act_and_mul, (
             "DeepGemm only supports is_act_and_mul=True for now.")
@@ -1905,7 +1905,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
-        extra_expert_args: Optional[dict[str, Any]],
     ):
         # Check constraints.
         if self.use_int4_w4a16:
@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -8,7 +8,6 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||||
TopKWeightAndReduceDelegate)
|
TopKWeightAndReduceDelegate)
|
||||||
from vllm.model_executor.layers.fused_moe.utils import extract_required_args
|
|
||||||
from vllm.utils import has_triton_kernels
|
from vllm.utils import has_triton_kernels
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -160,12 +159,16 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
num_dispatchers: int,
|
num_dispatchers: int,
|
||||||
w1_precision: "PrecisionConfig",
|
w1_precision: "PrecisionConfig",
|
||||||
w2_precision: "PrecisionConfig",
|
w2_precision: "PrecisionConfig",
|
||||||
|
w1_bias: Optional[torch.Tensor],
|
||||||
|
w2_bias: Optional[torch.Tensor],
|
||||||
):
|
):
|
||||||
super().__init__(quant_config)
|
super().__init__(quant_config)
|
||||||
self.max_num_tokens = max_num_tokens
|
self.max_num_tokens = max_num_tokens
|
||||||
self.num_dispatchers = num_dispatchers
|
self.num_dispatchers = num_dispatchers
|
||||||
self.w1_precision = w1_precision
|
self.w1_precision = w1_precision
|
||||||
self.w2_precision = w2_precision
|
self.w2_precision = w2_precision
|
||||||
|
self.w1_bias = w1_bias
|
||||||
|
self.w2_bias = w2_bias
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def activation_formats(
|
def activation_formats(
|
||||||
@ -219,11 +222,7 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
workspace2: torch.Tensor,
|
workspace2: torch.Tensor,
|
||||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||||
apply_router_weight_on_input: bool,
|
apply_router_weight_on_input: bool,
|
||||||
extra_expert_args: Optional[dict[str, Any]],
|
|
||||||
):
|
):
|
||||||
w1_bias, w2_bias = (extract_required_args(extra_expert_args,
|
|
||||||
["w1_bias", "w2_bias"]))
|
|
||||||
|
|
||||||
return triton_kernel_fused_experts(
|
return triton_kernel_fused_experts(
|
||||||
output,
|
output,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@ -240,8 +239,8 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
|||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
w1_scale=w1_scale,
|
w1_scale=w1_scale,
|
||||||
w2_scale=w2_scale,
|
w2_scale=w2_scale,
|
||||||
w1_bias=w1_bias,
|
w1_bias=self.w1_bias,
|
||||||
w2_bias=w2_bias,
|
w2_bias=self.w2_bias,
|
||||||
w1_precision=self.w1_precision,
|
w1_precision=self.w1_precision,
|
||||||
w2_precision=self.w2_precision,
|
w2_precision=self.w2_precision,
|
||||||
a1_scale=a1q_scale,
|
a1_scale=a1q_scale,
|
||||||
|
|||||||
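Note: the hunks above replace the per-call `extra_expert_args` plumbing with constructor state, so `BatchedOAITritonExperts` receives `w1_bias`/`w2_bias` once and reuses them in `apply`. Below is a minimal, self-contained Python sketch of that pattern; the class and tensors are illustrative stand-ins, not the vLLM types.

from typing import Optional

import torch


class ExpertsWithBias:
    """Illustrative stand-in for an experts impl that owns its bias tensors."""

    def __init__(self, w1_bias: Optional[torch.Tensor],
                 w2_bias: Optional[torch.Tensor]):
        # Biases are captured once at construction time instead of being
        # re-extracted from an extra_expert_args dict on every apply() call.
        self.w1_bias = w1_bias
        self.w2_bias = w2_bias

    def apply(self, hidden_states: torch.Tensor) -> torch.Tensor:
        out = hidden_states
        if self.w1_bias is not None:
            out = out + self.w1_bias
        if self.w2_bias is not None:
            out = out + self.w2_bias
        return out


experts = ExpertsWithBias(w1_bias=torch.zeros(8), w2_bias=None)
print(experts.apply(torch.ones(8)).sum())  # tensor(8.)

The design choice is the same as in the diff: anything that is fixed for the lifetime of the layer belongs on the object, which keeps the hot-path `apply` signature uniform across implementations.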
@@ -37,7 +37,6 @@ from vllm.platforms import current_platform
 from vllm.platforms.interface import CpuArchEnum
 from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
                         round_up)
-from vllm.utils.flashinfer import has_flashinfer

 if current_platform.is_cuda_alike():
     from .fused_batched_moe import BatchedTritonExperts
@@ -49,9 +48,6 @@ if current_platform.is_cuda_alike():
         from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
         from .deepep_ll_prepare_finalize import (DEEPEP_QUANT_BLOCK_SHAPE,
                                                  DeepEPLLPrepareAndFinalize)
-    if has_flashinfer():
-        from .flashinfer_cutlass_prepare_finalize import (
-            FlashInferCutlassMoEPrepareAndFinalize)
 else:
     fused_experts = None  # type: ignore
     FusedMoEPermuteExpertsUnpermute = None  # type: ignore
@@ -80,7 +76,12 @@ class FusedMoeWeightScaleSupported(Enum):

 class FusedMoEMethodBase(QuantizeMethodBase):

-    moe: FusedMoEConfig
+    # TODO(bnell): also pass quant_config?
+    def __init__(self, moe: FusedMoEConfig):
+        super().__init__()
+        self.moe = moe
+        self.fused_experts: Optional[Callable] = None
+        self.topk_indices_dtype = None

     @abstractmethod
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -99,16 +100,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         return False

     @staticmethod
-    def maybe_make_prepare_finalize(
-            moe: FusedMoEConfig) -> Optional[FusedMoEPrepareAndFinalize]:
+    def _maybe_make_prepare_finalize(
+            moe: FusedMoEConfig, ) -> Optional[FusedMoEPrepareAndFinalize]:
         all2all_manager = get_ep_group().device_communicator.all2all_manager
         assert all2all_manager is not None

         prepare_finalize: Optional[FusedMoEPrepareAndFinalize] = None

-        if moe.use_flashinfer_cutlass_kernels:
-            prepare_finalize = FlashInferCutlassMoEPrepareAndFinalize(
-                quant_dtype=moe.quant_dtype, )
+        assert not moe.use_flashinfer_cutlass_kernels, \
+            "Must be created in modelopt.py"
         if moe.use_pplx_kernels:
             hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
                 moe.max_num_tokens,
@@ -188,14 +189,25 @@ class FusedMoEMethodBase(QuantizeMethodBase):

         return prepare_finalize

-    def init_prepare_finalize(self, moe: FusedMoEConfig):
-        self.moe = moe
-        prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(
-            self.moe)
+    def maybe_make_prepare_finalize(
+        self,
+        moe: FusedMoEConfig,
+    ) -> Optional[FusedMoEPrepareAndFinalize]:
+        if moe.moe_parallel_config.use_all2all_kernels:
+            return FusedMoEMethodBase._maybe_make_prepare_finalize(moe)
+        else:
+            return None
+
+    def init_prepare_finalize(self):
+        assert self.moe is not None
+        prepare_finalize = self.maybe_make_prepare_finalize(self.moe)

-        self.topk_indices_dtype = None
         if prepare_finalize is not None:
-            logger.debug("%s", prepare_finalize.__class__.__name__)
+            logger.debug("%s for %s(%s)", prepare_finalize.__class__.__name__,
+                         self, id(self))
+            assert self.topk_indices_dtype is None
+            assert self.fused_experts is None, \
+                f"Attempt to override experts for {id(self)}!"
             self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
             experts = self.select_gemm_impl(prepare_finalize, self.moe)
             self.fused_experts = FusedMoEModularKernel(
@@ -214,12 +226,6 @@ class FusedMoEMethodBase(QuantizeMethodBase):
             f"{self.__class__.__name__} must select appropriate gemm "
             "implementation based on the prepare_finalize")

-    def maybe_swap_experts_impl(
-        self,
-        moe_parallel_config: FusedMoEParallelConfig,
-    ):
-        pass
-
     @abstractmethod
     def apply(
         self,
@@ -251,10 +257,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""

     def __init__(self, moe: FusedMoEConfig):
-        super().__init__()
-        self.fused_experts = fused_experts  # type: ignore
-        self.topk_indices_dtype = None
-        self.moe = moe
+        super().__init__(moe)
         self.has_bias = self.moe.has_bias
         self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
         if self.rocm_aiter_moe_enabled:
@@ -266,6 +269,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     def select_gemm_impl(
         self,
         prepare_finalize: FusedMoEPrepareAndFinalize,
+        # TODO(bnell): Remove. Every layer should have an moe config object.
         moe: FusedMoEConfig,
     ) -> FusedMoEPermuteExpertsUnpermute:
         if (prepare_finalize.activation_format ==
@@ -474,9 +478,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 expert_map=expert_map,
                 activation=activation,
                 apply_router_weight_on_input=apply_router_weight_on_input)
-        else:
-            # add w1_bias/w2_bias to kwargs if they exist
-            kwargs = dict(
+        elif self.fused_experts is not None:
+            if self.has_bias:
+                raise ValueError(
+                    "FusedMoEModularKernel does not support bias.")
+            return self.fused_experts(
                 hidden_states=x,
                 w1=layer.w13_weight,
                 w2=layer.w2_weight,
@@ -488,17 +494,22 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 global_num_experts=global_num_experts,
                 expert_map=expert_map,
             )
-            if isinstance(self.fused_experts,
-                          FusedMoEModularKernel) and self.has_bias:
-                raise ValueError(
-                    "FusedMoEModularKernel does not support bias.")
-            if self.has_bias:
-                kwargs.update({
-                    "w1_bias": getattr(layer, "w13_bias", None),
-                    "w2_bias": getattr(layer, "w2_bias", None),
-                })
-            return self.fused_experts(**kwargs)
+        else:
+            assert fused_experts is not None
+            return fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                w1_bias=layer.w13_bias if self.has_bias else None,
+                w2_bias=layer.w2_bias if self.has_bias else None,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                activation=activation,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                global_num_experts=global_num_experts,
+                expert_map=expert_map,
+            )

     def forward_cpu(
         self,
@@ -868,8 +879,6 @@ class FusedMoE(CustomOp):
             moe_quant_params["intermediate_size_full"] = intermediate_size

         self.quant_method.create_weights(layer=self, **moe_quant_params)
-        if isinstance(self.quant_method, FusedMoEMethodBase):
-            self.quant_method.maybe_swap_experts_impl(self.moe_parallel_config)

         # Chunked all2all staging tensor
         self.batched_hidden_states: Optional[torch.Tensor] = None
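Note: with the hunks above, a quantization method passes its `FusedMoEConfig` to `FusedMoEMethodBase.__init__`, may override `maybe_make_prepare_finalize`, and later calls `init_prepare_finalize()` with no arguments. The following self-contained sketch mirrors that lifecycle with simplified stand-in classes; it is illustrative only and does not use the real vLLM types.

from typing import Callable, Optional


class StubPrepareFinalize:
    """Stand-in for a FusedMoEPrepareAndFinalize-style object."""

    def topk_indices_dtype(self):
        return None


class StubMoEMethod:
    """Stand-in mirroring the reworked base-class lifecycle."""

    def __init__(self, moe: dict):
        # The MoE config is captured once here, so later calls need no args.
        self.moe = moe
        self.fused_experts: Optional[Callable] = None
        self.topk_indices_dtype = None

    def maybe_make_prepare_finalize(self, moe: dict):
        # Base behaviour: only build a prepare/finalize object when an
        # all2all backend is in use; subclasses may override this hook.
        return StubPrepareFinalize() if moe.get("use_all2all") else None

    def init_prepare_finalize(self) -> None:
        assert self.moe is not None
        prepare_finalize = self.maybe_make_prepare_finalize(self.moe)
        if prepare_finalize is not None:
            assert self.fused_experts is None
            self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
            # The real code builds a FusedMoEModularKernel at this point.
            self.fused_experts = lambda *args, **kwargs: None


method = StubMoEMethod({"use_all2all": True})
method.init_prepare_finalize()
print(method.fused_experts is not None)  # True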
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from enum import Enum
 from math import prod
-from typing import Any, Optional, final
+from typing import Optional, final

 import torch

@@ -150,15 +150,23 @@ class FusedMoEPrepareAndFinalize(ABC):

     @abstractmethod
     def prepare(
-        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor, num_experts: int,
-        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
+        self,
+        a1: torch.Tensor,
+        a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-        extra_prepare_args: Optional[dict[str, Any]]
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor],
-               Optional[ExpertTokensMetadata], Optional[torch.Tensor],
-               Optional[torch.Tensor]]:
+    ) -> tuple[
+            torch.Tensor,
+            Optional[torch.Tensor],
+            Optional[ExpertTokensMetadata],
+            Optional[torch.Tensor],
+            Optional[torch.Tensor],
+    ]:
         """
         Perform any quantization (and/or) dispatching needed
         for this kernel.
@@ -186,11 +194,15 @@ class FusedMoEPrepareAndFinalize(ABC):
         raise NotImplementedError

     @abstractmethod
-    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
-                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                 apply_router_weight_on_input: bool,
-                 weight_and_reduce_impl: TopKWeightAndReduce,
-                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: TopKWeightAndReduce,
+    ) -> None:
         """
         Perform any combine plus apply weights and perform a reduction on the
         fused experts output.
@@ -368,7 +380,6 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         workspace2: torch.Tensor,
         expert_tokens_meta: Optional[ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
-        extra_expert_args: Optional[dict[str, Any]],
     ):
         """
         This function computes the intermediate result of a Mixture of Experts
@@ -454,18 +465,27 @@ class FusedMoEModularKernel(torch.nn.Module):
                 f"{fused_experts.activation_formats[0]}")

     def _do_fused_experts(
-        self, fused_out: Optional[torch.Tensor], a1: torch.Tensor,
-        a1q: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
-        topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-        activation: str, global_num_experts: int, local_num_experts: int,
-        expert_map: Optional[torch.Tensor],
-        w1_scale: Optional[torch.Tensor], w2_scale: Optional[torch.Tensor],
-        w1_zp: Optional[torch.Tensor], w2_zp: Optional[torch.Tensor],
-        a1q_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor],
-        expert_tokens_meta: Optional[ExpertTokensMetadata],
-        apply_router_weight_on_input: bool,
-        extra_expert_args: Optional[dict[str, Any]]) -> torch.Tensor:
+        self,
+        fused_out: Optional[torch.Tensor],
+        a1: torch.Tensor,
+        a1q: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        local_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        expert_tokens_meta: Optional[ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+    ) -> torch.Tensor:

         _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)

@@ -509,7 +529,7 @@ class FusedMoEModularKernel(torch.nn.Module):
             workspace2=workspace2,
             expert_tokens_meta=expert_tokens_meta,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            extra_expert_args=extra_expert_args)
+        )

         return fused_out

@@ -533,7 +553,6 @@ class FusedMoEModularKernel(torch.nn.Module):
         a2_scale: Optional[torch.Tensor],
         expert_tokens_meta: Optional[ExpertTokensMetadata],
         apply_router_weight_on_input: bool,
-        extra_expert_args: Optional[dict[str, Any]],
     ) -> torch.Tensor:

         _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
@@ -541,6 +560,9 @@ class FusedMoEModularKernel(torch.nn.Module):
         CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
         num_chunks = cdiv(M, CHUNK_SIZE)

+        # TODO(bnell): get rid of one level here, update slice functions
+        # to nops on num_chunks==1
+
         if not self.fused_experts.supports_chunking() or num_chunks == 1:
             return self._do_fused_experts(
                 fused_out=None,
@@ -562,7 +584,7 @@ class FusedMoEModularKernel(torch.nn.Module):
                 a2_scale=a2_scale,
                 expert_tokens_meta=expert_tokens_meta,
                 apply_router_weight_on_input=apply_router_weight_on_input,
-                extra_expert_args=extra_expert_args)
+            )

         # Chunking required case
         assert num_chunks > 1
@@ -618,15 +640,6 @@ class FusedMoEModularKernel(torch.nn.Module):
                 expert_num_tokens=c_expert_num_tokens,
                 expert_num_tokens_cpu=c_expert_num_tokens_cpu)

-        m = None
-        if extra_expert_args is not None and 'm' in extra_expert_args:
-            m = extra_expert_args.get('m')
-
-        if extra_expert_args is not None:
-            chunked_extra_expert_args = extra_expert_args
-        else:
-            chunked_extra_expert_args = {}
-
         for chunk_idx in range(num_chunks):
             c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = (
                 slice_input_tensors(chunk_idx))
@@ -637,11 +650,6 @@ class FusedMoEModularKernel(torch.nn.Module):
                 expert_tokens_meta, c_topk_ids, local_num_experts,
                 expert_map)

-            s = chunk_idx * CHUNK_SIZE
-            e = min(s + CHUNK_SIZE, M)
-
-            if m is not None:
-                chunked_extra_expert_args['m'] = e - s
             self._do_fused_experts(
                 fused_out=slice_output_tensor(chunk_idx),
                 a1=a1,
@@ -662,7 +670,7 @@ class FusedMoEModularKernel(torch.nn.Module):
                 a2_scale=c_a2_scale,
                 expert_tokens_meta=c_expert_tokens_meta,
                 apply_router_weight_on_input=apply_router_weight_on_input,
-                extra_expert_args=chunked_extra_expert_args)
+            )

         return fused_out

@@ -684,9 +692,6 @@ class FusedMoEModularKernel(torch.nn.Module):
         a1_scale: Optional[torch.Tensor] = None,
         a2_scale: Optional[torch.Tensor] = None,
         apply_router_weight_on_input: bool = False,
-        extra_expert_args: Optional[dict] = None,
-        extra_prepare_args: Optional[dict] = None,
-        extra_finalize_args: Optional[dict] = None,
     ) -> torch.Tensor:
         """
         This function computes a Mixture of Experts (MoE) layer using two sets
@@ -719,12 +724,6 @@ class FusedMoEModularKernel(torch.nn.Module):
         - apply_router_weight_on_input (bool): When true, the topk weights are
           applied directly on the inputs. This is only applicable when topk is
           1.
-        - extra_expert_args (Optional[dict]): Extra keyword arguments to pass to
-          fused_experts.apply.
-        - extra_prepare_args (Optional[dict]): Extra keyword arguments to pass
-          to prepare.
-        - extra_finalize_args (Optional[dict]): Extra keyword arguments to pass
-          to finalize.

         Returns:
         - torch.Tensor: The output tensor after applying the MoE layer.
@@ -748,7 +747,6 @@ class FusedMoEModularKernel(torch.nn.Module):
             expert_map,
             apply_router_weight_on_input,
             self.fused_experts.quant_config,
-            extra_prepare_args,
         )

         # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
@@ -786,12 +784,15 @@ class FusedMoEModularKernel(torch.nn.Module):
             a2_scale=a2_scale,
             expert_tokens_meta=expert_tokens_meta,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            extra_expert_args=extra_expert_args)
+        )

         self.prepare_finalize.finalize(
-            output, fused_out, topk_weights, topk_ids,
+            output,
+            fused_out,
+            topk_weights,
+            topk_ids,
             apply_router_weight_on_input,
             self.fused_experts.finalize_weight_and_reduce_impl(),
-            extra_finalize_args)
+        )

         return output
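Note: after the hunks above, `prepare` and `finalize` no longer accept `extra_prepare_args`/`extra_finalize_args`; implementation-specific inputs are expected to be constructor state on the concrete class. Below is a compact, self-contained sketch of the trimmed interface shape, using stand-in types where the vLLM ones would otherwise be imported.

from abc import ABC, abstractmethod
from typing import Optional

import torch


class StubPrepareAndFinalize(ABC):
    """Stand-in with the same shape as the trimmed abstract interface."""

    @abstractmethod
    def prepare(self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
                a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
                topk_ids: torch.Tensor, num_experts: int,
                expert_map: Optional[torch.Tensor],
                apply_router_weight_on_input: bool,
                quant_config: object) -> tuple:
        # No extra_prepare_args: anything an implementation needs beyond
        # these inputs should live on the object itself (constructor state).
        raise NotImplementedError

    @abstractmethod
    def finalize(self, output: torch.Tensor,
                 fused_expert_output: torch.Tensor,
                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 apply_router_weight_on_input: bool,
                 weight_and_reduce_impl: object) -> None:
        # Likewise, no extra_finalize_args.
        raise NotImplementedError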
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional
+from typing import Optional, Union

 import pplx_kernels as pplx
 import torch
@@ -21,7 +21,7 @@ def pplx_hidden_dim_scale_bytes(
     max_num_tokens: int,
     hidden_dim: int,
     in_dtype: torch.dtype,
-    quant_dtype: Optional[torch.dtype],
+    quant_dtype: Union[torch.dtype, str, None],
     per_act_token_quant: bool,
     block_shape: Optional[list[int]],
 ):
@@ -32,6 +32,7 @@ def pplx_hidden_dim_scale_bytes(
     # ceil_div(hidden_dim, block_size) * sizeof(float32)
     # For per-token: set to 4 * sizeof(float32) (x4 for alignment)
     if quant_dtype is not None:
+        assert isinstance(quant_dtype, torch.dtype)
         assert quant_dtype.itemsize == 1
         hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
         elem_size = torch.float32.itemsize
@@ -89,12 +90,16 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         return self.num_dispatchers_

     def prepare(
-        self, a1: torch.Tensor, a1_scale: Optional[torch.Tensor],
-        a2_scale: Optional[torch.Tensor], topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor, num_experts: int,
-        expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool,
+        self,
+        a1: torch.Tensor,
+        a1_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-        extra_prepare_args: Optional[dict[str, Any]]
     ) -> tuple[torch.Tensor, Optional[torch.Tensor],
                Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
                Optional[torch.Tensor]]:
@@ -213,11 +218,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):

         return expert_x, expert_x_scale, expert_tokens_meta, None, None

-    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
-                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                 apply_router_weight_on_input: bool,
-                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
-                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
         assert isinstance(
             weight_and_reduce_impl, TopKWeightAndReduceDelegate
         ), ("Weight application and reduction happens in the combine kernel.")
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional
+from typing import Optional

 import torch

@@ -38,7 +38,6 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
         expert_map: Optional[torch.Tensor],
         apply_router_weight_on_input: bool,
         quant_config: FusedMoEQuantConfig,
-        extra_prepare_args: Optional[dict[str, Any]],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor],
                Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
                Optional[torch.Tensor]]:
@@ -50,32 +49,26 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
                 "apply_router_weight_on_input is only implemented for topk=1"
             a1.mul_(topk_weights.to(a1.dtype))

-        if (extra_prepare_args is not None
-                and extra_prepare_args.get("skip_quant", True)):
-            # Skip quantization if explicitly requested
-            return a1, None, None, None, None
-
         a1q, a1q_scale = moe_kernel_quantize_input(
             a1, a1_scale, quant_config.quant_dtype,
             quant_config.per_act_token_quant, quant_config.block_shape)

         return a1q, a1q_scale, None, None, None

-    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
-                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                 apply_router_weight_on_input: bool,
-                 weight_and_reduce_impl: mk.TopKWeightAndReduce,
-                 extra_finalize_args: Optional[dict[str, Any]]) -> None:
-        if (extra_finalize_args is not None
-                and extra_finalize_args.get("skip_weight_reduce", True)):
-            assert output.shape == fused_expert_output.shape
-            output.copy_(fused_expert_output)
-        else:
-            if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
-                weight_and_reduce_impl = TopKWeightAndReduceContiguous()
-            weight_and_reduce_impl.apply(
-                output=output,
-                fused_expert_output=fused_expert_output,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                apply_router_weight_on_input=apply_router_weight_on_input)
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
+        if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
+            weight_and_reduce_impl = TopKWeightAndReduceContiguous()
+        weight_and_reduce_impl.apply(
+            output=output,
+            fused_expert_output=fused_expert_output,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            apply_router_weight_on_input=apply_router_weight_on_input)
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Any, Optional
+from typing import Optional

 import torch

@@ -119,18 +119,28 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             local_num_experts,
             expert_tokens_meta)

-    def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
-              w1: torch.Tensor, w2: torch.Tensor, topk_weights: torch.Tensor,
-              topk_ids: torch.Tensor, activation: str, global_num_experts: int,
-              expert_map: Optional[torch.Tensor],
-              w1_scale: Optional[torch.Tensor],
-              w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
-              w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
-              a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
-              workspace2: torch.Tensor,
-              expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
-              apply_router_weight_on_input: bool,
-              extra_expert_args: Optional[dict[str, Any]]):
+    def apply(
+        self,
+        output: torch.Tensor,
+        hidden_states: torch.Tensor,
+        w1: torch.Tensor,
+        w2: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        activation: str,
+        global_num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        w1_scale: Optional[torch.Tensor],
+        w2_scale: Optional[torch.Tensor],
+        w1_zp: Optional[torch.Tensor],
+        w2_zp: Optional[torch.Tensor],
+        a1q_scale: Optional[torch.Tensor],
+        a2_scale: Optional[torch.Tensor],
+        workspace13: torch.Tensor,
+        workspace2: torch.Tensor,
+        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
+        apply_router_weight_on_input: bool,
+    ):
         use_deep_gemm = (self.allow_deep_gemm
                          and (_valid_deep_gemm(hidden_states, w1, w2)
                               or is_blackwell_deep_gemm_e8m0_used()))
@@ -158,5 +168,4 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             workspace2,
             expert_tokens_meta,
             apply_router_weight_on_input,
-            extra_expert_args,
         )
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from math import prod
-from typing import Any, Optional, Union
+from typing import Optional, Union

 import torch

@@ -189,7 +189,7 @@ def moe_kernel_quantize_input(
         return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
     elif quant_dtype == torch.int8:
         return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
-    elif quant_dtype == torch.uint8:  # nvfp4
+    elif quant_dtype == "nvfp4":
         return _fp4_quantize(A,
                              A_scale,
                              is_sf_swizzled_layout=is_fp4_scale_swizzled)
@@ -252,17 +252,3 @@ def _validate_scale_shape(
         assert block_shape is not None
         expected = (a.shape[0], cdiv(a.shape[1], block_shape[1]))
         assert a_scale.shape == expected, f"{a_scale.shape} == {expected}"
-
-
-def extract_required_args(
-    extra_args: Optional[dict[str, Any]],
-    required_keys: list[str],
-) -> tuple[Any, ...]:
-    if extra_args is None:
-        raise ValueError("`extra_args` must be provided.")
-
-    missing_keys = [k for k in required_keys if k not in extra_args]
-    if missing_keys:
-        raise ValueError(f"Missing keys in `extra_args`: {missing_keys}")
-
-    return tuple(extra_args[k] for k in required_keys)
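Note: the quantization helper above now keys the FP4 path off the string sentinel "nvfp4" (with `quant_dtype` typed as `Union[torch.dtype, str, None]`) instead of overloading `torch.uint8`. A toy sketch of that dispatch pattern follows; the conversions below are placeholders, not the real vLLM quantization kernels.

from typing import Union

import torch


def quantize_input(a: torch.Tensor,
                   quant_dtype: Union[torch.dtype, str, None]):
    # Placeholder dispatch mirroring the sentinel-based selection.
    if quant_dtype is None:
        return a, None
    if quant_dtype == "nvfp4":
        # The FP4 path is chosen by the string sentinel, not torch.uint8.
        return a.to(torch.uint8), torch.ones(1)
    assert isinstance(quant_dtype, torch.dtype) and quant_dtype.itemsize == 1
    return a.to(quant_dtype), torch.ones(1)


x = torch.randn(4, 8)
print(quantize_input(x, None)[0].dtype)        # torch.float32
print(quantize_input(x, "nvfp4")[0].dtype)     # torch.uint8 (placeholder)
print(quantize_input(x, torch.int8)[0].dtype)  # torch.int8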
@@ -241,7 +241,7 @@ class AutoRoundConfig(QuantizationConfig):

         if isinstance(layer, FusedMoE):
             if use_marlin:
-                return AWQMoEMethod(quant_args_marlin)
+                return AWQMoEMethod(quant_args_marlin, layer.moe)
             from vllm.model_executor.layers.quantization.moe_wna16 import (
                 MoeWNA16Config)

@@ -339,7 +339,7 @@ class AutoRoundConfig(QuantizationConfig):
                 }
                 return MoeWNA16Config.from_config(config).get_quant_method(
                     layer, prefix)
-            return GPTQMarlinMoEMethod(quant_args_marlin)
+            return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe)

         if isinstance(layer, (LinearBase, ParallelLMHead)):
             if use_marlin:
@@ -113,7 +113,7 @@ class AWQConfig(QuantizationConfig):
             }
             awq_marlin_config = AWQMarlinConfig.from_config(
                 marlin_compatible_config_dict)
-            return AWQMoEMethod(awq_marlin_config)
+            return AWQMoEMethod(awq_marlin_config, layer.moe_config)
         return None


@@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
+    FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
     UnquantizedFusedMoEMethod)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod,
@@ -151,7 +151,7 @@ class AWQMarlinConfig(QuantizationConfig):
                     "Falling back to Moe WNA16 kernels.")
                 return MoeWNA16Config.from_config(
                     self.full_config).get_quant_method(layer, prefix)
-            return AWQMoEMethod(self)
+            return AWQMoEMethod(self, layer.moe_config)
         return None

     @classmethod
@@ -328,7 +328,12 @@ class AWQMarlinLinearMethod(LinearMethodBase):

 class AWQMoEMethod(FusedMoEMethodBase):

-    def __init__(self, quant_config: AWQMarlinConfig):
+    def __init__(
+        self,
+        quant_config: AWQMarlinConfig,
+        moe: FusedMoEConfig,
+    ):
+        super().__init__(moe)
         self.quant_config = quant_config
         if self.quant_config.weight_bits != 4:
             raise ValueError("AWQMoEMethod only supports 4bit now.")
@@ -500,6 +505,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        assert self.fused_experts is None
+
         if enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `AWQMoEMethod` yet.")
@@ -516,7 +523,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype)

         return torch.ops.vllm.fused_marlin_moe(
             x,
@@ -535,4 +543,4 @@ class AWQMoEMethod(FusedMoEMethodBase):
             expert_map=expert_map,
             w1_zeros=layer.w13_qzeros,
             w2_zeros=layer.w2_qzeros,
             workspace=layer.workspace)
@@ -7,7 +7,7 @@ import torch
 from packaging import version

 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
+                                                        FusedMoEConfig,
                                                         FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod,
@@ -132,7 +133,7 @@ class BitsAndBytesConfig(QuantizationConfig):
                 return UnquantizedLinearMethod()
             return BitsAndBytesLinearMethod(self)
         elif isinstance(layer, FusedMoE):
-            return BitsAndBytesMoEMethod(self)
+            return BitsAndBytesMoEMethod(self, layer.moe_config)
         return None


@@ -411,7 +412,12 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
         quant_config: The BitsAndBytes quantization config.
     """

-    def __init__(self, quant_config: BitsAndBytesConfig):
+    def __init__(
+        self,
+        quant_config: BitsAndBytesConfig,
+        moe: FusedMoEConfig,
+    ):
+        super().__init__(moe)
         try:
             import bitsandbytes
             if version.parse(
@@ -422,7 +428,6 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
             raise ImportError("Please install bitsandbytes>=0.46.1 via "
                               "`pip install bitsandbytes>=0.46.1` to use "
                               "bitsandbytes quantizer.") from err
-        self.topk_indices_dtype = None
         self.quant_config = quant_config

     def create_weights(
@@ -470,6 +475,7 @@ class BitsAndBytesMoEMethod(FusedMoEMethodBase):
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
+        assert self.fused_experts is None

         if enable_eplb:
             raise NotImplementedError(
@@ -11,20 +11,21 @@ from compressed_tensors.quantization import (ActivationOrdering,
                                              QuantizationStrategy)

 import vllm.envs as envs
+import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (
     FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
     FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
     FusedMoeWeightScaleSupported)
-from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa
-    FlashInferCutlassMoEPrepareAndFinalize)
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
+    is_valid_flashinfer_cutlass_fused_moe)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
     WNA16_SUPPORTED_BITS, WNA16_SUPPORTED_TYPES_MAP)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
-    build_flashinfer_fp4_cutlass_moe_kernel,
-    flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1)
+    build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
+    select_nvfp4_gemm_impl)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_moe_marlin_supports_layer, marlin_make_workspace_new,
     marlin_moe_permute_scales)
@@ -58,6 +59,9 @@ __all__ = [

 class CompressedTensorsMoEMethod(FusedMoEMethodBase):

+    def __init_(self, moe: FusedMoEConfig):
+        super().__init__(moe)
+
     @staticmethod
     def get_moe_method(
         quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
@@ -81,18 +85,22 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
                     "WNA16MoE is not supported with actorder=group/dynamic."
                 )
                 logger.info_once("Using CompressedTensorsWNA16MoEMethod")
-                return CompressedTensorsWNA16MoEMethod(quant_config)
+                return CompressedTensorsWNA16MoEMethod(quant_config,
+                                                       layer.moe_config)
             else:
                 logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
-                return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
+                return CompressedTensorsWNA16MarlinMoEMethod(
+                    quant_config, layer.moe_config)
         elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant):
-            return CompressedTensorsW4A4MoeMethod()
+            return CompressedTensorsW4A4MoeMethod(layer.moe_config, layer)
         elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
               or quant_config._is_fp8_w8a8_sm100(weight_quant, input_quant)
              or quant_config._is_fp8_w8a8(weight_quant, input_quant)):
-            return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
+            return CompressedTensorsW8A8Fp8MoEMethod(quant_config,
+                                                     layer.moe_config)
         elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
-            return CompressedTensorsW8A8Int8MoEMethod(quant_config)
+            return CompressedTensorsW8A8Int8MoEMethod(quant_config,
+                                                      layer.moe_config)
         else:
             raise RuntimeError(
                 f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
@@ -100,15 +108,16 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):

 class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):

-    def __init__(self):
+    def __init__(self, moe: FusedMoEConfig, layer: torch.nn.Module):
         from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
             detect_nvfp4_moe_support)
+        super().__init__(moe)
         _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
         self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
         self.allow_flashinfer = _nvfp4.allow_flashinfer
         self.use_marlin = _nvfp4.use_marlin
         self.group_size = 16
-        self.fused_experts = None  # type: ignore[assignment]
+        self.layer = layer

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
@@ -265,19 +274,36 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         layer.w2_input_scale_quant = torch.nn.Parameter(
             (layer.w2_input_global_scale), requires_grad=False)

-    def maybe_swap_experts_impl(self, moe_parallel_config):
+    def maybe_make_prepare_finalize(
+        self,
+        moe: FusedMoEConfig,
+    ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
         if not self.allow_flashinfer:
-            return
-        self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
-            moe_parallel_config)
+            return super().maybe_make_prepare_finalize(moe)

-    def select_gemm_impl(self, prepare_finalize, moe):
+        prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
+            moe,
+            a1_gscale=self.layer.w13_input_scale_quant,
+        )
+        logger.debug_once("%s", prepare_finalize.__class__.__name__)
+        return prepare_finalize
+
+    def select_gemm_impl(
+        self,
+        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+        moe: FusedMoEConfig,
+    ) -> mk.FusedMoEPermuteExpertsUnpermute:
         """Return the appropriate GEMM experts implementation."""
-        assert moe is not None and prepare_finalize is not None
-        from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (  # noqa: E501
-            select_nvfp4_gemm_impl)
-
-        return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)
+        experts = select_nvfp4_gemm_impl(
+            moe,
+            g1_alphas=self.layer.g1_alphas,
+            g2_alphas=self.layer.g2_alphas,
+            a1_gscale=self.layer.w13_input_scale_quant,
+            a2_gscale=self.layer.w2_input_scale_quant,
+            allow_flashinfer=self.allow_flashinfer,
+        )
+        logger.debug_once("Using %s", experts.__class__.__name__)
+        return experts

     def apply(
         self,
@@ -301,6 +327,8 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        assert self.fused_experts is None
+
         if enable_eplb:
             raise NotImplementedError("EPLB not supported for "
                                       "`CompressedTensorsW4A4MoeMethod` yet.")
@@ -317,6 +345,7 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype,
         )

         if self.use_marlin:
@@ -340,15 +369,22 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):

         # FlashInfer fused experts path
         if self.fused_experts is not None:
-            return flashinfer_fp4_cutlass_moe_forward(
-                self.fused_experts,
-                layer,
-                x,
-                topk_weights,
-                topk_ids,
+            assert is_valid_flashinfer_cutlass_fused_moe(
+                x, layer.w13_weight, layer.w2_weight), (
+                    "Flashinfer CUTLASS Fused MoE not applicable!")
+
+            return self.fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=False,  # TODO(shuw): fix later, now output is high prec
                 activation=activation,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map,
+                w1_scale=layer.w13_blockscale_swizzled,
+                w2_scale=layer.w2_blockscale_swizzled,
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )

@@ -376,7 +412,6 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
             n=layer.w2_weight.shape[2] * 2,
             k=x.shape[1],
             e=layer.w13_weight.shape[0],
-            device=x.device,
             apply_router_weight_on_input=apply_router_weight_on_input).to(
                 x.dtype)

@@ -384,15 +419,16 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
 class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):

     def __init__(
         self,
-        quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
+        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        moe: FusedMoEConfig,
     ):
+        super().__init__(moe)
         self.quant_config = quant_config
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
             "weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations")
-        self.topk_indices_dtype = None

         per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
                       and self.input_quant.strategy
@@ -429,7 +465,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             self.weight_quant, self.input_quant)
         self.use_cutlass = (quant_config._is_fp8_w8a8_sm90(
             self.weight_quant, self.input_quant) or self.is_fp8_w8a8_sm100)
-        self.fused_experts = None  # type: ignore[assignment]
         self.disable_expert_map = False

     def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -614,25 +649,31 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     ) -> FusedMoEPermuteExpertsUnpermute:
         # cutlass path
|
||||||
if self.use_cutlass:
|
if self.use_cutlass:
|
||||||
from vllm.model_executor.layers.fused_moe import CutlassExpertsFp8
|
from vllm.model_executor.layers.fused_moe import (
|
||||||
|
CutlassBatchedExpertsFp8, CutlassExpertsFp8)
|
||||||
|
|
||||||
use_batched_format = (prepare_finalize.activation_format ==
|
experts: FusedMoEPermuteExpertsUnpermute
|
||||||
FusedMoEActivationFormat.BatchedExperts)
|
|
||||||
|
|
||||||
num_dispatchers = prepare_finalize.num_dispatchers()
|
num_dispatchers = prepare_finalize.num_dispatchers()
|
||||||
num_experts = (moe.num_local_experts
|
|
||||||
if use_batched_format else moe.num_experts)
|
|
||||||
|
|
||||||
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
|
if (prepare_finalize.activation_format ==
|
||||||
|
FusedMoEActivationFormat.BatchedExperts):
|
||||||
experts = CutlassExpertsFp8(
|
logger.debug("CutlassBatchedExpertsFp8(%s)",
|
||||||
num_experts,
|
self.__class__.__name__)
|
||||||
moe.in_dtype,
|
experts = CutlassBatchedExpertsFp8(
|
||||||
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
moe.num_local_experts,
|
||||||
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
num_dispatchers,
|
||||||
num_dispatchers=num_dispatchers,
|
moe.in_dtype,
|
||||||
use_batched_format=use_batched_format,
|
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
||||||
)
|
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug("CutlassExpertsFp8(%s)", self.__class__.__name__)
|
||||||
|
experts = CutlassExpertsFp8(
|
||||||
|
moe.in_dtype,
|
||||||
|
self.input_quant.strategy == QuantizationStrategy.TOKEN,
|
||||||
|
self.weight_quant.strategy == QuantizationStrategy.CHANNEL,
|
||||||
|
)
|
||||||
|
|
||||||
self.disable_expert_map = (num_dispatchers > 1
|
self.disable_expert_map = (num_dispatchers > 1
|
||||||
or not experts.supports_expert_map())
|
or not experts.supports_expert_map())
|
||||||
@ -834,9 +875,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||||
|
moe: FusedMoEConfig,
|
||||||
):
|
):
|
||||||
|
super().__init__(moe)
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
|
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
|
||||||
"weights")
|
"weights")
|
||||||
@ -934,6 +977,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||||
logical_replica_count: Optional[torch.Tensor] = None,
|
logical_replica_count: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
assert self.fused_experts is None
|
||||||
|
|
||||||
if enable_eplb:
|
if enable_eplb:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"EPLB not supported for "
|
"EPLB not supported for "
|
||||||
@ -951,7 +996,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
scoring_func=scoring_func,
|
scoring_func=scoring_func,
|
||||||
e_score_correction_bias=e_score_correction_bias)
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
|
indices_type=self.topk_indices_dtype)
|
||||||
|
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
@ -975,9 +1021,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||||
|
moe: FusedMoEConfig,
|
||||||
):
|
):
|
||||||
|
super().__init__(moe)
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
# TODO: @dsikka: refactor this to use schemes as other kernels
|
# TODO: @dsikka: refactor this to use schemes as other kernels
|
||||||
# are supported + check if the layer is being ignored.
|
# are supported + check if the layer is being ignored.
|
||||||
@ -1233,6 +1281,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
|||||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||||
logical_replica_count: Optional[torch.Tensor] = None,
|
logical_replica_count: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
assert self.fused_experts is None
|
||||||
|
|
||||||
if enable_eplb:
|
if enable_eplb:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"EPLB not supported for "
|
"EPLB not supported for "
|
||||||
@ -1251,7 +1301,8 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
|||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
scoring_func=scoring_func,
|
scoring_func=scoring_func,
|
||||||
e_score_correction_bias=e_score_correction_bias)
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
|
indices_type=self.topk_indices_dtype)
|
||||||
|
|
||||||
return torch.ops.vllm.fused_marlin_moe(
|
return torch.ops.vllm.fused_marlin_moe(
|
||||||
x,
|
x,
|
||||||
@ -1279,9 +1330,11 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
|
|||||||
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||||
|
moe: FusedMoEConfig,
|
||||||
):
|
):
|
||||||
|
super().__init__(moe)
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
# TODO: @dsikka: refactor this to use schemes as other kernels
|
# TODO: @dsikka: refactor this to use schemes as other kernels
|
||||||
# are supported + check if the layer is being ignored.
|
# are supported + check if the layer is being ignored.
|
||||||
@ -1459,6 +1512,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||||
logical_replica_count: Optional[torch.Tensor] = None,
|
logical_replica_count: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
assert self.fused_experts is None
|
||||||
|
|
||||||
if enable_eplb:
|
if enable_eplb:
|
||||||
raise NotImplementedError("EPLB not supported for "
|
raise NotImplementedError("EPLB not supported for "
|
||||||
"`CompressedTensorsWNA16MoEMethod` yet.")
|
"`CompressedTensorsWNA16MoEMethod` yet.")
|
||||||
@ -1475,7 +1530,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
scoring_func=scoring_func,
|
scoring_func=scoring_func,
|
||||||
e_score_correction_bias=e_score_correction_bias)
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
|
indices_type=self.topk_indices_dtype)
|
||||||
|
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
x,
|
x,
|
||||||
|
|||||||
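Note: the hunks above replace the old `maybe_swap_experts_impl` hook with `maybe_make_prepare_finalize` plus `select_gemm_impl`, so a quantization method no longer builds a whole modular kernel itself; it only supplies the two halves and the base class composes them. The following is a minimal, standalone sketch of that wiring pattern using toy classes; the `Toy*` names are invented for illustration and are not vLLM APIs.

# Toy sketch of the prepare/finalize + experts composition pattern.
from typing import Optional


class ToyPrepareFinalize:
    """Stands in for FusedMoEPrepareAndFinalize: dispatch/quantize, then combine."""

    def prepare(self, x: float) -> float:
        return x  # no-op in this toy example

    def finalize(self, y: float) -> float:
        return y


class ToyExperts:
    """Stands in for the FusedMoEPermuteExpertsUnpermute GEMM implementation."""

    def __call__(self, x: float) -> float:
        return x * 2.0


class ToyModularKernel:
    """Composes the two halves, mirroring the role of FusedMoEModularKernel."""

    def __init__(self, prepare_finalize: ToyPrepareFinalize, experts: ToyExperts):
        self.prepare_finalize = prepare_finalize
        self.experts = experts

    def __call__(self, x: float) -> float:
        a = self.prepare_finalize.prepare(x)
        return self.prepare_finalize.finalize(self.experts(a))


class ToyQuantMethod:
    """A quant method now only supplies the two halves, as in the diff above."""

    def maybe_make_prepare_finalize(self) -> Optional[ToyPrepareFinalize]:
        return ToyPrepareFinalize()

    def select_gemm_impl(self, prepare_finalize: ToyPrepareFinalize) -> ToyExperts:
        return ToyExperts()

    def init_prepare_finalize(self) -> Optional[ToyModularKernel]:
        prepare_finalize = self.maybe_make_prepare_finalize()
        if prepare_finalize is None:
            return None
        return ToyModularKernel(prepare_finalize,
                                self.select_gemm_impl(prepare_finalize))


if __name__ == "__main__":
    kernel = ToyQuantMethod().init_prepare_finalize()
    assert kernel is not None
    print(kernel(3.0))  # 6.0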
@@ -6,7 +6,8 @@ from typing import Any, Callable, Optional
 import torch
 
 from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
-from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
+from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
+                                                  FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -46,13 +47,18 @@ class ExpertsInt8Config(QuantizationConfig):
         if isinstance(layer, LinearBase):
             return UnquantizedLinearMethod()
         elif isinstance(layer, FusedMoE):
-            return ExpertsInt8MoEMethod(self)
+            return ExpertsInt8MoEMethod(self, layer.moe_config)
         return None
 
 
 class ExpertsInt8MoEMethod(FusedMoEMethodBase):
 
-    def __init__(self, quant_config: ExpertsInt8Config):
+    def __init__(
+        self,
+        quant_config: ExpertsInt8Config,
+        moe: FusedMoEConfig,
+    ):
+        super().__init__(moe)
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -122,6 +128,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        assert self.fused_experts is None
+
         if enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `ExpertsInt8MoEMethod` yet.")
@@ -138,7 +146,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype)
 
         return fused_experts(
             x,
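The same constructor change recurs across every method in this commit: the quantization config's factory now hands the layer's MoE config to the method, which forwards it to the base class. A tiny illustrative sketch of that threading is below; the classes are invented stand-ins, not vLLM APIs.

# Toy illustration of passing the layer's MoE config through super().__init__.
from dataclasses import dataclass


@dataclass
class ToyMoEConfig:
    num_experts: int
    ep_size: int = 1


class ToyFusedMoEMethodBase:
    def __init__(self, moe: ToyMoEConfig):
        self.moe = moe
        self.fused_experts = None        # installed later by init_prepare_finalize
        self.topk_indices_dtype = None   # chosen by the prepare/finalize impl


class ToyInt8MoEMethod(ToyFusedMoEMethodBase):
    def __init__(self, quant_config: dict, moe: ToyMoEConfig):
        super().__init__(moe)            # mirrors super().__init__(moe) above
        self.quant_config = quant_config


method = ToyInt8MoEMethod({"bits": 8}, ToyMoEConfig(num_experts=16))
print(method.moe.num_experts)  # 16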
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import functools
 from typing import TYPE_CHECKING, Any, Callable, Optional
 
 import torch
@@ -142,7 +141,7 @@ class Fp8Config(QuantizationConfig):
                 return UnquantizedLinearMethod()
             return Fp8LinearMethod(self)
         elif isinstance(layer, FusedMoE):
-            return Fp8MoEMethod(self)
+            return Fp8MoEMethod(self, layer.moe_config)
         elif isinstance(layer, Attention):
             return Fp8KVCacheMethod(self)
         return None
@@ -479,9 +478,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         quant_config: The quantization config.
     """
 
-    def __init__(self, quant_config: Fp8Config):
-        from vllm.model_executor.layers.fused_moe import fused_experts
+    def __init__(self, quant_config: Fp8Config, moe: FusedMoEConfig):
+        super().__init__(moe)
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
 
@@ -529,15 +527,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 "CutlassBlockScaledGroupedGemm not supported on the current "
                 "platform.")
 
-        self.topk_indices_dtype = None
-        self.fused_experts = functools.partial(  # type: ignore
-            fused_experts,
-            use_fp8_w8a8=True,
-            block_shape=self.quant_config.weight_block_size,
-            allow_deep_gemm=self.allow_deep_gemm,
-            allow_cutlass_block_scaled_grouped_gemm=(
-                self.allow_cutlass_block_scaled_grouped_gemm))
-
     def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
                        intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
@@ -1033,7 +1022,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 num_expert_group=num_expert_group,
                 topk_group=topk_group,
                 apply_router_weight_on_input=apply_router_weight_on_input)
-        else:
+        elif self.fused_experts is not None:
             return self.fused_experts(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -1052,6 +1041,30 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 a1_scale=layer.w13_input_scale,
                 a2_scale=layer.w2_input_scale,
             )
+        else:
+            from vllm.model_executor.layers.fused_moe import fused_experts
+            return fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=True,
+                activation=activation,
+                global_num_experts=global_num_experts,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+                expert_map=expert_map,
+                w1_scale=(layer.w13_weight_scale_inv
+                          if self.block_quant else layer.w13_weight_scale),
+                w2_scale=(layer.w2_weight_scale_inv
+                          if self.block_quant else layer.w2_weight_scale),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+                use_fp8_w8a8=True,
+                block_shape=self.quant_config.weight_block_size,
+                allow_deep_gemm=self.allow_deep_gemm,
+                allow_cutlass_block_scaled_grouped_gemm=(
+                    self.allow_cutlass_block_scaled_grouped_gemm))
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):
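In the FP8 hunks above, the constructor no longer pre-binds `fused_experts` with `functools.partial`; instead `apply` first checks whether a modular kernel was installed in `self.fused_experts` and, if not, calls the plain `fused_experts` function with the FP8 flags spelled out at the call site. The sketch below only illustrates that dispatch order with toy functions; the names are invented and the real argument lists are those shown in the diff.

# Toy sketch of the "modular kernel if installed, explicit fallback otherwise" dispatch.
from typing import Callable, Optional


def plain_fused_experts(x: float, *, use_fp8_w8a8: bool = False) -> float:
    # Stand-in for the unfused fallback path with flags passed explicitly.
    return x * (0.5 if use_fp8_w8a8 else 1.0)


class ToyFp8Method:
    def __init__(self) -> None:
        # Previously a functools.partial(...) was bound here; now nothing is.
        self.fused_experts: Optional[Callable[[float], float]] = None

    def apply(self, x: float) -> float:
        if self.fused_experts is not None:   # modular kernel path
            return self.fused_experts(x)
        # fallback path: quantization flags appear at the call site
        return plain_fused_experts(x, use_fp8_w8a8=True)


m = ToyFp8Method()
print(m.apply(4.0))            # 2.0 (fallback)
m.fused_experts = lambda x: x + 1
print(m.apply(4.0))            # 5.0 (modular kernel)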
@@ -11,6 +11,7 @@ from torch.nn.parameter import Parameter, UninitializedParameter
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
+                                                        FusedMoEConfig,
                                                         FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -58,7 +59,7 @@ class GGUFConfig(QuantizationConfig):
         elif isinstance(layer, VocabParallelEmbedding):
             return GGUFEmbeddingMethod(self)
         elif isinstance(layer, FusedMoE):
-            return GGUFMoEMethod(self)
+            return GGUFMoEMethod(self, layer.moe_config)
         return None
 
 
@@ -445,7 +446,12 @@ class GGUFMoEMethod(FusedMoEMethodBase):
         quant_config: The GGUF quantization config.
     """
 
-    def __init__(self, quant_config: GGUFConfig):
+    def __init__(
+        self,
+        quant_config: GGUFConfig,
+        moe: FusedMoEConfig,
+    ):
+        super().__init__(moe)
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -525,6 +531,8 @@ class GGUFMoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
     ):
+        assert self.fused_experts is None
+
         if enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `GGUFMoEMethod` yet.")
@@ -545,7 +553,8 @@ class GGUFMoEMethod(FusedMoEMethodBase):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype)
         return fused_moe_gguf(x, layer.w13_qweight, layer.w2_qweight,
                               topk_weights, topk_ids,
                               layer.w13_qweight_type.weight_type,
@@ -10,7 +10,7 @@ import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
+    FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
     UnquantizedFusedMoEMethod)
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
@@ -375,7 +375,12 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
 class GPTQMarlinMoEMethod(FusedMoEMethodBase):
     """MoE Marlin method with quantization."""
 
-    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
+    def __init__(
+        self,
+        quant_config: GPTQMarlinConfig,
+        moe: FusedMoEConfig,
+    ) -> None:
+        super().__init__(moe)
         self.quant_config = quant_config
         if self.quant_config.quant_type.size_bits == 4:
             self.quant_type = scalar_types.uint4b8
@@ -646,6 +651,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        assert self.fused_experts is None
+
         if enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `GPTQMarlinMoEMethod` yet.")
@@ -662,7 +669,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype)
 
         return torch.ops.vllm.fused_marlin_moe(
             x,
@@ -12,7 +12,9 @@ import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
+from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
+from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
+    is_valid_flashinfer_cutlass_fused_moe)
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
@@ -22,8 +24,8 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
-    build_flashinfer_fp4_cutlass_moe_kernel,
-    flashinfer_fp4_cutlass_moe_forward, reorder_w1w3_to_w3w1)
+    build_flashinfer_fp4_cutlass_moe_prepare_finalize, reorder_w1w3_to_w3w1,
+    select_nvfp4_gemm_impl)
 from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     apply_flashinfer_per_tensor_scale_fp8, register_moe_scaling_factors,
     rotate_flashinfer_fp8_moe_weights, swap_w13_to_w31)
@@ -177,7 +179,7 @@ class ModelOptFp8Config(QuantizationConfig):
         elif isinstance(layer, Attention):
             return ModelOptFp8KVCacheMethod(self)
         elif isinstance(layer, FusedMoE):
-            return ModelOptFp8MoEMethod(self)
+            return ModelOptFp8MoEMethod(self, layer.moe_config)
         return None
 
 
@@ -273,7 +275,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         quant_config: The ModelOpt quantization config.
     """
 
-    def __init__(self, quant_config: ModelOptFp8Config) -> None:
+    def __init__(
+        self,
+        quant_config: ModelOptFp8Config,
+        moe: FusedMoEConfig,
+    ) -> None:
+        super().__init__(moe)
         self.quant_config = quant_config
         from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
             cutlass_fp8_supported)
@@ -454,6 +461,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        assert self.fused_experts is None
+
         if enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `ModelOptFp8MoEMethod` yet.")
@@ -484,6 +493,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype,
         )
         from vllm.model_executor.layers.fused_moe.fused_moe import (
             fused_experts)
@@ -699,7 +709,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
         elif isinstance(layer, Attention):
             return ModelOptFp8KVCacheMethod(self)
         elif isinstance(layer, FusedMoE):
-            return ModelOptNvFp4FusedMoE(self)
+            return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
         return None
 
 
@@ -923,10 +933,17 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         quant_config: NVFP4 Quant Config
     """
 
-    def __init__(self, quant_config: ModelOptNvFp4Config) -> None:
-        self.quant_config = quant_config
+    def __init__(
+        self,
+        quant_config: ModelOptNvFp4Config,
+        moe: FusedMoEConfig,
+        layer: torch.nn.Module,
+    ) -> None:
         from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import (  # noqa: E501
             detect_nvfp4_moe_support)
+        super().__init__(moe)
+        self.quant_config = quant_config
+        self.layer = layer
         _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__)
         self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported
         self.allow_flashinfer = _nvfp4.allow_flashinfer
@@ -952,27 +969,35 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
         self.fused_experts: Optional[
             mk.FusedMoEModularKernel] = None  # type: ignore[assignment]
 
-    def maybe_swap_experts_impl(
+    def maybe_make_prepare_finalize(
         self,
-        moe_parallel_config: FusedMoEParallelConfig,
-    ):
+        moe: FusedMoEConfig,
+    ) -> Optional[mk.FusedMoEPrepareAndFinalize]:
         if not self.allow_flashinfer:
-            return
-        self.fused_experts = build_flashinfer_fp4_cutlass_moe_kernel(
-            moe_parallel_config)
+            return super().maybe_make_prepare_finalize(moe)
 
-    # This method update self.fused_experts
-    # only prepare_finalize is not None call select_gemm_impl
-    # so when native cutlass fp4, fused_expert is in fuse_moe.py fused_expert
-    # when it's not called(TP case), we still have 2 kernels to use.
-    def select_gemm_impl(self, prepare_finalize,
-                         moe) -> mk.FusedMoEPermuteExpertsUnpermute:
-
-        assert moe is not None and prepare_finalize is not None
-        from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (  # noqa: E501
-            select_nvfp4_gemm_impl)
-
-        return select_nvfp4_gemm_impl(self.allow_flashinfer, moe, logger)
+        prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(
+            moe,
+            a1_gscale=self.layer.w13_input_scale_quant,
+        )
+        logger.debug_once("%s", prepare_finalize.__class__.__name__)
+        return prepare_finalize
+
+    def select_gemm_impl(
+        self,
+        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+        moe: FusedMoEConfig,
+    ) -> mk.FusedMoEPermuteExpertsUnpermute:
+        experts = select_nvfp4_gemm_impl(
+            moe,
+            g1_alphas=self.layer.g1_alphas,
+            g2_alphas=self.layer.g2_alphas,
+            a1_gscale=self.layer.w13_input_scale_quant,
+            a2_gscale=self.layer.w2_input_scale_quant,
+            allow_flashinfer=self.allow_flashinfer,
+        )
+        logger.debug_once("Using %s", experts.__class__.__name__)
+        return experts
 
     def uses_weight_scale_2_pattern(self) -> bool:
         """
@@ -1362,7 +1387,8 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype)
 
         if self.use_marlin:
             return torch.ops.vllm.fused_marlin_moe(
@@ -1404,21 +1430,28 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
                 n=layer.w2_weight.shape[2] * 2,
                 k=x.shape[1],
                 e=layer.w13_weight.shape[0],
-                device=x.device,
                 expert_map=expert_map,
                 apply_router_weight_on_input=apply_router_weight_on_input)
         else:
             assert self.allow_flashinfer and \
                 self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS
-            out = flashinfer_fp4_cutlass_moe_forward(
-                self.fused_experts,
-                layer,
-                x,
-                topk_weights,
-                topk_ids,
+
+            assert is_valid_flashinfer_cutlass_fused_moe(
+                x, layer.w13_weight, layer.w2_weight), (
+                    "Flashinfer CUTLASS Fused MoE not applicable!")
+
+            out = self.fused_experts(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=False,  # TODO(shuw): fix later, now output is high prec
                 activation=activation,
                 global_num_experts=global_num_experts,
                 expert_map=expert_map,
+                w1_scale=layer.w13_blockscale_swizzled,
+                w2_scale=layer.w2_blockscale_swizzled,
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )
 
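The `maybe_make_prepare_finalize` override above follows an "override with fallback" shape: the derived method builds its FlashInfer CUTLASS prepare/finalize object only when its backend applies, otherwise it defers to the base class. A minimal sketch of that shape, with invented class names rather than vLLM APIs, is below.

# Toy sketch of override-with-fallback for maybe_make_prepare_finalize.
from typing import Optional


class BaseMethod:
    def maybe_make_prepare_finalize(self) -> Optional[str]:
        # The base class builds the generic object (or returns None).
        return "default-prepare-finalize"


class FlashInferLikeMethod(BaseMethod):
    def __init__(self, allow_flashinfer: bool) -> None:
        self.allow_flashinfer = allow_flashinfer

    def maybe_make_prepare_finalize(self) -> Optional[str]:
        if not self.allow_flashinfer:
            return super().maybe_make_prepare_finalize()
        return "flashinfer-cutlass-prepare-finalize"


print(FlashInferLikeMethod(False).maybe_make_prepare_finalize())  # default
print(FlashInferLikeMethod(True).maybe_make_prepare_finalize())   # flashinfer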
@@ -7,7 +7,7 @@ import torch
 
 from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
 from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
+    FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -160,7 +160,7 @@ class MoeWNA16Config(QuantizationConfig):
             else:
                 raise ValueError("moe_wna16 only support gptq and awq.")
         elif isinstance(layer, FusedMoE):
-            return MoeWNA16Method(self)
+            return MoeWNA16Method(self, layer.moe_config)
         return None
 
 
@@ -175,7 +175,12 @@ class MoeWNA16Method(FusedMoEMethodBase):
         quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
     """
 
-    def __init__(self, quant_config: MoeWNA16Config):
+    def __init__(
+        self,
+        quant_config: MoeWNA16Config,
+        moe: FusedMoEConfig,
+    ):
+        super().__init__(moe)
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -302,6 +307,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        assert self.fused_experts is None
+
         if enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `MoeWNA16Method` yet.")
@@ -318,7 +325,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype)
 
         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp
@@ -82,7 +82,7 @@ class Mxfp4Config(QuantizationConfig):
 class Mxfp4MoEMethod(FusedMoEMethodBase):
 
     def __init__(self, moe: FusedMoEConfig):
-        super().__init__()
+        super().__init__(moe)
         self.topk_indices_dtype = None
         self.moe = moe
         self.use_marlin = self._should_use_marlin()
@@ -7,7 +7,8 @@ import torch
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
+from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
+                                                   FusedMoEMethodBase,
                                                    FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     OCP_MX_BLOCK_SIZE)
@@ -25,6 +26,9 @@ __all__ = [
 
 class QuarkMoEMethod(FusedMoEMethodBase):
 
+    def __init__(self, moe: FusedMoEConfig):
+        super().__init__(moe)
+
     @staticmethod
     def get_moe_method(
         quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
@@ -42,17 +46,24 @@ class QuarkMoEMethod(FusedMoEMethodBase):
         input_config = layer_quant_config.get("input_tensors")
 
         if quant_config._is_fp8_w8a8(weight_config, input_config):
-            return QuarkW8A8Fp8MoEMethod(weight_config, input_config)
+            return QuarkW8A8Fp8MoEMethod(weight_config, input_config,
+                                         module.moe_config)
         elif quant_config._is_mx_fp4(weight_config, input_config):
-            return QuarkW4A4MXFp4MoEMethod(weight_config, input_config)
+            return QuarkW4A4MXFp4MoEMethod(weight_config, input_config,
+                                           module.moe_config)
         else:
             raise RuntimeError("Unsupported FusedMoe scheme")
 
 
 class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
 
-    def __init__(self, weight_config: dict[str, Any], input_config: dict[str,
-                                                                          Any]):
+    def __init__(
+        self,
+        weight_config: dict[str, Any],
+        input_config: dict[str, Any],
+        moe: FusedMoEConfig,
+    ):
+        super().__init__(moe)
         self.weight_quant = weight_config
         self.input_quant = input_config
 
@@ -215,6 +226,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        assert self.fused_experts is None
+
         if enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.")
@@ -231,7 +244,8 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype)
 
         return fused_experts(
             x,
@@ -253,8 +267,13 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
 
 class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
 
-    def __init__(self, weight_config: dict[str, Any], input_config: dict[str,
-                                                                          Any]):
+    def __init__(
+        self,
+        weight_config: dict[str, Any],
+        input_config: dict[str, Any],
+        moe: FusedMoEConfig,
+    ):
+        super().__init__(moe)
         self.weight_quant = weight_config
         self.input_quant = input_config
 
@@ -369,6 +388,7 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        assert self.fused_experts is None
 
         if enable_eplb:
             raise NotImplementedError(
@@ -386,7 +406,8 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype)
 
         out = fused_experts(
             x,
@@ -10,7 +10,8 @@ import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
+from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
+                                                  FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -76,7 +77,7 @@ class RTNConfig(QuantizationConfig):
         if isinstance(layer, LinearBase):
             return RTNLinearMethod(self)
         elif isinstance(layer, FusedMoE):
-            return RTNMoEMethod(self)
+            return RTNMoEMethod(self, layer.moe_config)
         return None
 
 
@@ -210,7 +211,8 @@ class RTNLinearMethod(LinearMethodBase):
 
 class RTNMoEMethod(FusedMoEMethodBase):
 
-    def __init__(self, quant_config: RTNConfig):
+    def __init__(self, quant_config: RTNConfig, moe: FusedMoEConfig):
+        super().__init__(moe)
         self.quant_config = quant_config
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
@@ -289,6 +291,8 @@ class RTNMoEMethod(FusedMoEMethodBase):
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        assert self.fused_experts is None
+
        if enable_eplb:
             raise NotImplementedError(
                 "EPLB not supported for `RTNMoEMethod` yet.")
@@ -305,7 +309,8 @@ class RTNMoEMethod(FusedMoEMethodBase):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=self.topk_indices_dtype)
 
         weight_bits = self.quant_config.weight_bits
         group_size = self.quant_config.group_size
@@ -3,33 +3,30 @@
 """Utility helpers for NVFP4 + FlashInfer fused-MoE path"""
 from __future__ import annotations
 
-from typing import Optional
-
 import torch
 
 import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import FusedMoEParallelConfig
+from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
-    FlashInferExperts, is_valid_flashinfer_cutlass_fused_moe)
+    FlashInferExperts)
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
     FlashInferCutlassMoEPrepareAndFinalize)
 from vllm.platforms import current_platform
-
-logger = init_logger(__name__)
+from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
 
 __all__ = [
     "is_flashinfer_fp4_cutlass_moe_available",
     "reorder_w1w3_to_w3w1",
-    "build_flashinfer_fp4_cutlass_moe_kernel",
-    "flashinfer_fp4_cutlass_moe_forward",
+    "build_flashinfer_fp4_cutlass_moe_prepare_finalize",
 ]
 
 
 def is_flashinfer_fp4_cutlass_moe_available() -> bool:
     """Return ``True`` when FlashInfer CUTLASS NV-FP4 kernels can be used."""
-    return (envs.VLLM_USE_FLASHINFER_MOE_FP4 and current_platform.is_cuda()
+    return (envs.VLLM_USE_FLASHINFER_MOE_FP4
+            and has_flashinfer_cutlass_fused_moe()
+            and current_platform.is_cuda()
             and current_platform.is_device_capability(100))
 
 
@@ -49,105 +46,33 @@ def reorder_w1w3_to_w3w1(weight: torch.Tensor,
                          dim=dim).contiguous())
 
 
-def build_flashinfer_fp4_cutlass_moe_kernel(
-        moe_parallel_config: FusedMoEParallelConfig, ) -> mk.FusedMoEModularKernel:
-    """Create *and return* a FlashInfer CUTLASS fused-MoE modular kernel"""
-    experts = FlashInferExperts(
-        use_nvfp4_w4a4=True,
-        use_dp=moe_parallel_config.dp_size > 1,
-        ep_rank=moe_parallel_config.ep_rank,
-        ep_size=moe_parallel_config.ep_size,
-        tp_rank=moe_parallel_config.tp_rank,
-        tp_size=moe_parallel_config.tp_size,
-    )
-    logger.debug_once("FlashInferExperts (util)")
-    return mk.FusedMoEModularKernel(
-        FlashInferCutlassMoEPrepareAndFinalize(quant_dtype=torch.uint8),
-        experts,
-    )
-
-
-def flashinfer_fp4_cutlass_moe_forward(
-    fused_experts: mk.FusedMoEModularKernel,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    activation: str,
-    global_num_experts: int,
-    expert_map: Optional[torch.Tensor],
-    apply_router_weight_on_input: bool,
-) -> torch.Tensor:
-    """Common forward wrapper for FlashInfer NV-FP4 fused-MoE"""
-
-    assert is_valid_flashinfer_cutlass_fused_moe(
-        x, layer.w13_weight,
-        layer.w2_weight), ("FlashInfer CUTLASS fused-MoE not applicable!")
-
-    a1_gscale = layer.w13_input_scale_quant
-    a2_gscale = layer.w2_input_scale_quant
-
-    extra_expert_args = {
-        "g1_alphas": layer.g1_alphas,
-        "g2_alphas": layer.g2_alphas,
-        # Avoid confusion with a1_scale and a2_scale
-        # where are batch size related.
-        "a1_gscale": a1_gscale,
-        "a2_gscale": a2_gscale,
-        "out_dtype": x.dtype,
-    }
-    extra_prepare_args = {
-        "use_dp": layer.dp_size > 1,
-        "local_tokens": x.shape[0],
-        "a1_gscale": a1_gscale,
-    }
-    extra_finalize_args = {
-        "use_dp": layer.dp_size > 1,
-        "local_tokens": x.shape[0],
-    }
-
-    return fused_experts(
-        hidden_states=x,
-        w1=layer.w13_weight,
-        w2=layer.w2_weight,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        inplace=False,  # TODO(shuw): fix later, now output is high prec
-        activation=activation,
-        global_num_experts=global_num_experts,
-        expert_map=expert_map,
-        w1_scale=layer.w13_blockscale_swizzled,
-        w2_scale=layer.w2_blockscale_swizzled,
-        apply_router_weight_on_input=apply_router_weight_on_input,
-        extra_expert_args=extra_expert_args,
-        extra_prepare_args=extra_prepare_args,
-        extra_finalize_args=extra_finalize_args,
-    )
+def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
+    moe: FusedMoEConfig,
+    a1_gscale: torch.Tensor,
+) -> mk.FusedMoEPrepareAndFinalize:
+    """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
+    use_dp = moe.moe_parallel_config.dp_size > 1
+    return FlashInferCutlassMoEPrepareAndFinalize(use_dp, a1_gscale=a1_gscale)
 
 
 def select_nvfp4_gemm_impl(
-    allow_flashinfer: bool,
-    moe,  # FusedMoEConfig
-    logger):
+    moe: FusedMoEConfig,
+    g1_alphas: torch.Tensor,
+    g2_alphas: torch.Tensor,
+    a1_gscale: torch.Tensor,
+    a2_gscale: torch.Tensor,
+    allow_flashinfer: bool,
+) -> mk.FusedMoEPermuteExpertsUnpermute:
     """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers"""
 
-    # lazy import
-    from vllm.distributed import get_ep_group
-
-    all2all_manager = get_ep_group().device_communicator.all2all_manager
-    assert all2all_manager is not None
-
     if allow_flashinfer:
-        flashinfer_backend = envs.VLLM_FLASHINFER_MOE_BACKEND
-        if flashinfer_backend != "throughput":
-            raise ValueError(
-                f"Only throughput backend is supported for FlashInferExperts, "
-                f"but got {flashinfer_backend}.")
-        logger.debug_once(
-            "Initializing FlashInferExperts with throughput backend.")
         return FlashInferExperts(
-            use_nvfp4_w4a4=True,
-            use_dp=moe.moe_parallel_config.dp_size > 1,
+            g1_alphas=g1_alphas,
+            g2_alphas=g2_alphas,
+            a1_gscale=a1_gscale,
+            a2_gscale=a2_gscale,
+            out_dtype=moe.in_dtype,
+            quant_dtype="nvfp4",
             ep_rank=moe.moe_parallel_config.ep_rank,
             ep_size=moe.moe_parallel_config.ep_size,
             tp_rank=moe.moe_parallel_config.tp_rank,