pytorch/test/cpp/api/autograd.cpp

Support custom autograd functions in C++ (#23572) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23572 ### **(The stack from #23020 was moved into this PR)** Adding API for custom autograd operations, with user defined forward and backward, [like in python](https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd). The custom operation should be a subclass of Function, with static forward and backward functions. `forward()` can accept any arguments similar to the Python API and `backward()` should accept a variable list as an argument. Both `forward()` and `backward() `accept a AutogradContext* which can be used to share data between them. Variables can be saved in the context using `save_for_backward()` and other data can be saved in the map `save` in the form of `<std::string, at::IValue>` pairs. Variables saved in forward can be accessed with `get_saved_variables()`. Example usage: ``` class MyFunction : public Function<MyFunction> { public: static variable_list forward(AutogradContext *ctx, int n, Variable var) { // Save data for backward in context ctx->saved_data["n"] = n; return {var}; } static variable_list backward(AutogradContext *ctx, variable_list grad_output) { // Use data saved in forward auto n = ctx->saved_data["n"].toInt(); return {grad_output[0]*n}; } }; ``` Then, it can be used with: ``` Variable x; MyFunction::apply(6, x); ``` Also AutogradContext has methods to mark outputs as non differentiable and mark inputs as dirty similar to the [Python API](https://github.com/pytorch/pytorch/blob/ff23a02ac4fdf1fe76d5b24666333f1ea0a918b7/torch/autograd/function.py#L26). Test Plan: Added tests for the custom autograd function API based on test_autograd.py. Currently only the tests for the basic functionality have been added. More tests will be added later. Differential Revision: D16583428 fbshipit-source-id: 0bd42f19ce37bcd99d3080d16195ad74d40d0413
2019-07-31 18:25:23 +00:00
#include <gtest/gtest.h>
#include <torch/torch.h>
#include <test/cpp/api/support.h>
using namespace torch::autograd;
#define ASSERT_VARIABLE_EQ(a,b) ASSERT_TRUE(torch::allclose((a),(b)))
#define EXPECT_VARIABLE_EQ(a,b) EXPECT_TRUE(torch::allclose((a),(b)))
std::string graph_desc(std::shared_ptr<Node> node) {
  if (!node) {
    return "None";
  }
  auto result = node->name() + "(";
  auto next_edges = node->next_edges();
  for (auto& edge : next_edges) {
    result += graph_desc(edge.function);
  }
  return result + ")";
}

// f(x, y) = x + 2y + xy, so df/dx = 1 + y and df/dy = 2 + x.
Variable simple_fn(const Variable& x, const Variable& y) {
  return x + 2 * y + x * y;
}

TEST(AutogradAPITests, BackwardSimpleTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  backward({res.sum()}, {});

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({2, 2}) * 2);
}

TEST(AutogradAPITests, BackwardTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);

  // Two backward passes accumulate into .grad(), so each expected gradient is doubled.
  backward({res}, {torch::ones({2, 2})}, {}, true);
  backward({res}, {torch::ones({2, 2})});

  ASSERT_VARIABLE_EQ(x.grad(), 2 * (y + torch::ones({2, 2})));
  ASSERT_VARIABLE_EQ(y.grad(), 2 * (x + torch::ones({2, 2}) * 2));
}

TEST(AutogradAPITests, GradSimpleTest) {
  // basic grad
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  auto grad_res = grad({res}, {x, y}, {torch::ones({2, 2})});

  ASSERT_VARIABLE_EQ(grad_res[0], y + torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(grad_res[1], x + torch::ones({2, 2}) * 2);
}

TEST(AutogradAPITests, GradTest) {
  Variable x = torch::randn({2, 2}, torch::requires_grad());
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  auto res = simple_fn(x, y);
  res.backward(torch::ones({2, 2}), false, true);

  Variable x_grad = y + torch::ones({2, 2});
  Variable y_grad = x + torch::ones({2, 2}) * 2;
  ASSERT_VARIABLE_EQ(x.grad(), x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), y_grad);

  // grad_sum = 2 * (1 + y) + (2 + x), so its derivative with respect to x is 1.
  Variable grad_sum = 2 * x.grad() + y.grad();
  auto x_hv = grad({grad_sum}, {x}, {torch::ones({2, 2})}, {}, true);

  ASSERT_VARIABLE_EQ(x_hv[0], torch::ones({2, 2}));
  ASSERT_VARIABLE_EQ(x.grad(), x_grad);
  ASSERT_VARIABLE_EQ(y.grad(), y_grad);
}

TEST(AutogradAPITests, GradNonLeafTest) {
  Variable x_init = torch::randn({2, 2}, torch::requires_grad());
  Variable x = x_init;
  Variable y = torch::randn({2, 2}, torch::requires_grad());
  Variable grad_output = torch::ones({2, 2});

  for (int i = 0; i < 5; ++i) {
    auto res = simple_fn(x, y);
    auto input_grads = grad({res}, {x}, {grad_output}, {}, true);

    Variable grad_x_expected = y + torch::ones({2, 2});
    ASSERT_VARIABLE_EQ(input_grads[0], grad_x_expected);
    ASSERT_FALSE(x.grad().defined());
    ASSERT_FALSE(y.grad().defined());
    x = x + 0.05 * input_grads[0];
  }

  float val_init = simple_fn(x_init, y).sum().item().toFloat();
  float val_final = simple_fn(x, y).sum().item().toFloat();
  ASSERT_TRUE(val_final > val_init);

  x.backward(grad_output, false, true);
  ASSERT_TRUE(x_init.grad().defined());
  ASSERT_TRUE(y.grad().defined());
}

TEST(AutogradAPITests, GradUnreachableTest) {
  Variable x = torch::ones({1}, torch::requires_grad());
  Variable y = torch::ones({1}, torch::requires_grad());

  Variable z = x * 2;
  Variable w = y * 2;

  auto grad_res = grad({x * 2}, {x, y}, {}, {}, false, true);
  ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
  ASSERT_FALSE(grad_res[1].defined());

  // This is slightly different than the case above, because z doesn't even
  // have a grad accumulator allocated.
  z = torch::ones({1}, torch::requires_grad());
  grad_res = grad({x * 2}, {x, z}, {}, {}, false, true);

  ASSERT_VARIABLE_EQ(grad_res[0], x * 2);
  ASSERT_FALSE(grad_res[1].defined());
}
TEST(CustomAutogradTest, CustomFunction) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable var1, int mul, Variable var2) {
      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return var1 + mul * var2 + var1 * var2;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      int mul = ctx->saved_data["mul"].toInt();
      auto saved = ctx->get_saved_variables();
      auto var1 = saved[0];
      auto var2 = saved[1];
      // One gradient per forward() argument; the non-tensor `mul` gets an undefined Variable.
      variable_list output = {
          grad_output[0] + grad_output[0] * var2,
          Variable(),
          grad_output[0] * mul + grad_output[0] * var1};
      return output;
    }
  };

  Variable x = torch::randn({5,5}, torch::requires_grad());
  Variable y = torch::randn({5,5}, torch::requires_grad());
  auto res = MyFunction::apply(x,2,y);
  auto go = torch::ones({}, torch::requires_grad());
  res.sum().backward(go, false, true);

  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5,5}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5,5})*2);
}

TEST(CustomAutogradTest, FunctionReturnsInput) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable var1) {
      return var1;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return {grad_output[0]*2};
    }
  };

  Variable x(torch::ones(1, torch::requires_grad()));
  MyFunction::apply(x).backward(torch::ones(1), true, true);
  ASSERT_VARIABLE_EQ(x.grad(), torch::full(1,2));
}

TEST(CustomAutogradTest, NoGradCustomFunction) {
  // Custom Function should respect grad mode
  struct MyOp : public Function<MyOp> {
    static Variable forward(AutogradContext *ctx, Variable x) {
      return x+1;
    }

    static variable_list backward(AutogradContext *ctx, variable_list dy) {
      return dy;
    }
  };

  auto x = torch::ones({5,5}, torch::requires_grad());
  {
    at::NoGradGuard no_grad;
    auto y = MyOp::apply(x);
    ASSERT_FALSE(y.requires_grad());
  }
}

TEST(CustomAutogradTest, MarkNonDifferentiable) {
  struct MyFunction : public Function<MyFunction> {
    static Variable forward(AutogradContext *ctx, Variable v) {
      Variable output = v > 0;
      ctx->mark_non_differentiable({output});
      return output;
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return { (grad_output[0]*0.0) };
    }
  };

  auto x = torch::randn({5,5}, torch::requires_grad());
  auto mask = MyFunction::apply(x);
ASSERT_FALSE(mask.requires_grad());
auto y = x.masked_fill(mask, 0);
y.sum().backward();
}
TEST(CustomAutogradTest, MarkNonDifferentiableMixed) {
struct MyFunction : public Function<MyFunction> {
static variable_list forward(AutogradContext *ctx, Variable input) {
Variable a = input+1;
Variable b = input+2;
ctx->mark_non_differentiable({a});
return {a,b};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
const Variable &grad_a = grad_output[0], &grad_b = grad_output[1];
EXPECT_VARIABLE_EQ(grad_a, torch::zeros({5,5}));
EXPECT_VARIABLE_EQ(grad_b, torch::ones({5,5}));
return {grad_b};
}
};
auto x = torch::randn({5,5}, torch::requires_grad());
auto out = MyFunction::apply(x);
ASSERT_FALSE(out[0].requires_grad());
ASSERT_TRUE(out[1].requires_grad());
out[1].sum().backward();
ASSERT_VARIABLE_EQ(x.grad(), torch::ones({5,5}));
}
TEST(CustomAutogradTest, MarkNonDifferentiableNone) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext *ctx, Variable input) {
auto output = input.clone();
ctx->mark_non_differentiable({output});
return output;
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outputs) {
return {};
}
};
auto x = torch::randn({5,5}, torch::requires_grad());
auto r = MyFunction::apply(x * x);
(r * x).sum().backward();
}
TEST(CustomAutogradTest, ReturnLeafInplace) {
struct Inplace : public Function<Inplace> {
static variable_list forward(AutogradContext *ctx, Variable a, Variable b) {
ctx->mark_dirty({a});
return {a.add_(b), b+2};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
return {grad_output[0], grad_output[0] + grad_output[1]};
}
};
Variable x = torch::randn({5,5});
Variable y = torch::randn({5,5}, torch::requires_grad());
auto out = Inplace::apply(x,y);
auto &q = out[0];
ASSERT_TRUE(torch::equal(q, x));
ASSERT_TRUE(q.requires_grad());
q.sum().backward();
ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5,5}));
}
TEST(CustomAutogradTest, ReturnDuplicateInplace) {
struct DoubleInplace : public Function<DoubleInplace> {
static variable_list forward(AutogradContext *ctx, Variable x) {
x.mul_(2);
ctx->mark_dirty({x});
return {x,x};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outputs) {
return {grad_outputs[0]*2 + grad_outputs[1]*2};
}
};
auto x = torch::randn({5,5}, torch::requires_grad());
ASSERT_THROWS_WITH(DoubleInplace::apply(x), "leaf Variable that requires grad");
// TODO ASSERT_THROWS_WITH(DoubleInplace::apply(x.clone()[0]), "only one output");
auto out = DoubleInplace::apply(x.clone());
ASSERT_TRUE(torch::equal(out[0],out[1]));
}
TEST(CustomAutogradTest, ReturnDuplicate) {
struct DoubleDuplicate : public Function<DoubleDuplicate> {
static variable_list forward(AutogradContext *ctx, Variable x) {
auto output = x*2;
return {output, output};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outputs) {
return {grad_outputs[0]*2 + grad_outputs[1]*2};
}
};
auto x = torch::randn({5,5}, torch::requires_grad());
auto out = DoubleDuplicate::apply(x);
ASSERT_TRUE(torch::equal(out[0],out[1]));
}
TEST(CustomAutogradTest, SaveEmptyForBackward) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext *ctx, Variable input) {
ctx->save_for_backward({Variable(), input, Variable()});
return input*input;
}
static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
auto saved = ctx->get_saved_variables();
EXPECT_FALSE(saved[0].defined());
EXPECT_FALSE(saved[2].defined());
return {saved[1] * 2 * grad_output[0]};
}
};
Variable x = torch::randn({5,5}, torch::requires_grad());
auto y = MyFunction::apply(x);
y.sum().backward();
ASSERT_VARIABLE_EQ(x.grad(), 2*x);
}
TEST(CustomAutogradTest, InvalidGradients) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext *ctx, Variable x) {
return x*2;
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outputs) {
return {torch::randn(10, torch::dtype(torch::kFloat).requires_grad(true))};
}
};
auto input1 = torch::randn({5,5}, torch::dtype(torch::kFloat).requires_grad(true));
ASSERT_THROWS_WITH(
MyFunction::apply(input1).sum().backward(), "expected shape");
auto input2 = torch::randn(10, torch::dtype(torch::kDouble).requires_grad(true));
}
TEST(CustomAutogradTest, NoGradInput) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext*, Variable x) {
return x;
}
static variable_list backward(AutogradContext*, variable_list grad_outputs) {
return grad_outputs;
}
};
Variable x = torch::randn({5,5}, torch::requires_grad());
Variable y;
{
at::NoGradGuard no_grad;
y = MyFunction::apply(x);
}
ASSERT_TRUE(x.requires_grad());
ASSERT_FALSE(y.grad_fn());
}
TEST(CustomAutogradTest, TooManyGrads) {
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext*, Variable input) {
return input;
}
static variable_list backward(AutogradContext*, variable_list grad_output) {
grad_output.insert(grad_output.end(), {Variable(), Variable()});
return grad_output;
}
};
}
TEST(CustomAutogradTest, DepNoGrad) {
struct F1 : public Function<F1> {
static variable_list forward(AutogradContext *ctx, Variable input) {
auto out = torch::randn(input.sizes());
ctx->mark_non_differentiable({out});
return {input, out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
return {grad_output[0]};
}
};
struct F2 : public Function<F2> {
static Variable forward(AutogradContext*, Variable input, Variable ignore) {
return input;
}
static variable_list backward(AutogradContext*, variable_list grad_output) {
return {grad_output[0], Variable()};
}
};
auto x = torch::randn(5, torch::requires_grad());
auto out = F1::apply(x);
Variable &a = out[0], &b = out[1];
b = b+1; // Separate F1 and F2 by another operation
ASSERT_TRUE(a.requires_grad());
ASSERT_FALSE(b.requires_grad());
auto c = F2::apply(a,b);
c.backward(torch::ones(c.sizes()), false, false);
ASSERT_VARIABLE_EQ(x.grad(), torch::ones(x.sizes()));
}
TEST(CustomAutogradTest, Reentrant) {
static Variable y_data = torch::randn({2, 2});
struct Reenter : public Function<Reenter> {
static Variable forward(AutogradContext *ctx, Variable input) {
Variable output;
{
at::AutoGradMode enable_grad(true);
auto x = make_variable(input.tensor_data(), true);
auto y = make_variable(y_data.tensor_data(), true);
output = x*y;
ctx->saved_data["x"] = x;
ctx->saved_data["y"] = y;
ctx->saved_data["output_var"] = output;
}
return output.detach();
}
static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
{
at::AutoGradMode enable_grad(true);
auto out = ctx->saved_data["output_var"].toTensor();
out.sum().backward();
}
return {ctx->saved_data["x"].toTensor().grad() * grad_output[0]};
}
};
auto x = torch::randn({2,2}, torch::requires_grad());
auto out = Reenter::apply(x);
out.sum().backward();
ASSERT_VARIABLE_EQ(x.grad(), y_data);
}
TEST(CustomAutogradTest, DeepReentrant) {
struct DeepReenter : public Function<DeepReenter> {
static Variable forward(AutogradContext *ctx, Variable x) {
{
at::AutoGradMode enable_grad(true);
ctx->saved_data["x"] = make_variable(x.tensor_data(), true) -1;
}
return ctx->saved_data["x"].toTensor().detach();
}
static variable_list backward(AutogradContext*ctx, variable_list grad_output) {
if (!ctx->saved_data["x"].toTensor().is_nonzero()) {
return grad_output;
}
{
at::AutoGradMode enable_grad(true);
apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
return grad_output;
}
}
};
// This should not cause a stack overflow
auto v = torch::tensor({8193}, torch::dtype(torch::kFloat).requires_grad(true));
DeepReenter::apply(v).sum().backward();
}
TEST(CustomAutogradTest, ReentrantPriority) {
static std::vector<int> order;
struct MyFunction : public Function<MyFunction> {
static Variable forward(AutogradContext*, Variable x) {
return x;
}
static variable_list backward(AutogradContext*, variable_list grad) {
order.push_back(0);
return grad;
}
};
struct Reenter : public Function<Reenter> {
static Variable forward(AutogradContext *ctx, Variable x) {
{
at::AutoGradMode enable_grad(true);
ctx->saved_data["x"] = make_variable(x.tensor_data(), true) -1;
}
return ctx->saved_data["x"].toTensor().detach();
}
static variable_list backward(AutogradContext*ctx, variable_list grad_output) {
order.push_back(1);
if (!ctx->saved_data["x"].toTensor().is_nonzero()) {
return grad_output;
}
{
at::AutoGradMode enable_grad(true);
apply(ctx->saved_data["x"].toTensor())[0].sum().backward();
return grad_output;
}
}
};
auto a = MyFunction::apply(torch::tensor({6}, torch::dtype(torch::kFloat).requires_grad(true)));
auto b = Reenter::apply(torch::tensor({9}, torch::dtype(torch::kFloat).requires_grad(true)));
auto v = a*b;
v.backward();
// All the reentrant tasks should be prioritized over the MyFunction backward
// task.
ASSERT_EQ(order.size(), 10);
ASSERT_EQ(std::count(order.begin(), order.end(), 1), 9);
ASSERT_EQ(order.back(), 0);
}
TEST(CustomAutogradTest, Hooks) {
Variable x = torch::ones({5,5}, torch::requires_grad());
Variable y = torch::ones({5,5})*4;
y.set_requires_grad(true);
int counter = 0;
std::function<void(int, Variable)> bw_hook([&counter](int inc, Variable grad){
counter += inc;
});
Variable z = x * x + x * 2 + x * y + y;
x.register_hook([&bw_hook](Variable grad){
bw_hook(0, grad);
});
auto hook_1 = z.register_hook([&bw_hook](Variable grad){
bw_hook(1, grad);
});
z.backward(torch::ones({5,5}), true, true);
ASSERT_EQ(counter, 1);
auto hook_2 = z.register_hook([&bw_hook](Variable grad){
bw_hook(2, grad);
});
z.backward(torch::ones({5,5}), true, true);
ASSERT_EQ(counter, 4);
z.remove_hook(hook_2);
z.backward(torch::ones({5,5}), true, true);
ASSERT_EQ(counter, 5);
std::function<Variable(Variable)> bw_hook_modify([](Variable grad){
return grad.mul(2);
});
z.remove_hook(hook_1);
z.register_hook(bw_hook_modify);
y.grad().zero_();
z.backward(torch::ones({5,5}), true, false);
ASSERT_VARIABLE_EQ(y.grad(), (x+1)*2);
y.register_hook(bw_hook_modify);
y.grad().zero_();
z.backward(torch::ones({5,5}), false, false);
ASSERT_VARIABLE_EQ(y.grad(), (x+1)*4);
ASSERT_THROWS_WITH(y.remove_hook(3), "Invalid index");
}
TEST(CustomAutogradTest, HookNone) {
struct NoneGradientFunction : public Function<NoneGradientFunction> {
static variable_list forward(AutogradContext *ctx, Variable x, Variable y) {
return {x,y};
}
static variable_list backward(AutogradContext *ctx, variable_list grad) {
return {grad[0], Variable()};
}
};
bool was_called = false;
auto hook = ([&was_called](Variable grad){
ASSERT_TRUE(grad.defined());
was_called = true;
});
auto x = torch::randn({5,5}, torch::requires_grad());
auto y = torch::randn({5,5});
auto out = NoneGradientFunction::apply(x,y);
Variable rx = out[0], ry = out[1];
rx.register_hook(hook);
ry.register_hook(hook);
(rx+ry).sum().backward();
ASSERT_TRUE(was_called);
}
// TODO add these tests if needed
// test_once_differentiable
// test_sparse_backward
// test_save_output_nr
// test_free_deep_graph_pyfunction
// test_naughty_anomaly_access
// test_naughty_autograd_function_stashing_ctx
// test_custom_autograd_repeated_grad_grad
// test_return_leaf
// test_anomaly_detect_nan
// test_no_grad_copy