From eaa2c0e00991d24e4a0ba19d3b9bb1f45628a9b5 Mon Sep 17 00:00:00 2001 From: Juan Torrente <79209573+jnt0rrente@users.noreply.github.com> Date: Fri, 23 Aug 2024 22:40:03 +0000 Subject: [PATCH] Improves error message when passing wrong tensor type to torch.nn.functional.one_hot (#134209) The function expects a Tensor of type LongTensor. It currently throws the following error: "one_hot is only applicable to index tensor." which, in my opinion, does not provide the user with enough information about what the problem is. This PR simply adds extra information to the error message in this specific scenario. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134209 Approved by: https://github.com/mikaylagawarecki --- aten/src/ATen/native/Onehot.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/Onehot.cpp b/aten/src/ATen/native/Onehot.cpp index ffd19b2e93a..fcbe7fd1ddc 100644 --- a/aten/src/ATen/native/Onehot.cpp +++ b/aten/src/ATen/native/Onehot.cpp @@ -15,7 +15,7 @@ namespace at::native { Tensor one_hot(const Tensor &self, int64_t num_classes) { - TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor."); + TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor of type LongTensor."); // using meta bit test to catch Fake Tensor as well until __torch_function__ if (self.key_set().has_all(DispatchKeySet(BackendComponent::MetaBit)) ||