mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Improves error message when passing wrong tensor type to torch.nn.functional.one_hot (#134209)
The function expects a Tensor of type LongTensor. It currently throws the following error: "one_hot is only applicable to index tensor." which, imo, does not provide the user with enough information on what the problem is. PR simply adds extra information to the error message on this specific scenario. Pull Request resolved: https://github.com/pytorch/pytorch/pull/134209 Approved by: https://github.com/mikaylagawarecki
This commit is contained in:
parent
09a82f3d24
commit
eaa2c0e009
1 changed files with 1 additions and 1 deletions
|
|
@ -15,7 +15,7 @@
|
|||
namespace at::native {
|
||||
|
||||
Tensor one_hot(const Tensor &self, int64_t num_classes) {
|
||||
TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor.");
|
||||
TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor of type LongTensor.");
|
||||
|
||||
// using meta bit test to catch Fake Tensor as well until __torch_function__
|
||||
if (self.key_set().has_all(DispatchKeySet(BackendComponent::MetaBit)) ||
|
||||
|
|
|
|||
Loading…
Reference in a new issue