mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Register gradient op with engine (#21205)
Summary: cc dreiss Pull Request resolved: https://github.com/pytorch/pytorch/pull/21205 Differential Revision: D15578948 Pulled By: bddppq fbshipit-source-id: ef285174e8637daef624c8088ebd903a70582345
This commit is contained in:
parent
daa1e2de1a
commit
4c19421f16
2 changed files with 49 additions and 12 deletions
|
|
@ -1318,6 +1318,13 @@ C10_DECLARE_REGISTRY(
|
|||
C10_MACRO_EXPAND(REGISTER_CPU_OPERATOR(__VA_ARGS__))
|
||||
#endif
|
||||
|
||||
#ifdef CAFFE2_NO_GRADIENT_OPS
|
||||
#define REGISTER_CPU_GRADIENT_OPERATOR_WITH_ENGINE(...) /* No gradients. */
|
||||
#else
|
||||
#define REGISTER_CPU_GRADIENT_OPERATOR_WITH_ENGINE(...) \
|
||||
C10_MACRO_EXPAND(REGISTER_CPU_OPERATOR_WITH_ENGINE(__VA_ARGS__))
|
||||
#endif
|
||||
|
||||
C10_DECLARE_REGISTRY(
|
||||
CUDAOperatorRegistry,
|
||||
OperatorBase,
|
||||
|
|
|
|||
|
|
@ -317,6 +317,22 @@ TEST(NetTest, TestScaffoldingDAGNet) {
|
|||
EXPECT_TRUE(net->Run());
|
||||
}
|
||||
|
||||
class FooGradientOp : public JustTest {
|
||||
public:
|
||||
using JustTest::JustTest;
|
||||
string type() override {
|
||||
return "FooGradient";
|
||||
}
|
||||
};
|
||||
|
||||
class FooGradientDummyEngineOp : public JustTest {
|
||||
public:
|
||||
using JustTest::JustTest;
|
||||
string type() override {
|
||||
return "FooGradientDummyEngine";
|
||||
}
|
||||
};
|
||||
|
||||
class GetFooGradient : public GradientMakerBase {
|
||||
using GradientMakerBase::GradientMakerBase;
|
||||
vector<OperatorDef> GetGradientDefs() override {
|
||||
|
|
@ -328,6 +344,12 @@ class GetFooGradient : public GradientMakerBase {
|
|||
}
|
||||
};
|
||||
|
||||
GRADIENT_OPERATOR_SCHEMA(FooGradient).NumInputs(1).NumOutputs(1);
|
||||
REGISTER_CPU_GRADIENT_OPERATOR(FooGradient, FooGradientOp)
|
||||
REGISTER_CPU_GRADIENT_OPERATOR_WITH_ENGINE(
|
||||
FooGradient,
|
||||
DUMMY_ENGINE,
|
||||
FooGradientDummyEngineOp)
|
||||
REGISTER_GRADIENT(Foo, GetFooGradient);
|
||||
|
||||
TEST(OperatorGradientRegistryTest, GradientSimple) {
|
||||
|
|
@ -342,23 +364,31 @@ TEST(OperatorGradientRegistryTest, GradientSimple) {
|
|||
GradientOpsMeta meta = GetGradientForOp(def, g_output);
|
||||
// Check the names, input and output.
|
||||
EXPECT_EQ(meta.ops_.size(), 1);
|
||||
const OperatorDef& grad_op = meta.ops_[0];
|
||||
EXPECT_EQ(grad_op.type(), "FooGradient");
|
||||
EXPECT_EQ(grad_op.name(), "");
|
||||
EXPECT_EQ(grad_op.input_size(), 1);
|
||||
EXPECT_EQ(grad_op.output_size(), 1);
|
||||
EXPECT_EQ(grad_op.input(0), "out_grad");
|
||||
EXPECT_EQ(grad_op.output(0), "in_grad");
|
||||
const OperatorDef& grad_op_def = meta.ops_[0];
|
||||
EXPECT_EQ(grad_op_def.type(), "FooGradient");
|
||||
EXPECT_EQ(grad_op_def.name(), "");
|
||||
EXPECT_EQ(grad_op_def.input_size(), 1);
|
||||
EXPECT_EQ(grad_op_def.output_size(), 1);
|
||||
EXPECT_EQ(grad_op_def.input(0), "out_grad");
|
||||
EXPECT_EQ(grad_op_def.output(0), "in_grad");
|
||||
// Checks the engine, device option and arguments.
|
||||
EXPECT_EQ(grad_op.engine(), "DUMMY_ENGINE");
|
||||
EXPECT_EQ(grad_op.device_option().device_type(), PROTO_CPU);
|
||||
EXPECT_EQ(grad_op.arg_size(), 1);
|
||||
EXPECT_EQ(grad_op.arg(0).SerializeAsString(),
|
||||
MakeArgument<int>("arg", 1).SerializeAsString());
|
||||
EXPECT_EQ(grad_op_def.engine(), "DUMMY_ENGINE");
|
||||
EXPECT_EQ(grad_op_def.device_option().device_type(), PROTO_CPU);
|
||||
EXPECT_EQ(grad_op_def.arg_size(), 1);
|
||||
EXPECT_EQ(
|
||||
grad_op_def.arg(0).SerializeAsString(),
|
||||
MakeArgument<int>("arg", 1).SerializeAsString());
|
||||
// Checks the gradient name for input.
|
||||
EXPECT_EQ(meta.g_input_.size(), 1);
|
||||
EXPECT_TRUE(meta.g_input_[0].IsDense());
|
||||
EXPECT_EQ(meta.g_input_[0].dense_, "in_grad");
|
||||
|
||||
Workspace ws;
|
||||
EXPECT_NE(ws.CreateBlob("out_grad"), nullptr);
|
||||
unique_ptr<OperatorBase> grad_op = CreateOperator(grad_op_def, &ws);
|
||||
EXPECT_NE(nullptr, grad_op.get());
|
||||
EXPECT_EQ(
|
||||
static_cast<JustTest*>(grad_op.get())->type(), "FooGradientDummyEngine");
|
||||
}
|
||||
|
||||
TEST(EnginePrefTest, PerOpEnginePref) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue