diff --git a/csrc/attention/mla/cutlass_mla_kernels.cu b/csrc/attention/mla/cutlass_mla_kernels.cu
index f4b6b19f4b232..9d05d910dd81f 100644
--- a/csrc/attention/mla/cutlass_mla_kernels.cu
+++ b/csrc/attention/mla/cutlass_mla_kernels.cu
@@ -207,7 +207,7 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out,
               "page_table must be a 32-bit integer tensor");
 
   auto in_dtype = q_nope.dtype();
-  at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(q_nope));
   const cudaStream_t stream =
       at::cuda::getCurrentCUDAStream(q_nope.get_device());
   if (in_dtype == at::ScalarType::Half) {
diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu
index f62d08c17c6d8..c83d72751a55c 100644
--- a/csrc/mamba/causal_conv1d/causal_conv1d.cu
+++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu
@@ -185,9 +185,7 @@ void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
     params.conv_states_ptr = nullptr;
   }
 
-  // Otherwise the kernel will be launched from cuda:0 device
-  // Cast to char to avoid compiler warning about narrowing
-  at::cuda::CUDAGuard device_guard{(char)x.get_device()};
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
   auto stream = at::cuda::getCurrentCUDAStream().stream();
   DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
       causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
@@ -278,9 +276,7 @@ void causal_conv1d_update(const at::Tensor &x,
     params.conv_state_indices_ptr = nullptr;
   }
 
-  // Otherwise the kernel will be launched from cuda:0 device
-  // Cast to char to avoid compiler warning about narrowing
-  at::cuda::CUDAGuard device_guard{(char)x.get_device()};
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
   auto stream = at::cuda::getCurrentCUDAStream().stream();
   DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
       causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
index 0c9df925bdbf6..785d316025eca 100644
--- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
+++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu
@@ -647,9 +647,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
   );
 
 
-  // Otherwise the kernel will be launched from cuda:0 device
-  // Cast to char to avoid compiler warning about narrowing
-  at::cuda::CUDAGuard device_guard{(char)u.get_device()};
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(u));
   auto stream = at::cuda::getCurrentCUDAStream().stream();
   DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
       selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu
index b51033c9b72c9..190d66f318a83 100644
--- a/csrc/quantization/fp4/nvfp4_experts_quant.cu
+++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu
@@ -561,7 +561,7 @@ void scaled_fp4_experts_quant_sm100a(
   TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
 
   auto in_dtype = input.dtype();
-  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream =
       at::cuda::getCurrentCUDAStream(input.get_device());
   if (in_dtype == at::ScalarType::Half) {
@@ -579,4 +579,4 @@ void scaled_fp4_experts_quant_sm100a(
   } else {
     TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
   }
-}
\ No newline at end of file
+}
diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu
index fef74111624f0..d32911357a953 100644
--- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu
+++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu
@@ -347,7 +347,7 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output,
   auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
   auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
   auto output_ptr = static_cast<int64_t*>(output.data_ptr());
-  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
 
   // We don't support e8m0 scales at this moment.
diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
index 97c0e0da7b1fb..7572a7eb3122d 100644
--- a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
+++ b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
@@ -267,7 +267,7 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
               B_sf.sizes()[1], ")");
 
   auto out_dtype = D.dtype();
-  at::cuda::CUDAGuard device_guard{(char)A.get_device()};
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
 
   if (out_dtype == at::ScalarType::Half) {
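
Note (not part of the diff): every hunk above switches to the same guard pattern, shown below as a minimal standalone sketch. The helper name launch_on_tensor_device is a hypothetical placeholder; only OptionalCUDAGuard, device_of, and getCurrentCUDAStream are the actual PyTorch APIs used in the change.

    // Sketch of the pattern adopted in this PR, assuming a hypothetical helper.
    // device_of(x) returns an optional Device; OptionalCUDAGuard makes that
    // device current for the enclosing scope (and is a no-op if the optional is
    // empty), so stream lookups and kernel launches target x's device rather
    // than cuda:0.
    #include <ATen/cuda/CUDAContext.h>
    #include <c10/cuda/CUDAGuard.h>
    #include <torch/all.h>

    void launch_on_tensor_device(torch::Tensor const& x) {  // hypothetical helper
      const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
      const cudaStream_t stream = at::cuda::getCurrentCUDAStream(x.get_device());
      // ... enqueue kernels on `stream`; they now run on x's device ...
    }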