From 8ef7ccd669a36f49eb052d12a71df9b7836cafa5 Mon Sep 17 00:00:00 2001
From: anjali411
Date: Sat, 14 Nov 2020 22:43:59 -0800
Subject: [PATCH] Fix auto exponent issue for torch.pow (#47024)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47024

Fixes https://github.com/pytorch/pytorch/issues/46936

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#47024 Fix auto exponent issue for torch.pow**

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D24698027

Pulled By: anjali411

fbshipit-source-id: f23fdb65c925166243593036e08214c4f041a63d
---
 aten/src/ATen/native/Pow.cpp            |  6 ++----
 c10/core/Scalar.cpp                     | 10 ++++++++++
 c10/core/Scalar.h                       | 13 +++++++++++++
 test/cpp/api/autograd.cpp               |  2 +-
 torch/csrc/autograd/FunctionsManual.cpp | 20 ++++++++++----------
 5 files changed, 36 insertions(+), 15 deletions(-)

diff --git a/aten/src/ATen/native/Pow.cpp b/aten/src/ATen/native/Pow.cpp
index ca5d1848a4b..573832b0523 100644
--- a/aten/src/ATen/native/Pow.cpp
+++ b/aten/src/ATen/native/Pow.cpp
@@ -31,11 +31,9 @@ Tensor& pow_out(Tensor& result, const Tensor& base, Scalar exp) {
               "result type ", common_dtype, "can't be cast to the desired output type ",
               result.scalar_type());
 
-  auto exponent = (exp.isComplex()) ? exp.toComplexDouble() : exp.toDouble();
-
-  if (exponent == 0.0) {
+  if (exp.equal(0.0)) {
     result.resize_as_(base).fill_(1);
-  } else if (exponent == 1.0) {
+  } else if (exp.equal(1.0)) {
     result.resize_as_(base).copy_(base);
   } else {
     auto iter = TensorIterator::unary_op(result, base.to(common_dtype));
diff --git a/c10/core/Scalar.cpp b/c10/core/Scalar.cpp
index 35aa5d60f00..212c41d5b19 100644
--- a/c10/core/Scalar.cpp
+++ b/c10/core/Scalar.cpp
@@ -21,4 +21,14 @@ Scalar Scalar::conj() const {
   }
 }
 
+Scalar Scalar::log() const {
+  if (isComplex()) {
+    return std::log(v.z);
+  } else if (isFloatingPoint()) {
+    return std::log(v.d);
+  } else {
+    return std::log(v.i);
+  }
+}
+
 } // namespace c10
diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h
index 6151f6d2b15..04f70abf1f2 100644
--- a/c10/core/Scalar.h
+++ b/c10/core/Scalar.h
@@ -88,6 +88,19 @@ class C10_API Scalar {
   Scalar operator-() const;
   Scalar conj() const;
 
+  Scalar log() const;
+
+  template <typename T>
+  bool equal(T num) const {
+    if (isComplex()) {
+      return v.z == num;
+    } else if (isFloatingPoint()) {
+      return v.d == num;
+    } else {
+      return v.i == num;
+    }
+  }
+
   ScalarType type() const {
     if (isComplex()) {
       return ScalarType::ComplexDouble;
diff --git a/test/cpp/api/autograd.cpp b/test/cpp/api/autograd.cpp
index 81e530d0dbe..635956d5407 100644
--- a/test/cpp/api/autograd.cpp
+++ b/test/cpp/api/autograd.cpp
@@ -175,7 +175,7 @@ TEST(AutogradAPITests, AnomalyMode) {
   auto y = x.pow(1.5);
   auto gr =
     grad({y}, {x}, {}, /*retain_graph=*/true, /*create_backward=*/true);
-  ASSERT_THROWS_WITH(grad({gr[0]}, {x});, "returned nan");
+  ASSERT_THROWS_WITH(grad({gr[0]}, {x}, {torch::tensor({0.0})});, "returned nan");
   auto msgs = warnings.messages();
   ASSERT_EQ(msgs.size(), 2);
   ASSERT_TRUE(
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 674c5bbbe60..968c1ad58ed 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -191,12 +191,12 @@ Tensor norm_backward(Tensor grad, const Tensor & self, const optional<Scalar> &
   return norm_backward(grad, self, p_, norm);
 }
 
-Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent_) {
-  auto exponent = (exponent_.isComplex()) ? exponent_.toComplexDouble() : exponent_.toDouble();
-  if (exponent == 0.0) {
+Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent) {
+  if (exponent.equal(0.0)) {
     return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   } else {
-    auto out = grad * (exponent * self.pow(exponent - 1)).conj();
+    auto grad_lambda = [&](auto exp) { return grad * (exp * self.pow(exp - 1)).conj(); };
+    Tensor out = (exponent.isComplex()) ? grad_lambda(exponent.toComplexDouble()) : grad_lambda(exponent.toDouble());
     return handle_r_to_c(self, out);
   }
 }
@@ -229,9 +229,9 @@ Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& expo
 }
 
 Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exponent, Tensor result) {
-  auto base_ = base.isComplex() ? base.toComplexDouble() : base.toDouble();
-  auto grad_lambda = [](auto a, auto b) { return (a * std::log(b)).conj(); };
-  if (base_ == 0.0) {
+  auto grad_lambda = [](Tensor a, Scalar b) { return (a * b.log()).conj(); };
+
+  if (base.equal(0.0)) {
     auto cond = [](auto exp) {
       if (exp.is_complex()) {
         return at::logical_and(at::imag(exp) == 0, at::real(exp) >= 0);
@@ -241,10 +241,10 @@ Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exp
     };
     auto out = grad * at::where(cond(exponent), at::zeros({}, grad.options()),
-                      grad_lambda(result, base_));
+                      grad_lambda(result, base));
     return handle_r_to_c(exponent, out);
   } else {
-    auto out = grad * grad_lambda(result, base_);
+    auto out = grad * grad_lambda(result, base);
    return handle_r_to_c(exponent, out);
   }
 }
 
@@ -2747,7 +2747,7 @@ Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) {
 }
 
 Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) {
-  // since first backward takes care of scaling by frequency, 
+  // since first backward takes care of scaling by frequency,
   // we don't need to worry about it here.
   auto gg_weight = grad.index_select(0, indices.reshape(-1));
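
Note on the "auto exponent issue" named in the subject: it most likely refers to the removed ternaries of the form `auto exponent = (exp.isComplex()) ? exp.toComplexDouble() : exp.toDouble();`. Because one branch yields a complex<double> and the other a double, the common type of the conditional expression is complex<double>, so `exponent` is deduced as complex even when the exponent is purely real; the patch sidesteps this by comparing through the new Scalar::equal and by calling grad_lambda separately with toComplexDouble() or toDouble(). A minimal standalone sketch of the deduction behaviour, using std::complex as a stand-in for c10::complex (the snippet and its names are illustrative only and are not part of the patch):

#include <complex>
#include <iostream>

int main() {
  bool is_complex = false;  // pretend the Scalar holds the purely real value 2.0
  // The common type of the ternary is std::complex<double>, so even the
  // "real" branch produces a complex value.
  auto exponent = is_complex ? std::complex<double>(2.0, 1.0) : 2.0;
  std::cout << exponent << "\n";  // prints "(2,0)", not "2"
  return 0;
}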