diff --git a/aten/src/ATen/native/Onehot.cpp b/aten/src/ATen/native/Onehot.cpp index ffd19b2e93a..fcbe7fd1ddc 100644 --- a/aten/src/ATen/native/Onehot.cpp +++ b/aten/src/ATen/native/Onehot.cpp @@ -15,7 +15,7 @@ namespace at::native { Tensor one_hot(const Tensor &self, int64_t num_classes) { - TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor."); + TORCH_CHECK(self.dtype() == kLong, "one_hot is only applicable to index tensor of type LongTensor."); // using meta bit test to catch Fake Tensor as well until __torch_function__ if (self.key_set().has_all(DispatchKeySet(BackendComponent::MetaBit)) ||