Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23572
### **(The stack from #23020 was moved into this PR)**
This adds an API for custom autograd operations with user-defined forward and backward functions, [like in Python](https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd).
The custom operation should be a subclass of `Function`, with static `forward()` and `backward()` functions. `forward()` can accept any number of arguments, similar to the Python API, while `backward()` should accept a `variable_list` as its argument.
Both `forward()` and `backward()` accept an `AutogradContext*` which can be used to share data between them.
Variables can be saved in the context using `save_for_backward()`, and other data can be stored in the `saved_data` map as `<std::string, at::IValue>` pairs. Variables saved in `forward()` can be accessed in `backward()` with `get_saved_variables()`.
Example usage:
```
class MyFunction : public Function<MyFunction> {
 public:
  static variable_list forward(AutogradContext *ctx, int n, Variable var) {
    // Save data for backward in context
    ctx->saved_data["n"] = n;
    return {var};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
    // Use data saved in forward
    auto n = ctx->saved_data["n"].toInt();
    return {grad_output[0]*n};
  }
};
```
Then, it can be used with:
```
Variable x;
MyFunction::apply(6, x);
```
`AutogradContext` also has methods to mark outputs as non-differentiable and to mark inputs as dirty, similar to the [Python API](ff23a02ac4/torch/autograd/function.py (L26)).
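For instance, a function that adds 1 to its input in place and also returns a non-differentiable mask could look roughly like the sketch below (the `AddOneWithMask` name and shapes are made up for illustration, not part of this PR):
```
struct AddOneWithMask : public Function<AddOneWithMask> {
  static variable_list forward(AutogradContext *ctx, Variable x) {
    Variable mask = x > 0;              // boolean output: no gradient flows through it
    ctx->mark_non_differentiable({mask});
    ctx->mark_dirty({x});               // x is modified in place below
    return {x.add_(1), mask};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
    // d(x + 1)/dx == 1 and the mask contributes no gradient,
    // so the incoming gradient of the first output passes straight through.
    return {grad_output[0]};
  }
};
```
As in the Python API, marking an output non-differentiable stops autograd from tracking gradients for it, and marking an input dirty tells autograd that the tensor was modified in place.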
Test Plan: Added tests for the custom autograd function API, based on test_autograd.py. Currently only tests for the basic functionality have been added; more tests will be added later.
Differential Revision: D16583428
fbshipit-source-id: 0bd42f19ce37bcd99d3080d16195ad74d40d0413
129 lines
3.8 KiB
C++
#include <gtest/gtest.h>

#include <torch/autograd.h>
#include <torch/utils.h>

#include <test/cpp/api/support.h>

using namespace torch::autograd;

#define ASSERT_VARIABLE_EQ(a,b) ASSERT_TRUE(torch::allclose((a),(b)))

// Recursively describes the autograd graph rooted at `node` as nested node names.
std::string graph_desc(std::shared_ptr<Node> node) {
  if (!node) {
    return "None";
  }
  auto result = node->name() + "(";
  auto next_edges = node->next_edges();
  for (auto& edge : next_edges) {
    result += graph_desc(edge.function);
  }
  return result + ")";
}

TEST(CustomAutogradTest, CustomFunction) {
  struct MyFunction : public Function<MyFunction> {
    static variable_list forward(AutogradContext *ctx, Variable var1, int mul, Variable var2) {
      // Save non-variable data in `saved_data` and variables via `save_for_backward`.
      ctx->saved_data["mul"] = mul;
      ctx->save_for_backward({var1, var2});
      return {var1 + mul*var2 + var1*var2};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      int mul = ctx->saved_data["mul"].toInt();
      auto saved = ctx->get_saved_variables();
      auto var1 = saved[0];
      auto var2 = saved[1];
      // One gradient per forward input; the non-variable `mul` gets an empty Variable.
      variable_list output = {grad_output[0] + grad_output[0]*var2,
                              Variable(),
                              grad_output[0]*mul + grad_output[0]*var1};
      return output;
    }
  };

  Variable x = torch::randn({5,5}, torch::requires_grad());
  Variable y = torch::randn({5,5}, torch::requires_grad());
  auto res = MyFunction::apply(x, 2, y)[0];
  auto go = torch::ones({}, torch::requires_grad());
  res.sum().backward(go, false, true);

  // res = x + 2*y + x*y, so d(res)/dx = 1 + y and d(res)/dy = 2 + x.
  ASSERT_VARIABLE_EQ(x.grad(), y + torch::ones({5,5}));
  ASSERT_VARIABLE_EQ(y.grad(), x + torch::ones({5,5})*2);
}

TEST(CustomAutogradTest, FunctionReturnsInput) {
  struct MyFunction : public Function<MyFunction> {
    static variable_list forward(AutogradContext *ctx, Variable var1) {
      return {var1};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return {grad_output[0]*2};
    }
  };

  Variable x(torch::ones(1, torch::requires_grad()));
  MyFunction::apply(x)[0].backward(torch::ones(1), true, true);
  ASSERT_VARIABLE_EQ(x.grad(), torch::full(1, 2));
}

TEST(CustomAutogradTest, NoGradCustomFunction) {
  // Custom Function should respect grad mode
  struct MyOp : public Function<MyOp> {
    static variable_list forward(AutogradContext *ctx, Variable x) {
      return {x+1};
    }

    static variable_list backward(AutogradContext *ctx, variable_list dy) {
      return dy;
    }
  };

  auto x = torch::ones({5,5}, torch::requires_grad());
  {
    at::NoGradGuard no_grad;
    auto y = MyOp::apply(x)[0];
    ASSERT_FALSE(y.requires_grad());
  }
}

TEST(CustomAutogradTest, MarkNonDifferentiable) {
  struct MyFunction : public Function<MyFunction> {
    static variable_list forward(AutogradContext *ctx, Variable v) {
      Variable output = v > 0;
      ctx->mark_non_differentiable({output});
      return {output};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return { (grad_output[0]*0.0) };
    }
  };

  auto x = torch::randn({5,5}, torch::requires_grad());
  auto mask = MyFunction::apply(x)[0];
  // The boolean mask was marked non-differentiable, so it must not require grad.
  ASSERT_FALSE(mask.requires_grad());
  auto y = x.masked_fill(mask, 0);
  y.sum().backward();
}

TEST(CustomAutogradTest, ReturnLeafInplace) {
  struct Inplace : public Function<Inplace> {
    static variable_list forward(AutogradContext *ctx, Variable a, Variable b) {
      // `a` is modified in place, so it must be marked dirty.
      ctx->mark_dirty({a});
      return {a.add_(b), b+2};
    }

    static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
      return {grad_output[0], grad_output[0] + grad_output[1]};
    }
  };

  Variable x = torch::randn({5,5});
  Variable y = torch::randn({5,5}, torch::requires_grad());

  auto out = Inplace::apply(x, y);
  auto &q = out[0];
  // The first output aliases the leaf `x` but now requires grad through the custom op.
  ASSERT_TRUE(torch::equal(q, x));
  ASSERT_TRUE(q.requires_grad());
  q.sum().backward();
  ASSERT_VARIABLE_EQ(y.grad(), torch::ones({5,5}));
}