Register gradient op with engine (#21205)

Summary:
cc dreiss
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21205

Differential Revision: D15578948

Pulled By: bddppq

fbshipit-source-id: ef285174e8637daef624c8088ebd903a70582345
Junjie Bai 2019-05-31 18:45:48 -07:00 committed by Facebook Github Bot
parent daa1e2de1a
commit 4c19421f16
2 changed files with 49 additions and 12 deletions


@@ -1318,6 +1318,13 @@ C10_DECLARE_REGISTRY(
  C10_MACRO_EXPAND(REGISTER_CPU_OPERATOR(__VA_ARGS__))
#endif
#ifdef CAFFE2_NO_GRADIENT_OPS
#define REGISTER_CPU_GRADIENT_OPERATOR_WITH_ENGINE(...) /* No gradients. */
#else
#define REGISTER_CPU_GRADIENT_OPERATOR_WITH_ENGINE(...) \
  C10_MACRO_EXPAND(REGISTER_CPU_OPERATOR_WITH_ENGINE(__VA_ARGS__))
#endif
C10_DECLARE_REGISTRY(
    CUDAOperatorRegistry,
    OperatorBase,
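The new macro mirrors the existing REGISTER_CPU_GRADIENT_OPERATOR: it expands to nothing when CAFFE2_NO_GRADIENT_OPS is defined and otherwise forwards to REGISTER_CPU_OPERATOR_WITH_ENGINE. A minimal usage sketch (the operator class and engine names below are illustrative placeholders, not part of this change):

// Hypothetical operator file; all names here are placeholders.
// Default CPU gradient implementation.
REGISTER_CPU_GRADIENT_OPERATOR(MyOpGradient, MyOpGradientOp);
// Alternative implementation, chosen when the OperatorDef's
// engine field is set to "MYENGINE".
REGISTER_CPU_GRADIENT_OPERATOR_WITH_ENGINE(
    MyOpGradient,
    MYENGINE,
    MyOpGradientOpForMyEngine);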


@@ -317,6 +317,22 @@ TEST(NetTest, TestScaffoldingDAGNet) {
  EXPECT_TRUE(net->Run());
}
class FooGradientOp : public JustTest {
 public:
  using JustTest::JustTest;
  string type() override {
    return "FooGradient";
  }
};
class FooGradientDummyEngineOp : public JustTest {
 public:
  using JustTest::JustTest;
  string type() override {
    return "FooGradientDummyEngine";
  }
};
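Both test ops derive from JustTest, a helper defined earlier in this test file and not visible in this hunk. Roughly (a from-memory sketch, not the verbatim definition), it is a trivial OperatorBase subclass whose virtual type() lets the test identify which registered class was actually instantiated:

// Approximate shape of the JustTest helper used above (assumption).
class JustTest : public OperatorBase {
 public:
  using OperatorBase::OperatorBase;
  bool Run(int /* stream_id */) override {
    return true;
  }
  virtual string type() {
    return "base";
  }
};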
class GetFooGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
@@ -328,6 +344,12 @@ class GetFooGradient : public GradientMakerBase {
  }
};
GRADIENT_OPERATOR_SCHEMA(FooGradient).NumInputs(1).NumOutputs(1);
REGISTER_CPU_GRADIENT_OPERATOR(FooGradient, FooGradientOp)
REGISTER_CPU_GRADIENT_OPERATOR_WITH_ENGINE(
    FooGradient,
    DUMMY_ENGINE,
    FooGradientDummyEngineOp)
REGISTER_GRADIENT(Foo, GetFooGradient);
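The assertions in the test below depend on a forward OperatorDef for "Foo" built earlier in the test (outside this hunk) with one input, one output, an int argument arg=1, and engine "DUMMY_ENGINE"; the gradient maker's defaults then copy the device option, engine, and arguments onto the generated FooGradient def. A sketch of that setup, assuming the CreateOperatorDef overload that takes arguments, a device option, and an engine string:

// Assumed setup for the test below (not shown in this diff hunk).
Argument arg = MakeArgument<int>("arg", 1);
DeviceOption option;
option.set_device_type(PROTO_CPU);
OperatorDef def = CreateOperatorDef(
    "Foo",
    "",
    std::vector<string>{"in"},
    std::vector<string>{"out"},
    std::vector<Argument>{arg},
    option,
    "DUMMY_ENGINE");
vector<GradientWrapper> g_output(1);
g_output[0].dense_ = "out_grad";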
TEST(OperatorGradientRegistryTest, GradientSimple) {
@@ -342,23 +364,31 @@ TEST(OperatorGradientRegistryTest, GradientSimple) {
  GradientOpsMeta meta = GetGradientForOp(def, g_output);
  // Check the names, input and output.
  EXPECT_EQ(meta.ops_.size(), 1);
  const OperatorDef& grad_op = meta.ops_[0];
  EXPECT_EQ(grad_op.type(), "FooGradient");
  EXPECT_EQ(grad_op.name(), "");
  EXPECT_EQ(grad_op.input_size(), 1);
  EXPECT_EQ(grad_op.output_size(), 1);
  EXPECT_EQ(grad_op.input(0), "out_grad");
  EXPECT_EQ(grad_op.output(0), "in_grad");
  const OperatorDef& grad_op_def = meta.ops_[0];
  EXPECT_EQ(grad_op_def.type(), "FooGradient");
  EXPECT_EQ(grad_op_def.name(), "");
  EXPECT_EQ(grad_op_def.input_size(), 1);
  EXPECT_EQ(grad_op_def.output_size(), 1);
  EXPECT_EQ(grad_op_def.input(0), "out_grad");
  EXPECT_EQ(grad_op_def.output(0), "in_grad");
  // Checks the engine, device option and arguments.
  EXPECT_EQ(grad_op.engine(), "DUMMY_ENGINE");
  EXPECT_EQ(grad_op.device_option().device_type(), PROTO_CPU);
  EXPECT_EQ(grad_op.arg_size(), 1);
  EXPECT_EQ(grad_op.arg(0).SerializeAsString(),
      MakeArgument<int>("arg", 1).SerializeAsString());
  EXPECT_EQ(grad_op_def.engine(), "DUMMY_ENGINE");
  EXPECT_EQ(grad_op_def.device_option().device_type(), PROTO_CPU);
  EXPECT_EQ(grad_op_def.arg_size(), 1);
  EXPECT_EQ(
      grad_op_def.arg(0).SerializeAsString(),
      MakeArgument<int>("arg", 1).SerializeAsString());
  // Checks the gradient name for input.
  EXPECT_EQ(meta.g_input_.size(), 1);
  EXPECT_TRUE(meta.g_input_[0].IsDense());
  EXPECT_EQ(meta.g_input_[0].dense_, "in_grad");
  Workspace ws;
  EXPECT_NE(ws.CreateBlob("out_grad"), nullptr);
  unique_ptr<OperatorBase> grad_op = CreateOperator(grad_op_def, &ws);
  EXPECT_NE(nullptr, grad_op.get());
  EXPECT_EQ(
      static_cast<JustTest*>(grad_op.get())->type(), "FooGradientDummyEngine");
}
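The new tail of the test is what exercises the engine registration: the gradient def carries engine "DUMMY_ENGINE", so CreateOperator resolves the (FooGradient, DUMMY_ENGINE) entry and instantiates FooGradientDummyEngineOp rather than FooGradientOp. As a sketch of the expected fallback (an assumption, not part of this test), clearing the engine on a copy of the def should select the plain CPU registration instead:

// Assumed fallback behavior: with no engine set, the default CPU
// registration (FooGradientOp) is chosen.
OperatorDef plain_def = grad_op_def;
plain_def.clear_engine();
unique_ptr<OperatorBase> plain_op = CreateOperator(plain_def, &ws);
// static_cast<JustTest*>(plain_op.get())->type() would return "FooGradient".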
TEST(EnginePrefTest, PerOpEnginePref) {