Fix auto exponent issue for torch.pow (#47024)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47024

Fixes https://github.com/pytorch/pytorch/issues/46936

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#47024 Fix auto exponent issue for torch.pow**

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D24698027

Pulled By: anjali411

fbshipit-source-id: f23fdb65c925166243593036e08214c4f041a63d
anjali411 2020-11-14 22:43:59 -08:00 committed by Facebook GitHub Bot
parent d293413b3e
commit 8ef7ccd669
5 changed files with 36 additions and 15 deletions
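Background on the "auto exponent" problem fixed here: in a C++ conditional expression, both branches are converted to a common type, and double converts to std::complex<double> but not the other way around. So `auto exponent = exp.isComplex() ? exp.toComplexDouble() : exp.toDouble();` deduced std::complex<double> even when the exponent was real. A minimal standalone sketch of that behaviour (plain C++, not PyTorch code):

#include <complex>
#include <iostream>
#include <type_traits>

int main() {
  bool is_complex = false;
  // Both branches are converted to a common type before the result is picked;
  // double converts to std::complex<double>, so the expression's static type
  // is std::complex<double> no matter which branch runs at runtime.
  auto exponent = is_complex ? std::complex<double>(2.0, 1.0) : 2.0;
  static_assert(std::is_same<decltype(exponent), std::complex<double>>::value,
                "exponent is deduced as complex even on the real branch");
  std::cout << exponent << "\n";  // prints (2,0)
}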


@@ -31,11 +31,9 @@ Tensor& pow_out(Tensor& result, const Tensor& base, Scalar exp) {
       "result type ", common_dtype, "can't be cast to the desired output type ",
       result.scalar_type());
-  auto exponent = (exp.isComplex()) ? exp.toComplexDouble() : exp.toDouble();
-  if (exponent == 0.0) {
+  if (exp.equal(0.0)) {
     result.resize_as_(base).fill_(1);
-  } else if (exponent == 1.0) {
+  } else if (exp.equal(1.0)) {
     result.resize_as_(base).copy_(base);
   } else {
     auto iter = TensorIterator::unary_op(result, base.to(common_dtype));


@@ -21,4 +21,14 @@ Scalar Scalar::conj() const {
   }
 }
+
+Scalar Scalar::log() const {
+  if (isComplex()) {
+    return std::log(v.z);
+  } else if (isFloatingPoint()) {
+    return std::log(v.d);
+  } else {
+    return std::log(v.i);
+  }
+}
 } // namespace c10
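A small usage sketch for the Scalar::log() added above (illustrative only; the variable names are made up, and linking against c10 is assumed): log() dispatches on the scalar's own tag, so a real scalar produces a real result and a complex scalar produces a complex one.

#include <c10/core/Scalar.h>
#include <c10/util/complex.h>
#include <iostream>

int main() {
  c10::Scalar real_base(2.0);
  c10::Scalar complex_base(c10::complex<double>(0.0, 1.0));  // the imaginary unit i

  // Real input stays real: log(2.0) ~= 0.693.
  std::cout << real_base.log().toDouble() << "\n";

  // Complex input stays complex: log(i) = i*pi/2, so real part ~0, imaginary part ~1.5708.
  auto l = complex_base.log().toComplexDouble();
  std::cout << l.real() << " " << l.imag() << "\n";
  return 0;
}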


@@ -88,6 +88,19 @@ class C10_API Scalar {
   Scalar operator-() const;
   Scalar conj() const;
+  Scalar log() const;
+
+  template<typename T>
+  bool equal(T num) const {
+    if (isComplex()) {
+      return v.z == num;
+    } else if (isFloatingPoint()) {
+      return v.d == num;
+    } else {
+      return v.i == num;
+    }
+  }
+
   ScalarType type() const {
     if (isComplex()) {
       return ScalarType::ComplexDouble;

@@ -175,7 +175,7 @@ TEST(AutogradAPITests, AnomalyMode) {
     auto y = x.pow(1.5);
     auto gr =
         grad({y}, {x}, {}, /*retain_graph=*/true, /*create_backward=*/true);
-    ASSERT_THROWS_WITH(grad({gr[0]}, {x});, "returned nan");
+    ASSERT_THROWS_WITH(grad({gr[0]}, {x}, {torch::tensor({0.0})});, "returned nan");
     auto msgs = warnings.messages();
     ASSERT_EQ(msgs.size(), 2);
     ASSERT_TRUE(


@@ -191,12 +191,12 @@ Tensor norm_backward(Tensor grad, const Tensor & self, const optional<Scalar> &
   return norm_backward(grad, self, p_, norm);
 }
 
-Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent_) {
-  auto exponent = (exponent_.isComplex()) ? exponent_.toComplexDouble() : exponent_.toDouble();
-  if (exponent == 0.0) {
+Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent) {
+  if (exponent.equal(0.0)) {
     return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   } else {
-    auto out = grad * (exponent * self.pow(exponent - 1)).conj();
+    auto grad_lambda = [&](auto exp) { return grad * (exp * self.pow(exp - 1)).conj(); };
+    Tensor out = (exponent.isComplex()) ? grad_lambda(exponent.toComplexDouble()) : grad_lambda(exponent.toDouble());
     return handle_r_to_c(self, out);
   }
 }
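The generic lambda above is what removes the promotion problem in pow_backward: grad_lambda is instantiated separately for the double call and the complex call, so a real exponent keeps a real gradient path and only a genuinely complex exponent goes through complex arithmetic. A minimal sketch of that idea in plain C++ (not the PyTorch code):

#include <complex>
#include <type_traits>

int main() {
  // A generic lambda gets a fresh instantiation per argument type, so nothing
  // is promoted to a common type up front (unlike the old ternary on `auto`).
  auto backward_factor = [](auto exp) { return exp * (exp - 1.0); };

  auto real_case = backward_factor(1.5);                           // stays double
  auto complex_case = backward_factor(std::complex<double>(1.5));  // stays complex

  static_assert(std::is_same<decltype(real_case), double>::value,
                "real exponent keeps a real type");
  static_assert(std::is_same<decltype(complex_case), std::complex<double>>::value,
                "complex exponent keeps a complex type");
  return 0;
}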
@@ -229,9 +229,9 @@ Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& expo
 }
 
 Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exponent, Tensor result) {
-  auto base_ = base.isComplex() ? base.toComplexDouble() : base.toDouble();
-  auto grad_lambda = [](auto a, auto b) { return (a * std::log(b)).conj(); };
-  if (base_ == 0.0) {
+  auto grad_lambda = [](Tensor a, Scalar b) { return (a * b.log()).conj(); };
+  if (base.equal(0.0)) {
     auto cond = [](auto exp) {
       if (exp.is_complex()) {
         return at::logical_and(at::imag(exp) == 0, at::real(exp) >= 0);
@@ -241,10 +241,10 @@ Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exp
     };
     auto out = grad * at::where(cond(exponent),
                                 at::zeros({}, grad.options()),
-                                grad_lambda(result, base_));
+                                grad_lambda(result, base));
     return handle_r_to_c(exponent, out);
   } else {
-    auto out = grad * grad_lambda(result, base_);
+    auto out = grad * grad_lambda(result, base);
     return handle_r_to_c(exponent, out);
   }
 }
@@ -2747,7 +2747,7 @@ Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) {
 }
 
 Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) {
-  // since first backward takes care of scaling by frequency, 
+  // since first backward takes care of scaling by frequency,
   // we don't need to worry about it here.
   auto gg_weight = grad.index_select(0, indices.reshape(-1));