[model][refactor] remove cuda hard code in models and layers (#13658)

Mengqing Cao 2025-02-24 22:10:14 +08:00 committed by GitHub
parent 437b76ff59
commit 23eca9cf68
7 changed files with 29 additions and 14 deletions
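
The change is mechanical across all seven files: every hard-coded "cuda" device string is replaced by the device type reported by vLLM's platform abstraction. A minimal sketch of the pattern (the 1024-element workspace is an arbitrary placeholder, not taken from the diff):

import torch

from vllm.platforms import current_platform

# Before (hard-coded, fails on non-CUDA backends):
#   workspace = torch.zeros(1024, dtype=torch.int, device="cuda")
# After: the device string comes from the platform abstraction, so the same
# code runs wherever vLLM's current_platform points.
workspace = torch.zeros(1024, dtype=torch.int,
                        device=current_platform.device_type)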


@@ -7,6 +7,7 @@ import torch
 from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk, moe_align_block_size, try_get_optimal_moe_config)
+from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils import direct_register_custom_op
@@ -238,7 +239,7 @@ def fused_marlin_moe(
     max_workspace_size = (max(2 * N, K) // 64) * 16
     workspace = torch.zeros(max_workspace_size,
                             dtype=torch.int,
-                            device="cuda",
+                            device=current_platform.device_type,
                             requires_grad=False)
     if has_no_zp:


@@ -30,6 +30,7 @@ import torch.nn as nn
 from transformers import PretrainedConfig
 from vllm.model_executor.custom_op import CustomOp
+from vllm.platforms import current_platform
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -650,9 +651,13 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
                          is_neox_style, dtype)
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
-        pos_freqs = self.base**(torch.arange(
-            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
-                                self.rotary_dim)
+        pos_freqs = self.base**(
+            torch.arange(0,
+                         self.rotary_dim,
+                         2,
+                         dtype=torch.float,
+                         device=current_platform.device_type) /
+            self.rotary_dim)
         inv_freq_extrapolation = 1.0 / pos_freqs
         inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
@@ -670,7 +675,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
     def _compute_cos_sin_cache(self) -> torch.Tensor:
         inv_freq = self._compute_inv_freq(self.scaling_factor)
         t = torch.arange(self.max_position_embeddings * self.scaling_factor,
-                         device="cuda",
+                         device=current_platform.device_type,
                          dtype=torch.float32)
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = (freqs.cos() * self.mscale)
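
For context, a hedged standalone sketch of what the updated DeepSeek scaling rotary path now does: both the inverse frequencies and the position index tensor are allocated on current_platform.device_type instead of a literal "cuda". The helper name and signature below are illustrative, not vLLM's actual API.

import torch

from vllm.platforms import current_platform

def build_cos_sin_cache(rotary_dim: int, max_positions: int,
                        base: float = 10000.0) -> torch.Tensor:
    # Inverse frequencies on the platform's device rather than hard-coded "cuda".
    device = current_platform.device_type
    inv_freq = 1.0 / (base**(torch.arange(
        0, rotary_dim, 2, dtype=torch.float, device=device) / rotary_dim))
    # Position indices on the same device, then the outer product of the two.
    t = torch.arange(max_positions, dtype=torch.float32, device=device)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    return torch.cat((freqs.cos(), freqs.sin()), dim=-1)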


@@ -7,6 +7,8 @@ import torch
 import torch.jit
 import torch.nn as nn
+from vllm.platforms import current_platform
 class SpecDecodeBaseSampler(nn.Module):
     """Base class for samplers used for Speculative Decoding verification
@@ -35,7 +37,7 @@ class SpecDecodeBaseSampler(nn.Module):
     def init_gpu_tensors(self, device: Union[int, str]) -> None:
         assert self.num_accepted_tokens is None
         if isinstance(device, int):
-            device = f"cuda:{device}"
+            device = f"{current_platform.device_type}:{device}"
         elif not isinstance(device, str):
             raise ValueError(f"Device must be int or str, get {type(device)}")
         self.num_accepted_tokens = torch.tensor(0,
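
The sampler change generalizes the device-index handling; a small sketch of the resulting behaviour (the function name is illustrative, not part of the class):

from typing import Union

from vllm.platforms import current_platform

def normalize_device(device: Union[int, str]) -> str:
    # An integer index becomes "<device_type>:<index>" (e.g. "cuda:0" on CUDA,
    # or the equivalent on another platform) instead of always "cuda:<index>".
    if isinstance(device, int):
        return f"{current_platform.device_type}:{device}"
    if isinstance(device, str):
        return device
    raise ValueError(f"Device must be int or str, get {type(device)}")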


@@ -914,7 +914,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 if param_name + "." in k:
                     quant_state[k] = temp_state_dict[k]
-            return QuantState.from_dict(quant_state, device="cuda")
+            return QuantState.from_dict(quant_state,
+                                        device=current_platform.device_type)
         # Second iterate over all prequant and normal weights
         # pre quantized weights would have a quant_state


@@ -30,6 +30,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.arctic import ArcticConfig
@@ -138,13 +139,13 @@ class ArcticMoE(nn.Module):
             torch.empty(self.num_experts,
                         2 * self.intermediate_size,
                         self.hidden_size,
-                        device="cuda",
+                        device=current_platform.device_type,
                         dtype=self.params_dtype))
         self.w2s = nn.Parameter(
             torch.empty(self.num_experts,
                         self.hidden_size,
                         self.intermediate_size,
-                        device="cuda",
+                        device=current_platform.device_type,
                         dtype=self.params_dtype))
         set_weight_attrs(self.ws, {
             "weight_loader": self.weight_loader,


@@ -51,6 +51,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from .interfaces import SupportsLoRA, SupportsPP
@@ -98,13 +99,13 @@ class MiniCPMMoE(nn.Module):
             torch.empty(self.num_total_experts,
                         2 * self.intermediate_size,
                         self.hidden_size,
-                        device="cuda",
+                        device=current_platform.device_type,
                         dtype=self.params_dtype))
         self.w2s = nn.Parameter(
             torch.empty(self.num_total_experts,
                         self.hidden_size,
                         self.intermediate_size,
-                        device="cuda",
+                        device=current_platform.device_type,
                         dtype=self.params_dtype))
         set_weight_attrs(self.ws, {


@@ -59,6 +59,7 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ImageSize,
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from .idefics2_vision_model import Idefics2VisionTransformer
@@ -1184,7 +1185,8 @@ class MiniCPMV2_0(MiniCPMVBaseModel):
             quant_config=quant_config,
             prefix=prefix)
-        return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+        return resampler.to(device=current_platform.device_type,
+                            dtype=torch.get_default_dtype())
     def get_vision_embedding(
         self,
@@ -1266,7 +1268,8 @@ class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA):
             quant_config=quant_config,
             prefix=prefix)
-        return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+        return resampler.to(device=current_platform.device_type,
+                            dtype=torch.get_default_dtype())
     def get_vision_embedding(
         self,
@@ -1360,7 +1363,8 @@ class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
             quant_config=quant_config,
             prefix=prefix)
-        return resampler.to(device="cuda", dtype=torch.get_default_dtype())
+        return resampler.to(device=current_platform.device_type,
+                            dtype=torch.get_default_dtype())
     def get_vision_embedding(
         self,
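
The three MiniCPM-V resampler builders all end with the same move-to-device call; a sketch of that pattern, using nn.Linear as a stand-in for the real resampler module:

import torch
import torch.nn as nn

from vllm.platforms import current_platform

# Stand-in module; the real code constructs a vision resampler here.
resampler = nn.Linear(4096, 4096)
resampler = resampler.to(device=current_platform.device_type,
                         dtype=torch.get_default_dtype())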