Add tensor.is_privateuseone (#113421)

We found a scenario in which ``tensor.device().is_privateuseone()`` is used to check whether a tensor is on the ``PrivateUse1`` backend, but the call fails.
For example, in the generated ``Autograd`` code:
```cpp
::std::tuple<at::Tensor,at::Tensor,at::Tensor> native_batch_norm(c10::DispatchKeySet ks, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps) {
  auto& input_ = unpack(input, "input", 0);
  [[maybe_unused]] auto _any_requires_grad = compute_requires_grad( input, weight, bias );

  [[maybe_unused]] auto _any_has_forward_grad_result0 = (isFwGradDefined(input) || isFwGradDefined(weight) || isFwGradDefined(bias));
  check_no_requires_grad(running_mean, "running_mean", "native_batch_norm");
  check_no_requires_grad(running_var, "running_var", "native_batch_norm");
  std::shared_ptr<NativeBatchNormBackward0> grad_fn;
  if (_any_requires_grad) {
    grad_fn = std::shared_ptr<NativeBatchNormBackward0>(new NativeBatchNormBackward0(), deleteNode);
    grad_fn->set_next_edges(collect_next_edges( input, weight, bias ));
    grad_fn->eps = eps;
    grad_fn->input_ = SavedVariable(input, false);
    grad_fn->running_mean_ = SavedVariable(running_mean, false);
    grad_fn->running_var_ = SavedVariable(running_var, false);
    grad_fn->training = training;
    grad_fn->weight_ = SavedVariable(weight, false);
  }
  ...
}
```
When ``weight`` is ``None``, an empty (undefined) tensor is generated automatically and passed on to the backward computation:
c7e12c7427/torch/csrc/autograd/saved_variable.cpp (L121-L128)
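Roughly what happens at that point (a minimal sketch, not the exact upstream code): when nothing was saved for the optional ``weight``, the backward function receives a default-constructed, undefined tensor rather than a real tensor with a device.
```cpp
#include <ATen/ATen.h>

// Sketch only: mirrors the idea of the referenced unpack() path, under the
// assumption that a SavedVariable holding no data yields an undefined tensor.
at::Tensor unpack_saved_or_undefined(const at::Tensor& saved) {
  if (!saved.defined()) {
    return at::Tensor();  // undefined: no storage, no device
  }
  return saved;
}
```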
At the start of the backward computation in our scenario, we need to determine whether the input tensor is on ``PrivateUse1``. However, if we use ``tensor.device().is_privateuseone()``, we get the error ``"tensor does not have a device"``:
c7e12c7427/c10/core/TensorImpl.h (L1223-L1235)
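A minimal illustration of the failure mode and of the accessor this PR adds (the function name below is made up for the example):
```cpp
#include <ATen/ATen.h>

// Hypothetical backward-side check; `weight` is the undefined tensor that
// autograd hands us when the optional argument was None.
bool weight_is_on_privateuse1(const at::Tensor& weight) {
  // Old check: throws "tensor does not have a device" for an undefined tensor.
  // return weight.device().is_privateuseone();

  // New accessor from this PR: simply returns false for an undefined tensor.
  return weight.is_privateuseone();
}
```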
I think this part of the code can be improved; what do you think?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113421
Approved by: https://github.com/albanD
Author: Chen_Liqing, 2023-11-13 01:51:27 +00:00 (committed by PyTorch MergeBot)
commit b910d9eaa6, parent 7afb503e3c
4 changed files with 17 additions and 3 deletions


```diff
@@ -64,8 +64,7 @@ inline bool is_autocast_eligible(
     case c10::DeviceType::XLA:
       return tensor.is_xla() && tensor.is_floating_point();
     case c10::DeviceType::PrivateUse1:
-      return tensor.device().type() == c10::DeviceType::PrivateUse1 &&
-          tensor.is_floating_point();
+      return tensor.is_privateuseone() && tensor.is_floating_point();
     default:
       return false;
   }
```


```diff
@@ -476,6 +476,12 @@ class TORCH_API TensorBase {
     return impl_->is_ve();
   }
 
+  /// Returns if a `Tensor` has PrivateUse1 backend.
+  bool is_privateuseone() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_privateuseone();
+  }
+
   /// Returns if a `Tensor` has sparse backend.
   bool is_sparse() const {
     // NB: this is not a native function to avoid dispatching overhead.
```


```diff
@@ -42,7 +42,7 @@ Tensor make_feature_noise(const Tensor& input) {
 }
 
 bool is_fused_kernel_acceptable(const Tensor& input, double p) {
-  return (input.is_cuda() || input.is_xpu() || input.is_lazy() || input.device().is_privateuseone()) && p > 0 && p < 1 && input.sym_numel() > 0;
+  return (input.is_cuda() || input.is_xpu() || input.is_lazy() || input.is_privateuseone()) && p > 0 && p < 1 && input.sym_numel() > 0;
 }
 
 // NB: sure, we could have used different overloads here, but I would feel insecure
```


```diff
@@ -1156,6 +1156,15 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return device_opt_.has_value() && device_opt_->type() == kVE;
   }
 
+  bool is_privateuseone() const {
+    // NB: This method is not virtual and avoid dispatches for performance
+    // reasons.
+    if (C10_UNLIKELY(device_policy_)) {
+      return device_custom().is_privateuseone();
+    }
+    return device_opt_.has_value() && device_opt_->type() == kPrivateUse1;
+  }
+
   bool is_mkldnn() const {
     return key_set_.has_all(c10::mkldnn_ks);
   }
```