Fix auto exponent issue for torch.pow (#47024)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47024

Fixes https://github.com/pytorch/pytorch/issues/46936

Stack from [ghstack](https://github.com/ezyang/ghstack):
* **#47024 Fix auto exponent issue for torch.pow**

Test Plan: Imported from OSS

Reviewed By: malfet

Differential Revision: D24698027

Pulled By: anjali411

fbshipit-source-id: f23fdb65c925166243593036e08214c4f041a63d
This commit is contained in: parent d293413b3e, commit 8ef7ccd669.
5 changed files with 36 additions and 15 deletions.
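Note on the fix: the "auto exponent" in the title refers to the ternary removed in the hunks below, `auto exponent = (exp.isComplex()) ? exp.toComplexDouble() : exp.toDouble();`. The conditional operator has to deduce one common type for both branches, and because the complex type is constructible from double (but not the other way around), the deduced type is always the complex one, so even plain real exponents ended up on the complex code path. A minimal standalone sketch of that deduction rule, using std::complex rather than PyTorch's c10 types:

// Standalone illustration of the pitfall; not PyTorch code.
#include <complex>
#include <iostream>
#include <type_traits>

int main() {
  bool is_complex = false;                      // pretend the scalar holds a plain double
  double real_exp = 2.0;
  std::complex<double> complex_exp{2.0, 0.0};

  auto exponent = is_complex ? complex_exp : real_exp;

  // The deduced type is std::complex<double> regardless of the runtime branch.
  static_assert(std::is_same<decltype(exponent), std::complex<double>>::value,
                "the ternary promotes to the complex common type");
  std::cout << exponent << '\n';                // prints (2,0)
  return 0;
}

The commit sidesteps the deduction entirely by comparing and dispatching on the Scalar itself (the new Scalar::equal and per-branch lambdas below).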
@@ -31,11 +31,9 @@ Tensor& pow_out(Tensor& result, const Tensor& base, Scalar exp) {
       "result type ", common_dtype, "can't be cast to the desired output type ",
       result.scalar_type());
 
-  auto exponent = (exp.isComplex()) ? exp.toComplexDouble() : exp.toDouble();
-
-  if (exponent == 0.0) {
+  if (exp.equal(0.0)) {
     result.resize_as_(base).fill_(1);
-  } else if (exponent == 1.0) {
+  } else if (exp.equal(1.0)) {
     result.resize_as_(base).copy_(base);
   } else {
     auto iter = TensorIterator::unary_op(result, base.to(common_dtype));
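A usage-level sketch of what the fast paths above mean for callers (a sketch assuming a libtorch build; values and shapes are arbitrary): an exponent equal to 0 fills the output with ones, an exponent equal to 1 copies the base, and anything else runs through the TensorIterator kernel.

#include <torch/torch.h>
#include <iostream>

int main() {
  auto base = torch::randn({3});
  std::cout << torch::pow(base, 0) << '\n';  // fast path: result filled with ones
  std::cout << torch::pow(base, 1) << '\n';  // fast path: result is a copy of base
  std::cout << torch::pow(base, 2) << '\n';  // general TensorIterator path
  return 0;
}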
@@ -21,4 +21,14 @@ Scalar Scalar::conj() const {
   }
 }
 
+Scalar Scalar::log() const {
+  if (isComplex()) {
+    return std::log(v.z);
+  } else if (isFloatingPoint()) {
+    return std::log(v.d);
+  } else {
+    return std::log(v.i);
+  }
+}
+
 } // namespace c10
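Scalar::log exists so the exponent gradient further down can take log(base) in the base's own numeric kind: std::log keeps a complex argument complex (principal branch) instead of squeezing it through toDouble() first. A quick standalone illustration:

#include <complex>
#include <iostream>

int main() {
  double d = 8.0;
  std::complex<double> z{0.0, 1.0};        // the imaginary unit i
  std::cout << std::log(d) << '\n';        // 2.07944
  std::cout << std::log(z) << '\n';        // (0,1.5708), i.e. i*pi/2 (principal branch)
  return 0;
}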
@@ -88,6 +88,19 @@ class C10_API Scalar {
 
   Scalar operator-() const;
   Scalar conj() const;
+  Scalar log() const;
+
+  template<typename T>
+  bool equal(T num) const {
+    if (isComplex()) {
+      return v.z == num;
+    } else if (isFloatingPoint()) {
+      return v.d == num;
+    } else {
+      return v.i == num;
+    }
+  }
+
   ScalarType type() const {
     if (isComplex()) {
       return ScalarType::ComplexDouble;
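The templated equal() compares the query against whichever member of the Scalar's tagged union is active, so a complex Scalar is never collapsed to a double before the comparison. A rough sketch of the same idea using std::variant in place of c10's hand-rolled union (TinyScalar and the free function equal are illustrative names, not c10 API):

#include <complex>
#include <cstdint>
#include <iostream>
#include <variant>

using TinyScalar = std::variant<int64_t, double, std::complex<double>>;

// Compare in whichever representation the scalar actually holds.
template <typename T>
bool equal(const TinyScalar& s, T num) {
  if (auto z = std::get_if<std::complex<double>>(&s)) {
    return *z == std::complex<double>(num);
  }
  if (auto d = std::get_if<double>(&s)) {
    return *d == num;
  }
  return std::get<int64_t>(s) == num;
}

int main() {
  TinyScalar zero = std::complex<double>(0.0, 0.0);
  TinyScalar one = int64_t{1};
  std::cout << std::boolalpha << equal(zero, 0.0) << ' ' << equal(one, 1.0) << '\n';  // true true
  return 0;
}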
@@ -175,7 +175,7 @@ TEST(AutogradAPITests, AnomalyMode) {
     auto y = x.pow(1.5);
     auto gr =
         grad({y}, {x}, {}, /*retain_graph=*/true, /*create_backward=*/true);
-    ASSERT_THROWS_WITH(grad({gr[0]}, {x});, "returned nan");
+    ASSERT_THROWS_WITH(grad({gr[0]}, {x}, {torch::tensor({0.0})});, "returned nan");
     auto msgs = warnings.messages();
     ASSERT_EQ(msgs.size(), 2);
     ASSERT_TRUE(
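The test now passes an explicit grad_outputs tensor to the second grad() call. A minimal sketch of that call pattern outside the test harness (a sketch assuming a libtorch build that exposes torch::autograd::grad; values are arbitrary):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::tensor({2.0}, torch::requires_grad());
  auto y = x.pow(1.5);
  // First backward pass; keep the graph so we can backprop through it again.
  auto g1 = torch::autograd::grad({y}, {x}, {torch::ones_like(y)},
                                  /*retain_graph=*/true);
  // Same call with an explicit zero grad_output scales the gradient to zero.
  auto g0 = torch::autograd::grad({y}, {x}, {torch::zeros_like(y)});
  std::cout << g1[0] << g0[0] << '\n';  // ~2.1213 (1.5 * sqrt(2)) and 0
  return 0;
}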
@@ -191,12 +191,12 @@ Tensor norm_backward(Tensor grad, const Tensor & self, const optional<Scalar> &
   return norm_backward(grad, self, p_, norm);
 }
 
-Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent_) {
-  auto exponent = (exponent_.isComplex()) ? exponent_.toComplexDouble() : exponent_.toDouble();
-  if (exponent == 0.0) {
+Tensor pow_backward(Tensor grad, const Tensor & self, const Scalar & exponent) {
+  if (exponent.equal(0.0)) {
     return at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   } else {
-    auto out = grad * (exponent * self.pow(exponent - 1)).conj();
+    auto grad_lambda = [&](auto exp) { return grad * (exp * self.pow(exp - 1)).conj(); };
+    Tensor out = (exponent.isComplex()) ? grad_lambda(exponent.toComplexDouble()) : grad_lambda(exponent.toDouble());
     return handle_r_to_c(self, out);
   }
 }
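The rewritten pow_backward keeps the same derivative rule for the base, d/dx x^n = n * x^(n-1) (times the incoming grad, plus the conj() used for complex autograd), but routes it through a generic lambda so the arithmetic stays in double for real exponents and becomes complex only when the exponent really is complex. A standalone numeric sanity check of the rule itself, not of the autograd plumbing:

#include <cmath>
#include <complex>
#include <iostream>

int main() {
  // Real base and exponent: d/dx x^n = n * x^(n-1).
  double x = 3.0, n = 2.0, h = 1e-6;
  double analytic = n * std::pow(x, n - 1);                              // 6
  double numeric  = (std::pow(x + h, n) - std::pow(x - h, n)) / (2 * h);
  std::cout << analytic << " ~= " << numeric << '\n';

  // Complex base and exponent: the same rule, evaluated in complex arithmetic.
  std::complex<double> zx{2.0, 1.0}, zn{0.0, 1.0};
  std::complex<double> zanalytic = zn * std::pow(zx, zn - 1.0);
  std::complex<double> znumeric =
      (std::pow(zx + h, zn) - std::pow(zx - h, zn)) / (2 * h);
  std::cout << zanalytic << " ~= " << znumeric << '\n';
  return 0;
}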
@@ -229,9 +229,9 @@ Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& expo
 }
 
 Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exponent, Tensor result) {
-  auto base_ = base.isComplex() ? base.toComplexDouble() : base.toDouble();
-  auto grad_lambda = [](auto a, auto b) { return (a * std::log(b)).conj(); };
-  if (base_ == 0.0) {
+  auto grad_lambda = [](Tensor a, Scalar b) { return (a * b.log()).conj(); };
+
+  if (base.equal(0.0)) {
     auto cond = [](auto exp) {
       if (exp.is_complex()) {
         return at::logical_and(at::imag(exp) == 0, at::real(exp) >= 0);
@@ -241,10 +241,10 @@ Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exp
     };
     auto out = grad * at::where(cond(exponent),
                                at::zeros({}, grad.options()),
-                               grad_lambda(result, base_));
+                               grad_lambda(result, base));
     return handle_r_to_c(exponent, out);
   } else {
-    auto out = grad * grad_lambda(result, base_);
+    auto out = grad * grad_lambda(result, base);
    return handle_r_to_c(exponent, out);
   }
 }
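pow_backward_exponent covers the other partial derivative, d/dn base^n = base^n * log(base); that is why the forward result is reused and why Scalar::log was added above. When base equals 0 and the exponent is real and nonnegative, the gradient is defined to be zero instead of the nan that 0 * log(0) would produce, which is what the cond/at::where pair guards. A numeric check of the rule for an ordinary real base:

#include <cmath>
#include <iostream>

int main() {
  double base = 3.0, n = 2.0, h = 1e-6;
  double analytic = std::pow(base, n) * std::log(base);                  // result * log(base)
  double numeric  = (std::pow(base, n + h) - std::pow(base, n - h)) / (2 * h);
  std::cout << analytic << " ~= " << numeric << '\n';                    // both ~9.8875
  return 0;
}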
@@ -2747,7 +2747,7 @@ Tensor constant_pad_nd_backward(const Tensor& grad, IntArrayRef pad) {
 }
 
 Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) {
-  // since first backward takes care of scaling by frequency,
+  // since first backward takes care of scaling by frequency,
   // we don't need to worry about it here.
   auto gg_weight = grad.index_select(0, indices.reshape(-1));
 