diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h index a1b687ce448..b3f2fcd511f 100644 --- a/aten/src/ATen/autocast_mode.h +++ b/aten/src/ATen/autocast_mode.h @@ -64,8 +64,7 @@ inline bool is_autocast_eligible( case c10::DeviceType::XLA: return tensor.is_xla() && tensor.is_floating_point(); case c10::DeviceType::PrivateUse1: - return tensor.device().type() == c10::DeviceType::PrivateUse1 && - tensor.is_floating_point(); + return tensor.is_privateuseone() && tensor.is_floating_point(); default: return false; } diff --git a/aten/src/ATen/core/TensorBase.h b/aten/src/ATen/core/TensorBase.h index 887dbd9d7ad..90a90e60c08 100644 --- a/aten/src/ATen/core/TensorBase.h +++ b/aten/src/ATen/core/TensorBase.h @@ -476,6 +476,12 @@ class TORCH_API TensorBase { return impl_->is_ve(); } + /// Returns if a `Tensor` has PrivateUse1 backend. + bool is_privateuseone() const { + // NB: this is not a native function to avoid dispatching overhead. + return impl_->is_privateuseone(); + } + /// Returns if a `Tensor` has sparse backend. bool is_sparse() const { // NB: this is not a native function to avoid dispatching overhead. diff --git a/aten/src/ATen/native/Dropout.cpp b/aten/src/ATen/native/Dropout.cpp index 20a11645de1..8bad0710218 100644 --- a/aten/src/ATen/native/Dropout.cpp +++ b/aten/src/ATen/native/Dropout.cpp @@ -42,7 +42,7 @@ Tensor make_feature_noise(const Tensor& input) { } bool is_fused_kernel_acceptable(const Tensor& input, double p) { - return (input.is_cuda() || input.is_xpu() || input.is_lazy() || input.device().is_privateuseone()) && p > 0 && p < 1 && input.sym_numel() > 0; + return (input.is_cuda() || input.is_xpu() || input.is_lazy() || input.is_privateuseone()) && p > 0 && p < 1 && input.sym_numel() > 0; } // NB: sure, we could have used different overloads here, but I would feel insecure diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index b536a7a9279..c4cd54dcaec 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -1156,6 +1156,15 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { return device_opt_.has_value() && device_opt_->type() == kVE; } + bool is_privateuseone() const { + // NB: This method is not virtual and avoid dispatches for performance + // reasons. + if (C10_UNLIKELY(device_policy_)) { + return device_custom().is_privateuseone(); + } + return device_opt_.has_value() && device_opt_->type() == kPrivateUse1; + } + bool is_mkldnn() const { return key_set_.has_all(c10::mkldnn_ks); }