Improves error message when passing wrong tensor type to torch.nn.functional.one_hot (#134209)

The function expects a Tensor of type LongTensor. It currently throws the following error: "one_hot is only applicable to index tensor." which, imo, does not provide the user with enough information on what the problem is.

PR simply adds extra information to the error message on this specific scenario.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134209
Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
Juan Torrente 2024-08-23 22:40:03 +00:00 committed by PyTorch MergeBot
parent 09a82f3d24
commit eaa2c0e009

View file

@ -15,7 +15,7 @@
namespace at::native {
Tensor one_hot(const Tensor &self, int64_t num_classes) {
TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor of type LongTensor.");
// using meta bit test to catch Fake Tensor as well until __torch_function__
if (self.key_set().has_all(DispatchKeySet(BackendComponent::MetaBit)) ||