[model][refactor] remove cuda hard code in models and layers (#13658)

This commit is contained in:
Mengqing Cao 2025-02-24 22:10:14 +08:00 committed by GitHub
parent 437b76ff59
commit 23eca9cf68
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 29 additions and 14 deletions

View File

@ -7,6 +7,7 @@ import torch
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size, try_get_optimal_moe_config) fused_topk, moe_align_block_size, try_get_optimal_moe_config)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
@ -238,7 +239,7 @@ def fused_marlin_moe(
max_workspace_size = (max(2 * N, K) // 64) * 16 max_workspace_size = (max(2 * N, K) // 64) * 16
workspace = torch.zeros(max_workspace_size, workspace = torch.zeros(max_workspace_size,
dtype=torch.int, dtype=torch.int,
device="cuda", device=current_platform.device_type,
requires_grad=False) requires_grad=False)
if has_no_zp: if has_no_zp:

View File

@ -30,6 +30,7 @@ import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
def _rotate_neox(x: torch.Tensor) -> torch.Tensor: def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@ -650,9 +651,13 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style, dtype) is_neox_style, dtype)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(torch.arange( pos_freqs = self.base**(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / torch.arange(0,
self.rotary_dim) self.rotary_dim,
2,
dtype=torch.float,
device=current_platform.device_type) /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
@ -670,7 +675,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
def _compute_cos_sin_cache(self) -> torch.Tensor: def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor) inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor, t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device="cuda", device=current_platform.device_type,
dtype=torch.float32) dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq) freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale) cos = (freqs.cos() * self.mscale)

View File

@ -7,6 +7,8 @@ import torch
import torch.jit import torch.jit
import torch.nn as nn import torch.nn as nn
from vllm.platforms import current_platform
class SpecDecodeBaseSampler(nn.Module): class SpecDecodeBaseSampler(nn.Module):
"""Base class for samplers used for Speculative Decoding verification """Base class for samplers used for Speculative Decoding verification
@ -35,7 +37,7 @@ class SpecDecodeBaseSampler(nn.Module):
def init_gpu_tensors(self, device: Union[int, str]) -> None: def init_gpu_tensors(self, device: Union[int, str]) -> None:
assert self.num_accepted_tokens is None assert self.num_accepted_tokens is None
if isinstance(device, int): if isinstance(device, int):
device = f"cuda:{device}" device = f"{current_platform.device_type}:{device}"
elif not isinstance(device, str): elif not isinstance(device, str):
raise ValueError(f"Device must be int or str, get {type(device)}") raise ValueError(f"Device must be int or str, get {type(device)}")
self.num_accepted_tokens = torch.tensor(0, self.num_accepted_tokens = torch.tensor(0,

View File

@ -914,7 +914,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if param_name + "." in k: if param_name + "." in k:
quant_state[k] = temp_state_dict[k] quant_state[k] = temp_state_dict[k]
return QuantState.from_dict(quant_state, device="cuda") return QuantState.from_dict(quant_state,
device=current_platform.device_type)
# Second iterate over all prequant and normal weights # Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state # pre quantized weights would have a quant_state

View File

@ -30,6 +30,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.arctic import ArcticConfig from vllm.transformers_utils.configs.arctic import ArcticConfig
@ -138,13 +139,13 @@ class ArcticMoE(nn.Module):
torch.empty(self.num_experts, torch.empty(self.num_experts,
2 * self.intermediate_size, 2 * self.intermediate_size,
self.hidden_size, self.hidden_size,
device="cuda", device=current_platform.device_type,
dtype=self.params_dtype)) dtype=self.params_dtype))
self.w2s = nn.Parameter( self.w2s = nn.Parameter(
torch.empty(self.num_experts, torch.empty(self.num_experts,
self.hidden_size, self.hidden_size,
self.intermediate_size, self.intermediate_size,
device="cuda", device=current_platform.device_type,
dtype=self.params_dtype)) dtype=self.params_dtype))
set_weight_attrs(self.ws, { set_weight_attrs(self.ws, {
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,

View File

@ -51,6 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP from .interfaces import SupportsLoRA, SupportsPP
@ -98,13 +99,13 @@ class MiniCPMMoE(nn.Module):
torch.empty(self.num_total_experts, torch.empty(self.num_total_experts,
2 * self.intermediate_size, 2 * self.intermediate_size,
self.hidden_size, self.hidden_size,
device="cuda", device=current_platform.device_type,
dtype=self.params_dtype)) dtype=self.params_dtype))
self.w2s = nn.Parameter( self.w2s = nn.Parameter(
torch.empty(self.num_total_experts, torch.empty(self.num_total_experts,
self.hidden_size, self.hidden_size,
self.intermediate_size, self.intermediate_size,
device="cuda", device=current_platform.device_type,
dtype=self.params_dtype)) dtype=self.params_dtype))
set_weight_attrs(self.ws, { set_weight_attrs(self.ws, {

View File

@ -59,6 +59,7 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageSize,
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement) BaseProcessingInfo, PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from .idefics2_vision_model import Idefics2VisionTransformer from .idefics2_vision_model import Idefics2VisionTransformer
@ -1184,7 +1185,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
quant_config=quant_config, quant_config=quant_config,
prefix=prefix) prefix=prefix)
return resampler.to(device="cuda", dtype=torch.get_default_dtype()) return resampler.to(device=current_platform.device_type,
dtype=torch.get_default_dtype())
def get_vision_embedding( def get_vision_embedding(
self, self,
@ -1266,7 +1268,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
quant_config=quant_config, quant_config=quant_config,
prefix=prefix) prefix=prefix)
return resampler.to(device="cuda", dtype=torch.get_default_dtype()) return resampler.to(device=current_platform.device_type,
dtype=torch.get_default_dtype())
def get_vision_embedding( def get_vision_embedding(
self, self,
@ -1360,7 +1363,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
quant_config=quant_config, quant_config=quant_config,
prefix=prefix) prefix=prefix)
return resampler.to(device="cuda", dtype=torch.get_default_dtype()) return resampler.to(device=current_platform.device_type,
dtype=torch.get_default_dtype())
def get_vision_embedding( def get_vision_embedding(
self, self,